# PepMLM demo — Hugging Face Spaces app.
# NOTE(review): the original capture included the Space status banner
# ("Spaces:" / "Runtime error") above the code; preserved here as a comment.
# The runtime error is consistent with the invalid `default=` kwarg passed to
# gr.Slider below (removed in Gradio 4.x).
import gradio as gr
import torch
from torch.distributions.categorical import Categorical
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Load PepMLM (an ESM-2-based masked LM fine-tuned for peptide-binder
# generation) once at import time so every Gradio request reuses the
# same weights instead of reloading them.
tokenizer = AutoTokenizer.from_pretrained("TianlaiChen/PepMLM-650M")
model = AutoModelForMaskedLM.from_pretrained("TianlaiChen/PepMLM-650M")
def generate_peptide(protein_seq, peptide_length, top_k):
    """Sample a candidate peptide binder for *protein_seq* with PepMLM.

    The target sequence is extended with ``peptide_length`` ``<mask>``
    tokens; the masked LM scores every masked position in a single
    forward pass, and each position is filled independently by top-k
    sampling from its logits.

    Args:
        protein_seq: Target protein sequence (amino-acid letters).
        peptide_length: Number of residues to generate (coerced to int —
            Gradio sliders may deliver floats/strings).
        top_k: Sample each position from its k most likely tokens
            (coerced to int).

    Returns:
        The generated peptide as a plain string of residue letters.
    """
    peptide_length = int(peptide_length)
    top_k = int(top_k)

    # One <mask> per peptide position, appended after the target protein.
    masked_peptide = '<mask>' * peptide_length
    input_sequence = protein_seq + masked_peptide
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Column indices of the mask tokens within batch element 0.
    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    # Top-k sampling: restrict each position to its k best logits,
    # renormalize, then draw one token id per masked position.
    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
    probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
    predicted_indices = Categorical(probabilities).sample()
    predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)

    generated_peptide = tokenizer.decode(predicted_token_ids, skip_special_tokens=True)
    # decode() joins tokens with spaces; strip them to get a bare sequence.
    return generated_peptide.replace(' ', '')
# Define the Gradio interface.
# BUG FIX: the original passed `default="3"` to gr.Slider — `default` is not
# a gr.Slider parameter in Gradio 4.x (the initial value is `value=`), so the
# app crashed at startup with a TypeError. The kwarg is removed; `value=3`
# already provides the intended default. `step=1` keeps both sliders on
# integers, matching the int() coercion inside generate_peptide.
interface = gr.Interface(
    fn=generate_peptide,
    inputs=[
        gr.Textbox(label="Protein Sequence", info="Enter protein sequence here", type="text"),
        gr.Slider(3, 50, value=15, step=1, label="Peptide Length",
                  info='Default value is 15'),
        gr.Slider(1, 10, value=3, step=1, label="Top K Value",
                  info='Default value is 3'),
    ],
    outputs="textbox",
)

interface.launch()