aksell commited on
Commit
59008d0
1 Parent(s): 53ba5a6

Move app.py to top level

Browse files

To see if spaces finds the file then, it does not
find it in the hexviz directory, with the current
README.md config.

Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import py3Dmol
2
+ import stmol
3
+ import streamlit as st
4
+ from stmol import showmol
5
+
6
+ from hexviz.attention import Model, ModelType, get_attention_pairs
7
+
8
+ st.sidebar.title("pLM Attention Visualization")
9
+
10
+ st.title("pLM Attention Visualization")
11
+
12
+ # Define list of model types
13
+ models = [
14
+ Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
15
+ ]
16
+
17
+ selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
18
+ selected_model = next((model for model in models if model.name.value == selected_model_name), None)
19
+
20
+ pdb_id = st.text_input("PDB ID", "4RW0")
21
+
22
+ left, right = st.columns(2)
23
+ with left:
24
+ layer_one = st.number_input("Layer", value=1, min_value=1, max_value=selected_model.layers)
25
+ layer = layer_one - 1
26
+ with right:
27
+ head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
28
+ head = head_one - 1
29
+
30
+ min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
31
+
32
+ attention_pairs = get_attention_pairs(pdb_id, layer, head, min_attn, model_type=selected_model.name)
33
+
34
+ def get_3dview(pdb):
35
+ xyzview = py3Dmol.view(query=f"pdb:{pdb}")
36
+ xyzview.setStyle({"cartoon": {"color": "spectrum"}})
37
+ stmol.add_hover(xyzview, backgroundColor="black", fontColor="white")
38
+ for att_weight, first, second in attention_pairs:
39
+ stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight*3, cylColor='red', dashed=False)
40
+ return xyzview
41
+
42
+
43
+ xyzview = get_3dview(pdb_id)
44
+ showmol(xyzview, height=500, width=800)