# Hugging Face Spaces status: Running on Zero
import time

import gradio as gr
import spaces
import torch
from diffusers import (
    DDIMInverseScheduler,
    DDIMScheduler,
    StableDiffusionDiffEditPipeline,
)
from PIL import Image
# Default prompts pre-filled in the UI text boxes.
DEFAULT_SRC_PROMPT = "a woman"
DEFAULT_EDIT_PROMPT = "a woman, with red lips, 8k, high quality"

# Checkpoint loaded into the DiffEdit pipeline below. This constant used to
# name an SDXL checkpoint that was never loaded (the loader hard-coded the
# Stable Diffusion 2.1 id); it now reflects the model that is actually used.
BASE_MODEL = "stabilityai/stable-diffusion-2-1"

# Module-level pipeline shared by every request handled by this process.
basepipeline = StableDiffusionDiffEditPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    safety_checker=None,
    use_safetensors=True,
)

# Both schedulers are rebuilt from the same scheduler config: the forward
# DDIM scheduler for denoising and its inverse counterpart for inversion.
basepipeline.scheduler = DDIMScheduler.from_config(basepipeline.scheduler.config)
basepipeline.inverse_scheduler = DDIMInverseScheduler.from_config(basepipeline.scheduler.config)

# Memory-saving options provided by diffusers.
basepipeline.enable_model_cpu_offload()
basepipeline.enable_vae_slicing()
def image_to_image(
    input_image: Image.Image,
    source_prompt: str,
    target_prompt: str,
    num_inference_steps: int,
    start_step: int,
    guidance_scale: float,
):
    """Edit an image with the module-level DiffEdit pipeline.

    The image is resized to 768x768, then processed in three stages:
    ``generate_mask`` (source vs. target prompt), ``invert`` (latents of the
    image under the source prompt), and the masked denoising call that uses
    the target prompt with the source prompt as negative prompt.

    Args:
        input_image: PIL image to edit.
        source_prompt: Prompt describing the input image.
        target_prompt: Prompt describing the desired result.
        num_inference_steps: Step count passed to all three pipeline stages.
        start_step: Currently unused; kept so the UI wiring and the callable
            signature stay stable.
        guidance_scale: Guidance scale passed to all three pipeline stages.

    Returns:
        A tuple ``(output_image, mask_image, time_cost_str)`` where
        ``mask_image`` is the generated mask as an 8-bit grayscale image
        resized to the (768x768) working size and ``time_cost_str`` is the
        trail produced by ``get_time_cost``.

    Raises:
        gr.Error: If no input image was supplied.
    """
    # The Gradio image component yields None when nothing was uploaded;
    # surface that as a UI message instead of an AttributeError on resize.
    if input_image is None:
        raise gr.Error("Please upload an input image first.")

    run_task_time = 0
    time_cost_str = ''
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    input_image = input_image.resize((768, 768), Image.LANCZOS)

    mask_image = basepipeline.generate_mask(
        image=input_image,
        source_prompt=source_prompt,
        target_prompt=target_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )

    inv_latents = basepipeline.invert(
        prompt=source_prompt,
        image=input_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).latents

    # NOTE: start_step is not applied. An earlier experiment selected the
    # inverse latents by start step:
    #   inv_latents = inv_latents[-(start_step + 1)][None]

    output_image = basepipeline(
        prompt=target_prompt,
        mask_image=mask_image,
        image_latents=inv_latents,
        negative_prompt=source_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    # Turn the mask array into a displayable grayscale image: drop singleton
    # axes, scale to 0-255, and resize to match the working image size.
    mask_image = Image.fromarray(
        (mask_image.squeeze() * 255).astype("uint8"), "L"
    ).resize(input_image.size, Image.LANCZOS)

    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    return output_image, mask_image, time_cost_str
def get_time_cost(run_task_time, time_cost_str):
    """Record a timing checkpoint and extend the elapsed-time trail.

    Args:
        run_task_time: Timestamp (milliseconds since the epoch) of the
            previous checkpoint, or ``0`` to mark the first checkpoint.
        time_cost_str: Trail accumulated so far.

    Returns:
        A tuple ``(now_ms, trail)``. On the first checkpoint the trail is
        reset to ``'start'``. Otherwise the milliseconds elapsed since
        ``run_task_time`` are appended, joined with ``'-->'`` when the
        incoming trail is non-empty.
    """
    now_time = int(time.time() * 1000)

    # A zero timestamp is the sentinel for "no checkpoint taken yet".
    if run_task_time == 0:
        return now_time, 'start'

    separator = '-->' if time_cost_str != '' else ''
    return now_time, f'{time_cost_str}{separator}{now_time - run_task_time}'
def create_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI for the image-editing demo.

    The first row holds the two prompt boxes, the step sliders, and the
    guidance slider with the "Edit Image" button. The second row holds the
    input image, the output image, and the mask image with the timing text
    box. Clicking the button calls ``image_to_image`` with the six inputs
    and routes its three return values to the output image, mask image, and
    timing text box.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` instance.
    """
    with gr.Blocks() as demo:
        # Row 1: prompts and generation controls.
        with gr.Row():
            with gr.Column():
                input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
            with gr.Column():
                num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Inference Steps")
                start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
            with gr.Column():
                guidance_scale = gr.Slider(minimum=0, maximum=20, value=7.5, step=0.5, label="Guidance Scale")
                g_btn = gr.Button("Edit Image")

        # Row 2: input image and the read-only results.
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                output_image = gr.Image(label="Output Image", type="pil", interactive=False)
            with gr.Column():
                mask_image = gr.Image(label="Mask Image", type="pil", interactive=False)
                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)

        # Input order must match the positional parameters of image_to_image.
        g_btn.click(
            fn=image_to_image,
            inputs=[input_image, input_image_prompt, edit_prompt, num_inference_steps, start_step, guidance_scale],
            outputs=[output_image, mask_image, generated_cost],
        )

    return demo