gitlost-murali commited on
Commit
da59cbe
1 Parent(s): 33024b0

initial checkpoint inference push

Browse files
Files changed (4) hide show
  1. Dockerfile +24 -0
  2. app.py +116 -0
  3. requirements.txt +4 -0
  4. utils.py +144 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ FROM ubuntu:22.04
5
+ # install curl
6
+ RUN apt-get update && apt-get install -y curl && apt-get install -y git && \
7
+ curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
8
+ apt-get install -y git-lfs
9
+
10
+ WORKDIR /code
11
+
12
+ RUN git lfs clone https://huggingface.co/AskUI/pta-text-0.1 /code/model/
13
+
14
+ COPY ./requirements.txt /code/requirements.txt
15
+
16
+ RUN apt-get install -y python3 python3-pip
17
+
18
+ # RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
19
+ RUN pip install --upgrade -r /code/requirements.txt
20
+
21
+
22
+ COPY . .
23
+
24
+ CMD ["python3", "app.py", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image, ImageDraw
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch
7
+ from transformers import Pix2StructProcessor, Pix2StructVisionModel
8
+ from utils import download_default_font, render_header
9
+
10
+ class Pix2StructForRegression(nn.Module):
11
+ def __init__(self, sourcemodel_path, device):
12
+ super(Pix2StructForRegression, self).__init__()
13
+ self.model = Pix2StructVisionModel.from_pretrained(sourcemodel_path)
14
+ print("Pix2StructForRegression Model is Loaded...")
15
+ self.regression_layer1 = nn.Linear(768, 1536)
16
+ self.dropout1 = nn.Dropout(0.1)
17
+ self.regression_layer2 = nn.Linear(1536, 768)
18
+ self.dropout2 = nn.Dropout(0.1)
19
+ self.regression_layer3 = nn.Linear(768, 2)
20
+ self.device = device
21
+ print("Regression Layers are Loaded...")
22
+
23
+ def forward(self, *args, **kwargs):
24
+ outputs = self.model(*args, **kwargs)
25
+ sequence_output = outputs.last_hidden_state
26
+ first_token_output = sequence_output[:, 0, :]
27
+
28
+ x = F.relu(self.regression_layer1(first_token_output))
29
+ x = F.relu(self.regression_layer2(x))
30
+ regression_output = torch.sigmoid(self.regression_layer3(x))
31
+
32
+ return regression_output
33
+
34
+ def load_state_dict_file(self, checkpoint_path, strict=True):
35
+ print("Loading Model Weights...")
36
+ state_dict = torch.load(checkpoint_path, map_location=self.device)
37
+ self.load_state_dict(state_dict, strict=strict)
38
+ print("Model Weights are Loaded...")
39
+
40
+ class Inference:
41
+ def __init__(self) -> None:
42
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+ self.model, self.processor = self.load_model_and_processor("matcha-base", "model/pta-text-v0.1.pt")
44
+ print("Model and Processor are Loaded...")
45
+
46
+ def load_model_and_processor(self, model_name, checkpoint_path):
47
+ model = Pix2StructForRegression(sourcemodel_path=model_name, device=self.device)
48
+ model.load_state_dict_file(checkpoint_path=checkpoint_path)
49
+ model.eval()
50
+ model = model.to(self.device)
51
+ processor = Pix2StructProcessor.from_pretrained(model_name, is_vqa=False)
52
+ return model, processor
53
+
54
+ def prepare_image(self, image, prompt, processor):
55
+ image = image.resize((1920, 1080))
56
+ download_default_font_path = download_default_font()
57
+ rendered_image, _, render_variables = render_header(
58
+ image=image,
59
+ header=prompt,
60
+ bbox={"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0},
61
+ font_path=download_default_font_path,
62
+ )
63
+ encoding = processor(
64
+ images=rendered_image,
65
+ max_patches=2048,
66
+ add_special_tokens=True,
67
+ return_tensors="pt",
68
+ )
69
+ return encoding, render_variables
70
+
71
+ def predict_coordinates(self, encoding, model, render_variables):
72
+ with torch.no_grad():
73
+ pred_regression_outs = model(flattened_patches=encoding["flattened_patches"], attention_mask=encoding["attention_mask"])
74
+ new_height = render_variables["height"]
75
+ new_header_height = render_variables["header_height"]
76
+ new_total_height = render_variables["total_height"]
77
+
78
+ pred_regression_outs[:, 1] = (
79
+ (pred_regression_outs[:, 1] * new_total_height) - new_header_height
80
+ ) / new_height
81
+
82
+ pred_coordinates = pred_regression_outs.squeeze().tolist()
83
+ return pred_coordinates
84
+
85
+ def draw_circle_on_image(self, image, coordinates):
86
+ x, y = coordinates[0] * image.width, coordinates[1] * image.height
87
+ print(coordinates)
88
+ draw = ImageDraw.Draw(image)
89
+ radius = 5
90
+ draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill="red")
91
+ return image
92
+
93
+ def process_image_and_draw_circle(self, image, prompt):
94
+ encoding, render_variables = self.prepare_image(image, prompt, self.processor)
95
+ pred_coordinates = self.predict_coordinates(encoding.to(self.device) , self.model, render_variables)
96
+ result_image = self.draw_circle_on_image(image, pred_coordinates)
97
+ return result_image
98
+
99
+
100
+ def main():
101
+ inference = Inference()
102
+ print("Model and Processor are Loaded...")
103
+ # Gradio Interface
104
+ iface = gr.Interface(
105
+ fn=inference.process_image_and_draw_circle,
106
+ inputs=[gr.Image(type="pil", label = "Upload Image"),
107
+ gr.Textbox(label = "Prompt", placeholder="Enter prompt here...")],
108
+ outputs=gr.Image(type="pil"),
109
+ title="Pix2Struct Image Processing",
110
+ description="Upload an image and enter a prompt to see the model's prediction."
111
+ )
112
+
113
+
114
+ iface.launch()
115
+ if __name__ == "__main__":
116
+ main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ Pillow
utils.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import textwrap
4
+ from typing import Dict, Optional, Tuple
5
+
6
+ from huggingface_hub import hf_hub_download
7
+ from PIL import Image, ImageDraw, ImageFont
8
+
9
+ DEFAULT_FONT_PATH = "ybelkada/fonts"
10
+
11
+
12
+ def download_default_font():
13
+ font_path = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
14
+ return font_path
15
+
16
+
17
+ def render_text(
18
+ text: str,
19
+ text_size: int = 36,
20
+ text_color: str = "black",
21
+ background_color: str = "white",
22
+ left_padding: int = 5,
23
+ right_padding: int = 5,
24
+ top_padding: int = 5,
25
+ bottom_padding: int = 5,
26
+ font_bytes: Optional[bytes] = None,
27
+ font_path: Optional[str] = None,
28
+ ) -> Image.Image:
29
+ """
30
+ Render text. This script is entirely adapted from the original script that can be found here:
31
+ https://github.com/google-research/pix2struct/blob/main/pix2struct/preprocessing/preprocessing_utils.py
32
+
33
+ Args:
34
+ text (`str`, *optional*, defaults to ):
35
+ Text to render.
36
+ text_size (`int`, *optional*, defaults to 36):
37
+ Size of the text.
38
+ text_color (`str`, *optional*, defaults to `"black"`):
39
+ Color of the text.
40
+ background_color (`str`, *optional*, defaults to `"white"`):
41
+ Color of the background.
42
+ left_padding (`int`, *optional*, defaults to 5):
43
+ Padding on the left.
44
+ right_padding (`int`, *optional*, defaults to 5):
45
+ Padding on the right.
46
+ top_padding (`int`, *optional*, defaults to 5):
47
+ Padding on the top.
48
+ bottom_padding (`int`, *optional*, defaults to 5):
49
+ Padding on the bottom.
50
+ font_bytes (`bytes`, *optional*):
51
+ Bytes of the font to use. If `None`, the default font will be used.
52
+ font_path (`str`, *optional*):
53
+ Path to the font to use. If `None`, the default font will be used.
54
+ """
55
+ wrapper = textwrap.TextWrapper(
56
+ width=80
57
+ ) # Add new lines so that each line is no more than 80 characters.
58
+ lines = wrapper.wrap(text=text)
59
+ wrapped_text = "\n".join(lines)
60
+
61
+ if font_bytes is not None and font_path is None:
62
+ font = io.BytesIO(font_bytes)
63
+ elif font_path is not None:
64
+ font = font_path
65
+ else:
66
+ font = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
67
+ raise ValueError(
68
+ "Either font_bytes or font_path must be provided. "
69
+ f"Using default font {font}."
70
+ )
71
+ font = ImageFont.truetype(font, encoding="UTF-8", size=text_size)
72
+
73
+ # Use a temporary canvas to determine the width and height in pixels when
74
+ # rendering the text.
75
+ temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
76
+ _, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)
77
+
78
+ # Create the actual image with a bit of padding around the text.
79
+ image_width = text_width + left_padding + right_padding
80
+ image_height = text_height + top_padding + bottom_padding
81
+ image = Image.new("RGB", (image_width, image_height), background_color)
82
+ draw = ImageDraw.Draw(image)
83
+ draw.text(
84
+ xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font
85
+ )
86
+ return image
87
+
88
+
89
+ # Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
90
+ def render_header(
91
+ image: Image.Image, header: str, bbox: Dict[str, float], font_path: str, **kwargs
92
+ ) -> Tuple[Image.Image, Tuple[float, float, float, float]]:
93
+ """
94
+ Renders the input text as a header on the input image and updates the bounding box.
95
+
96
+ Args:
97
+ image (Image.Image):
98
+ The image to render the header on.
99
+ header (str):
100
+ The header text.
101
+ bbox (Dict[str,float]):
102
+ The bounding box in relative position (0-1), format ("x_min": 0,
103
+ "y_min": 0,
104
+ "x_max": 0,
105
+ "y_max": 0).
106
+ input_data_format (Union[str, ChildProcessError], optional):
107
+ The data format of the image.
108
+
109
+ Returns:
110
+ Tuple[Image.Image, Dict[str, float] ]:
111
+ The image with the header rendered and the updated bounding box.
112
+ """
113
+ assert os.path.exists(font_path), f"Font path {font_path} does not exist."
114
+ header_image = render_text(text=header, font_path=font_path, **kwargs)
115
+ new_width = max(header_image.width, image.width)
116
+
117
+ new_height = int(image.height * (new_width / image.width))
118
+ new_header_height = int(header_image.height * (new_width / header_image.width))
119
+
120
+ new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
121
+ new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
122
+ new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))
123
+
124
+ new_total_height = new_image.height
125
+
126
+ new_bbox = {
127
+ "xmin": bbox["xmin"],
128
+ "ymin": ((bbox["ymin"] * new_height) + new_header_height)
129
+ / new_total_height, # shift y_min down by the header's relative height
130
+ "xmax": bbox["xmax"],
131
+ "ymax": ((bbox["ymax"] * new_height) + new_header_height)
132
+ / new_total_height, # shift y_min down by the header's relative height
133
+ }
134
+
135
+ return (
136
+ new_image,
137
+ new_bbox,
138
+ {
139
+ "width": new_width,
140
+ "height": new_height,
141
+ "header_height": new_header_height,
142
+ "total_height": new_total_height,
143
+ },
144
+ )