Simon Duerr committed on
Commit
f78ba33
1 Parent(s): 776a14e
Files changed (1) hide show
  1. app.py +365 -12
app.py CHANGED
@@ -1,17 +1,370 @@
1
  import gradio as gr
2
- import jax
3
 
4
- demo = gr.Blocks()
5
- def speech_to_text():
6
- return "Test"
7
- with demo:
8
- text = gr.Textbox()
9
- seqs= ["DSDAJSKDPAKSDPKPEKPKEPKDPAD", "MKTPKPSKPDKAPSKFKRNGKMSODAPS;DPALSPFKOIFMENENOLMAPSDASPFMKOVMFEOM"]
10
- seqChoice = gr.Radio(seqs)
11
 
12
- b1 = gr.Button("Run ProtGPT2")
13
- b2 = gr.Button("Predicted structure")
14
 
15
- b1.click(speech_to_text, inputs=text, outputs=seqs)
16
 
17
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
 
3
+ import numpy as np
 
 
 
 
 
 
4
 
5
+ import os
 
6
 
7
+ import matplotlib
8
 
9
+ matplotlib.use("Agg")
10
+ import matplotlib.pyplot as plt
11
+ from transformers import pipeline as pl
12
+
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import sys
16
+
17
+ print(os.getcwd())
18
+ if "/scratch/duerr/gradiofold2/alphafold" not in sys.path:
19
+ sys.path.append("/scratch/duerr/gradiofold2/alphafold")
20
+
21
+ from alphafold.common import protein
22
+ from alphafold.data import pipeline
23
+ from alphafold.data import templates
24
+ from alphafold.model import data
25
+ from alphafold.model import config
26
+ from alphafold.model import model
27
+
28
+
29
def update_seqs(choice):
    """Build the textbox update that carries the selected sequence.

    Args:
        choice: Value handed to ``gr.Textbox.update``; in this file it is the
            current value of the "Generated sequences" radio group.

    Returns:
        Whatever ``gr.Textbox.update(choice)`` returns.
    """
    textbox_update = gr.Textbox.update(choice)
    return textbox_update
31
+
32
+
33
def mk_mock_template(query_sequence):
    """Build a blank template feature dict sized to ``query_sequence``.

    The template sequence is all gap characters ("-") and both the atom
    position and atom mask arrays are zero-filled, so the template carries no
    structural information.

    Args:
        query_sequence: Sequence whose length fixes the number of rows.

    Returns:
        Dict with the keys ``template_all_atom_positions``,
        ``template_all_atom_masks``, ``template_aatype`` and
        ``template_domain_names``; each array gets a leading axis of size 1.
    """
    rc = templates.residue_constants
    n_res = len(query_sequence)
    gap_sequence = "-" * n_res

    # One row per residue, one column per atom type known to residue_constants.
    blank_positions = np.zeros((n_res, rc.atom_type_num, 3))
    blank_masks = np.zeros((n_res, rc.atom_type_num))
    gap_onehot = rc.sequence_to_onehot(gap_sequence, rc.HHBLITS_AA_TO_ID)

    return {
        "template_all_atom_positions": blank_positions[None],
        "template_all_atom_masks": blank_masks[None],
        "template_aatype": np.array(gap_onehot)[None],
        "template_domain_names": ["none".encode()],
    }
51
+
52
+
53
def predict_structure(prefix, feature_dict, model_runners, random_seed=0):
    """Run each model runner on the features and write one PDB file per model.

    Args:
        prefix: Leading part of every output file name; each structure is
            written to ``{prefix}_unrelaxed_{model_name}.pdb``.
        feature_dict: Raw features passed to each runner's ``process_features``.
        model_runners: Mapping of model name to a runner object providing
            ``process_features`` and ``predict``.
        random_seed: Seed forwarded to ``process_features``.

    Returns:
        Dict mapping each model name to the ``"plddt"`` entry of its
        prediction result.
    """
    plddts = {}
    for name, runner in model_runners.items():
        features = runner.process_features(feature_dict, random_seed=random_seed)
        result = runner.predict(features)

        # pLDDT gets a trailing axis and is multiplied by the final atom mask;
        # the product is handed to from_prediction as the b-factor array.
        atom_mask = result["structure_module"]["final_atom_mask"]
        b_factors = result["plddt"][:, None] * atom_mask
        structure = protein.from_prediction(features, result, b_factors)

        out_path = f"{prefix}_unrelaxed_{name}.pdb"
        plddts[name] = result["plddt"]

        print(f"{name} {plddts[name].mean()}")

        with open(out_path, "w") as handle:
            handle.write(protein.to_pdb(structure))
    return plddts
78
+
79
+
80
# Lazily created ProtGPT2 pipeline, shared by every call to run_protgpt2.
_PROTGPT2_PIPELINE = None


def _get_protgpt2():
    """Return the shared ProtGPT2 text-generation pipeline, building it on first use."""
    global _PROTGPT2_PIPELINE
    if _PROTGPT2_PIPELINE is None:
        _PROTGPT2_PIPELINE = pl("text-generation", model="nferruz/ProtGPT2")
    return _PROTGPT2_PIPELINE


def run_protgpt2(startsequence, length):
    """Sample five ProtGPT2 continuations of ``startsequence``.

    The pipeline is created once and reused, so repeated button clicks do not
    rebuild it on every request.

    Args:
        startsequence: Prompt text handed to the pipeline.
        length: Value used as ``max_length``. It is cast to ``int`` because the
            UI takes it from a ``gr.Number`` field, which may deliver a float
            (TODO confirm for the installed gradio version).

    Returns:
        The pipeline's result for five sampled sequences; callers in this file
        read the ``"generated_text"`` key of each item.
    """
    protgpt2 = _get_protgpt2()
    return protgpt2(
        startsequence,
        max_length=int(length),
        do_sample=True,
        top_k=950,
        repetition_penalty=1.2,
        num_return_sequences=5,
        eos_token_id=0,
    )
92
+
93
+
94
def run_alphafold(startsequence):
    """Predict a structure for ``startsequence`` with AlphaFold ``model_1``.

    Newlines are removed from the input, then single-sequence features (the
    MSA holds only the query, and the template is the blank one from
    ``mk_mock_template``) are passed to ``predict_structure`` with the prefix
    ``"test"``.

    Args:
        startsequence: Sequence text; may contain newline characters.

    Returns:
        The ``"model_1"`` entry of the dict returned by ``predict_structure``.
    """
    # Only model_1 is enabled; model_2 .. model_5 were left commented out.
    model_names = ["model_1"]
    runners = {}
    for name in model_names:
        cfg = config.model_config(name)
        cfg.data.eval.num_ensemble = 1
        params = data.get_model_haiku_params(model_name=name, data_dir=".")
        runners[name] = model.RunModel(cfg, params)

    sequence = startsequence.replace("\n", "")

    features = {}
    features.update(
        pipeline.make_sequence_features(
            sequence=sequence, description="none", num_res=len(sequence)
        )
    )
    features.update(
        pipeline.make_msa_features(
            msas=[[sequence]], deletion_matrices=[[[0] * len(sequence)]]
        )
    )
    features.update(mk_mock_template(sequence))

    per_model_plddts = predict_structure("test", features, runners)
    return per_model_plddts["model_1"]
116
+
117
+
118
def update_protGPT2(inp, length):
    """Generate sequences with ProtGPT2 and offer them as radio choices.

    Args:
        inp: Start sequence passed to ``run_protgpt2``.
        length: Target length passed to ``run_protgpt2``.

    Returns:
        A ``gr.Radio.update`` result whose ``choices`` are the generated texts.
    """
    generated = run_protgpt2(inp, length)
    gen_seqs = [item["generated_text"] for item in generated]
    print(gen_seqs)
    # Pass the list by keyword: given positionally it would be taken as the
    # radio's selected value instead of replacing its list of options.
    return gr.Radio.update(choices=gen_seqs)
125
+
126
+
127
def _apply_plot_style():
    """Activate the seaborn "ticks" + "talk" style sheets under whichever name exists."""
    try:
        # Names used since matplotlib 3.6, where the seaborn sheets were renamed.
        plt.style.use(["seaborn-v0_8-ticks", "seaborn-v0_8-talk"])
    except OSError:
        # Older matplotlib only knows the original names.
        plt.style.use(["seaborn-ticks", "seaborn-talk"])


def update(inp):
    """Run AlphaFold on ``inp`` and build the three outputs shown in the UI.

    Args:
        inp: Sequence to predict; passed unchanged to ``run_alphafold``.

    Returns:
        Tuple of (viewer HTML from ``molecule`` for
        ``test_unrelaxed_model_1.pdb``, matplotlib figure plotting the pLDDT
        values, ``"mean ± std"`` text with one decimal each).

    Raises:
        ValueError: If ``inp`` is empty or ``None`` (nothing selected yet).
    """
    print("Running AF on", inp)
    if not inp:
        raise ValueError("No sequence selected for structure prediction.")
    plddts = run_alphafold(inp)
    print(plddts)

    _apply_plot_style()
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(plddts)
    ax.set_ylabel("predicted LDDT")
    ax.set_xlabel("positions")
    ax.set_title("pLDDT")

    return (
        # File name matches the "test" prefix and "model_1" used by run_alphafold.
        molecule("test_unrelaxed_model_1.pdb"),
        fig,
        f"{np.mean(plddts):.1f} ± {np.std(plddts):.1f}",
    )
148
+
149
+
150
def read_mol(molpath):
    """Return the full text content of the file at ``molpath``.

    Args:
        molpath: Path of the file to read (opened in text mode).

    Returns:
        The file's content as a single string; ``""`` for an empty file.
    """
    with open(molpath, "r") as fp:
        # One read replaces the previous line-by-line string concatenation.
        return fp.read()
157
+
158
+
159
# Viewer page up to (and including) the opening backtick of the JavaScript
# template literal that receives the PDB text.
_VIEWER_HTML_HEAD = """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<link rel="stylesheet" href="https://unpkg.com/flowbite@1.4.5/dist/flowbite.min.css" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 800px;
position: relative;
}
.space-x-2 > * + *{
margin-left: 0.5rem;
}
.p-1{
padding:0.5rem;
}
.flex{
display:flex;
align-items: center;
}
.w-4{
width:1rem;
}
.h-4{
height:1rem;
}
.mt-4{
margin-top:1rem;
}
</style>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>

<div id="container" class="mol-container"></div>
<div class="flex">
<div class="px-4">
<label for="sidechain" class="relative inline-flex items-center mb-4 cursor-pointer ">
<input id="sidechain"type="checkbox" class="sr-only peer">
<div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div>
<span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show side chains</span>
</label>
</div>
<button type="button" class="text-gray-900 bg-white hover:bg-gray-100 border border-gray-200 focus:ring-4 focus:outline-none focus:ring-gray-100 font-medium rounded-lg text-sm px-5 py-2.5 text-center inline-flex items-center dark:focus:ring-gray-600 dark:bg-gray-800 dark:border-gray-700 dark:text-white dark:hover:bg-gray-700 mr-2 mb-2" id="download">
<svg class="w-6 h-6 mr-2 -ml-1" fill="none" stroke="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path></svg>
Download predicted structure
</button>
</div>
<div class="text-sm">
<div class="font-medium mt-4"><b>AlphaFold model confidence:</b></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(0, 83, 214);">&nbsp;</span><span class="legendlabel">Very high
(pLDDT &gt; 90)</span></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(101, 203, 243);">&nbsp;</span><span class="legendlabel">Confident
(90 &gt; pLDDT &gt; 70)</span></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(255, 219, 19);">&nbsp;</span><span class="legendlabel">Low (70 &gt;
pLDDT &gt; 50)</span></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(255, 125, 69);">&nbsp;</span><span class="legendlabel">Very low
(pLDDT &lt; 50)</span></div>
<div class="row column legendDesc"> AlphaFold produces a per-residue confidence
score (pLDDT) between 0 and 100. Some regions below 50 pLDDT may be unstructured in isolation.
</div>
</div>
<script>
let viewer = null;
let voldata = null;
$(document).ready(function () {

let element = $("#container");
let config = { backgroundColor: "white" };
viewer = $3Dmol.createViewer( element, config );
viewer.ui.initiateUI();
let data = `"""

# Remainder of the viewer page, starting with the closing backtick of the
# template literal that holds the PDB text.
_VIEWER_HTML_TAIL = """`
viewer.addModel( data, "pdb" );
//AlphaFold code from https://gist.github.com/piroyon/30d1c1099ad488a7952c3b21a5bebc96
let colorAlpha = function (atom) {
if (atom.b < 50) {
return "OrangeRed";
} else if (atom.b < 70) {
return "Gold";
} else if (atom.b < 90) {
return "MediumTurquoise";
} else {
return "Blue";
}
};
viewer.setStyle({}, { cartoon: { colorfunc: colorAlpha } });
viewer.zoomTo();
viewer.render();
viewer.zoom(0.8, 2000);
viewer.getModel(0).setHoverable({}, true,
function (atom, viewer, event, container) {
console.log(atom)
if (!atom.label) {
atom.label = viewer.addLabel(atom.resn+atom.resi+" pLDDT=" + atom.b, { position: atom, backgroundColor: "mintcream", fontColor: "black" });
}
},
function (atom, viewer) {
if (atom.label) {
viewer.removeLabel(atom.label);
delete atom.label;
}
}
);
$("#sidechain").change(function () {
if (this.checked) {
BB = ["C", "O", "N"]
viewer.setStyle( {"and": [{resn: ["GLY", "PRO"], invert: true},{atom: BB, invert: true},]},{stick: {colorscheme: "WhiteCarbon", radius: 0.3}, cartoon: { colorfunc: colorAlpha }});
viewer.render()
} else {
viewer.setStyle({cartoon: { colorfunc: colorAlpha }});
viewer.render()
}
});
$("#download").click(function () {
download("gradioFold_model1.pdb", data);
})
});

function download(filename, text) {
var element = document.createElement("a");
element.setAttribute("href", "data:text/plain;charset=utf-8," + encodeURIComponent(text));
element.setAttribute("download", filename);

element.style.display = "none";
document.body.appendChild(element);

element.click();

document.body.removeChild(element);
}
</script>
</body></html>"""

# Iframe wrapper; {srcdoc} receives the HTML-escaped viewer page.
_VIEWER_IFRAME = """<iframe style="width: 800px; height: 1200px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""


def molecule(pdb):
    """Return an ``<iframe>`` snippet that shows the PDB file in a 3Dmol viewer.

    The file is read with ``read_mol`` and placed inside a JavaScript template
    literal in a standalone HTML page. The page's script colours the cartoon
    by the b-factor thresholds 50 / 70 / 90, offers a side-chain toggle and a
    download button. The finished page becomes the iframe's ``srcdoc``.

    Args:
        pdb: Path of the PDB file to display.

    Returns:
        HTML string containing a single ``<iframe>`` element.
    """
    import html  # local import so the block needs no change to the file's imports

    page = _VIEWER_HTML_HEAD + read_mol(pdb) + _VIEWER_HTML_TAIL
    # Escape the page for use as an attribute value: a quote character in the
    # page or the PDB text would otherwise end the single-quoted srcdoc early.
    # The browser decodes the entities again, so the iframe gets the same page.
    # NOTE(review): the PDB text itself is still inserted verbatim into the JS
    # template literal; a backtick or "${" in the file would break the script.
    return _VIEWER_IFRAME.format(srcdoc=html.escape(page, quote=True))
311
+
312
+
313
# Top-level UI definition: builds the Gradio Blocks page, wires the two
# buttons and the radio group to their handlers, and launches the app.
# NOTE(review): the nesting of the Group/Row containers below is a best-effort
# reconstruction -- confirm against the rendered layout.
proteindream = gr.Blocks()

with proteindream:
    gr.Markdown("# GradioFold")
    gr.Markdown(
        """GradioFold is a web-based tool that combines a large language model trained on natural protein sequence (protGPT2) with structure prediction using AlphaFold.
Type a start sequence or provide a sequence with blanks that protGPT2 can complete."""
    )
    gr.Markdown("## protGPT2")
    gr.Markdown(
        """
Enter a start sequence and have the language model complete it.
"""
    )
    # Inputs for sequence generation.
    with gr.Group():
        with gr.Row():
            inp = gr.Textbox(placeholder="M", label="Start sequence")
            length = gr.Number(value=50, label="Target sequence length")
        btn = gr.Button("Autocomplete sequences")

    # Sequences offered in the radio group before anything has been generated.
    seqs = [
        "MTAEADPAPLAANPPAPVRPIQFHDVSVRYEARPWLRALWDVASGSFIGLLGASGAGKSTCVDLLNGVRKPSSGERFVRGQPSRGRKGRFNRRVAMVFQDVRHQLFSRSVAREIAFGLENLPTSAAAIDRRVS",
        "MTAGIVAGGIAGGVAGYKAKKHRKAVKATMIAAGVSGGIGGGYIGEKFNRRLAKHEDRVRRSAPRHKKHSSYSKSSGEGGGILGKLFGR",
        "MTAVLVAIALEMQNPHRMALAAVLCGQFTVAVAAEPFAPEGVAEGLNPLGDLLAESPLLEVVSATLALLVALGTATSLSWISGPVSALPAPSFQSSETPYPQRPIERESFDQDSREEDPWDRL",
        "MTARVRNRSSSRSYVLDFADLADGQREVLLPESRGNASEVDLPAGTTVNVTIDVTASGTGTLTARTPDGADVVSNEYELTVERDTDLTRVETESPQVAAGETATVTGTAENVGTVAGEREVTAYVDGE",
        "MTAAGWREEGTPFARIARQLGRHVTSVRQAAGRVRQQMGLTSPDPADPPRSGPTPTIPIEQERA",
    ]
    seqChoice = gr.Radio(seqs, label="Generated sequences")
    # Generate button: start sequence + length in, radio group out.
    btn.click(fn=update_protGPT2, inputs=[inp, length], outputs=seqChoice)
    gr.Markdown("## AlphaFold")
    gr.Markdown(
        "Select a generated sequence above for structure prediction using AlphaFold2."
    )
    with gr.Group():
        chosenSeq = gr.Textbox(label="Chosen sequence")
        btn2 = gr.Button("Predict structure")
    # Prediction outputs: summary text, 3D viewer and pLDDT plot.
    with gr.Group():
        meanpLDDT = gr.Textbox(label="Mean pLDDT of chosen sequence")
        with gr.Row():
            mol = gr.HTML()
            plot = gr.Plot(label="pLDDT")
    gr.Markdown(
        """## Acknowledgements
This was a fun demo using Gradio, Huggingface and ColabFold. More information about the used algorithms can be found below.

All code is available on [Github]() and licensed under MIT license.

- ProtGPT2: Ferruz et.al [BioRxiv](https://doi.org/10.1101/2022.03.09.483666) [Code](https://huggingface.co/nferruz/ProtGPT2)
- AlphaFold2: Jumper et.al [Paper](https://doi.org/10.1038/s41586-021-03819-2) [Code](https://github.com/deepmind/alphafold) Model parameters released under CC BY 4.0
- ColabFold: Mirdita et.al [Paper](https://doi.org/10.1101/2021.08.15.456425 ) [Code](https://github.com/sokrypton/ColabFold)

Created by [@simonduerr](https://twitter.com/simonduerr)
"""
    )
    # Selecting a radio entry mirrors it into the "Chosen sequence" textbox.
    seqChoice.change(fn=update_seqs, inputs=seqChoice, outputs=chosenSeq)
    # Prediction reads the radio selection itself, not the textbox content.
    btn2.click(fn=update, inputs=seqChoice, outputs=[mol, plot, meanpLDDT])

proteindream.launch(share=False)