from functools import lru_cache
from typing import List

import cv2
import numpy as np
from diffusers.utils import load_image
from PIL import Image, ImageChops, ImageFilter
from ultralytics import YOLO
from .utils import *
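# The wildcard import above is expected (assumption based on usage below) to provide:
# resize_preserve_aspect_ratio, resize_image_to_nearest_eight, get_bounding_box,
# get_bounding_box_mask, overlay_mask, unload, format_results, filter_highest_score,
# and display_image_with_masks.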


def dilate_mask(mask, dilate_factor=6, blur_radius=2, erosion_factor=2):
    # Nothing to do when no mask was produced
    if mask is None:
        return None
    # Convert PIL image to NumPy array if necessary
    if isinstance(mask, Image.Image):
        mask = np.array(mask)

    # Ensure mask is in uint8 format
    mask = mask.astype(np.uint8)

    # Apply dilation
    kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    dilated_mask = cv2.dilate(mask, kernel, iterations=1)

    # Apply erosion for refinement
    kernel = np.ones((erosion_factor, erosion_factor), np.uint8)
    eroded_mask = cv2.erode(dilated_mask, kernel, iterations=1)

    # Apply Gaussian blur to smooth the edges
    blurred_mask = cv2.GaussianBlur(
        eroded_mask, (2 * blur_radius + 1, 2 * blur_radius + 1), 0
    )

    # Convert back to PIL image
    smoothed_mask = Image.fromarray(blurred_mask).convert("L")

    # Optionally, apply an additional blur for extra smoothness using PIL
    smoothed_mask = smoothed_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    return smoothed_mask

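# Illustrative usage (hypothetical mask variable), producing a feathered "L"-mode mask:
#   soft_mask = dilate_mask(person_mask, dilate_factor=8, blur_radius=2, erosion_factor=2)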

@lru_cache(maxsize=1)
def get_model(model_id):
    model = YOLO(model=model_id)
    return model


def combine_masks(masks: List[dict], labels: List[str], is_label=True) -> Image.Image:
    """
    Combine masks with the specified labels into a single mask, optimized for speed and non-overlapping of excluded masks.

    Parameters:
    - masks (List[dict]): A list of dictionaries, each containing the mask under a 'mask' key and its label under a 'label' key.
    - labels (List[str]): A list of labels to include in the combination.

    Returns:
    - Image.Image: The combined mask as a PIL Image object, or None if no masks are combined.
    """
    labels_set = set(labels)  # Convert labels list to a set for O(1) lookups

    # Filter and convert mask images based on the specified labels
    mask_images = [
        mask["mask"].convert("L")
        for mask in masks
        if (mask["label"] in labels_set) == is_label
    ]

    # Ensure there is at least one mask to combine
    if not mask_images:
        return None  # Or raise an appropriate error, e.g., ValueError("No masks found for the specified labels.")

    # Initialize the combined mask with the first mask
    combined_mask = mask_images[0]

    # Merge the remaining masks into the combined mask via pixel-wise maximum (union for binary masks)
    for mask in mask_images[1:]:
        combined_mask = ImageChops.lighter(combined_mask, mask)

    return combined_mask

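# Illustrative call (hypothetical variables): each entry pairs a PIL mask with a label.
#   masks = [{"label": "hair", "mask": hair_mask}, {"label": "face", "mask": face_mask}]
#   hair_only = combine_masks(masks, ["hair"], is_label=True)
#   all_but_hair = combine_masks(masks, ["hair"], is_label=False)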

body_labels = ["hair", "face", "arm", "hand", "leg", "foot", "outfit"]


class BodyMask:

    def __init__(
        self,
        image_path,
        model_id,
        labels=body_labels,
        overlay="mask",
        widen_box=0,
        elongate_box=0,
        resize_to=640,
        dilate_factor=0,
        is_label=False,
        resize_to_nearest_eight=False,
        verbose=True,
        remove_overlap=True,
    ):
        self.image_path = image_path
        self.image = self.get_image(
            resize_to=resize_to, resize_to_nearest_eight=resize_to_nearest_eight
        )
        self.labels = labels
        self.is_label = is_label
        self.model_id = model_id
        self.model = get_model(self.model_id)
        self.model_labels = self.model.names
        self.verbose = verbose
        self.results = self.get_results()
        self.dilate_factor = dilate_factor
        self.body_mask = self.get_body_mask()
        self.box = get_bounding_box(self.body_mask)
        self.body_box = self.get_body_box(
            remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
        )
        if overlay == "box":
            self.overlay = overlay_mask(
                self.image, self.body_box, opacity=0.9, color="red"
            )
        else:
            self.overlay = overlay_mask(
                self.image, self.body_mask, opacity=0.9, color="red"
            )

    def get_image(self, resize_to, resize_to_nearest_eight):
        image = load_image(self.image_path)
        if resize_to:
            image = resize_preserve_aspect_ratio(image, resize_to)
        if resize_to_nearest_eight:
            image = resize_image_to_nearest_eight(image)
        return image

    def get_body_mask(self):
        body_mask = combine_masks(self.results, self.labels, self.is_label)
        return dilate_mask(body_mask, self.dilate_factor)

    def get_results(self):
        imgsz = max(self.image.size)
        results = self.model(
            self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
        )[0]
        self.masks, self.boxes, self.scores, self.phrases = unload(
            results, self.model_labels
        )
        results = format_results(
            self.masks,
            self.boxes,
            self.scores,
            self.phrases,
            self.model_labels,
            person_masks_only=False,
        )
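        # Each entry in `results` is expected to hold at least a "label" and a PIL "mask"
        # (exact fields come from format_results; see get_mask / combine_masks below).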

        # keep only the highest-scoring result for labels that should appear at most once
        results = filter_highest_score(results, ["hair", "face", "phone"])
        return results

    def display_results(self):
        cols = min(len(self.masks), 4)
        display_image_with_masks(self.image, self.results, cols=cols)

    def get_mask(self, mask_label):
        assert mask_label in self.phrases, "Mask label not found in results"
        return [f for f in self.results if f.get("label") == mask_label]

    def combine_masks(self, mask_labels: List[str], is_label=True):
        """
        Combine the masks whose labels are in `mask_labels`, or, when `is_label` is
        False, all detected masks whose labels are not in the list.
        """
        if not is_label:
            mask_labels = [
                phrase for phrase in self.phrases if phrase not in mask_labels
            ]
        masks = [
            row.get("mask") for row in self.results if row.get("label") in mask_labels
        ]
        if not masks:
            return None
        combined_mask = masks[0]
        # merge via pixel-wise maximum (union for binary masks)
        for mask in masks[1:]:
            combined_mask = ImageChops.lighter(combined_mask, mask)
        return combined_mask

    def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
        body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
        if remove_overlap:
            body_box = self.remove_overlap(body_box)
        return body_box

    def remove_overlap(self, body_box):
        """
        Remove regions of the body box that overlap with the masks for `self.labels`
        """
        # convert the box mask to a numpy array
        box_array = np.array(body_box)

        # combine the masks for the labels to subtract
        mask = self.combine_masks(mask_labels=self.labels, is_label=True)
        if mask is None:
            # nothing to subtract
            return body_box

        # convert the combined mask to a numpy array
        mask_array = np.array(mask)

        # wherever the combined mask is white, set the box to black
        box_array[mask_array == 255] = 0

        # convert the result back to a PIL image
        return Image.fromarray(box_array)


if __name__ == "__main__":
    url = "https://sjc1.vultrobjects.com/photo-storage/images/525d1f68-314c-455b-a8b6-f5dc3fa044e4.jpeg"
    image_name = url.split("/")[-1]
    labels = ["face", "hair", "phone", "hand"]
    image = load_image(url)

    # Keep the original size so the working image can be resized back afterwards
    original_size = image.size

    # NOTE: placeholder checkpoint name; point this at a YOLO segmentation model
    # whose classes include labels such as "hair", "face", "outfit", etc.
    model_id = "yolo-body-parse-seg.pt"

    # Create the body mask, excluding the listed labels from the combined mask
    body_mask = BodyMask(
        image,
        model_id,
        overlay="box",
        labels=labels,
        widen_box=50,
        elongate_box=10,
        dilate_factor=0,
        resize_to=640,
        is_label=False,
        remove_overlap=True,
        verbose=False,
    )

    # Resize the (possibly downscaled) working image back to the original size
    image = body_mask.image.resize(original_size)
    body_mask.body_box.save(image_name)