|
from __future__ import annotations |
|
|
|
import inspect |
|
from abc import ABC, abstractmethod |
|
from typing import Any, Callable, Iterable, List, Mapping, Optional |
|
|
|
from diffusers.utils import logging |
|
from PIL import Image |
|
|
|
from asdff.utils import ( |
|
ADOutput, |
|
bbox_padding, |
|
composite, |
|
mask_dilate, |
|
mask_gaussian_blur, |
|
) |
|
from asdff.yolo import yolo_detector |
|
|
|
# Module logger, requested under the "diffusers" name rather than this
# module's own __name__.
logger = logging.get_logger("diffusers")

# What a detector hands back: the list of mask images it produced, or None
# when it found nothing.
_DetectorResult = Optional[List[Image.Image]]

# A detector maps a single input image to its masks (or None).
DetectorType = Callable[[Image.Image], _DetectorResult]
|
|
|
|
|
def ordinal(n: int) -> str:
    """Return ``n`` followed by its English ordinal suffix (``1`` -> ``"1st"``).

    Numbers whose last two digits are 11, 12 or 13 always take ``"th"``;
    otherwise the last digit picks ``"st"``/``"nd"``/``"rd"``, falling back
    to ``"th"``.
    """
    if n % 100 in (11, 12, 13):
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return str(n) + suffix
|
|
|
|
|
class AdPipelineBase(ABC):
    """Mixin that adds a detect-and-inpaint pass on top of a diffusers pipeline.

    Subclasses provide:

    * ``inpaint_pipeline`` -- the callable invoked on each cropped mask region.
    * ``txt2img_class`` -- the class whose unbound ``__call__`` is invoked with
      ``self`` to produce the initial images.

    Calling an instance produces (or accepts) images, runs every detector on
    each of them, inpaints each detected mask region, and composites the
    inpainted crop back into the image.
    """

    @property
    @abstractmethod
    def inpaint_pipeline(self) -> Callable:
        """Callable that inpaints a single cropped region."""
        raise NotImplementedError

    @property
    @abstractmethod
    def txt2img_class(self) -> type:
        """Class whose ``__call__`` is run (with ``self``) for text-to-image."""
        raise NotImplementedError

    def __call__(
        self,
        common: Mapping[str, Any] | None = None,
        txt2img_only: Mapping[str, Any] | None = None,
        inpaint_only: Mapping[str, Any] | None = None,
        images: Image.Image | Iterable[Image.Image] | None = None,
        detectors: DetectorType | Iterable[DetectorType] | None = None,
        mask_dilation: int = 4,
        mask_blur: int = 4,
        mask_padding: int = 32,
    ):
        """Generate images (unless ``images`` is given), then detect and inpaint.

        Args:
            common: Keyword arguments passed to both the text-to-image and the
                inpainting calls.
            txt2img_only: Extra keyword arguments for the text-to-image call
                only. Ignored, with a warning, when ``images`` is given.
            inpaint_only: Extra keyword arguments for the inpainting call only.
                ``strength`` defaults to 0.4 when absent.
            images: One image or an iterable of images to process instead of
                generating them.
            detectors: One detector or an iterable of detectors. Defaults to
                ``self.default_detector``. Detectors run in order, and each one
                sees the image as updated by the previous ones.
            mask_dilation: Forwarded to ``mask_dilate`` for every mask.
            mask_blur: Forwarded to ``mask_gaussian_blur`` for every mask.
            mask_padding: Forwarded to ``bbox_padding`` to grow each mask's bbox.

        Returns:
            ``ADOutput``. ``init_images`` holds a copy of every input image.
            ``images`` holds only the images where at least one mask was
            inpainted.
        """
        if common is None:
            common = {}
        if txt2img_only is None:
            txt2img_only = {}
        if inpaint_only is None:
            inpaint_only = {}
        if "strength" not in inpaint_only:
            # Build a new mapping so the caller's dict is never mutated.
            inpaint_only = {**inpaint_only, "strength": 0.4}

        if detectors is None:
            detectors = [self.default_detector]
        elif not isinstance(detectors, Iterable):
            detectors = [detectors]
        # Every image is run through every detector. Materialize once so that a
        # one-shot iterator (e.g. a generator) is not exhausted after the first
        # image, which would leave later images without any detection.
        detectors = list(detectors)

        if images is None:
            txt2img_output = self.process_txt2img(common, txt2img_only)
            txt2img_images = txt2img_output[0]
        else:
            if txt2img_only:
                msg = "Both `images` and `txt2img_only` are specified. if `images` is specified, `txt2img_only` is ignored."
                logger.warning(msg)

            txt2img_images = [images] if not isinstance(images, Iterable) else images

        init_images = []
        final_images = []

        for i, init_image in enumerate(txt2img_images):
            init_images.append(init_image.copy())
            final_image = None

            for j, detector in enumerate(detectors):
                masks = detector(init_image)
                if masks is None:
                    logger.info(
                        f"No object detected on {ordinal(i + 1)} image with {ordinal(j + 1)} detector."
                    )
                    continue

                for k, mask in enumerate(masks):
                    mask = mask.convert("L")
                    mask = mask_dilate(mask, mask_dilation)
                    # getbbox() returns None for an empty (all-zero) mask.
                    bbox = mask.getbbox()
                    if bbox is None:
                        logger.info(f"No object in {ordinal(k + 1)} mask.")
                        continue
                    mask = mask_gaussian_blur(mask, mask_blur)
                    bbox_padded = bbox_padding(bbox, init_image.size, mask_padding)

                    inpaint_output = self.process_inpainting(
                        common,
                        inpaint_only,
                        init_image,
                        mask,
                        bbox_padded,
                    )
                    inpaint_image = inpaint_output[0][0]

                    final_image = composite(
                        init_image,
                        mask,
                        inpaint_image,
                        bbox_padded,
                    )
                    # Later masks and detectors work on the updated image.
                    init_image = final_image

            if final_image is not None:
                final_images.append(final_image)

        return ADOutput(images=final_images, init_images=init_images)

    @property
    def default_detector(self) -> Callable[..., list[Image.Image] | None]:
        """Detector used when ``__call__`` receives no ``detectors``."""
        return yolo_detector

    def _get_txt2img_args(
        self, common: Mapping[str, Any], txt2img_only: Mapping[str, Any]
    ):
        """Merge ``common`` and ``txt2img_only`` (the latter wins), forcing PIL output."""
        return {**common, **txt2img_only, "output_type": "pil"}

    def _get_inpaint_args(
        self, common: Mapping[str, Any], inpaint_only: Mapping[str, Any]
    ):
        """Build the keyword arguments for one ``inpaint_pipeline`` call.

        If the inpaint pipeline accepts ``control_image`` and ``common`` has an
        ``image`` but no ``control_image``, that ``image`` is passed as
        ``control_image`` instead. ``inpaint_only`` overrides ``common``, one
        image is requested per call, and output is PIL. ``common`` itself is
        not modified.
        """
        common = dict(common)
        sig = inspect.signature(self.inpaint_pipeline)
        if (
            "control_image" in sig.parameters
            and "control_image" not in common
            and "image" in common
        ):
            common["control_image"] = common.pop("image")
        return {
            **common,
            **inpaint_only,
            "num_images_per_prompt": 1,
            "output_type": "pil",
        }

    def process_txt2img(
        self, common: Mapping[str, Any], txt2img_only: Mapping[str, Any]
    ):
        """Run the text-to-image ``__call__`` of ``txt2img_class`` on ``self``."""
        txt2img_args = self._get_txt2img_args(common, txt2img_only)
        return self.txt2img_class.__call__(self, **txt2img_args)

    def process_inpainting(
        self,
        common: Mapping[str, Any],
        inpaint_only: Mapping[str, Any],
        init_image: Image.Image,
        mask: Image.Image,
        bbox_padded: tuple[int, int, int, int],
    ):
        """Inpaint the ``bbox_padded`` region of ``init_image`` under ``mask``.

        The image and the mask are both cropped to ``bbox_padded``. When a
        ``control_image`` is present, the matching region of it is cropped and
        resized to the crop size, so the conditioning lines up with the crop.
        Returns the raw output of ``inpaint_pipeline``.
        """
        crop_image = init_image.crop(bbox_padded)
        crop_mask = mask.crop(bbox_padded)
        inpaint_args = self._get_inpaint_args(common, inpaint_only)
        inpaint_args["image"] = crop_image
        inpaint_args["mask_image"] = crop_mask

        if "control_image" in inpaint_args:
            inpaint_args["control_image"] = self._crop_control_image(
                inpaint_args["control_image"],
                init_image.size,
                bbox_padded,
                crop_image.size,
            )
        return self.inpaint_pipeline(**inpaint_args)

    @staticmethod
    def _crop_control_image(control, image_size, bbox, crop_size):
        """Cut the region of ``control`` matching ``bbox`` and resize it to ``crop_size``.

        ``bbox`` is given in ``image_size`` coordinates. It is scaled per axis
        onto the control image's own size before cropping, so a control image
        at a different resolution from the image still lines up. Lists and
        tuples (multi-ControlNet) are processed element-wise and returned as a
        list. Objects that are not PIL images are only resized to ``crop_size``.
        """
        if isinstance(control, (list, tuple)):
            return [
                AdPipelineBase._crop_control_image(c, image_size, bbox, crop_size)
                for c in control
            ]
        if not isinstance(control, Image.Image):
            return control.resize(crop_size)

        img_w, img_h = image_size
        ctl_w, ctl_h = control.size
        scale_x = ctl_w / img_w
        scale_y = ctl_h / img_h
        x1, y1, x2, y2 = bbox
        box = (
            round(x1 * scale_x),
            round(y1 * scale_y),
            round(x2 * scale_x),
            round(y2 * scale_y),
        )
        return control.crop(box).resize(crop_size)
|
|