# Lang2mol-Diff / app.py
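"""Streamlit demo for Lang2mol-Diff: generates a molecule from a free-text
description by sampling a text-conditioned diffusion model over SELFIES
tokens and decoding the result to a SMILES string."""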
import os

import selfies as sf
import streamlit as st
import torch
from transformers import T5EncoderModel

from src.improved_diffusion import dist_util
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.scripts.mytokenizers import Tokenizer
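

# Heavy resources (tokenizer, text encoder, denoiser, diffusion schedule) are
# cached with st.cache_resource so they are built once per server process
# rather than on every Streamlit rerun.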
@st.cache_resource
def get_encoder():
    # Frozen BioT5 text encoder; its hidden states condition the diffusion model.
    model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
    model.eval()
    return model


@st.cache_resource
def get_tokenizer():
    return Tokenizer()


@st.cache_resource
def get_model():
    # Transformer denoiser; these hyperparameters must match the ones used to
    # train the checkpoint loaded below.
    model = TransformerNetModel(
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        vocab_size=35073,
        hidden_size=1024,
        num_attention_heads=16,
        num_hidden_layers=12,
    )
    # EMA (rate 0.9999) weights from training step 360000, loaded on CPU.
    model.load_state_dict(
        dist_util.load_state_dict(
            os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
            map_location="cpu",
        )
    )
    model.eval()
    return model


@st.cache_resource
def get_diffusion():
    # "sqrt" noise schedule over 2000 diffusion steps, respaced to every 10th
    # step so that sampling takes 200 steps instead of 2000.
    return SpacedDiffusion(
        use_timesteps=list(range(0, 2000, 10)),
        betas=gd.get_named_beta_schedule("sqrt", 2000),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )
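

# Inference pipeline: tokenize the description, encode it with BioT5, run the
# conditioned DDIM sampling loop, round the latents back to tokens, and decode
# the resulting SELFIES string to SMILES.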
tokenizer = get_tokenizer()
encoder = get_encoder()
model = get_model()
diffusion = get_diffusion()
sample_fn = diffusion.ddim_sample_loop  # deterministic DDIM sampler over the respaced steps

text_input = st.text_area("Enter molecule description")

# Sampling is expensive, so only run once a description has been entered.
if text_input:
    # Tokenize the description to a fixed length of 256 with a padding mask.
    output = tokenizer(
        text_input,
        max_length=256,
        truncation=True,
        padding="max_length",
        add_special_tokens=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    with torch.no_grad():
        # Encode the description; the hidden states and mask condition sampling.
        caption_state = encoder(
            input_ids=output["input_ids"],
            attention_mask=output["attention_mask"],
        ).last_hidden_state
        caption_mask = output["attention_mask"]

        # Sample one sequence of 256 latent token embeddings of width 32.
        outputs = sample_fn(
            model,
            (1, 256, 32),
            clip_denoised=False,
            denoised_fn=None,
            model_kwargs={},
            top_p=1.0,
            progress=True,
            caption=(caption_state, caption_mask),
        )

        # Project the sampled embeddings to vocabulary logits and round each
        # position to its highest-scoring token.
        logits = model.get_logits(outputs)
        outputs = torch.topk(logits, k=1, dim=-1).indices.squeeze(-1)

    # The model generates SELFIES; strip special tokens and decode to SMILES.
    decoded = tokenizer.decode(outputs)
    selfies_string = decoded[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
    result = sf.decoder(selfies_string).replace("\t", "")
    st.write(result)
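
# To serve this demo locally: `streamlit run app.py`.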