JingyeChen22 committed on
Commit
cc859d1
1 Parent(s): 9de996f

Update util.py

Browse files
Files changed (1) hide show
  1. util.py +54 -16
util.py CHANGED
@@ -26,7 +26,7 @@ for index, c in enumerate(alphabet):
26
 
27
 
28
 
29
- def transform_mask_pil(mask_root):
30
  """
31
  This function extracts the mask area and text area from the images.
32
 
@@ -37,13 +37,13 @@ def transform_mask_pil(mask_root):
37
  * The white area is the text area
38
  """
39
  img = np.array(mask_root)
40
- img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
41
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
42
  ret, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY) # pixel value is set to 0 or 255 according to the threshold
43
  return 1 - (binary.astype(np.float32) / 255)
 
44
 
45
-
46
- def transform_mask(mask_root: str):
47
  """
48
  This function extracts the mask area and text area from the images.
49
 
@@ -54,7 +54,7 @@ def transform_mask(mask_root: str):
54
  * The white area is the text area
55
  """
56
  img = cv2.imread(mask_root)
57
- img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
58
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
59
  ret, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY) # pixel value is set to 0 or 255 according to the threshold
60
  return 1 - (binary.astype(np.float32) / 255)
@@ -125,7 +125,45 @@ def filter_segmentation_mask(segmentation_mask: np.array):
125
 
126
 
127
 
128
- def combine_image(args, sub_output_dir: str, pred_image_list: List, image_pil: Image, character_mask_pil: Image, character_mask_highlight_pil: Image, caption_pil_list: List):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  """
130
  This function combines all the outputs and useful inputs together.
131
 
@@ -143,20 +181,20 @@ def combine_image(args, sub_output_dir: str, pred_image_list: List, image_pil: I
143
  if size == 1:
144
  return pred_image_list[0]
145
  elif size == 2:
146
- blank = Image.new('RGB', (512*2, 512), (0,0,0))
147
  blank.paste(pred_image_list[0],(0,0))
148
- blank.paste(pred_image_list[1],(512,0))
149
  elif size == 3:
150
- blank = Image.new('RGB', (512*3, 512), (0,0,0))
151
  blank.paste(pred_image_list[0],(0,0))
152
- blank.paste(pred_image_list[1],(512,0))
153
- blank.paste(pred_image_list[2],(1024,0))
154
  elif size == 4:
155
- blank = Image.new('RGB', (512*2, 512*2), (0,0,0))
156
  blank.paste(pred_image_list[0],(0,0))
157
- blank.paste(pred_image_list[1],(512,0))
158
- blank.paste(pred_image_list[2],(0,512))
159
- blank.paste(pred_image_list[3],(512,512))
160
 
161
 
162
  return blank
@@ -303,4 +341,4 @@ def inpainting_merge_image(original_image, mask_image, inpainting_image):
303
  table.append(0)
304
  mask_image = mask_image.point(table, "1")
305
  merged_image = Image.composite(inpainting_image, original_image, mask_image)
306
- return merged_image
 
26
 
27
 
28
 
29
+ def transform_mask_pil(mask_root, size):
30
  """
31
  This function extracts the mask area and text area from the images.
32
 
 
37
  * The white area is the text area
38
  """
39
  img = np.array(mask_root)
40
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_NEAREST)
41
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
42
  ret, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY) # pixel value is set to 0 or 255 according to the threshold
43
  return 1 - (binary.astype(np.float32) / 255)
44
+
45
 
46
+ def transform_mask(mask_root, size):
 
47
  """
48
  This function extracts the mask area and text area from the images.
49
 
 
54
  * The white area is the text area
55
  """
56
  img = cv2.imread(mask_root)
57
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_NEAREST)
58
  gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
59
  ret, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY) # pixel value is set to 0 or 255 according to the threshold
60
  return 1 - (binary.astype(np.float32) / 255)
 
125
 
126
 
127
 
128
def combine_image(args, resolution, sub_output_dir: str, pred_image_list: List, image_pil: Image, character_mask_pil: Image, character_mask_highlight_pil: Image, caption_pil_list: List):
    """
    This function combines the predicted images into a single canvas.

    A single prediction is returned as-is. Two or three predictions are
    placed left to right in one row, and four predictions fill a 2x2 grid.
    Every image is pasted at an offset that is a multiple of `resolution`
    onto a black RGB canvas.

    Args:
        args (argparse.ArgumentParser): The arguments. Not read by this function.
        resolution (int): Edge length, in pixels, of one cell of the canvas.
        sub_output_dir (str): Not read by this function.
        pred_image_list (List): List of predicted images (1 to 4 entries).
        image_pil (Image): The original image. Not read by this function.
        character_mask_pil (Image): The character-level segmentation mask. Not read by this function.
        character_mask_highlight_pil (Image): The character-level segmentation mask highlighting character regions with green color. Not read by this function.
        caption_pil_list (List): List of captions. Not read by this function.

    Returns:
        The only element of `pred_image_list` when it holds one image,
        otherwise a new RGB image containing all predictions.

    Raises:
        ValueError: If `pred_image_list` holds fewer than 1 or more than 4 images.
    """
    # (columns, rows) of the canvas for every supported number of predictions.
    grid_shapes = {2: (2, 1), 3: (3, 1), 4: (2, 2)}

    count = len(pred_image_list)

    if count == 1:
        return pred_image_list[0]

    # Fail loudly instead of reaching `return blank` with `blank` never assigned.
    if count not in grid_shapes:
        raise ValueError(
            f"combine_image supports 1 to 4 predicted images, got {count}"
        )

    columns, rows = grid_shapes[count]
    blank = Image.new('RGB', (resolution * columns, resolution * rows), (0, 0, 0))

    # Fill the grid row by row: index -> (row, column) -> pixel offset.
    for index, pred_image in enumerate(pred_image_list):
        row, column = divmod(index, columns)
        blank.paste(pred_image, (resolution * column, resolution * row))

    return blank
164
+
165
+
166
+ def combine_image_gradio(args, size, sub_output_dir: str, pred_image_list: List, image_pil: Image, character_mask_pil: Image, character_mask_highlight_pil: Image, caption_pil_list: List):
167
  """
168
  This function combines all the outputs and useful inputs together.
169
 
 
181
  if size == 1:
182
  return pred_image_list[0]
183
  elif size == 2:
184
+ blank = Image.new('RGB', (size*2, size), (0,0,0))
185
  blank.paste(pred_image_list[0],(0,0))
186
+ blank.paste(pred_image_list[1],(size,0))
187
  elif size == 3:
188
+ blank = Image.new('RGB', (size*3, size), (0,0,0))
189
  blank.paste(pred_image_list[0],(0,0))
190
+ blank.paste(pred_image_list[1],(size,0))
191
+ blank.paste(pred_image_list[2],(size*2,0))
192
  elif size == 4:
193
+ blank = Image.new('RGB', (size*2, size*2), (0,0,0))
194
  blank.paste(pred_image_list[0],(0,0))
195
+ blank.paste(pred_image_list[1],(size,0))
196
+ blank.paste(pred_image_list[2],(0,size))
197
+ blank.paste(pred_image_list[3],(size,size))
198
 
199
 
200
  return blank
 
341
  table.append(0)
342
  mask_image = mask_image.point(table, "1")
343
  merged_image = Image.composite(inpainting_image, original_image, mask_image)
344
+ return merged_image