'''
Code from https://github.com/blender-nlp/MolT5

```bibtex
@article{edwards2022translation,
  title={Translation between Molecules and Natural Language},
  author={Edwards, Carl and Lai, Tuan and Ros, Kevin and Honke, Garrett and Ji, Heng},
  journal={arXiv preprint arXiv:2204.11817},
  year={2022}
}
```
'''
import argparse
import csv

import numpy as np

# Load metric dependencies.
from nltk.translate.bleu_score import corpus_bleu
from Levenshtein import distance as lev

from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')  # silence RDKit parse warnings for invalid SMILES


def evaluate(input_fp, verbose=False):
    # Read (description, ground-truth SMILES, generated SMILES) triples from a
    # tab-separated file with columns 'description', 'ground truth', 'output'.
    outputs = []
    with open(input_fp) as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for line in reader:
            gt_smi = line['ground truth']
            ot_smi = line['output']
            outputs.append((line['description'], gt_smi, ot_smi))

    # Character-level BLEU: tokenize each SMILES string into characters.
    # (METEOR scoring from the original repository is omitted here.)
    references = []
    hypotheses = []
    for i, (smi, gt, out) in enumerate(outputs):
        if verbose and i % 100 == 0:
            print(i, 'processed.')
        references.append([list(gt)])
        hypotheses.append(list(out))

    bleu_score = corpus_bleu(references, hypotheses)
    if verbose:
        print('BLEU score:', bleu_score)

    # Exact match, Levenshtein distance, and validity.
    levs = []
    num_exact = 0
    bad_mols = 0
    for smi, gt, out in outputs:
        try:
            m_out = Chem.MolFromSmiles(out)
            m_gt = Chem.MolFromSmiles(gt)
            # Compare InChI strings so that equivalent SMILES written
            # differently still count as exact matches; a plain string
            # comparison (gt == out) would miss them.
            if Chem.MolToInchi(m_out) == Chem.MolToInchi(m_gt):
                num_exact += 1
        except Exception:
            # Chem.MolFromSmiles returns None for invalid SMILES, which makes
            # Chem.MolToInchi raise; count the generation as invalid.
            bad_mols += 1
        levs.append(lev(out, gt))

    # Exact matching score.
    exact_match_score = num_exact / len(outputs)
    if verbose:
        print('Exact Match:')
        print(exact_match_score)

    # Levenshtein score (mean character-level edit distance).
    levenshtein_score = np.mean(levs)
    if verbose:
        print('Levenshtein:')
        print(levenshtein_score)

    # Fraction of generated SMILES that RDKit can parse.
    validity_score = 1 - bad_mols / len(outputs)
    if verbose:
        print('validity:', validity_score)

    return bleu_score, exact_match_score, levenshtein_score, validity_score


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, default='caption2smiles_example.txt',
                        help='path where test generations are saved')
    args = parser.parse_args()

    evaluate(args.input_file, verbose=True)
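

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original MolT5 repository). The file name,
# descriptions, and SMILES below are illustrative assumptions, chosen only to
# show the expected input format: a tab-separated file with a header row
# naming the columns 'description', 'ground truth', and 'output'. The second
# row also illustrates why evaluate() compares InChI strings: 'C(C)O' and
# 'CCO' are different SMILES spellings of the same molecule (ethanol).
def _demo():
    rows = [
        ('Benzene, a six-membered aromatic ring.', 'c1ccccc1', 'c1ccccc1'),
        ('Ethanol.', 'CCO', 'C(C)O'),
    ]
    with open('demo_generations.txt', 'w') as f:  # hypothetical file name
        f.write('description\tground truth\toutput\n')
        for desc, gt, out in rows:
            f.write(f'{desc}\t{gt}\t{out}\n')
    bleu, exact, lev_dist, validity = evaluate('demo_generations.txt', verbose=True)
    # Both pairs match under InChI canonicalization and both outputs parse,
    # so exact == 1.0 and validity == 1.0 here.
    print(bleu, exact, lev_dist, validity)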