Lang2mol-Diff / src /anlg_infill /post_process.py
ndhieunguyen's picture
Add application file
7dd9869
raw
history blame
No virus
1.34 kB
import torch
import json
from transformers import BertForMaskedLM, BertTokenizer
# Input results file (read below one JSON record per line) and the masked LM
# used to fill leftover 'UNK' placeholders.
filename = 'diffusion_lm/improved-diffusion/anlg_results/diff_roc_mbr.json2'
bert_model = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model)
# Prefer the GPU, but fall back to CPU instead of crashing inside `.cuda()` on
# machines without CUDA; on a GPU machine this is the same as before.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForMaskedLM.from_pretrained(bert_model).to(device)
# Inference only (dropout off). from_pretrained already returns an eval-mode
# model; this just makes the intent explicit.
model.eval()
# Load every record of the results file; each non-blank line is one JSON object.
full_lst = []
with open(filename, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # A blank (e.g. trailing empty) line would make json.loads raise.
            continue
        full_lst.append(json.loads(line))
# Repair generated infills that still contain literal 'UNK' placeholders.
# Each 'UNK' in example['sample'] is replaced by BERT's mask token, and the
# sample is placed between its observations (obs1 before, obs2 after) for
# context. Every masked position is then filled with BERT's highest-scoring
# token. The repaired sequence is only printed; nothing is written to disk.
for example in full_lst:
    sent = example['sample']
    obs1 = example['obs1']
    obs2 = example['obs2']
    if 'UNK' in sent:
        # Join with spaces: plain `+` fused the last word of one segment with
        # the first word of the next whenever a segment lacked trailing
        # punctuation/whitespace, corrupting the context BERT sees.
        sent = ' '.join([obs1, sent.replace('UNK', tokenizer.mask_token), obs2])
        print(sent)
        model_inputs = tokenizer(sent, return_tensors="pt")
        model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
        # Inference only: don't build an autograd graph for every example.
        with torch.no_grad():
            model_out = model(**model_inputs)
        mask_words = model_inputs['input_ids'] == tokenizer.mask_token_id
        # One row of vocabulary scores per masked position.
        masked_logits = model_out.logits[mask_words].view(-1, model_out.logits.size(-1))
        if masked_logits.size(0) > 0:
            # Greedy fill: write the top-1 token id into each masked position.
            model_inputs['input_ids'][mask_words] = masked_logits.argmax(dim=-1)
            print(tokenizer.batch_decode(model_inputs['input_ids'].tolist()))
    else:
        print('NO NEED THIS FIX. ')