CNNs-interpretation-visualization / pages /3_Adversarial_attack.py
taquynhnga's picture
added adversarial attacks & change to streamlit 1.19.0
0c1e42b
raw
history blame
No virus
7.86 kB
import streamlit as st
import pandas as pd
import numpy as np
import random
from backend.utils import make_grid, load_dataset, load_model, load_image
from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img, LoadImage, ShowHeatMap, ShowMaskedImage
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from matplotlib.backends.backend_agg import RendererAgg
from backend.adversarial_attack import *
# Page-level setup: matplotlib lock handle, layout, titles, metadata table
# and the default for the image ID kept in the session state.

# Lock attribute taken from matplotlib's Agg renderer class.
_lock = RendererAgg.lock

# Page configuration is issued before any other Streamlit output.
st.set_page_config(layout='wide')

# Colour constants for this page.
BACKGROUND_COLOR = '#bcd0e7'
SECONDARY_COLOR = '#bce7db'

st.title('Adversarial Attack')
st.write('How adversarial attacks affect ConvNeXt interpretation?')

# ImageNet metadata table; the path is relative to the working directory.
imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')

# The chosen image ID is kept in the session state; it starts at 0 the
# first time the key is missing.
image_id = None
if 'image_id' not in st.session_state:
    st.session_state['image_id'] = 0
# ----------------------------- INPUT ----------------------------------
st.header('Input')
input_col_1, input_col_2, input_col_3 = st.columns(3)

# --------------------------- INPUT column 1 ---------------------------
# Form for selecting the image: either confirm the typed ID or draw a
# random one. The result is written to st.session_state.image_id.
with input_col_1:
    with st.form('image_form'):
        st.write('**Choose or generate a random image**')

        # Placeholder so the ID field can be redrawn below with the random ID.
        chosen_image_id_input = st.empty()
        image_id = chosen_image_id_input.number_input(
            'Image ID:',
            format='%d',
            step=1,
            value=st.session_state.image_id,
        )

        choose_image_button = st.form_submit_button('Choose the defined image')
        random_id = st.form_submit_button('Generate a random image')

        if random_id:
            # NOTE(review): randint includes both end points, so 50000 itself
            # can be drawn - confirm load_image accepts that ID.
            image_id = st.session_state.image_id = random.randint(0, 50000)
            # Redraw the field so it shows the freshly drawn ID.
            chosen_image_id_input.number_input(
                'Image ID:',
                format='%d',
                step=1,
                value=image_id,
            )

        if choose_image_button:
            image_id = int(image_id)
            st.session_state.image_id = image_id
# ---------------------------- SET UP OUTPUT ------------------------------
# Placeholders and column rows are created here, in page order; the code
# further down fills them in.
epsilon_container = st.empty()
st.header('Output')
st.subheader('Perform attack')

# "Perform attack" section: a header row and an image row, each made of five
# columns with equal weights.
header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1] * 5)
output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1] * 5)

# Placeholder for the misclassification message.
error_container = st.empty()
# Placeholder for the SmoothGrad section heading.
smoothgrad_header_container = st.empty()

# SmoothGrad section: again a header row and an image row of five columns.
smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1] * 5)
smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1] * 5)

# Load the currently selected image and unpack the fields used below.
original_image_dict = load_image(st.session_state.image_id)
input_image, input_label, input_id = (
    original_image_dict[field] for field in ('image', 'label', 'id')
)
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
# Classify the unmodified image, show it, and report when the prediction
# does not match the image's own label.
with output_col_1:
    pred_prob, pred_class_id, pred_class_label = feed_forward(input_image)
    st.image(input_image)

# Confidence text shared by the header and the misclassification notice.
confidence_text = f'{pred_prob*100:.1f}% confidence'
header_col_1.write(f'Class ID {input_id} - {input_label}: {confidence_text}')

# The predicted ID is compared with input_id - 1.
# NOTE(review): presumably dataset IDs are 1-based and model class IDs are
# 0-based - confirm against load_image / feed_forward.
if pred_class_id != (input_id-1):
    with error_container.container():
        st.write(f'Predicted output: Class ID {pred_class_id} - {pred_class_label} {confidence_text}')
        st.error(
            'ConvNeXt misclassified the chosen image. Please choose or generate another image.',
            icon="🚫",
        )
# ----------------------------- INPUT column 2 & 3 ----------------------------
# Column 2: pick one epsilon by hand, or ask for an automatic search.
with input_col_2:
    with st.form('epsilon_form'):
        st.write('**Set epsilon or find the smallest epsilon automatically**')
        chosen_epsilon_input = st.empty()
        epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=0.001, format='%.3f', step=0.001)
        epsilon_button = st.form_submit_button('Choose the defined epsilon')
        find_epsilon = st.form_submit_button('Find the smallest epsilon automatically')

# Column 3: optional bounds for the automatic search.
# Both fields share the 0.001 lower bound of the 'Epsilon:' field above, so a
# zero or negative maximum/step cannot be submitted.
with input_col_3:
    with st.form('iterate_epsilon_form'):
        max_epsilon = st.number_input(
            'Maximum value of epsilon (Optional setting)',
            min_value=0.001,
            value=0.500,
            format='%.3f',
        )
        step_epsilon = st.number_input(
            'Step (Optional setting)',
            min_value=0.001,
            value=0.001,
            format='%.3f',
        )
        setting_button = st.form_submit_button('Set iterating mode')
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
# Runs only when the unmodified image was classified as input_id - 1 and one
# of the epsilon-form buttons was pressed on this run. It performs the attack
# (single epsilon or a sweep), shows "image + perturbation = attacked image",
# and then compares SmoothGrad output before and after the attack.
if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
    # Set to False when there is no attack output to display.
    attack_ran = True

    with output_col_3:
        if epsilon_button:
            # One attack with the epsilon entered in the form.
            perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, epsilon)
        else:
            # Sweep step, 2*step, ... limited by max_epsilon and to at most
            # 1000 candidates.
            epsilons = [i*step_epsilon for i in range(1, 1001) if i*step_epsilon <= max_epsilon]
            if not epsilons:
                # Without candidates the loop below would never run and the
                # attack variables would stay undefined, so report and skip.
                attack_ran = False
                epsilon_container.error(
                    'No epsilon value to check. Set the step no larger than the maximum value of epsilon.',
                    icon="🚫",
                )
            else:
                epsilon_container.progress(0, text='Checking epsilon')
                for i, e in enumerate(epsilons):
                    perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, e)
                    epsilon_container.progress(i/len(epsilons), text=f'Checking epsilon={e:.3f}. Confidence={new_prob*100:.1f}%')
                    epsilon = e

                    # Stop at the first epsilon that changes the predicted class.
                    if new_id != input_id - 1:
                        epsilon_container.empty()
                        st.balloons()
                        break
                else:
                    # No candidate changed the prediction; the output of the
                    # last candidate is still displayed below.
                    epsilon_container.error(f'FSGM failed to attack on this image at epsilon={e:.3f}. Set higher maximum value of epsilon or choose another image',
                                            icon = "🚫")

        if attack_ran:
            # Convert the attack output to a uint8 array: take the first item
            # and move axis 0 to the end.
            perturbed_image = deprocess_image(perturbed_data.detach().numpy())[0].astype(np.uint8).transpose(1,2,0)
            # NOTE(review): perturbed_image is uint8; if input_image is (or
            # converts to) uint8 as well, this subtraction wraps modulo 256 -
            # confirm that is what ShowImage expects.
            perturbed_amount = perturbed_image - input_image
            header_col_3.write(f'Pertubed amount - epsilon={epsilon:.3f}')
            st.image(ShowImage(perturbed_amount))

    if attack_ran:
        with output_col_2:
            st.image(LoadImage('frontend/images/plus-sign.png'))

        with output_col_4:
            st.image(LoadImage('frontend/images/equal-sign.png'))

        # ---------------------------- DISPLAY COL 5 ROW 1 ------------------------------
        with output_col_5:
            st.image(ShowImage(perturbed_image))
            # The displayed class ID is new_id + 1, mirroring the input_id - 1
            # comparison used for the prediction check.
            header_col_5.write(f'Class ID {new_id+1} - {new_label}: {new_prob*100:.1f}% confidence')

        # -------------------------- DISPLAY SMOOTHGRAD ---------------------------
        smoothgrad_header_container.subheader('SmoothGrad visualization')

        with smoothgrad_col_1:
            smooth_head_1.write('SmoothGrad before attacked')
            # epsilon=0 is used for the "before attack" images.
            heatmap_image, masked_image, mask = generate_images(st.session_state.image_id, epsilon=0)
            st.image(heatmap_image)
            st.image(masked_image)

        with smoothgrad_col_3:
            smooth_head_3.write('SmoothGrad after attacked')
            heatmap_image_attacked, masked_image_attacked, attacked_mask= generate_images(st.session_state.image_id, epsilon=epsilon)
            st.image(heatmap_image_attacked)
            st.image(masked_image_attacked)

        with smoothgrad_col_2:
            st.image(LoadImage('frontend/images/minus-sign-5.png'))

        with smoothgrad_col_5:
            smooth_head_5.write('SmoothGrad difference')
            # Absolute difference between the two masks.
            difference_mask = abs(attacked_mask-mask)
            st.image(ShowHeatMap(difference_mask))
            masked_image = ShowMaskedImage(difference_mask, perturbed_image)
            st.image(masked_image)

        with smoothgrad_col_4:
            st.image(LoadImage('frontend/images/equal-sign.png'))