File size: 6,832 Bytes
e067d8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import argparse
import json
import os

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder

from common import articles_to_paragraphs, kilt_wikipedia_columns
from common import kilt_wikipedia_paragraph_columns as columns


def eval_generate(args):
    """Answer every question of a KILT input file and write KILT-format predictions.

    For each input item the question is embedded with the DPR question encoder,
    the nearest Wikipedia paragraphs are looked up in the FAISS index, and the
    BART ELI5 model generates an answer conditioned on the retrieved paragraphs.
    One JSON object per question is written to the output file.

    Args:
        args: namespace providing ``question_encoder_name`` (model name of the
            question encoder), ``index_file_name`` (FAISS index file),
            ``kilt_input_file`` (JSON-lines input, each item having ``id`` and
            ``input``) and ``kilt_output_file`` (JSON-lines output path).
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    question_model.eval()

    eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
    eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
    eli5_model.eval()

    min_snippet_length = 20  # retrieved passages must have more words than this
    topk = 21  # number of passages requested from the index per question
    min_chars_per_passage = 200
    # only the first third of the retrieved passages is used as generation context
    max_context_passages = topk // 3

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=256,
                                                   cache_file_name="./data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
    # Put the index on GPU 0 only when CUDA is actually available; device=None
    # keeps it on CPU so the script also runs on machines without a GPU.
    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name,
                                               device=0 if use_cuda else None)

    def embed_questions_for_retrieval(questions):
        """Return the encoder's pooled output for a list of questions as a numpy array."""
        query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            q_reps = question_model(query["input_ids"].to(device),
                                    query["attention_mask"].to(device)).pooler_output
        return q_reps.cpu().numpy()

    def query_index(question):
        """Return the retrieved passages for ``question`` as a list of per-passage dicts."""
        question_embedding = embed_questions_for_retrieval([question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)

        # wiki_passages is column-oriented (column name -> list of values);
        # turn it into one dict per hit. The index may return fewer than topk
        # hits, so iterate over what actually came back.
        return [{k: wiki_passages[k][i] for k in columns} for i in range(len(scores))]

    def create_kilt_datapoint(q_id, query, answer, res_list):
        """Assemble a KILT data point from a question, its answer and the supporting passages."""
        # see https://github.com/facebookresearch/KILT#kilt-data-format
        provenance = [{
            "wikipedia_id": r["wikipedia_id"],  # *mandatory*
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,  # wrt original evidence
            "meta": None  # dataset/task specific
        } for r in res_list]

        output = [{"answer": answer, "provenance": provenance}]

        return {"id": q_id,
                "input": query,
                "output": output,  # each element is an answer or provenance (can have multiple of each)
                "meta": None  # dataset/task specific
                }

    # Read the whole input up front (skipping blank lines) so the file handle
    # is not held open during the long-running generation loop.
    with open(args.kilt_input_file, "r", encoding="utf-8") as f:
        kilt_items = [json.loads(line) for line in f if line.strip()]

    kilt_output = []
    for item in tqdm(kilt_items, desc="Creating KILT response document"):
        query = item["input"]
        res_list = query_index(query)

        # drop very short passages, then keep only the leading ones as context
        res_list = [res for res in res_list
                    if len(res["text"].split()) > min_snippet_length][:max_context_passages]
        documents = [res["text"] for res in res_list]
        conditioned_doc = "<P> " + " <P> ".join(documents)

        query_and_docs = "question: {} context: {}".format(query, conditioned_doc)

        model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
        # deterministic beam search (8 beams, no sampling), a single answer per question
        generated_answers_encoded = eli5_model.generate(input_ids=model_input["input_ids"].to(device),
                                                        attention_mask=model_input["attention_mask"].to(device),
                                                        min_length=50,
                                                        max_length=250,
                                                        do_sample=False,
                                                        early_stopping=True,
                                                        num_beams=8,
                                                        temperature=1.0,
                                                        top_k=None,
                                                        top_p=None,
                                                        no_repeat_ngram_size=3,
                                                        num_return_sequences=1)
        answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                             clean_up_tokenization_spaces=True)

        kilt_output.append(create_kilt_datapoint(item["id"], query, answer[0], res_list))

    # one JSON object per line (JSON-lines)
    with open(args.kilt_output_file, "w") as fp:
        for kilt_example in kilt_output:
            json.dump(kilt_example, fp)
            fp.write("\n")


if __name__ == "__main__":
    # Command-line entry point: parse the file/model options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str)
    parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str)
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )

    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    args = parser.parse_args()

    # Validate explicitly instead of with `assert`, which is stripped when
    # Python runs with -O and would let a missing input file slip through.
    if not os.path.isfile(args.kilt_input_file):
        parser.error(f"Input file {args.kilt_input_file} couldn't be loaded")
    eval_generate(args)