# HTK / evaluate_squad_v2.py
# Hugging Face Hub page header (not part of the script): uploader "faori",
# commit 550665c (verified), commit message "Upload folder using huggingface_hub",
# file size 4.62 kB.
import os
os.environ["TF_ENABLE_ONEDNN_OPTS"] = '0'
from huggingface_hub import login
from typing import Union, Any, Dict
# from datasets.arrow_dataset import Batch
import argparse
import datasets
from transformers.utils import logging, check_min_version
from transformers.utils.versions import require_version
from retro_reader import RetroReader
from retro_reader.constants import EXAMPLE_FEATURES
import torch
# Dependency guards, run at import time.
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.13.0.dev0")
require_version("datasets>=1.8.0")
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
def schema_integrate(example) -> Union[Dict, Any]:
    """Convert a batch of SQuAD v2 rows into the column layout returned below.

    Written for batched mapping: every value in ``example`` is a list with
    one entry per row.

    Args:
        example: Batch with the columns ``title``, ``question``, ``context``,
            ``id`` and ``answers``. Each ``answers`` entry is a dict whose
            ``text`` list is empty (or otherwise falsy) for unanswerable
            questions.

    Returns:
        A dict of equally long lists with the keys ``guid``, ``question``,
        ``context``, ``answers``, ``title``, ``classtype``, ``source``,
        ``is_impossible`` and ``dataset``. Unanswerable rows get the
        placeholder answer ``{"text": [""], "answer_start": [-1]}`` and
        ``is_impossible=True``; answerable rows keep their original
        ``answers`` dict.
    """
    title = example["title"]
    question = example["question"]
    context = example["context"]
    guid = example["id"]
    n_rows = len(title)

    answers, is_impossible = [], []
    for answer_examples in example["answers"]:
        has_answer = bool(answer_examples["text"])
        if has_answer:
            answers.append(answer_examples)
        else:
            # Placeholder span for questions without an answer.
            answers.append({"text": [""], "answer_start": [-1]})
        is_impossible.append(not has_answer)

    # NOTE(review): the original comment here says "The feature names must be
    # sorted." Presumably the key order has to line up with EXAMPLE_FEATURES;
    # the order below is unchanged from the original. Confirm before editing.
    return {
        "guid": guid,
        "question": question,
        "context": context,
        "answers": answers,
        "title": title,
        "classtype": [""] * n_rows,
        # ``source`` and ``dataset`` carry the same values but are built as
        # separate lists so that mutating one column cannot affect the other.
        "source": ["squad_v2"] * n_rows,
        "is_impossible": is_impossible,
        "dataset": ["squad_v2"] * n_rows,
    }
# Data augmentation for multiple answers
def data_aug_for_multiple_answers(examples) -> Union[Dict, Any]:
    """Expand rows that have several gold answers into one row per answer.

    Args:
        examples: Batch (mapping of column name to list) that contains at
            least the columns ``answers`` and ``is_impossible``. Each
            ``answers`` entry is a dict with ``text`` and ``answer_start``
            lists.

    Returns:
        A dict with the same columns as ``examples``:

        * an answerable row with more than one answer is emitted once per
          answer, each copy holding a single-element ``text`` /
          ``answer_start`` pair;
        * an unanswerable row is emitted once with empty ``text`` and
          ``answer_start`` lists;
        * any other row is copied unchanged.

    Raises:
        AssertionError: If ``text`` and ``answer_start`` differ in length and
            ``answer_start`` does not begin with ``-1``.
    """
    result = {key: [] for key in examples.keys()}

    def update(i, answers=None):
        # Append row ``i`` to every output column, substituting ``answers``
        # for the original ``answers`` value when one is supplied.
        for key, column in result.items():
            if key == "answers" and answers is not None:
                column.append(answers)
            else:
                column.append(examples[key][i])

    for i, (answers, unanswerable) in enumerate(
        zip(examples["answers"], examples["is_impossible"])
    ):
        texts = answers["text"]
        starts = answers["answer_start"]
        # Guard the ``starts[0]`` lookup so that an empty ``answer_start``
        # paired with a non-empty ``text`` fails this check instead of
        # raising an unrelated IndexError.
        assert len(texts) == len(starts) or (
            len(starts) > 0 and starts[0] == -1
        ), (
            f"row {i}: 'text' has {len(texts)} entries but 'answer_start' "
            f"has {len(starts)}"
        )

        if unanswerable:
            update(i, {"text": [], "answer_start": []})
        elif len(texts) > 1:
            # Index into ``starts`` (rather than zip) so a too-short
            # ``answer_start`` list still raises instead of dropping answers.
            for n_ans, text in enumerate(texts):
                update(i, {"text": [text], "answer_start": [starts[n_ans]]})
        else:
            update(i)
    return result
def main(args):
    """Evaluate a Retro Reader on the SQuAD v2.0 validation split.

    Loads the ``squad_v2`` dataset, converts it with ``schema_integrate``,
    loads a ``RetroReader`` from the given config file, runs ``evaluate`` on
    the validation split and prints the result.

    Args:
        args: Parsed command-line namespace. This function reads
            ``args.debug`` (truncate the validation split to 5 rows) and
            ``args.configs`` (path passed to ``RetroReader.load``).
            ``args.batch_size`` is not read here.
    """
    # Load SQuAD V2.0 dataset
    print("Loading SQuAD v2.0 dataset ...")
    squad_v2 = datasets.load_dataset("squad_v2")
    # TODO: Visualize a sample from the dataset

    # Integrate into the schema used in this library
    # Note: The columns used for preprocessing are `question`, `context`, `answers`
    #       and `is_impossible`. The remaining columns are columns that exist to
    #       process other types of data.

    # Minimize the dataset for debugging
    if args.debug:
        squad_v2["validation"] = squad_v2["validation"].select(range(5))

    print("Integrating into the schema used in this library ...")
    squad_v2 = squad_v2.map(
        schema_integrate,
        batched=True,
        remove_columns=squad_v2.column_names["train"],
        features=EXAMPLE_FEATURES,
    )

    # Load Retro Reader
    # features: parse arguments
    #           make train/eval dataset from examples
    #           load model from the Hugging Face hub
    #           set sketch/intensive reader and rear verifier
    print("Loading Retro Reader ...")
    retro_reader = RetroReader.load(
        config_file=args.configs,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Evaluate (this script does not train)
    res = retro_reader.evaluate(squad_v2["validation"])
    print(res)
    # Logged at WARNING level, as in the original.
    logger.warning("Evaluate retrospective reader Done.")
if __name__ == "__main__":
    # Command-line entry point: build the parser, parse argv, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--configs",
        "-c",
        type=str,
        default="configs/inference_electra_base.yaml",
        help="config file path",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=int,
        default=1024,
        help="batch size",
    )
    parser.add_argument(
        "--debug",
        "-d",
        action="store_true",
        help="debug mode",
    )
    args = parser.parse_args()
    main(args)