---
license: llama3.1
base_model:
- THUDM/LongCite-llama3.1-8b
datasets:
- THUDM/LongCite-45k
pipeline_tag: text-generation
tags:
- imatrix
- importance matrix
- gguf
- llama.cpp
---

GGUF version of LongCite. You need to add the following tokens as stop tokens: `[128000, 128007, 128009]`, i.e. `["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"]`.

By default the EOS token is 128007 (`<|end_header_id|>`), and so far this seems to work. The model works both in citation mode and in naive question-answer mode.

No chat template is provided, because the model requires Python pre-processing (building the prompt before it is sent to the LLM) and post-processing of the output; see the example code below. With the llama-cpp-python bindings, the stop tokens can be passed as in the sketch that follows.
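A minimal sketch, assuming llama-cpp-python; the GGUF file name, context size, and sampling settings are placeholders, not part of this repository:

```python
from llama_cpp import Llama

# Placeholder file name: point model_path at the GGUF file you downloaded.
llm = Llama(model_path="longcite-llama3.1-8b.Q4_K_M.gguf", n_ctx=8192)

# LongCite prompt built by the example code below, wrapped in the Llama 3 tags.
prompt = "..."

out = llm(
    prompt,
    max_tokens=1024,
    stop=["<|begin_of_text|>", "<|end_header_id|>", "<|eot_id|>"],
)
print(out["choices"][0]["text"])
```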
iMatrix generated using [this dataset](https://github.com/ggerganov/llama.cpp/discussions/5263#discussioncomment-9552049).

Example code:

```python
import re

from nltk.tokenize import PunktSentenceTokenizer


class LongCiteModel:

    @staticmethod
    def text_split_by_punctuation(original_text, return_dict=False):
        # text = re.sub(r'([a-z])\.([A-Z])', r'\1. \2', original_text)  # separate periods written without a space
        text = original_text
        custom_sent_tokenizer = PunktSentenceTokenizer()
        punctuations = r"([。;!?])"  # for Chinese support
        separated = custom_sent_tokenizer.tokenize(text)
        separated = sum([re.split(punctuations, s) for s in separated], [])
        # Put the punctuation back at the end of the preceding sentence.
        for i in range(1, len(separated)):
            if re.match(punctuations, separated[i]):
                separated[i - 1] += separated[i]
                separated[i] = ''
        separated = [s for s in separated if s != ""]
        if len(separated) == 1:
            separated = original_text.split('\n\n')
        separated = [s.strip() for s in separated if s.strip() != ""]
        if not return_dict:
            return separated
        else:
            pos = 0
            res = []
            for i, sent in enumerate(separated):
                st = original_text.find(sent, pos)
                assert st != -1, sent
                ed = st + len(sent)
                res.append({
                    'c_idx': i,
                    'content': sent,
                    'start_idx': st,
                    'end_idx': ed,
                })
                pos = ed
            return res

    @staticmethod
    def get_prompt(context, question):
        sents = LongCiteModel.text_split_by_punctuation(context, return_dict=True)
        splited_context = ""
        for i, s in enumerate(sents):
            st, ed = s['start_idx'], s['end_idx']
            assert s['content'] == context[st:ed], s
            # Extend each chunk up to the start of the next one, so the
            # concatenation of all chunks reproduces the full context.
            ed = sents[i + 1]['start_idx'] if i < len(sents) - 1 else len(context)
            sents[i] = {
                'content': context[st:ed],
                'start': st,
                'end': ed,
                'c_idx': s['c_idx'],
            }
            # Prefix every chunk with a <C{i}> marker the model can cite.
            splited_context += f"<C{i}>" + context[st:ed]
        prompt = '''Please answer the user's question based on the following document. When a sentence S in your response uses information from some chunks in the document (i.e., <C{s1}>-<C{e1}>, <C{s2}>-<C{e2}>, ...), please append these chunk numbers to S in the format "<statement>{S}<cite>[{s1}-{e1}][{s2}-{e2}]...</cite></statement>". You must answer in the same language as the user's question.\n\n[Document Start]\n%s\n[Document End]\n\n%s''' % (splited_context, question)
        return prompt, sents, splited_context

    @staticmethod
    def get_citations(statement, sents):
        c_texts = re.findall(r'<cite>(.*?)</cite>', statement, re.DOTALL)
        spans = sum([re.findall(r"\[([0-9]+\-[0-9]+)\]", c_text, re.DOTALL) for c_text in c_texts], [])
        statement = re.sub(r'<cite>(.*?)</cite>', '', statement, flags=re.DOTALL)
        merged_citations = []
        for i, s in enumerate(spans):
            try:
                st, ed = [int(x) for x in s.split('-')]
                if st > len(sents) - 1 or ed < st:
                    continue
                st, ed = max(0, st), min(ed, len(sents) - 1)
                assert st <= ed, str(c_texts) + '\t' + str(len(sents))
                # Merge a span that directly continues the previous one.
                if len(merged_citations) > 0 and st == merged_citations[-1]['end_sentence_idx'] + 1:
                    merged_citations[-1].update({
                        "end_sentence_idx": ed,
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[merged_citations[-1]['start_sentence_idx']:ed + 1]]),
                    })
                else:
                    merged_citations.append({
                        "start_sentence_idx": st,
                        "end_sentence_idx": ed,
                        "start_char_idx": sents[st]['start'],
                        'end_char_idx': sents[ed]['end'],
                        'cite': ''.join([x['content'] for x in sents[st:ed + 1]]),
                    })
            except:
                print(c_texts, len(sents), statement)
                raise
        return statement, merged_citations[:3]

    @staticmethod
    def postprocess(answer, sents, splited_context):
        res = []
        pos = 0
        new_answer = ""
        while True:
            st = answer.find("<statement>", pos)
            if st == -1:
                st = len(answer)
            ed = answer.find("</statement>", st)
            # Text before the next <statement> tag is kept without citations.
            statement = answer[pos:st]
            if len(statement.strip()) > 5:
                res.append({
                    "statement": statement,
                    "citation": []
                })
                new_answer += f"<statement>{statement}</statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement

            if ed == -1:
                break

            statement = answer[st + len("<statement>"):ed]
            if len(statement.strip()) > 0:
                statement, citations = LongCiteModel.get_citations(statement, sents)
                res.append({
                    "statement": statement,
                    "citation": citations
                })
                c_str = ''.join(['<cite>[{}-{}]</cite>'.format(c['start_sentence_idx'], c['end_sentence_idx']) for c in citations])
                new_answer += f"<statement>{statement}{c_str}</statement>"
            else:
                res.append({
                    "statement": statement,
                    "citation": None,
                })
                new_answer += statement
            pos = ed + len("</statement>")
        return {
            "answer": new_answer.strip(),
            "statements_with_citations": [x for x in res if x['citation'] is not None],
            "splited_context": splited_context.strip(),
            "all_statements": res,
        }

    @staticmethod
    def truncate_from_middle(prompt, max_input_length=None, tokenizer=None):
        if max_input_length is None:
            return prompt
        else:
            assert tokenizer is not None
            tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
            if len(tokenized_prompt) > max_input_length:
                # Keep the head and tail of the prompt, dropping the middle.
                half = int(max_input_length / 2)
                prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) \
                    + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
            return prompt


if __name__ == "__main__":
    context = '''
your context
'''
    query = "your user question here"
    prompt, sents, splited_context = LongCiteModel.get_prompt(context, query)
    print('Prompt:', prompt)
    # Add the Llama 3 tags to the prompt before sending it to the model
    # (see the wrapping sketch below).
    max_input_length = 4096
    # Optionally bound the prompt length, e.g.:
    # prompt = LongCiteModel.truncate_from_middle(prompt, max_input_length, tokenizer)
    output = "..."  # what the LLM returned
    result = LongCiteModel.postprocess(output, sents, splited_context)
```
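For reference, "adding the Llama 3 tags" means wrapping the LongCite prompt in the standard Llama 3.1 chat framing before requesting the completion. A minimal sketch (the helper name `wrap_llama3` is not part of the original code):

```python
def wrap_llama3(prompt: str) -> str:
    # Standard Llama 3 chat framing: a user turn, then an opened
    # assistant turn that the model completes.
    return (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        + prompt
        + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
```

The raw completion, which ends at one of the stop tokens listed above, is what gets passed to `LongCiteModel.postprocess`.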