File size: 6,992 Bytes
18f2f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c1e42b
18f2f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a287fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18f2f54
 
8a287fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c1e42b
8a287fa
 
 
 
 
 
 
 
 
 
 
18f2f54
8a287fa
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import streamlit as st
import numpy as np

from plotly.subplots import make_subplots
import plotly.graph_objects as go

import graphviz

from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates
from frontend import on_click_graph
from backend.utils import load_dataset_dict

# Fill/stroke colour applied to the architecture-graph node the user clicks.
# (The misspelt name is kept because the rest of the script refers to it.)
HIGHTLIGHT_COLOR = "#e7bcc5"

# Page layout is configured before any other st.* call in this script.
st.set_page_config(layout="wide")

# -------------------------- LOAD DATASET ---------------------------------
# NOTE(review): judging by later use, this is indexed as
# dataset_dict[img_index // 10_000][img_index % 10_000] -> record with
# 'image', 'label' and 'id' fields — confirm against backend.utils.
dataset_dict = load_dataset_dict()

# -------------------------- LOAD GRAPH -----------------------------------

def load_dot_to_graph(filename):
    """Read a DOT file and rebuild its statements inside a new ``graphviz.Digraph``.

    The file's first line (the ``digraph ... {`` header) and last line (the
    closing ``}``) are dropped. Every line in between becomes one entry of the
    new graph's ``body``, so the caller can adjust attributes such as
    ``graph_attr`` on an editable ``Digraph`` rather than a ``Source``.

    Args:
        filename: Path to a DOT file whose header and closing brace each sit
            on their own line.

    Returns:
        ``(graph, dot)``: the rebuilt ``graphviz.Digraph`` and the
        ``graphviz.Source`` loaded from ``filename``.

    Raises:
        ValueError: if, after trimming surrounding whitespace, the file has
            fewer than two lines (no separate header and closing-brace line).
    """
    dot = graphviz.Source.from_file(filename)
    # strip() first so leading/trailing blank lines are not mistaken for the
    # header/footer lines that are about to be dropped.
    source_lines = str(dot).strip().splitlines()
    if len(source_lines) < 2:
        raise ValueError(
            f'{filename!r} does not contain a header line and a closing-brace line')
    graph = graphviz.Digraph()
    # splitlines() removed the line terminators. graphviz's own body entries
    # end in '\n', so restore it to keep statements (and // comments) from
    # running into the next line when the graph source is assembled.
    graph.body += [line + '\n' for line in source_lines[1:-1]]
    return graph, dot
    
st.title('Maximally activating image patches')
st.write('Visualize image patches that maximize the activation of layers in ConvNeXt model')

# Architecture diagram for the left-hand column. Only the rebuilt Digraph is
# needed; the raw Source returned alongside it is discarded.
convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
convnext_graph, _ = load_dot_to_graph(convnext_dot_file)

# Cap the rendered drawing via Graphviz's "size" graph attribute ("width,height").
convnext_graph.graph_attr.update(size='4,40')

# -------------------------- DISPLAY GRAPH -----------------------------------

def chosen_node_text(clicked_node_title):
    """Translate a clicked graph-node title into a label and an activation key.

    Titles are parsed as whitespace-separated tokens. When present,
    ``'stage N'`` must be the first token pair and ``'block M'`` the second.
    The last token names the layer.

    Examples: ``'stage 1 block 2 dwconv'``, ``'stage 2 downsampling'``,
    ``'embeddings'``.

    Returns:
        ``(display_text, activation_key)``: a human-readable breadcrumb and the
        key used to index the activation mapping.

    Malformed titles surface as IndexError/TypeError/ValueError from the
    parsing below; there is no dedicated error.
    """
    # Glue each keyword to its number so 'stage 3' becomes a single 'stage_3' token.
    title = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_')
    tokens = title.split()

    stage_id = tokens[0].split('_')[1] if 'stage' in title else None
    block_id = tokens[1].split('_')[1] if 'block' in title else None
    layer_id = tokens[-1]

    if 'embeddings' in layer_id:
        return 'Patchify layer', 'embeddings.patch_embeddings'

    if 'downsampling' in layer_id:
        # NOTE(review): unlike the block key below, stage_id is used here without
        # subtracting 1 — confirm this matches how the DOT file numbers
        # downsampling nodes.
        label = f'Stage {stage_id} > Downsampling layer'
        key = f'encoder.stages[{stage_id}].downsampling_layer[1]'
        return label, key

    # Block layers: UI ids are 1-based, the activation key indices are 0-based.
    label = f'Stage {stage_id} > Block {block_id} > {layer_id} layer'
    key = (f'encoder.stages[{int(stage_id) - 1}]'
           f'.layers[{int(block_id) - 1}].{layer_id}')
    return label, key


# Keyword arguments forwarded to the custom on_click_graph component.
props = dict(
    hightlight_color=HIGHTLIGHT_COLOR,
    initial_state=dict(
        group_1_header='Choose an option from group 1',
        group_2_header='Choose an option from group 2',
    ),
)


col1, col2 = st.columns((2, 5))

# Left column: the architecture diagram the user clicks on.
with col1:
    st.markdown("#### Architecture")
    st.write('')
    st.write('Click on a layer below to generate top-k maximally activating image patches')
    st.graphviz_chart(convnext_graph)

# Right column: the component reports the clicked node.
# NOTE(review): from later use it appears to return None or a dict with
# choice.node_title / choice.node_id — confirm against frontend.on_click_graph.
with col2:
    st.markdown("#### Output")
    nodes = on_click_graph(key='toggle_buttons', **props)

# -------------------------- DISPLAY OUTPUT -----------------------------------

if nodes is not None:
    clicked_node_title = nodes["choice"]["node_title"]
    clicked_node_id = nodes["choice"]["node_id"]
    display_text, activation_key = chosen_node_text(clicked_node_title)
    col2.write(f'**Chosen layer:** {display_text}')

    # Collapse the component's stale iframe container and colour the clicked
    # node of the architecture SVG. Doubled braces are literal CSS braces
    # inside the f-string.
    highlight_style = f'''
        <style>
            div[data-stale]:has(iframe) {{
                height: 0;
            }}
            #{clicked_node_id}>polygon {{
                fill: {HIGHTLIGHT_COLOR};
                stroke: {HIGHTLIGHT_COLOR};
            }}
        </style>
    '''
    col2.markdown(highlight_style, unsafe_allow_html=True)

    with col2:
        # Both stay None until the form is submitted; plots are drawn only then.
        layer_infos = None
        top_k_records = None
        with st.form('top_k_form'):
            activation_path = './data/activation/convnext_activation.json'
            activation = load_activation(activation_path)
            # Axis 1 of the layer's activation array is the channel axis.
            num_channels = activation[activation_key].shape[1]

            top_k = st.slider('Choose K for top-K maximally activating patches', 1, 20, value=10)
            # st.slider rejects a default outside [min_value, max_value], so the
            # default upper end must not exceed this layer's channel count.
            channel_start, channel_end = st.slider(
                'Choose channel range of this layer (recommend to choose small range less than 30)',
                1, num_channels, value=(1, min(30, num_channels)))
            submit_button = st.form_submit_button('Generate image patches')
            if submit_button:
                top_k_records = activation[activation_key][:top_k, :, :]
                layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')

        if layer_infos is not None:
            num_cols, num_rows = top_k, channel_end - channel_start + 1
            st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")
            column_titles = tuple(f"#{i+1}" for i in range(top_k))

            # One single-row figure per selected channel (1-based in the UI).
            for row in range(channel_start, channel_end + 1):
                is_first_row = row == channel_start
                # Only the first figure shows the "#1 ... #K" titles and needs room for them.
                top_margin = 50 if is_first_row else 0
                fig = make_subplots(
                    rows=1, cols=num_cols, shared_yaxes=True,
                    subplot_titles=column_titles if is_first_row else None)

                c = row - 1  # 0-based channel index into top_k_records
                for col in range(1, num_cols + 1):
                    k = col - 1
                    # Fields of each (k, channel) record as used here: [0] activation
                    # value, [1]/[2] spatial indices, [3] global image index.
                    img_index = int(top_k_records[k, c, 3])
                    activation_value = top_k_records[k, c, 0]
                    # NOTE(review): assumes dataset_dict is split into chunks of
                    # 10_000 images — confirm against load_dataset_dict.
                    sample = dataset_dict[img_index // 10_000][img_index % 10_000]
                    img = sample['image']
                    class_label = sample['label']
                    class_id = sample['id']

                    idx_x, idx_y = top_k_records[k, c, 1], top_k_records[k, c, 2]
                    x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
                    img = np.array(img)[y1:y2, x1:x2, :]

                    hovertemplate = f"""Top-{col}<br>Activation value: {activation_value:.5f}<br>Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"""
                    fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)

                # Figure-wide styling: applied once per figure rather than per trace.
                fig.update_xaxes(showticklabels=False, showgrid=False)
                fig.update_yaxes(showticklabels=False, showgrid=False)
                fig.update_layout(
                    margin={'b': 0, 't': top_margin, 'r': 0, 'l': 0},
                    showlegend=False,
                    yaxis_title=row,
                    height=100,
                    plot_bgcolor='rgba(0,0,0,0)',
                    paper_bgcolor='rgba(0,0,0,0)',
                    hoverlabel=dict(bgcolor="#e9f2f7"),
                )
                st.plotly_chart(fig, use_container_width=True)


else:
    col2.markdown('Chosen layer: <code>None</code>', unsafe_allow_html=True)
    # Collapse the component's stale iframe container; the declaration ends
    # inside the braces and the <style> element is closed.
    col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0;}</style>""", unsafe_allow_html=True)