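# Source: appears to be a Hugging Face Hub file view — uploaded by
# ndhieunguyen in commit 7dd9869 ("Add application file"); file size 5.28 kB.
# (Page controls "raw", "history blame" and the "No virus" scan badge were
# captured along with the code.)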
import json
import os
import subprocess
import sys

import torch
from spacy.lang.en import English
from transformers import AutoModelForCausalLM

from improved_diffusion.rounding import rounding_func, load_models, load_tokenizer
# Select which aNLG split to process and load it. The file holds a JSON
# object mapping example ids to records (iterated with .items() below).
SPLIT = 'test'
_SPLIT_FILES = {
    'val': 'diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json',
    'test': 'diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json',
}
if SPLIT not in _SPLIT_FILES:
    # Raise instead of `assert False`: asserts are stripped under `python -O`,
    # which would leave `source_file` undefined.
    raise ValueError(f"invalid split {SPLIT!r}; expected one of {sorted(_SPLIT_FILES)}")
source_file = _SPLIT_FILES[SPLIT]
with open(source_file, 'r', encoding='utf-8') as f:
    sent_lst = json.load(f)
# Blank spaCy English pipeline; only its tokenizer is used by this script.
nlp = English()
tokenizer = nlp.tokenizer

# Which generator to run:
#   'ar'   - decode with an autoregressive AutoModelForCausalLM via generate()
#   'diff' - shell out to ../scripts/infill.py with a diffusion-model checkpoint
MODE = 'ar'

# Example entry of `sent_lst`, keyed by example id. Only "obs1" and "obs2"
# are read by this script ("gold_labels" is presumably absent from the
# no-label test file — confirm).
#
#   "00b9adb2-b3b6-4737-902a-50f308bac4b5-1": {
#       "gold_labels": [
#           "I put my baby in the car and drove around.",
#           "I realized he needed his blanket, which I had forgotten at a faraway hotel.",
#           "I took a drive to get my baby to sleep.",
#           "I took my baby for a drive and she fell asleep in the car."
#       ],
#       "obs1": "My baby would not go to sleep last night.",
#       "obs2": "I wound up driving for hours."
#   },

# Report how many examples will be processed.
print(len(sent_lst))
if MODE == 'ar':
    # Path to the AR model trained for LMing this task. An earlier checkpoint
    # (roc_e=20_b=32_m=gpt2_wikitext-103-raw-v1_101_wp_pad_infill) used to be
    # assigned first and was immediately overwritten; only this one is loaded.
    model_name = 'predictability/diff_models/roc_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad_infill_v2'
    model = AutoModelForCausalLM.from_pretrained(model_name).cuda()

    # tokenizer2 maps token id -> word (generated ids are looked up in it);
    # vocab is its inverse, word -> token id, used to encode the prompts.
    tokenizer2 = load_tokenizer(
        'roc',
        'random',
        'predictability/diffusion_models_v7/diff_roc_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart',
    )
    vocab = {v: k for k, v in tokenizer2.items()}
    print(len(tokenizer2), len(vocab), 'loaded vocabs')
# JSON-lines output for MODE == 'ar': one JSON object per input example.
outfile = 'ar_sample_full_test_v2.json'
# The output filename doubles as the decoding switch: a name containing
# 'sample' draws NUM_SAMPLES stochastic continuations per example; any other
# name runs beam search and keeps the single best continuation.
SAMPLE_OUTPUT = 'sample' in outfile
NUM_SAMPLES = 50


def _decode_tokens(token_ids, id2word):
    """Convert one row of generated token ids into a space-joined sentence.

    Args:
        token_ids: list of int token ids taken from ``model.generate`` output.
        id2word: mapping from token id to word (``tokenizer2`` in this script).

    Returns:
        The words joined by single spaces, after dropping every 'PAD' token
        and then a leading 'START' and a trailing 'END' if present. Returns
        "" when no words remain.
    """
    words = [w for w in (id2word[t] for t in token_ids) if w != 'PAD']
    # The `words and` guards keep an all-PAD (or START-only) row from raising
    # IndexError.
    if words and words[0] == 'START':
        words = words[1:]
    if words and words[-1] == 'END':
        words = words[:-1]
    return " ".join(words)


def _run_autoregressive(idx, val, filehandle):
    """Generate infills for one example with the AR model; write one JSON line.

    Reads the module-level ``model``, ``vocab``, ``tokenizer2`` and spaCy
    ``tokenizer``. Writes ``{'idx', 'obs1', 'obs2', 'samples'}`` when
    SAMPLE_OUTPUT is set, otherwise ``{'idx', 'obs1', 'obs2', 'sample'}``.
    """
    partial_seq = f"{val['obs1']} {val['obs2']}"
    print(partial_seq)
    # Prompt = START followed by the word ids of obs1 + obs2; words missing
    # from the vocabulary map to UNK.
    word_idx_lst = [vocab['START']] + [vocab.get(x.text, vocab['UNK']) for x in tokenizer(partial_seq)]
    init_prompt = torch.tensor([word_idx_lst], dtype=torch.long, device='cuda')
    print(init_prompt.shape)

    if SAMPLE_OUTPUT:
        print(f'sampling {NUM_SAMPLES} examples.')
        # expand() repeats the single prompt row as a view, without copying.
        init_prompt = init_prompt.expand(NUM_SAMPLES, -1)
        sample_out = model.generate(init_prompt, do_sample=True, max_length=64, top_k=len(vocab))
    else:
        # Beam search: the generate() kwarg is `num_beams` (the old `num_beam`
        # was not a valid option). top_k only affects sampling, so it is omitted.
        sample_out = model.generate(init_prompt, do_sample=False, num_beams=4, max_length=64)
    print(sample_out.shape)
    # generate() returns prompt + continuation for a causal LM; keep the continuation.
    sample_out = sample_out[:, init_prompt.size(1):]

    out_dict = {'idx': idx, 'obs1': val['obs1'], 'obs2': val['obs2']}
    if SAMPLE_OUTPUT:
        out_dict['samples'] = [_decode_tokens(row.tolist(), tokenizer2) for row in sample_out]
    else:
        out_dict['sample'] = _decode_tokens(sample_out[0].tolist(), tokenizer2)
    print(json.dumps(out_dict), file=filehandle)


def _run_diffusion_infill(idx, val):
    """Invoke ../scripts/infill.py on one example (MODE == 'diff').

    The prompt is obs1, ten PAD placeholders for the missing middle, then obs2,
    re-tokenized with spaCy and joined by single spaces, e.g.
    "Brenna and I used to be best friends . PAD PAD ... PAD We never talked again ."
    infill.py is called with --out_dir ../anlg_results and --notes {SPLIT}_{idx}.
    """
    partial_seq = f"{val['obs1']} " + "PAD " * 10 + f"{val['obs2']}"
    partial_seq = " ".join(x.text for x in tokenizer(partial_seq))
    print(partial_seq, idx)
    # Pass an argument list instead of a shell string: no shell quoting is
    # involved, so text with apostrophes (e.g. the token "n't") is passed intact.
    command = [
        "python", "../scripts/infill.py",
        "--model_path", "predictability/diffusion_models_v7/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long/ema_0.9999_800000.pt",
        "--batch_size", "50",
        "--partial_seq", partial_seq,
        "--eval_task_", "infill",
        "--notes", f"{SPLIT}_{idx}",
        "--out_dir", "../anlg_results",
    ]
    result = subprocess.run(command, check=False)
    # A failing example does not abort the run, but it is reported.
    if result.returncode != 0:
        print(f"infill.py exited with code {result.returncode} for example {idx}", file=sys.stderr)
    torch.cuda.empty_cache()


if MODE == 'ar':
    # `with` guarantees the output file is closed even if generation fails.
    with open(outfile, 'w', encoding='utf-8') as filehandle:
        for idx, val in enumerate(sent_lst.values()):
            _run_autoregressive(idx, val, filehandle)
    print(f'written to {outfile}')
elif MODE == 'diff':
    for idx, val in enumerate(sent_lst.values()):
        _run_diffusion_infill(idx, val)
else:
    raise ValueError(f"invalid MODE {MODE!r}; expected 'ar' or 'diff'")