from sentencepiece import SentencePieceProcessor
import os
import torch


class ExLlamaTokenizer:

    def __init__(self, tokenizer_model_path):

        self.path = tokenizer_model_path
        self.tokenizer = SentencePieceProcessor(model_file = self.path)

        self.unk_token = "<unk>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token_id = self.tokenizer.unk_id()   # is the same as the pad token id
        self.eos_token_id = self.tokenizer.eos_id()
        self.bos_token_id = self.tokenizer.bos_id()
        self.pad_token_id = 0   # self.tokenizer.pad_id()
        self.newline_token_id = 13

        # (token string, token id) pairs used by encode()/decode() when handling special tokens
        self.special_characters = [(self.bos_token, self.bos_token_id),
                                   (self.eos_token, self.eos_token_id),
                                   (self.unk_token, self.unk_token_id)]
    # Encode string or list of strings

    def encode(self, text, return_mask = False, max_seq_len = 2048, add_bos = False, add_eos = False, encode_special_characters = False):

        if isinstance(text, list):

            # text is a list of strings

            list_ids = self.tokenizer.EncodeAsIds(text)

            # pad bos and eos

            if add_bos:
                for ids in list_ids: ids.insert(0, self.bos_token_id)
            if add_eos:
                for ids in list_ids: ids.append(self.eos_token_id)

            # Left-pad every sequence to the length of the longest one

            max_length = max([len(ids) for ids in list_ids])

            needs_mask = False
            padded_ids = []
            for ids in list_ids:
                if len(ids) != len(list_ids[0]): needs_mask = True
                padding = torch.full((max_length - len(ids),), self.pad_token_id)
                sequence = torch.tensor(ids)
                padded_ids.append(torch.cat((padding, sequence), dim = 0).long())

            stacked_ids = torch.stack(padded_ids, dim = 0)

            if return_mask:
                if needs_mask:
                    mask_padding = torch.full((stacked_ids.shape[0], max_seq_len - stacked_ids.shape[1]), True, dtype = torch.bool, device = "cpu")
                    mask = stacked_ids != self.pad_token_id
                    mask = torch.cat((mask, mask_padding), dim = 1)
                    return stacked_ids, mask
                else:
                    return stacked_ids, None
            else:
                return stacked_ids
        else:

            # text is a single string

            split_text = [text]

            # look for special characters
            if encode_special_characters:
                for special_character, special_token_id in self.special_characters:
                    temp_text = []
                    for segment in split_text:
                        if isinstance(segment, str) and special_character in segment:
                            # For each occurrence of the special character, keep the text before it,
                            # insert the special token ID, then continue with the rest of the text
                            parts = segment.split(special_character)
                            new_parts = []
                            for i, part in enumerate(parts):
                                new_parts.append(part)
                                if i < len(parts) - 1:  # add the special token id between parts, but not after the last part
                                    new_parts.append(special_token_id)
                            temp_text.extend(new_parts)
                        else:
                            temp_text.append(segment)
                    split_text = temp_text

            # Encode the string segments and splice in any special token IDs

            ids = []
            for text_chunk in split_text:
                if isinstance(text_chunk, str):
                    ids += self.tokenizer.EncodeAsIds(text_chunk)
                else:
                    ids.append(text_chunk)

            # pad bos and eos

            if add_bos:
                ids = [self.bos_token_id] + ids
            if add_eos:
                ids = ids + [self.eos_token_id]

            stacked_ids = torch.tensor(ids).unsqueeze(0)

            if return_mask:
                return stacked_ids, None
            else:
                return stacked_ids
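    # Note: with encode_special_characters = True, literal "<s>", "</s>" and "<unk>" substrings
    # in the input are mapped to their token IDs instead of being tokenized as plain text.
    # For example, encoding "<s>Prompt text</s>" this way yields the BOS id, the IDs for
    # "Prompt text", then the EOS id, as a (1, seq_len) tensor.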
    def decode(self, ids, decode_special_characters = False):

        special_ids = {id_: char for char, id_ in self.special_characters}  # lookup from token ID to token string

        if ids.dim() > 1:

            texts = []
            for i in range(ids.shape[0]):

                seq = ids[i].tolist()
                seq = [t for t in seq if t != self.pad_token_id]

                if decode_special_characters:

                    text_parts = []
                    normal_ids = []  # list of lists
                    current_normal_ids = []  # current run of normal IDs

                    for idx, id_ in enumerate(seq):
                        if id_ in special_ids:
                            # Save the current run of normal IDs, then start a new one
                            normal_ids.append(current_normal_ids)
                            current_normal_ids = []
                            # Store the special token as a string
                            text_parts.append(special_ids[id_])
                        else:
                            current_normal_ids.append(id_)

                    normal_ids.append(current_normal_ids)  # save the last run of normal IDs

                    # Decode each run of normal IDs, then interleave with the special-token strings
                    decoded_segments = [self.tokenizer.Decode(segment) for segment in normal_ids]
                    for idx, decoded_segment in enumerate(decoded_segments):
                        text_parts.insert(2 * idx, decoded_segment)

                    texts.append("".join(text_parts))

                else:

                    if self.eos_token_id in seq:  # truncate at the first EOS token
                        seq = seq[:seq.index(self.eos_token_id)]
                    texts.append(self.tokenizer.Decode(seq))

            return texts
        else:

            ids = ids.tolist()

            if decode_special_characters:

                text_parts = []
                normal_ids = []  # list of lists
                current_normal_ids = []  # current run of normal IDs

                for idx, id_ in enumerate(ids):
                    if id_ in special_ids:
                        # Save the current run of normal IDs, then start a new one
                        normal_ids.append(current_normal_ids)
                        current_normal_ids = []
                        # Store the special token as a string
                        text_parts.append(special_ids[id_])
                    else:
                        current_normal_ids.append(id_)

                normal_ids.append(current_normal_ids)  # save the last run of normal IDs

                # Decode each run of normal IDs, then interleave with the special-token strings
                decoded_segments = [self.tokenizer.Decode(segment) for segment in normal_ids]
                for idx, decoded_segment in enumerate(decoded_segments):
                    text_parts.insert(2 * idx, decoded_segment)

                text = "".join(text_parts)

            else:

                text = self.tokenizer.Decode(ids)

            return text
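    # Note: decode_special_characters = True re-emits "<s>", "</s>" and "<unk>" as literal text
    # by decoding the runs of ordinary IDs between special tokens separately and re-interleaving
    # them, so an encode()/decode() round trip with both flags set preserves those markers.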
    def num_tokens(self, text, encode_special_characters = False):

        if encode_special_characters:
            # encode() returns a (1, seq_len) tensor in this case
            ids = self.encode(text, encode_special_characters = True)
            return ids.size(1)
        else:
            ids = self.tokenizer.Encode(text)
            return len(ids)
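

# Minimal usage sketch, assuming a SentencePiece tokenizer.model on disk; the path below is a
# placeholder and should point at the tokenizer shipped with the model weights being loaded.
if __name__ == "__main__":

    tokenizer = ExLlamaTokenizer("/path/to/tokenizer.model")  # placeholder path

    # Single string -> (1, seq_len) tensor of token IDs; decode() returns a list for 2-D input
    ids = tokenizer.encode("Hello, world!", add_bos = True)
    print(ids.shape, tokenizer.decode(ids))

    # List of strings -> left-padded (batch, max_len) tensor, plus a padding mask when lengths differ
    batch_ids, mask = tokenizer.encode(["First prompt", "A somewhat longer second prompt"], return_mask = True)
    print(batch_ids.shape, None if mask is None else mask.shape)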