# PepMLM / app.py
# (Page-header residue from the hosting site: TianlaiChen, "gen", 9b1cba9, 3.7 kB)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch.distributions.categorical import Categorical
import numpy as np
import pandas as pd
# Both the tokenizer and the masked-language model come from the same
# checkpoint, so the identifier is spelled out once.
_CHECKPOINT = "TianlaiChen/PepMLM-650M"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(_CHECKPOINT)
def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Return the pseudo-perplexity of ``binder_seq`` given ``protein_seq``.

    The two sequences are concatenated and encoded, every binder position is
    replaced by the tokenizer's mask token, and the masked-LM loss is computed
    on those positions only. The result is ``exp(loss)``.

    Args:
        model: Masked-language model; called as ``model(input_ids, labels=...)``
            and expected to expose ``.device`` and return an object with a
            ``.loss`` tensor.
        tokenizer: Tokenizer providing ``encode(..., return_tensors='pt')`` and
            ``mask_token_id``.
        protein_seq: Target protein sequence (left unmasked as context).
        binder_seq: Binder sequence whose positions are masked and scored.

    Returns:
        ``exp`` of the masked-LM loss as a NumPy float, or NaN when
        ``binder_seq`` is empty (there is nothing to score).
    """
    if not binder_seq:
        # With no binder positions every label would be -100 (ignored), so the
        # loss has nothing to average over. Report NaN without paying for a
        # forward pass.
        return np.float64("nan")

    sequence = protein_seq + binder_seq
    tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)

    # Boolean mask selecting the binder tokens.
    # NOTE(review): the slice assumes one token per binder character and
    # exactly one trailing special token after the binder -- confirm for the
    # tokenizer in use.
    binder_mask = torch.zeros_like(tensor_input, dtype=torch.bool)
    binder_mask[0, -len(binder_seq) - 1:-1] = True

    # Hide the binder in the input; score only the binder in the labels
    # (-100 marks positions excluded from the loss).
    masked_input = tensor_input.masked_fill(binder_mask, tokenizer.mask_token_id)
    labels = tensor_input.masked_fill(~binder_mask, -100)

    with torch.no_grad():
        loss = model(masked_input, labels=labels).loss
    return np.exp(loss.item())
def generate_peptide(protein_seq, peptide_length, top_k, num_binders):
    """Sample peptide binders for a target protein and score each of them.

    The target sequence is followed by ``peptide_length`` ``<mask>`` tokens,
    the module-level ``model`` predicts logits for the masked positions, and
    each binder is drawn by independent top-k sampling at every position.
    Every sampled binder is then scored with ``compute_pseudo_perplexity``.

    Args:
        protein_seq: Target protein sequence.
        peptide_length: Number of masked positions to append (cast to int).
        top_k: How many highest-scoring tokens to sample from at each
            position (cast to int).
        num_binders: Number of binders to sample (cast to int).

    Returns:
        A tuple ``(rows, csv_path)``: ``rows`` is a list of
        ``[binder, perplexity]`` pairs and ``csv_path`` is the name of the CSV
        file the same rows were written to.
    """
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    # The masked input is identical for every binder, so the forward pass,
    # top-k filtering and softmax are done once; only sampling is repeated.
    # This relies on the model being deterministic between calls (models
    # returned by ``from_pretrained`` are in evaluation mode by default).
    input_sequence = protein_seq + '<mask>' * peptide_length
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    # Keep only the top-k candidates per position and renormalise over them.
    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
    sampler = Categorical(torch.nn.functional.softmax(top_k_logits, dim=-1))

    binders_with_ppl = []
    for _ in range(num_binders):
        # Sampled values index into the top-k list; map them back to
        # vocabulary ids before decoding.
        predicted_indices = sampler.sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        # Compute PPL for the generated binder
        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        binders_with_ppl.append([generated_binder, ppl_value])

    # Convert the list of lists to a pandas dataframe and save it as CSV.
    # NOTE(review): the path is fixed, so every call overwrites the same file.
    df = pd.DataFrame(binders_with_ppl, columns=["Binder", "Perplexity"])
    output_filename = "output.csv"
    df.to_csv(output_filename, index=False)

    return binders_with_ppl, output_filename
# Define the Gradio interface: four inputs feed generate_peptide, which
# returns the result table and the path of the CSV file offered for download.
interface = gr.Interface(
    fn=generate_peptide,
    inputs=[
        gr.Textbox(label="Protein Sequence", info="Enter protein sequence here", type="text"),
        gr.Slider(3, 50, value=15, label="Peptide Length", step=1, info='Default value is 15'),
        gr.Slider(1, 10, value=3, label="Top K Value", step=1, info='Default value is 3'),
        gr.Dropdown(choices=[1, 2, 4, 8, 16, 32], label="Number of Binders", value=1)
    ],
    outputs=[
        gr.Dataframe(
            headers=["Binder", "Perplexity"],
            datatype=["str", "number"],
            col_count=(2, "fixed")
        ),
        # Top-level component, consistent with the other components here;
        # the gr.outputs namespace is deprecated.
        gr.File(label="Download CSV")
    ],
    title="PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling"
)

interface.launch()