simonduerr commited on
Commit
5889003
1 Parent(s): 9d493bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -79,17 +79,18 @@ def predict_structure(prefix, feature_dict, model_runners, random_seed=0):
79
  return plddts
80
 
81
 
82
- def run_protgpt2(startsequence, length):
83
  protgpt2 = pl("text-generation", model="nferruz/ProtGPT2")
84
  sequences = protgpt2(
85
  startsequence,
86
  max_length=length,
87
  do_sample=True,
88
- top_k=950,
89
- repetition_penalty=1.2,
90
- num_return_sequences=5,
91
  eos_token_id=0,
92
  )
 
93
  torch.cuda.empty_cache()
94
  return sequences
95
 
@@ -342,11 +343,15 @@ with proteindream:
342
  with gr.Group():
343
  with gr.Row():
344
  inp = gr.Textbox(placeholder="M", label="Start sequence")
345
- length = gr.Number(value=50, label="Target sequence length")
 
 
 
 
346
  btn = gr.Button("Predict sequences using protGPT2")
347
 
348
  results = gr.Textbox(label="Results", lines=15)
349
- btn.click(fn=update_protGPT2, inputs=[inp, length], outputs=results)
350
 
351
  gr.Markdown("## AlphaFold")
352
  gr.Markdown(
 
79
  return plddts
80
 
81
 
82
+ def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_seqs):
83
  protgpt2 = pl("text-generation", model="nferruz/ProtGPT2")
84
  sequences = protgpt2(
85
  startsequence,
86
  max_length=length,
87
  do_sample=True,
88
+ top_k=top_k_poolsize,
89
+ repetition_penalty=repetitionPenalty,
90
+ num_return_sequences=max_seqs,
91
  eos_token_id=0,
92
  )
93
+ del protgpt2
94
  torch.cuda.empty_cache()
95
  return sequences
96
 
 
343
  with gr.Group():
344
  with gr.Row():
345
  inp = gr.Textbox(placeholder="M", label="Start sequence")
346
+ length = gr.Number(value=50, label="Max sequence length")
347
+ with gr.Row()
348
+ repetitionPenalty = gr.Slider(minimum=1, maximum=5,value=1.2, label="Repetition penalty")
349
+ top_k_poolsize = gr.Slider(minimum=700, maximum=52056,value=950, label="Top-K sampling pool size")
350
+ max_seqs = gr.Slider(minimum=2, maximum=20,value=5, label="Number of sequences to generate")
351
  btn = gr.Button("Predict sequences using protGPT2")
352
 
353
  results = gr.Textbox(label="Results", lines=15)
354
+ btn.click(fn=update_protGPT2, inputs=[inp, length, repetitionPenalty, top_k_poolsize, max_seqs], outputs=results)
355
 
356
  gr.Markdown("## AlphaFold")
357
  gr.Markdown(