|
import gradio as gr |
|
import os |
|
from ultralytics import YOLO |
|
from yolo.BodyMask import BodyMask |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib import patches |
|
from skimage.transform import resize |
|
from PIL import Image |
|
import io |
|
|
|
# Directory containing this script. Paths are resolved relative to it rather
# than to the process working directory, so the app also works when launched
# from elsewhere (assumes the checkpoint sits next to this script).
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Absolute path to the YOLO human-parsing checkpoint handed to BodyMask.
model_id = os.path.join(_APP_DIR, "yolo-human-parse-epoch-125.pt")
|
|
|
|
|
def _overlay_mask(image_np, mask):
    """Return a copy of *image_np* with the pixels selected by *mask* tinted blue.

    If the mask's shape differs from the image's (height, width) it is resized
    without anti-aliasing and binarised at 0.5; otherwise any value > 0 is
    treated as foreground.
    """
    mask_np = np.array(mask)
    if mask_np.shape != image_np.shape[:2]:
        mask_np = resize(
            mask_np, image_np.shape[:2], mode="constant", anti_aliasing=False
        )
        # Resizing can yield interpolated values; threshold to get a binary mask.
        mask_np = (mask_np > 0.5).astype(np.uint8)

    foreground = mask_np > 0
    combined = image_np.copy()
    # 50/50 blend with [0, 0, 255] (blue when channels are RGB). The float
    # result is cast back to the image dtype on assignment.
    combined[foreground] = combined[foreground] * 0.5 + np.array([0, 0, 255]) * 0.5
    return combined


def display_image_with_masks(image, results, cols=4):
    """Render one overlay panel per segmentation result and return it as PNG.

    Args:
        image: Anything ``np.array`` turns into an (H, W, 3) array, e.g. an
            RGB PIL image.
        results: Sized sequence of mappings with ``"mask"``, ``"label"`` and
            ``"score"`` keys, one per panel.
        cols: Number of panels per grid row; must be >= 1.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 containing the PNG-encoded
        figure. When *results* is empty, the figure shows the unmodified image
        titled "No masks found" instead of failing.

    Raises:
        ValueError: If the image is not a 3-channel array or *cols* < 1.
    """
    image_np = np.array(image)
    if image_np.ndim != 3 or image_np.shape[2] != 3:
        raise ValueError("Image must be a 3-dimensional array with 3 color channels")
    if cols < 1:
        raise ValueError("cols must be a positive integer")

    n = len(results)
    # Always at least one row so an empty result list still yields a figure.
    rows = max(1, (n + cols - 1) // cols)

    fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
    try:
        # subplots returns a single Axes or an ndarray depending on the grid
        # shape; flatten to a 1-D array for uniform indexing.
        axs = np.array(axs).reshape(-1)

        for ax, result in zip(axs, results):
            label = result["label"]
            score = float(result["score"])

            ax.imshow(_overlay_mask(image_np, result["mask"]))
            ax.axis("off")
            ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
            # Red frame spanning the image extent.
            ax.add_patch(
                patches.Rectangle(
                    (0, 0),
                    image_np.shape[1],
                    image_np.shape[0],
                    linewidth=1,
                    edgecolor="r",
                    facecolor="none",
                )
            )

        if n == 0:
            axs[0].imshow(image_np)
            axs[0].set_title("No masks found", fontsize=12)

        # Hide the frames of grid cells that have no result panel.
        for ax in axs[n:]:
            ax.axis("off")

        # Operate on this figure explicitly rather than pyplot's global
        # "current figure", which concurrent requests could swap out.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)
    finally:
        # Release the figure even if rendering fails part-way.
        plt.close(fig)

    return buf
|
|
|
|
|
def perform_segmentation(input_image):
    """Segment *input_image* with BodyMask and return the rendered overlay grid.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the button is clicked before an image is provided.

    Returns:
        A PIL image decoded from the PNG produced by ``display_image_with_masks``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a traceback from inside BodyMask.
    """
    if input_image is None:
        raise gr.Error("Please upload an image or select an example first.")

    # NOTE(review): BodyMask is constructed with the model path on every call;
    # if it loads the YOLO weights in its constructor, consider caching it.
    bm = BodyMask(input_image, model_id=model_id, resize_to=640)
    buf = display_image_with_masks(input_image, bm.results)
    return Image.open(buf)
|
|
|
|
|
|
|
# Example images offered below the upload widget. The directory is resolved
# relative to this script rather than the working directory. The list is
# sorted so the gallery order is stable, and a missing directory yields no
# examples instead of crashing at import time.
_SAMPLE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sample_images")
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

example_images = (
    sorted(
        os.path.join(_SAMPLE_DIR, f)
        for f in os.listdir(_SAMPLE_DIR)
        if f.lower().endswith(_IMAGE_EXTENSIONS)
    )
    if os.path.isdir(_SAMPLE_DIR)
    else []
)

with gr.Blocks() as demo:
    gr.Markdown("# YOLO Segmentation Demo with BodyMask")
    gr.Markdown(
        "Upload an image or select an example to see the YOLO segmentation results."
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image", height=512)
            segment_button = gr.Button("Perform Segmentation")

        output_image = gr.Image(label="Segmentation Result")

    # Only build the examples gallery when there is something to show;
    # cache_examples=True runs perform_segmentation on each example up front.
    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=input_image,
            outputs=output_image,
            fn=perform_segmentation,
            cache_examples=True,
        )

    segment_button.click(
        fn=perform_segmentation,
        inputs=input_image,
        outputs=output_image,
    )

# Launch only when run as a script, so importing this module (e.g. for tests
# or the `gradio` reload CLI) does not start a server.
if __name__ == "__main__":
    demo.launch()
|
|