MnLgt committed
Commit 6706230
1 Parent(s): 7ae85d5
.gitignore ADDED
@@ -0,0 +1,6 @@
1
+
2
+ gradio_cached_examples/
3
+ checkpoint-*
4
+ */example.ipynb
5
+
6
+ *.pyc
README.md CHANGED
@@ -56,7 +56,7 @@ To use this model, you'll need to have the appropriate YOLO framework installed.
56
  To use the model for inference, you can use the following Python script:
57
 
58
  ```python
59
- from yolo_segmentation import YOLO
59
+ from ultralytics import YOLO
60

61
  # Load the model
62
  model = YOLO('path/to/your/model.pt')
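
For reference, here is a fuller inference sketch built on the corrected import. It is a minimal example, not part of the commit: the weight filename is taken from `app.py`, and the sample image path from the `sample_images/` folder added below.

```python
# Minimal sketch, assuming the weights and sample images shipped in this commit.
from ultralytics import YOLO

model = YOLO("yolo-human-parse-epoch-125.pt")  # weight path assumed from app.py

# Run instance segmentation on one of the bundled sample images
results = model("sample_images/image_five.jpg", retina_masks=True)

for result in results:
    for cls, conf, box in zip(result.boxes.cls, result.boxes.conf, result.boxes.xyxy):
        print(result.names[int(cls)], f"{float(conf):.2f}", box.tolist())
```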
README.yaml DELETED
@@ -1,99 +0,0 @@
1
- ---
2
- language:
3
- - "en"
4
- thumbnail: "https://example.com/path/to/your/thumbnail.jpg"
5
- tags:
6
- - yolo
7
- - object-detection
8
- - image-segmentation
9
- - computer-vision
10
- - human-body-parts
11
- license: "mit"
12
- datasets:
13
- - custom_human_body_parts_dataset
14
- metrics:
15
- - mean_average_precision
16
- - intersection_over_union
17
- base_model: "ultralytics/yolov8x-seg"
18
- ---
19
-
20
- # YOLO Segmentation Model for Human Body Parts and Objects
21
-
22
- This model is a fine-tuned version of YOLOv8 for segmenting human body parts and objects. It can detect and segment 11 different classes, including various body parts, outfits, and phones.
23
-
24
- ## Model Details
25
-
26
- - **Model Type:** YOLOv8 for Instance Segmentation
27
- - **Task:** Segmentation
28
- - **Fine-tuning Dataset:** Custom dataset of human body parts and objects
29
- - **Number of Classes:** 11
30
-
31
- ## Classes
32
-
33
- The model can detect and segment the following classes:
34
-
35
- 0. Hair
36
- 1. Face
37
- 2. Neck
38
- 3. Arm
39
- 4. Hand
40
- 5. Back
41
- 6. Leg
42
- 7. Foot
43
- 8. Outfit
44
- 9. Person
45
- 10. Phone
46
-
47
- ## Usage
48
-
49
- This model can be used for various applications, including:
50
-
51
- - Human pose estimation
52
- - Gesture recognition
53
- - Fashion analysis
54
- - Person tracking
55
- - Human-computer interaction
56
-
57
- For detailed usage instructions, please refer to the model's README file.
58
-
59
- ## Training Procedure
60
-
61
- The model was fine-tuned on a custom dataset of annotated images containing human body parts and objects. The training process involved transfer learning from the base YOLOv8 model, with adjustments made to the final layers to accommodate the new class structure.
62
-
63
- ## Evaluation Results
64
-
65
- (Note: Replace these placeholder metrics with your actual evaluation results)
66
-
67
- lr/pg0:0.000572628
68
- lr/pg1:0.000572628
69
- lr/pg2:0.000572628
70
- metrics/mAP50-95(B):0.53001
71
- metrics/mAP50-95(M):0.42367
72
- metrics/mAP50(B):0.69407
73
- metrics/mAP50(M):0.61714
74
- metrics/precision(B):0.7047
75
- metrics/precision(M):0.68041
76
- metrics/recall(B):0.68802
77
- metrics/recall(M):0.62248
78
- model/GFLOPs:344.557
79
- model/parameters:71,761,441
80
- model/speed_PyTorch(ms):5.813
81
- train/box_loss:0.54718
82
- train/cls_loss:0.52977
83
- train/dfl_loss:0.95171
84
- train/seg_loss:1.34628
85
- val/box_loss:0.80538
86
- val/cls_loss:0.83434
87
- val/dfl_loss:1.18352
88
- val/seg_loss:2.19488
89
-
90
-
91
- ## Limitations and Biases
92
-
93
- - The model's performance may vary depending on lighting conditions and image quality.
94
- - It may have difficulty with occluded or partially visible body parts.
95
- - The model's performance on diverse body types and skin tones should be carefully evaluated to ensure fairness and inclusivity.
96
-
97
- ## Ethical Considerations
98
-
99
- Users of this model should be aware of privacy concerns related to human body detection and ensure they have appropriate consent for its application. The model should not be used for surveillance or any application that could infringe on personal privacy without explicit consent.
app.py ADDED
@@ -0,0 +1,131 @@
1
+ import gradio as gr
2
+ import os
3
+ from ultralytics import YOLO
4
+ from yolo.BodyMask import BodyMask
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib import patches
8
+ from skimage.transform import resize
9
+ from PIL import Image
10
+ import io
11
+
12
+ model_id = os.path.abspath("yolo-human-parse-epoch-125.pt")
13
+
14
+
15
+ def display_image_with_masks(image, results, cols=4):
16
+ # Convert PIL Image to numpy array
17
+ image_np = np.array(image)
18
+
19
+ # Check image dimensions
20
+ if image_np.ndim != 3 or image_np.shape[2] != 3:
21
+ raise ValueError("Image must be a 3-dimensional array with 3 color channels")
22
+
23
+ # Number of masks
24
+ n = len(results)
25
+ rows = (n + cols - 1) // cols # Calculate required number of rows
26
+
27
+ # Setting up the plot
28
+ fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
29
+ axs = np.array(axs).reshape(-1) # Flatten axs array for easy indexing
30
+
31
+ for i, result in enumerate(results):
32
+ mask = result["mask"]
33
+ label = result["label"]
34
+ score = float(result["score"])
35
+
36
+ # Convert PIL mask to numpy array and resize if necessary
37
+ mask_np = np.array(mask)
38
+ if mask_np.shape != image_np.shape[:2]:
39
+ mask_np = resize(
40
+ mask_np, image_np.shape[:2], mode="constant", anti_aliasing=False
41
+ )
42
+ mask_np = (mask_np > 0.5).astype(
43
+ np.uint8
44
+ ) # Threshold back to binary after resize
45
+
46
+ # Create an overlay where mask is True
47
+ overlay = np.zeros_like(image_np)
48
+ overlay[mask_np > 0] = [0, 0, 255] # Applying blue color on the mask area
49
+
50
+ # Combine the image and the overlay
51
+ combined = image_np.copy()
52
+ indices = np.where(mask_np > 0)
53
+ combined[indices] = combined[indices] * 0.5 + overlay[indices] * 0.5
54
+
55
+ # Show the combined image
56
+ ax = axs[i]
57
+ ax.imshow(combined)
58
+ ax.axis("off")
59
+ ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
60
+ rect = patches.Rectangle(
61
+ (0, 0),
62
+ image_np.shape[1],
63
+ image_np.shape[0],
64
+ linewidth=1,
65
+ edgecolor="r",
66
+ facecolor="none",
67
+ )
68
+ ax.add_patch(rect)
69
+
70
+ # Hide unused subplots if the total number of masks is not a multiple of cols
71
+ for idx in range(i + 1, rows * cols):
72
+ axs[idx].axis("off")
73
+
74
+ plt.tight_layout()
75
+
76
+ # Save the plot to a bytes buffer
77
+ buf = io.BytesIO()
78
+ plt.savefig(buf, format="png")
79
+ buf.seek(0)
80
+
81
+ # Clear the current figure
82
+ plt.close(fig)
83
+
84
+ return buf
85
+
86
+
87
+ def perform_segmentation(input_image):
88
+ bm = BodyMask(input_image, model_id=model_id, resize_to=640)
89
+ results = bm.results
90
+ buf = display_image_with_masks(input_image, results)
91
+
92
+ # Convert BytesIO to PIL Image
93
+ img = Image.open(buf)
94
+ return img
95
+
96
+
97
+ # Get example images
98
+ example_images = [
99
+ os.path.join("sample_images", f)
100
+ for f in os.listdir("sample_images")
101
+ if f.endswith((".png", ".jpg", ".jpeg"))
102
+ ]
103
+
104
+ with gr.Blocks() as demo:
105
+ gr.Markdown("# YOLO Segmentation Demo with BodyMask")
106
+ gr.Markdown(
107
+ "Upload an image or select an example to see the YOLO segmentation results."
108
+ )
109
+
110
+ with gr.Row():
111
+ with gr.Column():
112
+ input_image = gr.Image(type="pil", label="Input Image", height=512)
113
+ segment_button = gr.Button("Perform Segmentation")
114
+
115
+ output_image = gr.Image(label="Segmentation Result")
116
+
117
+ gr.Examples(
118
+ examples=example_images,
119
+ inputs=input_image,
120
+ outputs=output_image,
121
+ fn=perform_segmentation,
122
+ cache_examples=True,
123
+ )
124
+
125
+ segment_button.click(
126
+ fn=perform_segmentation,
127
+ inputs=input_image,
128
+ outputs=output_image,
129
+ )
130
+
131
+ demo.launch()
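
Outside the Gradio UI, the same segmentation step can be exercised directly. This is a sketch under the same assumptions as `app.py` (weights and sample images present in the working directory); it avoids importing `app.py` itself, since that module calls `demo.launch()` at import time.

```python
# Sketch: run the BodyMask pipeline used by the app, without launching Gradio.
import os
from PIL import Image
from yolo.BodyMask import BodyMask

model_id = os.path.abspath("yolo-human-parse-epoch-125.pt")
image = Image.open("sample_images/image_five.jpg").convert("RGB")

bm = BodyMask(image, model_id=model_id, resize_to=640)
for r in bm.results:
    print(r["label"], f"{float(r['score']):.3f}")
```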
example.ipynb ADDED
@@ -0,0 +1,150 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os \n",
10
+ "from ultralytics import YOLO\n",
11
+ "from yolo.BodyMask import BodyMask\n",
12
+ "\n",
13
+ "\n",
14
+ "model_id = os.path.abspath(\"yolo-human-parse-epoch-125.pt\")\n",
15
+ "\n",
16
+ "example_images = [\n",
17
+ " os.path.join(\"sample_images\", f)\n",
18
+ " for f in os.listdir(\"sample_images\")\n",
19
+ " if f.endswith((\".png\", \".jpg\", \".jpeg\"))\n",
20
+ "]\n",
21
+ "\n",
22
+ "image = example_images[0]\n",
23
+ "\n",
24
+ "bm = BodyMask(image, model_id=model_id)"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "bm.display_results()"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 8,
39
+ "metadata": {},
40
+ "outputs": [
41
+ {
42
+ "name": "stdout",
43
+ "output_type": "stream",
44
+ "text": [
45
+ "\u001b[0;31mInit signature:\u001b[0m\n",
46
+ "\u001b[0mgr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mImage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
47
+ "\u001b[0;34m\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str | PIL.Image.Image | np.ndarray | Callable | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
48
+ "\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
49
+ "\u001b[0;34m\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'webp'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
50
+ "\u001b[0;34m\u001b[0m \u001b[0mheight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int | str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
51
+ "\u001b[0;34m\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int | str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
52
+ "\u001b[0;34m\u001b[0m \u001b[0mimage_mode\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"Literal['1', 'L', 'P', 'RGB', 'RGBA', 'CMYK', 'YCbCr', 'LAB', 'HSV', 'I', 'F'] | None\"\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'RGB'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
53
+ "\u001b[0;34m\u001b[0m \u001b[0msources\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"list[Literal['upload', 'webcam', 'clipboard']] | Literal['upload', 'webcam', 'clipboard'] | None\"\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
54
+ "\u001b[0;34m\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"Literal['numpy', 'pil', 'filepath']\"\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'numpy'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
55
+ "\u001b[0;34m\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
56
+ "\u001b[0;34m\u001b[0m \u001b[0mevery\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Timer | float | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
57
+ "\u001b[0;34m\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Component | Sequence[Component] | set[Component] | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
58
+ "\u001b[0;34m\u001b[0m \u001b[0mshow_label\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
59
+ "\u001b[0;34m\u001b[0m \u001b[0mshow_download_button\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
60
+ "\u001b[0;34m\u001b[0m \u001b[0mcontainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
61
+ "\u001b[0;34m\u001b[0m \u001b[0mscale\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
62
+ "\u001b[0;34m\u001b[0m \u001b[0mmin_width\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m160\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
63
+ "\u001b[0;34m\u001b[0m \u001b[0minteractive\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
64
+ "\u001b[0;34m\u001b[0m \u001b[0mvisible\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
65
+ "\u001b[0;34m\u001b[0m \u001b[0mstreaming\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
66
+ "\u001b[0;34m\u001b[0m \u001b[0melem_id\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
67
+ "\u001b[0;34m\u001b[0m \u001b[0melem_classes\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'list[str] | str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
68
+ "\u001b[0;34m\u001b[0m \u001b[0mrender\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
69
+ "\u001b[0;34m\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int | str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
70
+ "\u001b[0;34m\u001b[0m \u001b[0mmirror_webcam\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
71
+ "\u001b[0;34m\u001b[0m \u001b[0mshow_share_button\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
72
+ "\u001b[0;34m\u001b[0m \u001b[0mplaceholder\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str | None'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
73
+ "\u001b[0;34m\u001b[0m \u001b[0mshow_fullscreen_button\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'bool'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
74
+ "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
75
+ "\u001b[0;31mDocstring:\u001b[0m \n",
76
+ "Creates an image component that can be used to upload images (as an input) or display images (as an output).\n",
77
+ "\n",
78
+ "Demos: sepia_filter, fake_diffusion\n",
79
+ "Guides: image-classification-in-pytorch, image-classification-in-tensorflow, image-classification-with-vision-transformers, create-your-own-friends-with-a-gan\n",
80
+ "\u001b[0;31mInit docstring:\u001b[0m\n",
81
+ "Parameters:\n",
82
+ " value: A PIL Image, numpy array, path or URL for the default value that Image component is going to take. If callable, the function will be called whenever the app loads to set the initial value of the component.\n",
83
+ " format: File format (e.g. \"png\" or \"gif\") to save image if it does not already have a valid format (e.g. if the image is being returned to the frontend as a numpy array or PIL Image). The format should be supported by the PIL library. This parameter has no effect on SVG files.\n",
84
+ " height: The height of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.\n",
85
+ " width: The width of the displayed image, specified in pixels if a number is passed, or in CSS units if a string is passed.\n",
86
+ " image_mode: \"RGB\" if color, or \"L\" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html for other supported image modes and their meaning. This parameter has no effect on SVG or GIF files. If set to None, the image_mode will be inferred from the image file.\n",
87
+ " sources: List of sources for the image. \"upload\" creates a box where user can drop an image file, \"webcam\" allows user to take snapshot from their webcam, \"clipboard\" allows users to paste an image from the clipboard. If None, defaults to [\"upload\", \"webcam\", \"clipboard\"] if streaming is False, otherwise defaults to [\"webcam\"].\n",
88
+ " type: The format the image is converted before being passed into the prediction function. \"numpy\" converts the image to a numpy array with shape (height, width, 3) and values from 0 to 255, \"pil\" converts the image to a PIL image object, \"filepath\" passes a str path to a temporary file containing the image. If the image is SVG, the `type` is ignored and the filepath of the SVG is returned. To support animated GIFs in input, the `type` should be set to \"filepath\" or \"pil\".\n",
89
+ " label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.\n",
90
+ " every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.\n",
91
+ " inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.\n",
92
+ " show_label: if True, will display label.\n",
93
+ " show_download_button: If True, will display button to download image.\n",
94
+ " container: If True, will place the component in a container - providing some extra padding around the border.\n",
95
+ " scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.\n",
96
+ " min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.\n",
97
+ " interactive: if True, will allow users to upload and edit an image; if False, can only be used to display images. If not provided, this is inferred based on whether the component is used as an input or output.\n",
98
+ " visible: If False, component will be hidden.\n",
99
+ " streaming: If True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'webcam'.\n",
100
+ " elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.\n",
101
+ " elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.\n",
102
+ " render: If False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.\n",
103
+ " key: if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.\n",
104
+ " mirror_webcam: If True webcam will be mirrored. Default is True.\n",
105
+ " show_share_button: If True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.\n",
106
+ " placeholder: Custom text for the upload area. Overrides default upload messages when provided. Accepts new lines and `#` to designate a heading.\n",
107
+ " show_fullscreen_button: If True, will show a fullscreen icon in the corner of the component that allows user to view the image in fullscreen mode. If False, icon does not appear.\n",
108
+ "\u001b[0;31mFile:\u001b[0m /opt/homebrew/Caskroom/miniforge/base/envs/lemons/lib/python3.10/site-packages/gradio/components/image.py\n",
109
+ "\u001b[0;31mType:\u001b[0m ComponentMeta\n",
110
+ "\u001b[0;31mSubclasses:\u001b[0m "
111
+ ]
112
+ }
113
+ ],
114
+ "source": [
115
+ "import gradio as gr \n",
116
+ "\n",
117
+ "gr.Image?"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {},
124
+ "outputs": [],
125
+ "source": []
126
+ }
127
+ ],
128
+ "metadata": {
129
+ "kernelspec": {
130
+ "display_name": "lemons",
131
+ "language": "python",
132
+ "name": "lemons"
133
+ },
134
+ "language_info": {
135
+ "codemirror_mode": {
136
+ "name": "ipython",
137
+ "version": 3
138
+ },
139
+ "file_extension": ".py",
140
+ "mimetype": "text/x-python",
141
+ "name": "python",
142
+ "nbconvert_exporter": "python",
143
+ "pygments_lexer": "ipython3",
144
+ "version": "3.10.14"
145
+ },
146
+ "orig_nbformat": 4
147
+ },
148
+ "nbformat": 4,
149
+ "nbformat_minor": 2
150
+ }
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ diffusers==0.30.3
2
+ gradio==4.44.0
3
+ matplotlib==3.8.4
4
+ numpy==2.1.1
5
+ Pillow==10.4.0
6
+ scikit-image
7
+ ultralytics==8.2.97
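
The `scikit-image` entry provides the `skimage` module imported by `app.py` and `yolo/utils.py`. A quick sanity check that the pinned stack resolves (a sketch; nothing here is part of the commit):

```python
# Sketch: confirm the imports used across this commit resolve in the installed environment.
import gradio
import matplotlib
import numpy
import PIL
import ultralytics
from diffusers.utils import load_image
from skimage.transform import resize

print("all imports OK")
```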
sample_images/image_five.jpg ADDED
sample_images/image_four.jpg ADDED
sample_images/image_six.jpg ADDED
yolo/BodyMask.py ADDED
@@ -0,0 +1,248 @@
1
+ import os
2
+ from functools import lru_cache
3
+ from typing import List
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from diffusers.utils import load_image
8
+ from PIL import Image, ImageChops, ImageFilter
9
+ from ultralytics import YOLO
10
+ from .utils import *
11
+
12
+
13
+ def dilate_mask(mask, dilate_factor=6, blur_radius=2, erosion_factor=2):
14
+ if not mask:
15
+ return None
16
+ # Convert PIL image to NumPy array if necessary
17
+ if isinstance(mask, Image.Image):
18
+ mask = np.array(mask)
19
+
20
+ # Ensure mask is in uint8 format
21
+ mask = mask.astype(np.uint8)
22
+
23
+ # Apply dilation
24
+ kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
25
+ dilated_mask = cv2.dilate(mask, kernel, iterations=1)
26
+
27
+ # Apply erosion for refinement
28
+ kernel = np.ones((erosion_factor, erosion_factor), np.uint8)
29
+ eroded_mask = cv2.erode(dilated_mask, kernel, iterations=1)
30
+
31
+ # Apply Gaussian blur to smooth the edges
32
+ blurred_mask = cv2.GaussianBlur(
33
+ eroded_mask, (2 * blur_radius + 1, 2 * blur_radius + 1), 0
34
+ )
35
+
36
+ # Convert back to PIL image
37
+ smoothed_mask = Image.fromarray(blurred_mask).convert("L")
38
+
39
+ # Optionally, apply an additional blur for extra smoothness using PIL
40
+ smoothed_mask = smoothed_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
41
+
42
+ return smoothed_mask
43
+
44
+
45
+ @lru_cache(maxsize=1)
46
+ def get_model(model_id):
47
+ model = YOLO(model=model_id)
48
+ return model
49
+
50
+
51
+ def combine_masks(masks: List[dict], labels: List[str], is_label=True) -> Image.Image:
52
+ """
53
+ Combine masks with the specified labels into a single mask, optimized for speed and non-overlapping of excluded masks.
54
+
55
+ Parameters:
56
+ - masks (List[dict]): A list of dictionaries, each containing the mask under a 'mask' key and its label under a 'label' key.
57
+ - labels (List[str]): A list of labels to include in the combination.
58
+
59
+ Returns:
60
+ - Image.Image: The combined mask as a PIL Image object, or None if no masks are combined.
61
+ """
62
+ labels_set = set(labels) # Convert labels list to a set for O(1) lookups
63
+
64
+ # Filter and convert mask images based on the specified labels
65
+ mask_images = [
66
+ mask["mask"].convert("L")
67
+ for mask in masks
68
+ if (mask["label"] in labels_set) == is_label
69
+ ]
70
+
71
+ # Ensure there is at least one mask to combine
72
+ if not mask_images:
73
+ return None # Or raise an appropriate error, e.g., ValueError("No masks found for the specified labels.")
74
+
75
+ # Initialize the combined mask with the first mask
76
+ combined_mask = mask_images[0]
77
+
78
+ # Combine the remaining masks with the existing combined_mask using a bitwise OR operation to ensure non-overlap
79
+ for mask in mask_images[1:]:
80
+ combined_mask = ImageChops.lighter(combined_mask, mask)
81
+
82
+ return combined_mask
83
+
84
+
85
+ body_labels = ["hair", "face", "arm", "hand", "leg", "foot", "outfit"]
86
+
87
+
88
+ class BodyMask:
89
+
90
+ def __init__(
91
+ self,
92
+ image_path,
93
+ model_id,
94
+ labels=body_labels,
95
+ overlay="mask",
96
+ widen_box=0,
97
+ elongate_box=0,
98
+ resize_to=640,
99
+ dilate_factor=0,
100
+ is_label=False,
101
+ resize_to_nearest_eight=False,
102
+ verbose=True,
103
+ remove_overlap=True,
104
+ ):
105
+ self.image_path = image_path
106
+ self.image = self.get_image(
107
+ resize_to=resize_to, resize_to_nearest_eight=resize_to_nearest_eight
108
+ )
109
+ self.labels = labels
110
+ self.is_label = is_label
111
+ self.model_id = model_id
112
+ self.model = get_model(self.model_id)
113
+ self.model_labels = self.model.names
114
+ self.verbose = verbose
115
+ self.results = self.get_results()
116
+ self.dilate_factor = dilate_factor
117
+ self.body_mask = self.get_body_mask()
118
+ self.box = get_bounding_box(self.body_mask)
119
+ self.body_box = self.get_body_box(
120
+ remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
121
+ )
122
+ if overlay == "box":
123
+ self.overlay = overlay_mask(
124
+ self.image, self.body_box, opacity=0.9, color="red"
125
+ )
126
+ else:
127
+ self.overlay = overlay_mask(
128
+ self.image, self.body_mask, opacity=0.9, color="red"
129
+ )
130
+
131
+ def get_image(self, resize_to, resize_to_nearest_eight):
132
+ image = load_image(self.image_path)
133
+ if resize_to:
134
+ image = resize_preserve_aspect_ratio(image, resize_to)
135
+ if resize_to_nearest_eight:
136
+ image = resize_image_to_nearest_eight(image)
137
+ else:
138
+ image = image
139
+ return image
140
+
141
+ def get_body_mask(self):
142
+ body_mask = combine_masks(self.results, self.labels, self.is_label)
143
+ return dilate_mask(body_mask, self.dilate_factor)
144
+
145
+ def get_results(self):
146
+ imgsz = max(self.image.size)
147
+ results = self.model(
148
+ self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
149
+ )[0]
150
+ self.masks, self.boxes, self.scores, self.phrases = unload(
151
+ results, self.model_labels
152
+ )
153
+ results = format_results(
154
+ self.masks,
155
+ self.boxes,
156
+ self.scores,
157
+ self.phrases,
158
+ self.model_labels,
159
+ person_masks_only=False,
160
+ )
161
+
162
+ # filter out lower score results
163
+ masks_to_filter = ["hair"]
164
+ results = filter_highest_score(results, ["hair", "face", "phone"])
165
+ return results
166
+
167
+ def display_results(self):
168
+ if len(self.masks) < 4:
169
+ cols = len(self.masks)
170
+ else:
171
+ cols = 4
172
+ display_image_with_masks(self.image, self.results, cols=cols)
173
+
174
+ def get_mask(self, mask_label):
175
+ assert mask_label in self.phrases, "Mask label not found in results"
176
+ return [f for f in self.results if f.get("label") == mask_label]
177
+
178
+ def combine_masks(self, mask_labels: List, no_labels=None, is_label=True):
179
+ """
180
+ Combine the masks included in the labels list or all of the masks not in the list
181
+ """
182
+ if not is_label:
183
+ mask_labels = [
184
+ phrase for phrase in self.phrases if phrase not in mask_labels
185
+ ]
186
+ masks = [
187
+ row.get("mask") for row in self.results if row.get("label") in mask_labels
188
+ ]
189
+ if len(masks) == 0:
190
+ return None
191
+ combined_mask = masks[0]
192
+ for mask in masks[1:]:
193
+ combined_mask = ImageChops.lighter(combined_mask, mask)
194
+ return combined_mask
195
+
196
+ def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
197
+ body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
198
+ if remove_overlap:
199
+ body_box = self.remove_overlap(body_box)
200
+ return body_box
201
+
202
+ def remove_overlap(self, body_box):
203
+ """
204
+ Remove mask regions that overlap with unwanted labels
205
+ """
206
+ # convert mask to numpy array
207
+ box_array = np.array(body_box)
208
+
209
+ # combine the masks for those labels
210
+ mask = self.combine_masks(mask_labels=self.labels, is_label=True)
211
+
212
+ # convert mask to numpy array
213
+ mask_array = np.array(mask)
214
+
215
+ # where the mask array is white set the box array to black
216
+ box_array[mask_array == 255] = 0
217
+
218
+ # convert the box array to an image
219
+ mask_image = Image.fromarray(box_array)
220
+ return mask_image
221
+
222
+
223
+ if __name__ == "__main__":
224
+ url = "https://sjc1.vultrobjects.com/photo-storage/images/525d1f68-314c-455b-a8b6-f5dc3fa044e4.jpeg"
225
+ image_name = url.split("/")[-1]
226
+ labels = ["face", "hair", "phone", "hand"]
227
+ image = load_image(url)
228
+ image_size = image.size
229
+ # Get the original size of the image
230
+ original_size = image.size
231
+
232
+ # Create body mask
233
+ body_mask = BodyMask(
234
+ image,
235
+ overlay="box",
236
+ labels=labels,
237
+ widen_box=50,
238
+ elongate_box=10,
239
+ dilate_factor=0,
240
+ resize_to=640,
241
+ is_label=False,
242
+ remove_overlap=True,
243
+ verbose=False,
244
+ )
245
+
246
+ # Resize the image back to the original size
247
+ image = body_mask.image.resize(original_size)
248
+ body_mask.body_box.save(image_name)
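
Note that the `__main__` demo above constructs `BodyMask` without the required `model_id` argument, so it would raise a `TypeError` as written. A minimal corrected sketch (the weight filename is an assumption carried over from `app.py`):

```python
# Sketch: the __main__ block with the missing model_id supplied.
import os
from diffusers.utils import load_image
from yolo.BodyMask import BodyMask

model_id = os.path.abspath("yolo-human-parse-epoch-125.pt")  # assumed, as in app.py
image = load_image(
    "https://sjc1.vultrobjects.com/photo-storage/images/525d1f68-314c-455b-a8b6-f5dc3fa044e4.jpeg"
)

body_mask = BodyMask(
    image,
    model_id=model_id,
    labels=["face", "hair", "phone", "hand"],
    is_label=False,      # combine every mask except these labels
    overlay="box",
    widen_box=50,
    elongate_box=10,
    resize_to=640,
    remove_overlap=True,
    verbose=False,
)
body_mask.body_box.save("body_box.png")  # illustrative output path
```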
yolo/utils.py ADDED
@@ -0,0 +1,291 @@
1
+ import matplotlib.patches as patches
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw
+ from skimage.transform import resize  # used by display_image_with_masks below
5
+
6
+
7
+ def unload_mask(mask):
8
+ mask = mask.cpu().numpy().squeeze()
9
+ mask = mask.astype(np.uint8) * 255
10
+ return Image.fromarray(mask)
11
+
12
+
13
+ def unload_box(box):
14
+ return box.cpu().numpy().tolist()
15
+
16
+
17
+ def masks_overlap(mask1, mask2):
18
+ return np.any(np.logical_and(mask1, mask2))
19
+
20
+
21
+ def remove_non_person_masks(person_mask, formatted_results):
22
+ return [
23
+ f
24
+ for f in formatted_results
25
+ if f.get("label") == "person" or masks_overlap(person_mask, f.get("mask"))
26
+ ]
27
+
28
+
29
+ def format_masks(masks):
30
+ return [unload_mask(mask) for mask in masks]
31
+
32
+
33
+ def format_boxes(boxes):
34
+ return [unload_box(box) for box in boxes]
35
+
36
+
37
+ def format_scores(scores):
38
+ return scores.cpu().numpy().tolist()
39
+
40
+
41
+ def unload(result, labels_dict):
42
+ masks = format_masks(result.masks.data)
43
+ boxes = format_boxes(result.boxes.xyxy)
44
+ scores = format_scores(result.boxes.conf)
45
+ labels = result.boxes.cls
46
+ labels = [int(label.item()) for label in labels]
47
+ phrases = [labels_dict[label] for label in labels]
48
+ return masks, boxes, scores, phrases
49
+
50
+
51
+ def format_results(masks, boxes, scores, labels, labels_dict, person_masks_only=True):
52
+ if isinstance(list(labels_dict.keys())[0], int):
53
+ labels_dict = {v: k for k, v in labels_dict.items()}
54
+
55
+ # check that the person mask is present
56
+ if person_masks_only:
57
+ assert "person" in labels, "Person mask not present in results"
58
+ results_dict = []
59
+ for row in zip(labels, scores, boxes, masks):
60
+ label, score, box, mask = row
61
+ label_id = labels_dict[label]
62
+ results_row = dict(
63
+ label=label, score=score, mask=mask, box=box, label_id=label_id
64
+ )
65
+ results_dict.append(results_row)
66
+ results_dict = sorted(results_dict, key=lambda x: x["label"])
67
+ if person_masks_only:
68
+ # Get the person mask
69
+ person_mask = [f for f in results_dict if f.get("label") == "person"][0]["mask"]
70
+ assert person_mask is not None, "Person mask not found in results"
71
+
72
+ # Remove any results that do no overlap with the person
73
+ results_dict = remove_non_person_masks(person_mask, results_dict)
74
+ return results_dict
75
+
76
+
77
+ def filter_highest_score(results, labels):
78
+ """
79
+ Filter results to remove entries with lower scores for specified labels.
80
+
81
+ Args:
82
+ results (list): List of dictionaries containing 'label', 'score', and other keys.
83
+ labels (list): List of labels to filter.
84
+
85
+ Returns:
86
+ list: Filtered results with only the highest score for each specified label.
87
+ """
88
+ # Dictionary to keep track of the highest score entry for each label
89
+ label_highest = {}
90
+
91
+ # First pass: identify the highest score for each label
92
+ for result in results:
93
+ label = result["label"]
94
+ if label in labels:
95
+ if (
96
+ label not in label_highest
97
+ or result["score"] > label_highest[label]["score"]
98
+ ):
99
+ label_highest[label] = result
100
+
101
+ # Second pass: construct the filtered list while preserving the order
102
+ filtered_results = []
103
+ seen_labels = set()
104
+
105
+ for result in results:
106
+ label = result["label"]
107
+ if label in labels:
108
+ if label in seen_labels:
109
+ continue
110
+ if result == label_highest[label]:
111
+ filtered_results.append(result)
112
+ seen_labels.add(label)
113
+ else:
114
+ filtered_results.append(result)
115
+
116
+ return filtered_results
117
+
118
+
119
+ def display_image_with_masks(image, results, cols=4, return_images=False):
120
+ # Convert PIL Image to numpy array
121
+ image_np = np.array(image)
122
+
123
+ # Check image dimensions
124
+ if image_np.ndim != 3 or image_np.shape[2] != 3:
125
+ raise ValueError("Image must be a 3-dimensional array with 3 color channels")
126
+
127
+ # Number of masks
128
+ n = len(results)
129
+ rows = (n + cols - 1) // cols # Calculate required number of rows
130
+
131
+ # Setting up the plot
132
+ fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
133
+ axs = np.array(axs).reshape(-1) # Flatten axs array for easy indexing
134
+ for i, result in enumerate(results):
135
+ mask = result["mask"]
136
+ label = result["label"]
137
+ score = float(result["score"])
138
+
139
+ # Convert PIL mask to numpy array and resize if necessary
140
+ mask_np = np.array(mask)
141
+ if mask_np.shape != image_np.shape[:2]:
142
+ mask_np = resize(
143
+ mask_np, image_np.shape[:2], mode="constant", anti_aliasing=False
144
+ )
145
+ mask_np = (mask_np > 0.5).astype(
146
+ np.uint8
147
+ ) # Threshold back to binary after resize
148
+
149
+ # Create an overlay where mask is True
150
+ overlay = np.zeros_like(image_np)
151
+ overlay[mask_np > 0] = [0, 0, 255] # Applying blue color on the mask area
152
+
153
+ # Combine the image and the overlay
154
+ combined = image_np.copy()
155
+ indices = np.where(mask_np > 0)
156
+ combined[indices] = combined[indices] * 0.5 + overlay[indices] * 0.5
157
+
158
+ # Show the combined image
159
+ ax = axs[i]
160
+ ax.imshow(combined)
161
+ ax.axis("off")
162
+ ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
163
+ rect = patches.Rectangle(
164
+ (0, 0),
165
+ image_np.shape[1],
166
+ image_np.shape[0],
167
+ linewidth=1,
168
+ edgecolor="r",
169
+ facecolor="none",
170
+ )
171
+ ax.add_patch(rect)
172
+
173
+ # Hide unused subplots if the total number of masks is not a multiple of cols
174
+ for idx in range(i + 1, rows * cols):
175
+ axs[idx].axis("off")
176
+ plt.tight_layout()
177
+ plt.show()
178
+
179
+
180
+ def get_bounding_box(mask):
181
+ """
182
+ Given a segmentation mask, return the bounding box for the mask object.
183
+ """
184
+ # Find indices where the mask is non-zero
185
+ coords = np.argwhere(mask)
186
+ # Get the minimum and maximum x and y coordinates
187
+ x_min, y_min = np.min(coords, axis=0)
188
+ x_max, y_max = np.max(coords, axis=0)
189
+ # Return the bounding box coordinates
190
+ return (y_min, x_min, y_max, x_max)
191
+
192
+
193
+ def get_bounding_box_mask(segmentation_mask, widen=0, elongate=0):
194
+ # Convert the PIL segmentation mask to a NumPy array
195
+ mask_array = np.array(segmentation_mask)
196
+
197
+ # Find the coordinates of the non-zero pixels
198
+ non_zero_y, non_zero_x = np.nonzero(mask_array)
199
+
200
+ # Calculate the bounding box coordinates
201
+ min_x, max_x = np.min(non_zero_x), np.max(non_zero_x)
202
+ min_y, max_y = np.min(non_zero_y), np.max(non_zero_y)
203
+
204
+ if widen > 0:
205
+ min_x = max(0, min_x - widen)
206
+ max_x = min(mask_array.shape[1], max_x + widen)
207
+
208
+ if elongate > 0:
209
+ min_y = max(0, min_y - elongate)
210
+ max_y = min(mask_array.shape[0], max_y + elongate)
211
+
212
+ # Create a new blank image for the bounding box mask
213
+ bounding_box_mask = Image.new("1", segmentation_mask.size)
214
+
215
+ # Draw the filled bounding box on the blank image
216
+ draw = ImageDraw.Draw(bounding_box_mask)
217
+ draw.rectangle([(min_x, min_y), (max_x, max_y)], fill=1)
218
+
219
+ return bounding_box_mask
220
+
221
+
222
+ colors = {
223
+ "blue": (136, 207, 249),
224
+ "red": (255, 0, 0),
225
+ "green": (0, 255, 0),
226
+ "yellow": (255, 255, 0),
227
+ "purple": (128, 0, 128),
228
+ "cyan": (0, 255, 255),
229
+ "magenta": (255, 0, 255),
230
+ "orange": (255, 165, 0),
231
+ "lime": (50, 205, 50),
232
+ "pink": (255, 192, 203),
233
+ "brown": (139, 69, 19),
234
+ "gray": (128, 128, 128),
235
+ "black": (0, 0, 0),
236
+ "white": (255, 255, 255),
237
+ "gold": (255, 215, 0),
238
+ "silver": (192, 192, 192),
239
+ "beige": (245, 245, 220),
240
+ "navy": (0, 0, 128),
241
+ "maroon": (128, 0, 0),
242
+ "olive": (128, 128, 0),
243
+ }
244
+
245
+
246
+ def overlay_mask(image, mask, opacity=0.5, color="blue"):
247
+ """
248
+ Takes in a PIL image and a PIL boolean image mask. Overlay the mask on the image
249
+ and color the mask with a low opacity blue with hex #88CFF9.
250
+ """
251
+ # Convert the boolean mask to an image with alpha channel
252
+ alpha = mask.convert("L").point(lambda x: 255 if x == 255 else 0, mode="1")
253
+
254
+ # Choose the color
255
+ r, g, b = colors[color]
256
+
257
+ color_mask = Image.new("RGBA", mask.size, (r, g, b, int(opacity * 255)))
258
+ mask_rgba = Image.composite(
259
+ color_mask, Image.new("RGBA", mask.size, (0, 0, 0, 0)), alpha
260
+ )
261
+
262
+ # Create a new RGBA image to overlay the mask on
263
+ overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
264
+
265
+ # Paste the mask onto the overlay
266
+ overlay.paste(mask_rgba, (0, 0))
267
+
268
+ # Create a new image to return by blending the original image and the overlay
269
+ result = Image.alpha_composite(image.convert("RGBA"), overlay)
270
+
271
+ # Convert the result back to the original mode and return it
272
+ return result.convert(image.mode)
273
+
274
+
275
+ def resize_preserve_aspect_ratio(image, max_side=512):
276
+ width, height = image.size
277
+ scale = min(max_side / width, max_side / height)
278
+ new_width = int(width * scale)
279
+ new_height = int(height * scale)
280
+ return image.resize((new_width, new_height))
281
+
282
+
283
+ def round_to_nearest_eigth(value):
284
+ return int((value // 8 * 8))
285
+
286
+
287
+ def resize_image_to_nearest_eight(image):
288
+ width, height = image.size
289
+ width, height = round_to_nearest_eigth(width), round_to_nearest_eigth(height)
290
+ image = image.resize((width, height))
291
+ return image
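
For completeness, a small self-contained sketch of the standalone mask helpers above, using a synthetic mask rather than a YOLO result (all file names here are illustrative):

```python
# Sketch: exercise get_bounding_box_mask and overlay_mask on a synthetic mask.
import numpy as np
from PIL import Image
from yolo.utils import get_bounding_box_mask, overlay_mask

# Synthetic 200x200 mask with a filled rectangle standing in for a detection
mask_array = np.zeros((200, 200), dtype=np.uint8)
mask_array[60:140, 50:150] = 255
mask = Image.fromarray(mask_array, mode="L")

image = Image.new("RGB", (200, 200), "white")
box_mask = get_bounding_box_mask(mask, widen=10, elongate=10)
overlaid = overlay_mask(image, box_mask.convert("L"), opacity=0.5, color="red")
overlaid.save("overlay_demo.png")  # illustrative output path
```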