lfqa1 / util /common.py
Achyut Tiwari
Add files via upload
e067d8b unverified
raw
history blame
No virus
5.12 kB
import re
import torch
# Column names for whole-article KILT Wikipedia records (judging by the name;
# this module itself only reads "wikipedia_id", "wikipedia_title" and "text").
kilt_wikipedia_columns = [
    "kilt_id",
    "wikipedia_id",
    "wikipedia_title",
    "text",
    "anchors",
    "categories",
    "wikidata_info",
    "history",
]

# Column names for paragraph-level records. These are the same keys that
# articles_to_paragraphs() produces and create_kilt_datapoint() reads.
kilt_wikipedia_paragraph_columns = [
    "wikipedia_id",
    "start_paragraph_id",
    "start_character",
    "end_paragraph_id",
    "end_character",
    "title",
    "section",
    "text",
]
def clean_question(text):
    """Normalize a question string for downstream use.

    The text is stripped of link references (via ``cleanup_references``),
    newlines are turned into spaces, the ``[deleted]`` placeholder is removed,
    runs of two or more whitespace characters are collapsed into one space,
    and the result is lower-cased and stripped of leading/trailing whitespace.

    Args:
        text: The raw question string.

    Returns:
        The cleaned, lower-cased question.
    """
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    # Remove the placeholder *before* collapsing whitespace; otherwise cutting
    # it out of the middle of a sentence ("a [deleted] b") leaves a double space.
    result = result.replace("[deleted]", "")
    result = re.sub(r"\s\s+", " ", result)
    return result.lower().strip()
def cleanup_references(text):
    """Remove Markdown-style link references and URL placeholders from *text*.

    Three passes are applied in order:

    1. Numbered citations such as ``([1](_URL_2_), [2](_URL_1_))`` are removed
       completely, along with any whitespace/opening parentheses directly
       before them and any commas/closing parentheses directly after them.
    2. Inline links such as ``[link text](_URL_19_)`` are replaced by their
       link text.
    3. Any remaining bare ``_URL_<n>_`` placeholders are removed.

    Whitespace is not normalized here; callers collapse it afterwards.

    Args:
        text: The raw string to clean.

    Returns:
        The string with references removed.
    """
    # Numbered citation: drop both the "[n]" label and its URL.
    #   "...views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_))."
    #   -> "...views on slavery."
    result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text)

    # Inline link: keep the link text, drop the URL.
    #   "At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South."
    #   -> "At the outbreak of the Civil War, Leyburn left his church and joined the South."
    result = re.sub(r"\[([^]]+)\]\([^)]+\)", r"\1", result)

    # Dangling placeholders. ``\d+`` (not ``\d``) so that multi-digit indices
    # such as ``_URL_19_`` are removed as well.
    result = re.sub(r"_URL_\d+_", "", result)
    return result
def clean_answer(text):
    """Normalize an answer string and truncate it with ``trim``.

    The text is stripped of link references (via ``cleanup_references``),
    newlines are turned into spaces, ``BULLET::::-`` list markers are removed,
    runs of two or more whitespace characters are collapsed into one space,
    and the stripped result is passed to ``trim`` with its default word limit.

    Args:
        text: The raw answer string.

    Returns:
        The cleaned and truncated answer (case is preserved).
    """
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    # Remove the bullet marker *before* collapsing whitespace: otherwise the gap
    # it leaves behind becomes a double space, which trim() -- splitting on
    # single spaces -- would count as an extra, empty word.
    result = result.replace("BULLET::::-", "")
    result = re.sub(r"\s\s+", " ", result)
    return trim(result.strip())
def trim(text, word_count: int = 100):
    """Return the first *word_count* space-separated tokens of *text*.

    Tokens are obtained by splitting on a single space character, so
    consecutive spaces yield empty tokens that count toward the limit.

    Args:
        text: The string to truncate.
        word_count: Maximum number of tokens to keep.

    Returns:
        The kept tokens joined back together with single spaces.
    """
    tokens = text.split(" ")
    kept = tokens[:word_count]
    return " ".join(kept)
def articles_to_paragraphs(examples):
    """Flatten a batch of articles into one record per paragraph.

    Args:
        examples: A column-oriented batch. ``examples["text"]`` is a sequence
            of articles, each providing a ``"paragraph"`` sequence of strings;
            ``examples["wikipedia_id"]`` and ``examples["wikipedia_title"]``
            are indexed in parallel with it.

    Returns:
        A dict of equal-length lists with one entry per paragraph:
        ``wikipedia_id``, ``title``, ``section``, ``text``,
        ``start_paragraph_id``, ``end_paragraph_id``, ``start_character`` and
        ``end_character``. Each paragraph spans exactly itself, so the start
        and end paragraph ids are equal and the character range is
        ``0..len(paragraph)``.
    """
    paragraph_rows = {
        "wikipedia_id": [],
        "title": [],
        "section": [],
        "text": [],
        "start_paragraph_id": [],
        "end_paragraph_id": [],
        "start_character": [],
        "end_character": [],
    }

    for article_idx, article in enumerate(examples["text"]):
        # The most recent section header seen in this article; empty until
        # the first header appears.
        current_section = ""
        for paragraph_idx, paragraph in enumerate(article["paragraph"]):
            if "Section::::" in paragraph:
                # Header paragraphs update the running section and are still
                # emitted as records of their own below.
                current_section = paragraph

            paragraph_rows["wikipedia_id"].append(examples["wikipedia_id"][article_idx])
            paragraph_rows["title"].append(examples["wikipedia_title"][article_idx])
            paragraph_rows["section"].append(current_section)
            paragraph_rows["text"].append(paragraph)
            paragraph_rows["start_paragraph_id"].append(paragraph_idx)
            paragraph_rows["end_paragraph_id"].append(paragraph_idx)
            paragraph_rows["start_character"].append(0)
            paragraph_rows["end_character"].append(len(paragraph))

    return paragraph_rows
def create_kilt_datapoint(eli5_example, columns, wiki_passages, min_length=20, topk=7):
    """Build a KILT-style data point from a question and retrieved passages.

    Args:
        eli5_example: Mapping with ``"q_id"``, ``"title"`` and
            ``"answers"`` (itself a mapping with a ``"text"`` sequence).
        columns: Keys to copy from every passage. Must include ``"text"``
            and every field emitted in the provenance records.
        wiki_passages: Iterable of passage mappings, in the order they should
            be considered.
        min_length: A passage is kept only if its text has strictly more
            whitespace-separated words than this.
        topk: Maximum number of kept passages used as provenance.

    Returns:
        A dict with ``id``, ``input``, ``output`` and ``meta``. ``output``
        holds one ``{"answer": ...}`` element per answer followed by a single
        ``{"provenance": [...]}`` element.
    """
    # Project every passage onto the requested columns first, then filter by
    # length, and only then cut down to the top-k survivors.
    projected = [{column: passage[column] for column in columns} for passage in wiki_passages]
    long_enough = [record for record in projected if len(record["text"].split()) > min_length]
    evidence = long_enough[:topk]

    # make a KILT data point
    # see https://github.com/facebookresearch/KILT#kilt-data-format
    answers = [{"answer": answer} for answer in eli5_example["answers"]["text"]]

    # evidence set for the answers from the KILT knowledge source
    provenance = []
    for record in evidence:
        provenance.append({
            "wikipedia_id": record["wikipedia_id"],  # *mandatory*
            "title": record["title"],
            "section": record["section"],
            "start_paragraph_id": record["start_paragraph_id"],
            "start_character": record["start_character"],
            "end_paragraph_id": record["end_paragraph_id"],
            "end_character": record["end_character"],
            "text": record["text"],
            "bleu_score": None,  # wrt original evidence
            "meta": None,  # dataset/task specific
        })

    output = answers + [{"provenance": provenance}]

    return {
        "id": eli5_example["q_id"],
        "input": eli5_example["title"],
        "output": output,  # each element is an answer or provenance (can have multiple of each)
        "meta": None,  # dataset/task specific
    }
def embed_questions(question_model, question_tokenizer, questions, max_length=128, device="cuda:0"):
    """Encode *questions* into embeddings with the given model.

    Args:
        question_model: Callable invoked as ``model(input_ids, attention_mask)``
            whose result exposes a ``pooler_output`` tensor.
        question_tokenizer: Callable invoked with the questions plus
            ``max_length``, ``padding="max_length"``, ``truncation=True`` and
            ``return_tensors="pt"``; must return a mapping containing
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        questions: The question text(s), passed to the tokenizer as-is.
        max_length: Forwarded to the tokenizer.
        device: Device the tokenized tensors are moved to before the forward
            pass. The model is not moved here -- presumably the caller has
            already placed it on the same device.

    Returns:
        The model's ``pooler_output`` as a NumPy array on the CPU.
    """
    encoded = question_tokenizer(
        questions,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        pooled = question_model(input_ids, attention_mask).pooler_output

    return pooled.cpu().numpy()
def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cuda:0"):
    """Encode the ``"text"`` column of *passages* into embeddings.

    Args:
        ctx_model: Callable invoked as ``model(input_ids, attention_mask)``
            whose result exposes a ``pooler_output`` tensor.
        ctx_tokenizer: Callable invoked with the passage texts plus
            ``max_length``, ``padding="max_length"``, ``truncation=True`` and
            ``return_tensors="pt"``; must return a mapping containing
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        passages: Mapping whose ``"text"`` entry holds the passage text(s).
        max_length: Forwarded to the tokenizer.
        device: Device the tokenized tensors are moved to before the forward
            pass. The model is not moved here -- presumably the caller has
            already placed it on the same device.

    Returns:
        ``{"embeddings": array}`` where ``array`` is the model's
        ``pooler_output`` converted to a NumPy array on the CPU.
    """
    encoded = ctx_tokenizer(
        passages["text"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        pooled = ctx_model(input_ids, attention_mask).pooler_output

    embeddings = pooled.cpu().numpy()
    return {"embeddings": embeddings}