merve (HF staff) committed
Commit 7decfba • 1 Parent(s): 2fd0925

Create app.py

Files changed (1)
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
+ import gradio as gr
+ import requests
+ from PIL import Image
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+
+ def infer_infographics(image, question):
+     model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-ai2d-base").to("cuda")
+     processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
+
+     inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
+
+     predictions = model.generate(**inputs)
+     return processor.decode(predictions[0], skip_special_tokens=True)
+
+ def infer_ui(image, question):
+     model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-screen2words-base").to("cuda")
+     processor = Pix2StructProcessor.from_pretrained("google/pix2struct-screen2words-base")
+
+     inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
+
+     predictions = model.generate(**inputs)
+     return processor.decode(predictions[0], skip_special_tokens=True)
+
+ def infer_chart(image, question):
+     model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-chartqa-base").to("cuda")
+     processor = Pix2StructProcessor.from_pretrained("google/pix2struct-chartqa-base")
+
+     inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
+
+     predictions = model.generate(**inputs)
+     return processor.decode(predictions[0], skip_special_tokens=True)
+
+ def infer_doc(image, question):
+     model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-base").to("cuda")
+     processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-base")
+     inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
+     predictions = model.generate(**inputs)
+     return processor.decode(predictions[0], skip_special_tokens=True)
+
+ css = """
+ #mkd {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML("<h1><center>Pix2Struct 📄</center></h1>")
+     gr.HTML("<h3><center>Pix2Struct is a powerful backbone for visual question answering. ⚡</center></h3>")
+     gr.HTML("<h3><center>Each tab in this app demonstrates Pix2Struct models fine-tuned on document question answering, infographics question answering, question answering on user interfaces, and charts. 📄📱📊</center></h3>")
+     gr.HTML("<h3><center>This app uses the base version of each model. For better performance, use the large checkpoints.</center></h3>")
+
+     with gr.Tab(label="Visual Question Answering over Documents"):
+         with gr.Row():
+             with gr.Column():
+                 input_img = gr.Image(label="Input Document")
+                 question = gr.Text(label="Question")
+                 submit_btn = gr.Button("Submit")
+                 output = gr.Text(label="Answer")
+         gr.Examples(
+             [["docvqa_example.png", "How many items are sold?"]],
+             inputs=[input_img, question],
+             outputs=[output],
+             fn=infer_doc,
+             cache_examples=True,
+             label='Click on any of the examples below to get document question answering results quickly 👇'
+         )
+
+         submit_btn.click(infer_doc, [input_img, question], [output])
+
+     with gr.Tab(label="Visual Question Answering over Infographics"):
+         with gr.Row():
+             with gr.Column():
+                 input_img = gr.Image(label="Input Image")
+                 question = gr.Text(label="Question")
+                 submit_btn = gr.Button("Submit")
+                 output = gr.Text(label="Answer")
+         gr.Examples(
+             [["infographics_example.jpeg", "What is this infographic about?"]],
+             inputs=[input_img, question],
+             outputs=[output],
+             fn=infer_infographics,
+             cache_examples=True,
+             label='Click on any of the examples below to get infographics QA results quickly 👇'
+         )
+
+         submit_btn.click(infer_infographics, [input_img, question], [output])
+     with gr.Tab(label="Caption User Interfaces"):
+         with gr.Row():
+             with gr.Column():
+                 input_img = gr.Image(label="Input UI Image")
+                 question = gr.Text(label="Question")
+                 submit_btn = gr.Button("Submit")
+                 output = gr.Text(label="Caption")
+         submit_btn.click(infer_ui, [input_img, question], [output])
+         gr.Examples(
+             [["screen2words_ui_example.png", "What is this UI about?"]],
+             inputs=[input_img, question],
+             outputs=[output],
+             fn=infer_ui,
+             cache_examples=True,
+             label='Click on any of the examples below to get UI captioning results quickly 👇'
+         )
+
+     with gr.Tab(label="Ask about Charts"):
+         with gr.Row():
+             with gr.Column():
+                 input_img = gr.Image(label="Input Chart")
+                 question = gr.Text(label="Question")
+                 submit_btn = gr.Button("Submit")
+                 output = gr.Text(label="Answer")
+
+         submit_btn.click(infer_chart, [input_img, question], [output])
+         gr.Examples(
+             [["chartqa_example.png", "How much percent is bicycle?"]],
+             inputs=[input_img, question],
+             outputs=[output],
+             fn=infer_chart,
+             cache_examples=True,
+             label='Click on any of the examples below to get chart question answering results quickly 👇'
+         )
+
+ demo.launch(debug=True)
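
Note that each infer_* handler above reloads its checkpoint from the Hub on every button click. A minimal sketch of an alternative, assuming the same checkpoints and a CUDA device are available (load_pix2struct and answer are hypothetical helper names, not part of this file), caches each model/processor pair once and reuses it across requests:

from functools import lru_cache
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

@lru_cache(maxsize=None)
def load_pix2struct(checkpoint):
    # Hypothetical helper: load and cache one model/processor pair per checkpoint.
    model = Pix2StructForConditionalGeneration.from_pretrained(checkpoint).to("cuda")
    processor = Pix2StructProcessor.from_pretrained(checkpoint)
    return model, processor

def answer(checkpoint, image, question):
    # Hypothetical helper: one generation pass, mirroring the infer_* functions above.
    model, processor = load_pix2struct(checkpoint)
    inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
    predictions = model.generate(**inputs)
    return processor.decode(predictions[0], skip_special_tokens=True)

# Example: answer("google/pix2struct-docvqa-base", image, "How many items are sold?")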