Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForMaskedLM | |
import torch | |
from torch.distributions.categorical import Categorical | |
# Load the PepMLM tokenizer and masked-language model once at import time so
# every request served by the app reuses the same objects.
_CHECKPOINT = "TianlaiChen/PepMLM-650M"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(_CHECKPOINT)
def generate_peptide(protein_seq, peptide_length, top_k):
    """Sample a peptide for ``protein_seq`` with top-k sampling.

    ``peptide_length`` mask tokens are appended to the protein sequence, the
    masked language model is run once, and every masked position is sampled
    independently from its ``top_k`` highest-scoring tokens.

    Args:
        protein_seq: Protein sequence the masked peptide is appended to.
        peptide_length: Number of masked positions to fill; anything
            accepted by ``int()`` (the UI passes strings).
        top_k: Number of highest-scoring tokens to sample from at each
            position; anything accepted by ``int()``. Clamped to the range
            ``[1, vocabulary size]``.

    Returns:
        The string ``"Generated Sequence: <peptide>"`` with the spaces
        produced by decoding removed.

    Raises:
        ValueError: If ``peptide_length`` or ``top_k`` cannot be converted
            to ``int``.
    """
    peptide_length = int(peptide_length)
    top_k = int(top_k)

    # Nothing to fill in: skip the model call entirely.
    if peptide_length < 1:
        return "Generated Sequence: "

    masked_peptide = '<mask>' * peptide_length
    input_sequence = protein_seq + masked_peptide
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Column indices of the mask tokens in the (single) input row.
    mask_token_indices = (
        inputs["input_ids"] == tokenizer.mask_token_id
    ).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    # topk() raises when k exceeds the size of the last dimension, and k < 1
    # leaves nothing to sample from, so keep k within [1, vocab_size].
    top_k = max(1, min(top_k, logits_at_masks.size(-1)))

    # Apply top-k sampling: Categorical normalises the k retained logits
    # itself, so no explicit softmax is needed.
    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
    predicted_indices = Categorical(logits=top_k_logits).sample()

    # Map positions within the top-k slice back to vocabulary token ids.
    predicted_token_ids = top_k_indices.gather(
        -1, predicted_indices.unsqueeze(-1)
    ).squeeze(-1)

    generated_peptide = tokenizer.decode(predicted_token_ids, skip_special_tokens=True)
    return f"Generated Sequence: {generated_peptide.replace(' ', '')}"
# Define the Gradio interface.
# Components are created from the top-level ``gr`` namespace with ``value=``
# for initial values; the legacy ``gr.inputs.*`` namespace and its
# ``default=`` keyword are not available in current Gradio releases.
interface = gr.Interface(
    fn=generate_peptide,
    inputs=[
        # The hint is a placeholder rather than an initial value so that,
        # with live=True, it is never submitted to the model as a sequence.
        gr.Textbox(label="Protein Sequence", placeholder="Enter protein sequence here"),
        # Dropdown values are strings; generate_peptide converts them to int.
        gr.Dropdown(choices=[str(i) for i in range(2, 51)], label="Peptide Length", value="15"),
        gr.Dropdown(choices=[str(i) for i in range(1, 11)], label="Top K Value", value="3"),
    ],
    outputs="textbox",
    live=True,
)
interface.launch()