import os
from functools import lru_cache
from typing import List, Optional

import cv2
import numpy as np
from diffusers.utils import load_image
from PIL import Image, ImageChops, ImageFilter
from ultralytics import YOLO

from .utils import *


def dilate_mask(mask, dilate_factor=6, blur_radius=2, erosion_factor=2):
    """Grow, refine and feather a mask.

    The mask is dilated with a ``dilate_factor`` x ``dilate_factor`` square
    kernel, eroded with an ``erosion_factor`` x ``erosion_factor`` kernel, and
    then Gaussian-blurred twice (OpenCV, then PIL) to soften its edges.

    Parameters:
    - mask: a PIL image, a NumPy array, or None.
    - dilate_factor (int): side length of the dilation kernel; ``<= 0`` skips
      dilation entirely.
    - blur_radius (int): radius used by both Gaussian blurs (OpenCV kernel size
      is ``2 * blur_radius + 1``).
    - erosion_factor (int): side length of the erosion kernel; ``<= 0`` skips
      erosion entirely.

    Returns:
    - Image.Image: the smoothed mask in mode "L", or None if ``mask`` is None.
    """
    # `is None` instead of `not mask`: the truth value of a multi-element NumPy
    # array is ambiguous and raises ValueError.
    if mask is None:
        return None

    if isinstance(mask, Image.Image):
        mask = np.array(mask)
    mask = mask.astype(np.uint8)

    # cv2.dilate/cv2.erode substitute a default 3x3 kernel when handed an empty
    # one, so a 0x0 kernel would still alter the mask. Skip the step instead.
    if dilate_factor > 0:
        kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=1)

    if erosion_factor > 0:
        kernel = np.ones((erosion_factor, erosion_factor), np.uint8)
        mask = cv2.erode(mask, kernel, iterations=1)

    ksize = 2 * blur_radius + 1
    mask = cv2.GaussianBlur(mask, (ksize, ksize), 0)

    smoothed_mask = Image.fromarray(mask).convert("L")
    # A second, PIL-side blur for extra smoothness.
    return smoothed_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))


@lru_cache(maxsize=1)
def get_model(model_id):
    """Load a YOLO model, caching only the most recently requested model id."""
    return YOLO(model=model_id)


def combine_masks(
    masks: List[dict], labels: List[str], is_label=True
) -> Optional[Image.Image]:
    """
    Combine the selected masks into a single mask (per-pixel maximum).

    Parameters:
    - masks (List[dict]): dictionaries holding a PIL mask under 'mask' and its
      label under 'label'.
    - labels (List[str]): labels used to select masks.
    - is_label (bool): if True, combine masks whose label IS in ``labels``;
      if False, combine masks whose label is NOT in ``labels``.

    Returns:
    - Image.Image: the combined mask in mode "L", or None if no mask matched.
    """
    labels_set = set(labels)  # O(1) membership tests

    mask_images = [
        mask["mask"].convert("L")
        for mask in masks
        if (mask["label"] in labels_set) == is_label
    ]
    if not mask_images:
        return None

    # ImageChops.lighter keeps the brighter pixel of the two images, i.e. a
    # union of the masks (overlaps are merged, not excluded).
    combined_mask = mask_images[0]
    for mask in mask_images[1:]:
        combined_mask = ImageChops.lighter(combined_mask, mask)
    return combined_mask


body_labels = ["hair", "face", "arm", "hand", "leg", "foot", "outfit"]


class BodyMask:
    """Segment an image with a YOLO model and derive a body mask and box.

    On construction this loads (and optionally resizes) the image, runs the
    model, combines the selected label masks into ``body_mask``, computes
    ``box`` / ``body_box`` and an ``overlay`` preview image. The helpers
    ``resize_preserve_aspect_ratio``, ``unload``, ``format_results``,
    ``filter_highest_score``, ``get_bounding_box``, ``get_bounding_box_mask``,
    ``overlay_mask`` and ``display_image_with_masks`` come from ``.utils``.
    """

    # Labels handed to ``filter_highest_score`` in ``get_results``; presumably
    # only the best-scoring detection per label is kept — see ``.utils``.
    highest_score_labels = ("hair", "face", "phone")

    def __init__(
        self,
        image_path,
        model_id,
        labels=None,
        overlay="mask",
        widen_box=0,
        elongate_box=0,
        resize_to=640,
        dilate_factor=0,
        is_label=False,
        resize_to_nearest_eight=False,
        verbose=True,
        remove_overlap=True,
    ):
        """
        Parameters:
        - image_path: anything ``diffusers.utils.load_image`` accepts (the
          ``__main__`` demo passes a PIL image).
        - model_id: YOLO model id/path passed to ``get_model``.
        - labels: labels used to build the body mask; defaults to a copy of
          ``body_labels``.
        - overlay: "box" to overlay ``body_box``; anything else overlays
          ``body_mask``.
        - widen_box / elongate_box: forwarded to ``get_bounding_box_mask``.
        - resize_to: target passed to ``resize_preserve_aspect_ratio``;
          falsy disables resizing.
        - dilate_factor: dilation kernel size for ``dilate_mask`` (0 = none).
        - is_label: if True the body mask is the union of masks IN ``labels``;
          if False, of masks NOT in ``labels``.
        - resize_to_nearest_eight: also apply ``resize_image_to_nearest_eight``.
        - verbose: forwarded to the YOLO call.
        - remove_overlap: subtract label masks from the box (``remove_overlap``).
        """
        self.image_path = image_path
        self.image = self.get_image(
            resize_to=resize_to, resize_to_nearest_eight=resize_to_nearest_eight
        )
        # Copy the default so the module-level list is never shared between
        # instances or mutated through one of them.
        self.labels = list(body_labels) if labels is None else labels
        self.is_label = is_label
        self.model_id = model_id
        self.model = get_model(self.model_id)
        self.model_labels = self.model.names
        self.verbose = verbose
        self.results = self.get_results()
        self.dilate_factor = dilate_factor
        self.body_mask = self.get_body_mask()
        self.box = get_bounding_box(self.body_mask)
        self.body_box = self.get_body_box(
            remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
        )
        overlay_source = self.body_box if overlay == "box" else self.body_mask
        self.overlay = overlay_mask(
            self.image, overlay_source, opacity=0.9, color="red"
        )

    def get_image(self, resize_to, resize_to_nearest_eight):
        """Load ``self.image_path`` and apply the requested resizes."""
        image = load_image(self.image_path)
        if resize_to:
            image = resize_preserve_aspect_ratio(image, resize_to)
        if resize_to_nearest_eight:
            image = resize_image_to_nearest_eight(image)
        return image

    def get_body_mask(self):
        """Combine the selected label masks and pass them through ``dilate_mask``.

        Returns None when no mask matches the label selection.
        """
        body_mask = combine_masks(self.results, self.labels, self.is_label)
        return dilate_mask(body_mask, self.dilate_factor)

    def get_results(self):
        """Run the model on ``self.image`` and return the formatted results.

        Side effect: sets ``self.masks``, ``self.boxes``, ``self.scores`` and
        ``self.phrases`` from the raw prediction.
        """
        # Infer at the image's largest side so masks match the working size.
        imgsz = max(self.image.size)
        prediction = self.model(
            self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
        )[0]
        self.masks, self.boxes, self.scores, self.phrases = unload(
            prediction, self.model_labels
        )
        results = format_results(
            self.masks,
            self.boxes,
            self.scores,
            self.phrases,
            self.model_labels,
            person_masks_only=False,
        )
        return filter_highest_score(results, list(self.highest_score_labels))

    def display_results(self):
        """Show the image with its masks, using at most 4 columns."""
        cols = min(len(self.masks), 4)
        display_image_with_masks(self.image, self.results, cols=cols)

    def get_mask(self, mask_label):
        """Return every result row whose label equals ``mask_label``.

        Raises:
        - AssertionError: if ``mask_label`` was not detected. Raised explicitly
          (not via ``assert``) so the check survives ``python -O``.
        """
        if mask_label not in self.phrases:
            raise AssertionError("Mask label not found in results")
        return [row for row in self.results if row.get("label") == mask_label]

    def combine_masks(self, mask_labels: List, no_labels=None, is_label=True):
        """
        Combine the masks included in the labels list or all of the masks not
        in the list (per-pixel maximum, returned in mode "L").

        Parameters:
        - mask_labels: labels to select.
        - no_labels: unused; kept for backward compatibility.
        - is_label: if False, select every detected phrase NOT in
          ``mask_labels`` instead.

        Returns None when no mask is selected.
        """
        if not is_label:
            mask_labels = [
                phrase for phrase in self.phrases if phrase not in mask_labels
            ]
        wanted = set(mask_labels)

        # Normalise to "L" (as the module-level combine_masks does):
        # ImageChops.lighter needs matching modes, and remove_overlap compares
        # pixel values against 255.
        masks = [
            row.get("mask").convert("L")
            for row in self.results
            if row.get("label") in wanted and row.get("mask") is not None
        ]
        if not masks:
            return None

        combined_mask = masks[0]
        for mask in masks[1:]:
            combined_mask = ImageChops.lighter(combined_mask, mask)
        return combined_mask

    def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
        """Build the box mask around ``self.body_mask``, optionally minus overlaps."""
        body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
        if remove_overlap:
            body_box = self.remove_overlap(body_box)
        return body_box

    def remove_overlap(self, body_box):
        """
        Remove mask regions that overlap with unwanted labels: pixels where the
        combined ``self.labels`` mask is 255 are set to 0 in the box.
        """
        box_array = np.array(body_box)

        # NOTE(review): this always subtracts the masks whose labels ARE in
        # self.labels. With is_label=False (the default) those are exactly the
        # labels excluded from the body mask; with is_label=True they are the
        # body itself — confirm that is intended.
        overlap = self.combine_masks(mask_labels=self.labels, is_label=True)
        if overlap is not None:
            box_array[np.array(overlap) == 255] = 0

        return Image.fromarray(box_array)


if __name__ == "__main__":
    url = "https://sjc1.vultrobjects.com/photo-storage/images/525d1f68-314c-455b-a8b6-f5dc3fa044e4.jpeg"
    image_name = url.split("/")[-1]
    labels = ["face", "hair", "phone", "hand"]

    # BodyMask requires a model id; the demo previously omitted it and failed
    # with a TypeError.
    model_id = os.environ.get("BODY_MASK_MODEL_ID")
    if not model_id:
        raise SystemExit(
            "Set BODY_MASK_MODEL_ID to the YOLO segmentation model to use."
        )

    image = load_image(url)
    original_size = image.size

    body_mask = BodyMask(
        image,
        model_id,
        overlay="box",
        labels=labels,
        widen_box=50,
        elongate_box=10,
        dilate_factor=0,
        resize_to=640,
        is_label=False,
        remove_overlap=True,
        verbose=False,
    )

    # NOTE(review): `image` is not used after this; body_box is saved at the
    # working (resized) resolution, not at original_size.
    image = body_mask.image.resize(original_size)
    body_mask.body_box.save(image_name)