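# stylech / app.py
# Committed by m451h: "Update app.py" (commit 6e32234, verified).
# Residue from the Hugging Face file viewer, kept for reference:
# raw · history · blame · contribute · delete · No virus · 1.86 kB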
import functools

import matplotlib.pyplot as plt
import requests
import torch
import torch.nn as nn
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid
from PIL import Image
from torchvision import transforms
from transformers import SamModel, SamProcessor
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
SEGMENTATION_MODEL_ID = "sayeed99/segformer_b3_clothes"
INPAINTING_MODEL_ID = "redstonehero/ReV_Animated_Inpainting"


@functools.lru_cache(maxsize=1)
def _load_segmenter():
    """Load the SegFormer clothes processor and model once; later calls reuse them."""
    processor = SegformerImageProcessor.from_pretrained(SEGMENTATION_MODEL_ID)
    model = AutoModelForSemanticSegmentation.from_pretrained(SEGMENTATION_MODEL_ID)
    return processor, model


@functools.lru_cache(maxsize=1)
def _load_inpainter():
    """Load the inpainting pipeline once, placed on the best available device.

    float16 weights are only used on CUDA. Half precision is poorly supported
    on CPU (many ops lack fp16 kernels), so the CPU fallback uses float32.
    """
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32
    pipeline = AutoPipelineForInpainting.from_pretrained(
        INPAINTING_MODEL_ID,
        torch_dtype=dtype,
    )
    return pipeline.to(device)


def _segment_mask(image, mask_id):
    """Return a black/white ``PIL.Image`` mask of the pixels predicted as ``mask_id``.

    Raises:
        ValueError: If ``mask_id`` is not one of the model's class indices.
    """
    processor, model = _load_segmenter()
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    num_classes = logits.shape[1]
    if not 0 <= mask_id < num_classes:
        raise ValueError(
            f"mask_id must be in [0, {num_classes}), got {mask_id}"
        )

    # Resize logits to the input image. PIL's size is (width, height), while
    # interpolate expects (height, width), hence the reversal.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    mask = (pred_seg == mask_id).numpy()
    return Image.fromarray((mask * 255).astype("uint8"))


def modify_image(image_url, prompt, mask_id=4):
    """Repaint one segmented clothing region of an image according to a prompt.

    The image is segmented with the SegFormer clothes model. Pixels whose
    predicted class equals ``mask_id`` form the inpainting mask, and the
    inpainting pipeline regenerates that region from ``prompt``. Both models
    are loaded lazily on first use and cached for subsequent calls.

    Args:
        image_url: Local path (or file object) of the image to edit. Despite
            the name, this goes straight to ``PIL.Image.open``, so a remote
            URL is not supported.
        prompt: Text describing the replacement content.
        mask_id: Segmentation class index to repaint. The UI offers 4 and 6
            (presumably upper clothes / pants; confirm against the model's
            ``id2label``).

    Returns:
        The inpainted ``PIL.Image``.

    Raises:
        ValueError: If ``mask_id`` is not a class index the model predicts.
    """
    with Image.open(image_url) as source:
        # Force-load the pixels before the file closes. This also normalises
        # RGBA/P/L uploads to the RGB input both models expect.
        image = source.convert("RGB")

    mask_image = _segment_mask(image, mask_id)
    pipeline = _load_inpainter()
    result = pipeline(
        prompt=prompt,
        num_inference_steps=24,
        image=image,
        mask_image=mask_image,
        guidance_scale=3,
        strength=1.0,
    )
    return result.images[0]
import gradio as gr
def gradio_wrapper(image, prompt, choice):
    """Adapt the Gradio form inputs to ``modify_image``.

    Args:
        image: Filepath of the uploaded image (the ``gr.Image`` input uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        prompt: Inpainting prompt from the textbox.
        choice: Selected mask ID as a string (``"4"`` or ``"6"``), or ``None``
            if no radio option is selected.

    Returns:
        The edited image produced by ``modify_image``.

    Raises:
        gr.Error: If no image was uploaded. The user then sees a readable
            message instead of a PIL traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    if choice is None:
        # No radio selection: defer to modify_image's own default mask ID.
        return modify_image(image, prompt)
    return modify_image(image, prompt, int(choice))
# Single-page UI: upload an image, describe the replacement, pick the class to repaint.
demo = gr.Interface(
    fn=gradio_wrapper,
    inputs=[
        gr.Image(type="filepath", label="Image"),
        gr.Textbox(label="Prompt"),
        # Preselect "4" (modify_image's default mask_id) so the radio never
        # submits None, which int() in the wrapper would reject.
        gr.Radio(["4", "6"], value="4", label="Mask ID"),
    ],
    outputs=gr.Image(label="Result"),
)
# inline=False: serve the app on its own rather than embedding it in a
# notebook output cell.
demo.launch(inline=False)