Spaces:
Sleeping
Sleeping
Move app.py to top level
Browse filesTo see if spaces finds the file then, it does not
find it in the hexviz directory, with the current
README.md config.
app.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import py3Dmol
|
2 |
+
import stmol
|
3 |
+
import streamlit as st
|
4 |
+
from stmol import showmol
|
5 |
+
|
6 |
+
from hexviz.attention import Model, ModelType, get_attention_pairs
|
7 |
+
|
8 |
+
st.sidebar.title("pLM Attention Visualization")
|
9 |
+
|
10 |
+
st.title("pLM Attention Visualization")
|
11 |
+
|
12 |
+
# Define list of model types
|
13 |
+
models = [
|
14 |
+
Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
|
15 |
+
]
|
16 |
+
|
17 |
+
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
|
18 |
+
selected_model = next((model for model in models if model.name.value == selected_model_name), None)
|
19 |
+
|
20 |
+
pdb_id = st.text_input("PDB ID", "4RW0")
|
21 |
+
|
22 |
+
left, right = st.columns(2)
|
23 |
+
with left:
|
24 |
+
layer_one = st.number_input("Layer", value=1, min_value=1, max_value=selected_model.layers)
|
25 |
+
layer = layer_one - 1
|
26 |
+
with right:
|
27 |
+
head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
|
28 |
+
head = head_one - 1
|
29 |
+
|
30 |
+
min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
|
31 |
+
|
32 |
+
attention_pairs = get_attention_pairs(pdb_id, layer, head, min_attn, model_type=selected_model.name)
|
33 |
+
|
34 |
+
def get_3dview(pdb):
|
35 |
+
xyzview = py3Dmol.view(query=f"pdb:{pdb}")
|
36 |
+
xyzview.setStyle({"cartoon": {"color": "spectrum"}})
|
37 |
+
stmol.add_hover(xyzview, backgroundColor="black", fontColor="white")
|
38 |
+
for att_weight, first, second in attention_pairs:
|
39 |
+
stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight*3, cylColor='red', dashed=False)
|
40 |
+
return xyzview
|
41 |
+
|
42 |
+
|
43 |
+
xyzview = get_3dview(pdb_id)
|
44 |
+
showmol(xyzview, height=500, width=800)
|