import argparse import json import os import torch from datasets import load_dataset from tqdm.auto import tqdm from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder from common import articles_to_paragraphs, kilt_wikipedia_columns from common import kilt_wikipedia_paragraph_columns as columns def eval_generate(args): device = ("cuda" if torch.cuda.is_available() else "cpu") question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name) question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device) _ = question_model.eval() eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5') eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device) _ = eli5_model.eval() min_snippet_length = 20 topk = 21 min_chars_per_passage = 200 kilt_wikipedia = load_dataset("kilt_wikipedia", split="full") kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True, remove_columns=kilt_wikipedia_columns, batch_size=256, cache_file_name=f"./data/wiki_kilt_paragraphs_full.arrow", desc="Expanding wiki articles into paragraphs") # use paragraphs that are not simple fragments or very short sentences kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter( lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage) kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0) def embed_questions_for_retrieval(questions): query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt") with torch.no_grad(): q_reps = question_model(query["input_ids"].to(device), query["attention_mask"].to(device)).pooler_output return q_reps.cpu().numpy() def query_index(question): question_embedding = embed_questions_for_retrieval([question]) scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk) retrieved_examples = [] r = list(zip(wiki_passages[k] for k in columns)) for i in range(topk): retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])}) return retrieved_examples def create_kilt_datapoint(q_id, query, answer, res_list): # make a KILT data point # see https://github.com/facebookresearch/KILT#kilt-data-format provenance = [{ "wikipedia_id": r["wikipedia_id"], # *mandatory* "title": r["title"], "section": r["section"], "start_paragraph_id": r["start_paragraph_id"], "start_character": r["start_character"], "end_paragraph_id": r["end_paragraph_id"], "end_character": r["end_character"], "text": r["text"], "bleu_score": None, # wrt original evidence "meta": None # dataset/task specific } for r in res_list] output = [{"answer": answer, "provenance": provenance}] return {"id": q_id, "input": query, "output": output, # each element is an answer or provenance (can have multiple of each) "meta": None # dataset/task specific } kilt_output = [] with open(args.kilt_input_file, "r") as f: kilt_items = [json.loads(x) for x in f.read().strip().split("\n")] progress_bar = tqdm(range(len(kilt_items)), desc="Creating KILT response document") for idx, item in enumerate(kilt_items): query = item["input"] res_list = query_index(query) res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)] documents = [res["text"] for res in res_list] conditioned_doc = "

" + "

".join([d for d in documents]) query_and_docs = "question: {} context: {}".format(query, conditioned_doc) model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt") generated_answers_encoded = eli5_model.generate(input_ids=model_input["input_ids"].to(device), attention_mask=model_input["attention_mask"].to(device), min_length=50, max_length=250, do_sample=False, early_stopping=True, num_beams=8, temperature=1.0, top_k=None, top_p=None, no_repeat_ngram_size=3, num_return_sequences=1) answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True) kilt_example = create_kilt_datapoint(item["id"], query, answer[0], res_list) kilt_output.append(kilt_example) progress_bar.update(1) with open(args.kilt_output_file, "w") as fp: for kilt_example in kilt_output: json.dump(kilt_example, fp) fp.write("\n") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str) parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str) parser.add_argument( "--question_encoder_name", default="vblagoje/dpr-question_encoder-single-lfqa-base", help="Question encoder to use", ) parser.add_argument( "--index_file_name", default="../data/kilt_dpr_wikipedia_first.faiss", help="Faiss index with passage embeddings", ) args = parser.parse_args() assert os.path.isfile(args.kilt_input_file), f"Input file {args.kilt_input_file} couldn't be loaded" eval_generate(args)