File size: 2,275 Bytes
9301987
c1b4f26
9301987
 
1d3d5c8
c1b4f26
1d3d5c8
6d1520c
 
 
1d3d5c8
6d1520c
 
c1b4f26
 
 
 
 
 
 
1d3d5c8
c1b4f26
9301987
1d3d5c8
 
 
 
 
9301987
c1b4f26
1d3d5c8
 
 
 
 
 
 
 
 
9301987
c1b4f26
 
 
 
 
 
 
9301987
 
 
 
c1b4f26
1d3d5c8
c1b4f26
b2813ce
1d3d5c8
6d1520c
1d3d5c8
 
 
 
 
b2813ce
1d3d5c8
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import glob
import streamlit as st

from streamlit_image_select import image_select
import streamlit.components.v1 as components

# Streamlit expects set_page_config to be the first command that renders
# anything on the page. Calling it before Model() is constructed guards
# against the model's construction (code not visible here) emitting Streamlit
# output first.
st.set_page_config(page_title="VQA", layout="wide")

# Streamlit re-executes this whole script on every widget interaction.
# Keeping the model in st.session_state means Model() is constructed only once
# per browser session instead of on every rerun.
# NOTE(review): st.cache_resource would share one instance across all
# sessions; only switch to it if Model.inference is safe to call concurrently.
if "model" not in st.session_state:
    print("INIT MODEL")
    from src.model import Model

    st.session_state.model = Model()
    print("DONE INIT MODEL")

# CSS injected into the page to hide Streamlit's default footer.
hide_menu_style = """
<style>
footer {visibility: hidden;}
</style>
"""
st.markdown(hide_menu_style, unsafe_allow_html=True)

# Example image path -> pre-filled question (in Vietnamese) shown in the
# question box when that example is selected.
mapper = dict(
    [
        # "Where is this?"
        ("images/000000000645.jpg", "Đây là đâu"),
        # "What is the maximum speed on this stretch of road?"
        ("images/000000000661.jpg", "Tốc độ tối đa trên đoạn đường này là bao nhiêu"),
        # "How much farther is it to Huế?"
        ("images/000000000674.jpg", "Còn bao xa nữa là tới Huế"),
        # "How long is this bridge?"
        ("images/000000000706.jpg", "Cầu này dài bao nhiêu"),
        # "How much does the chè khúc bạch cost?"
        ("images/000000000777.jpg", "Chè khúc bạch giá bao nhiêu"),
    ]
)

# Let the user upload their own image; otherwise fall back to a gallery example.
image = st.file_uploader(
    "Choose an image file",
    type=[
        "jpg",
        "jpeg",
        "png",
        "webp",
    ],
)
# glob.glob returns files in arbitrary (filesystem-dependent) order; sort so
# the gallery order, and therefore the default selection, is stable.
example = image_select("Examples", sorted(glob.glob("images/*.jpg")))

if image:
    # The model is later called with st.session_state.image, which is a file
    # path here, so the upload is written to disk first.
    # NOTE(review): this fixed path is shared by every concurrent session, so
    # simultaneous uploads can overwrite each other; consider a per-session
    # file name.
    with open("test.png", "wb") as f:
        f.write(image.getvalue())
    st.session_state.image = "test.png"
    st.session_state.question = ""  # no canned question for user uploads
else:
    # Example images without an entry in `mapper` get an empty question
    # instead of crashing the app with a KeyError.
    st.session_state.question = mapper.get(example, "")
    st.session_state.image = example

# Show the chosen image, ask for a question, and render the model's answer
# together with its text/image attention explanations.
if "image" in st.session_state:
    st.image(st.session_state.image)
    question = st.text_input("**Question:** ", value=st.session_state.question)
    # Always ask the model for attention visualisations alongside the answer.
    visualize = True
    if question:
        model = st.session_state.model
        answer, text_attention_html, images_visualize = model.inference(
            st.session_state.image, question, visualize
        )
        st.write(f"**Answer:** {answer}")

        if visualize:
            st.write("**Explanation**")
            # Narrow column for the text attention HTML, wide one for images.
            text_col, image_col = st.columns([1, 2])
            with text_col:
                st.write("*Text Attention*")
                # Rendered via components.html (an embedded frame) with a
                # fixed height and scrolling enabled.
                components.html(text_attention_html, height=960, scrolling=True)
            with image_col:
                st.write("*Image Attention*")
                for attention_image in images_visualize:
                    st.image(attention_image)