File size: 7,009 Bytes
f0a0a4a
 
671732c
93085f8
f0a0a4a
238e0cb
f0a0a4a
 
81c6fa4
b34ccc3
 
 
 
a2780d6
f0a0a4a
81c6fa4
 
f0a0a4a
 
 
 
 
fb4e118
186c0c1
a432919
186c0c1
ba6d9e2
 
 
f0a0a4a
a432919
f0a0a4a
 
72d0321
e0dd23e
f0a0a4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ea9225
f0a0a4a
 
 
 
2ea9225
 
1e72178
63af266
 
1e72178
63af266
 
 
1e72178
63af266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0a0a4a
a432919
 
 
e0dd23e
a43a5b0
e0dd23e
e8e6698
63af266
 
 
 
e0dd23e
 
 
 
 
 
63af266
a432919
e0dd23e
63af266
 
e0dd23e
f0a0a4a
 
63af266
df73a2b
 
e1b2bb3
f37a3ce
 
 
161ad07
fbde604
 
 
 
 
562368d
 
 
7f70a94
 
 
562368d
 
 
7f70a94
 
 
f37a3ce
ba6d9e2
a218a91
a432919
 
a218a91
 
 
3d33b8d
2c39dc2
 
3d33b8d
a218a91
3d33b8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Hugging Face Hub checkpoint served by this demo, and the revision to load.
# Use 'main' for the latest revision, or pin a specific commit, e.g.:
#   '348ddad8e958d370b7e341acd6050330faa0500f'  -> IoU = 0.47
#   '41210d7c42a22e77711711ec45508a6b63ec380f'  -> IoU = 0.42
pretrained_repo_name = 'ivelin/donut-refexp-combined-v1'
pretrained_revision = 'main'
print(f"Loading model checkpoint: {pretrained_repo_name}")

# Processor (image preprocessing + tokenizer) and model come from the same
# repository and revision so that their vocabularies agree.
processor = DonutProcessor.from_pretrained(
    pretrained_repo_name, revision=pretrained_revision)
model = VisionEncoderDecoderModel.from_pretrained(
    pretrained_repo_name, revision=pretrained_revision)

# Prefer the GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)


def _parse_coord(bbox: dict, key: str, default: float) -> float:
    """Return ``bbox[key]`` converted to a finite float, or *default*.

    The bounding box is parsed from free-text model output, so a coordinate may
    be missing, non-numeric (e.g. ``"abc"``), not a string at all (e.g. a
    list), or non-finite (e.g. ``"nan"``), which ``math.floor`` would reject.
    All of these cases fall back to *default*.
    """
    try:
        value = float(bbox.get(key, default))
    except (TypeError, ValueError):
        return default
    if not math.isfinite(value):
        return default
    return value


def process_refexp(image: Image.Image, prompt: str):
    """Predict and draw the bounding box of the UI element described by *prompt*.

    Args:
        image: Screenshot to search, as a PIL image (the Gradio
            ``gr.Image(type="pil")`` input). The predicted box is drawn onto
            this image in place.
        prompt: Referring expression, e.g. "select the add icon". Only the
            first 80 characters are used, lower-cased.

    Returns:
        An ``(image, bbox)`` tuple matching the interface's two outputs.
        ``bbox`` holds the predicted ``xmin``/``ymin``/``xmax``/``ymax`` as
        fractions of the image width/height, plus the cleaned
        ``"decoder output sequence"``. If the model predicts no usable
        bounding box, the image is returned undrawn with an all-zero box.
    """
    print(f"(image, prompt): {image}, {prompt}")

    # trim prompt to 80 characters and normalize to lowercase
    prompt = prompt[:80].lower()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: the task prompt ends with the opening
    # <s_target_bounding_box> tag, so generation continues from there
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
    decoder_prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        decoder_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(fr"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
    seqjson = processor.token2json(sequence)

    # Safeguard in case the predicted sequence has no target_bounding_box, or
    # token2json parsed it into something other than a dict of coordinates.
    bbox = seqjson.get("target_bounding_box") if isinstance(seqjson, dict) else None
    if not isinstance(bbox, dict):
        print(
            f"token2bbox seq has no predicted target_bounding_box, seq:{sequence}")
        bbox = {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0,
                "decoder output sequence": sequence}
        # Return both outputs declared by the Gradio interface.
        return image, bbox

    print(f"predicted bounding box with text coordinates: {bbox}")
    # The prediction is free text: fall back to the full image extent for any
    # coordinate that is missing or not a valid number.
    xmin = _parse_coord(bbox, "xmin", 0)
    ymin = _parse_coord(bbox, "ymin", 0)
    xmax = _parse_coord(bbox, "xmax", 1)
    ymax = _parse_coord(bbox, "ymax", 1)
    # replace str with float coords
    bbox = {"xmin": xmin, "ymin": ymin, "xmax": xmax,
            "ymax": ymax, "decoder output sequence": sequence}
    print(f"predicted bounding box with float coordinates: {bbox}")

    print(f"image object: {image}")
    print(f"image size: {image.size}")
    width, height = image.size
    print(f"image width, height: {width, height}")
    print(f"processed prompt: {decoder_prompt}")

    # scale the relative coordinates to image pixels
    px_xmin = math.floor(width * xmin)
    px_ymin = math.floor(height * ymin)
    px_xmax = math.floor(width * xmax)
    px_ymax = math.floor(height * ymax)

    print(
        f"to image pixel values: xmin, ymin, xmax, ymax: {px_xmin, px_ymin, px_xmax, px_ymax}")

    # Pillow's rectangle() rejects boxes whose second corner precedes the
    # first, so order the corners in case the model predicted them swapped.
    shape = [(min(px_xmin, px_xmax), min(px_ymin, px_ymax)),
             (max(px_xmin, px_xmax), max(px_ymin, px_ymax))]

    # draw bbox rectangle: green outline overlaid with a thinner white one
    draw = ImageDraw.Draw(image)
    draw.rectangle(shape, outline="green", width=5)
    draw.rectangle(shape, outline="white", width=2)

    return image, bbox


# Text shown around the Gradio interface (title, intro blurb, footer links).
title = "Demo: Donut 🍩 for UI RefExp (by GuardianUI)"
description = (
    "Gradio Demo for Donut RefExp task, an instance of "
    "`VisionEncoderDecoderModel` fine-tuned on "
    "[UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) "
    "Dataset (UI Referring Expression). "
    "To use it, simply upload your image and type a prompt and click 'submit', "
    "or click one of the examples to load them. "
    "See the model training "
    "<a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>"
    "Colab Notebook</a> for this space. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.15664' target='_blank'>"
    "Donut: OCR-free Document Understanding Transformer</a>"
    " | "
    "<a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a>"
    "</p>"
)
# Clickable examples for the interface. Each entry is an
# [image file, prompt] pair matching the interface's two inputs.
# Every entry needs a trailing comma: without it the next list literal is
# parsed as a subscript of the previous one and the module fails to load.
examples = [["example_1.jpg", "select the setting icon from top right corner"],
            ["example_1.jpg", "click on down arrow beside the entertainment"],
            ["example_1.jpg", "select the down arrow button beside lifestyle"],
            ["example_1.jpg", "click on the image beside the option traffic"],
            ["example_2.jpg", "enter the text field next to the name"],
            ["example_3.jpg", "select the third row first image"],
            ["example_3.jpg", "click the tick mark on the first image"],
            ["example_3.jpg", "select the ninth image"],
            ["example_3.jpg", "select the add icon"],
            ["example_3.jpg", "click the first image"],
            ["val-image-4.jpg", "select 4153365454"],
            ["val-image-4.jpg", "go to cell"],
            ["val-image-4.jpg", "select number above cell"],
            ["val-image-1.jpg", "select calendar option"],
            ["val-image-1.jpg", "select photos&videos option"],
            ["val-image-2.jpg", "click on change store"],
            ["example_2.jpg", "click on green color button"],
            ["example_2.jpg", "click on text which is beside call now"],
            ["example_2.jpg", "click on more button"],
            ["val-image-2.jpg", "click on shop menu at the bottom"],
            ["val-image-3.jpg", "click on image above short meow"],
            ["val-image-3.jpg", "go to cat sounds"],
            ]

# Wire the inference function into a Gradio form: a PIL image and a text
# prompt go in; the annotated image and the bounding-box JSON come out.
demo = gr.Interface(
    fn=process_refexp,
    inputs=[gr.Image(type="pil"), "text"],
    outputs=[gr.Image(type="pil"), "json"],
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Caching would run inference on every example at start-up, which makes
    # the Space too slow to restart after each app change commit.
    cache_examples=False,
)

demo.launch()