import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertModel
from tensor2tensor.data_generators import text_encoder


class LatinBERT(nn.Module):
    """Sentence encoder wrapping a pre-trained Latin BERT model and its
    tensor2tensor subword tokenizer; returns one L2-normalised pooler
    embedding per input sentence."""

    def __init__(self, bertPath, tokenizerPath):
        super().__init__()
        self.tokenizer = LatinTokenizer(tokenizerPath)
        self.model = BertModel.from_pretrained(bertPath)  # kept on CPU; move to GPU externally if desired
        self.model.eval()

    @torch.no_grad()
    def __call__(self, sentences):
        if not isinstance(sentences, list):
            sentences = [sentences]

        tokens_ids, masks, transforms = self.tokenizer.tokenize(sentences, 512)

        # BERT can attend to at most 512 positions; truncate anything longer.
        if tokens_ids.shape[-1] > 512:
            tokens_ids = torch.narrow(tokens_ids, -1, 0, 512)
            masks = torch.narrow(masks, -1, 0, 512)

        # Collapse the (batch, sentence, length) layout into (sentence, length).
        tokens_ids = tokens_ids.reshape((-1, tokens_ids.shape[-1]))
        masks = masks.reshape((-1, masks.shape[-1]))

        outputs = self.model(tokens_ids, attention_mask=masks)

        # Pooler output gives one vector per sentence; normalise so that dot
        # products between embeddings are cosine similarities.
        embeddings = F.normalize(outputs.pooler_output, p=2, dim=-1).cpu()
        return embeddings

    @property
    def dim(self):
        return 768


class LatinTokenizer:
    """Subword tokenizer built on tensor2tensor's SubwordTextEncoder, with the
    five BERT special tokens prepended to the vocabulary (ids 0-4)."""

    def __init__(self, model):
        self.vocab = {}
        self.reverseVocab = {}
        self.encoder = text_encoder.SubwordTextEncoder(model)

        self.vocab["[PAD]"] = 0
        self.vocab["[UNK]"] = 1
        self.vocab["[CLS]"] = 2
        self.vocab["[SEP]"] = 3
        self.vocab["[MASK]"] = 4

        # Shift every subword id by 5 to make room for the special tokens.
        for key in self.encoder._subtoken_string_to_id:
            self.vocab[key] = self.encoder._subtoken_string_to_id[key] + 5
            self.reverseVocab[self.encoder._subtoken_string_to_id[key] + 5] = key

    def convert_tokens_to_ids(self, tokens):
        # The special tokens are already in self.vocab (ids 0-4), so a single
        # lookup covers both subwords and specials.
        return [self.vocab[token] for token in tokens]

    def tokenize(self, sentences, max_batch):
        """Convert sentences into padded batches of wordpiece ids, attention
        masks and wordpiece-to-word transform matrices."""
        # Accept either raw strings or pre-tokenised sentences (lists of words).
        sentences = [s.split(" ") if isinstance(s, str) else s for s in sentences]

        all_data = []
        all_masks = []
        all_labels = []
        all_transforms = []

        for sentence in sentences:
            tok_ids = []
            input_mask = []
            labels = []
            transform = []

            all_toks = []
            n = 0
            for idx, word in enumerate(sentence):
                toks = self._tokenize(word)
                all_toks.append(toks)
                n += len(toks)

            cur = 0
            for idx, word in enumerate(sentence):
                toks = all_toks[idx]
                # Row of the transform matrix: averages the wordpieces that
                # make up this word back into a single word position.
                ind = list(np.zeros(n))
                for j in range(cur, cur + len(toks)):
                    ind[j] = 1. / len(toks)
                cur += len(toks)
                transform.append(ind)

                tok_ids.extend(self.convert_tokens_to_ids(toks))
                input_mask.extend(np.ones(len(toks)))
                # One placeholder label per word; only its length is used below.
                labels.append(1)

            all_data.append(tok_ids)
            all_masks.append(input_mask)
            all_labels.append(labels)
            all_transforms.append(transform)

        lengths = np.array([len(l) for l in all_data])

        # Note: sequences must be ordered from shortest to longest so the
        # current_batch adjustment below works.
        ordering = np.argsort(lengths)

        ordered_data = [None for i in range(len(all_data))]
        ordered_masks = [None for i in range(len(all_data))]
        ordered_labels = [None for i in range(len(all_data))]
        ordered_transforms = [None for i in range(len(all_data))]

        for i, ind in enumerate(ordering):
            ordered_data[i] = all_data[ind]
            ordered_masks[i] = all_masks[ind]
            ordered_labels[i] = all_labels[ind]
            ordered_transforms[i] = all_transforms[ind]

        batched_data = []
        batched_mask = []
        batched_labels = []
        batched_transforms = []

        i = 0
        current_batch = max_batch  # initial batch size; shrunk below for long sentences
        while i < len(ordered_data):
            batch_data = ordered_data[i:i + current_batch]
            batch_mask = ordered_masks[i:i + current_batch]
            batch_labels = ordered_labels[i:i + current_batch]
            batch_transforms = ordered_transforms[i:i + current_batch]

            # Pad every sentence in the batch to the batch's longest sequence
            # (in wordpieces) and longest label/transform list (in words).
            max_len = max([len(sent) for sent in batch_data])
            max_label = max([len(label) for label in batch_labels])

            for j in range(len(batch_data)):
                blen = len(batch_data[j])
                blab = len(batch_labels[j])

                for k in range(blen, max_len):
                    batch_data[j].append(0)
                    batch_mask[j].append(0)
                    for z in range(len(batch_transforms[j])):
                        batch_transforms[j][z].append(0)

                for k in range(blab, max_label):
                    batch_labels[j].append(-100)

                for k in range(len(batch_transforms[j]), max_label):
                    batch_transforms[j].append(np.zeros(max_len))

            batched_data.append(batch_data)
            batched_mask.append(batch_mask)
            batched_labels.append(batch_labels)
            batched_transforms.append(batch_transforms)

            i += current_batch

            # Adjust batch size: sentences are ordered from shortest to longest,
            # so decrease the batch size as they get longer.
            if max_len > 100:
                current_batch = 12
            if max_len > 200:
                current_batch = 6

        # Stacking all batches into a single tensor assumes they share the same
        # padded length (always true when everything fits in one batch).
        return (torch.LongTensor(batched_data).squeeze(),
                torch.FloatTensor(batched_mask).squeeze(),
                torch.FloatTensor(batched_transforms).squeeze())

    def _tokenize(self, text):
        """Split a word (or whitespace-separated text) into subword tokens,
        passing BERT special tokens through unchanged."""
        tokens = text.split(" ")
        wp_tokens = []
        for token in tokens:
            if token in {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}:
                wp_tokens.append(token)
            else:
                wp_toks = self.encoder.encode(token)
                for wp in wp_toks:
                    wp_tokens.append(self.reverseVocab[wp + 5])
        return wp_tokens


def main():
    model = LatinBERT("../../latinBert/latin_bert/models/latin_bert",
                      tokenizerPath="./tokenizer/latin.subword.encoder")
    sents = ["arma virumque cano",
             "arma gravi numero violentaque bella parabam"]
    output = model(sents)
    print("end", output.shape)


if __name__ == "__main__":
    main()
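
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script): because LatinBERT
# returns L2-normalised pooler embeddings, the cosine similarity between two
# sentences is simply the dot product of their vectors. The function name
# `similarity_demo` is hypothetical; the paths and example sentences are the
# same ones assumed in main().
# ---------------------------------------------------------------------------
def similarity_demo():
    model = LatinBERT("../../latinBert/latin_bert/models/latin_bert",
                      tokenizerPath="./tokenizer/latin.subword.encoder")
    embs = model(["arma virumque cano",
                  "arma gravi numero violentaque bella parabam"])
    # Unit-norm vectors: the dot product equals the cosine similarity.
    print("cosine similarity:", float(embs[0] @ embs[1]))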