# turbo_inversion/app_diffedit.py
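"""Gradio demo that edits an image with DiffEdit (diffusers'
StableDiffusionDiffEditPipeline): a mask is generated from a source/target
prompt pair, the input image is DDIM-inverted, and the masked region is
re-denoised with the target prompt."""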
import torch
import spaces
import time
import gradio as gr
from PIL import Image
from diffusers import (
    DDIMScheduler,
    DDIMInverseScheduler,
    StableDiffusionDiffEditPipeline,
)

DEFAULT_SRC_PROMPT = "a woman"
DEFAULT_EDIT_PROMPT = "a woman, with red lips, 8k, high quality"
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
# BASE_MODEL = "stabilityai/sdxl-turbo"
# BASE_MODEL = "stabilityai/stable-diffusion-2-1"
basepipeline = StableDiffusionDiffEditPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1",
torch_dtype=torch.float16,
safety_checker=None,
use_safetensors=True,
)

# DiffEdit needs a DDIM scheduler for sampling and its inverse for DDIM inversion of the input image.
basepipeline.scheduler = DDIMScheduler.from_config(basepipeline.scheduler.config)
basepipeline.inverse_scheduler = DDIMInverseScheduler.from_config(basepipeline.scheduler.config)

# Memory savings: offload idle submodules to CPU and decode the VAE in slices.
basepipeline.enable_model_cpu_offload()
basepipeline.enable_vae_slicing()


@spaces.GPU(duration=30)
def image_to_image(
    input_image: Image.Image,
    source_prompt: str,
    target_prompt: str,
    num_inference_steps: int,
    start_step: int,  # currently unused; the latent slicing that used it is commented out below
    guidance_scale: float,
):
    run_task_time = 0
    time_cost_str = ''
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    # Stable Diffusion 2.1 is trained at 768x768, so resize the input before editing.
    input_image = input_image.resize((768, 768), Image.LANCZOS)

    # DiffEdit step 1: generate a mask of the region that differs between the
    # source and target prompts.
    mask_image = basepipeline.generate_mask(
        image=input_image,
        source_prompt=source_prompt,
        target_prompt=target_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )

    # DiffEdit step 2: invert the image to noise latents conditioned on the source prompt.
    inv_latents = basepipeline.invert(
        prompt=source_prompt,
        image=input_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).latents

    # Optionally start denoising from the latents at a given inversion step:
    # inv_latents = inv_latents[-(start_step + 1)][None]

    # DiffEdit step 3: denoise from the inverted latents, applying the target
    # prompt only inside the generated mask.
    output_image = basepipeline(
        prompt=target_prompt,
        mask_image=mask_image,
        image_latents=inv_latents,
        negative_prompt=source_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]

    # generate_mask returns a numpy mask with values in [0, 1]; convert it to a grayscale PIL image for display.
    mask_image = Image.fromarray((mask_image.squeeze() * 255).astype("uint8"), "L").resize(input_image.size, Image.LANCZOS)

    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    return output_image, mask_image, time_cost_str


def get_time_cost(run_task_time, time_cost_str):
    """Append the elapsed milliseconds since the previous call to a '-->'-separated string."""
    now_time = int(time.time() * 1000)
    if run_task_time == 0:
        time_cost_str = 'start'
    else:
        if time_cost_str != '':
            time_cost_str += '-->'
        time_cost_str += f'{now_time - run_task_time}'
    run_task_time = now_time
    return run_task_time, time_cost_str


def create_demo() -> gr.Blocks:
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
            with gr.Column():
                num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Inference Steps")
                start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
            with gr.Column():
                guidance_scale = gr.Slider(minimum=0, maximum=20, value=7.5, step=0.5, label="Guidance Scale")
                g_btn = gr.Button("Edit Image")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                output_image = gr.Image(label="Output Image", type="pil", interactive=False)
            with gr.Column():
                mask_image = gr.Image(label="Mask Image", type="pil", interactive=False)
                generated_cost = gr.Textbox(label="Time cost by step (ms)", visible=True, interactive=False)

        g_btn.click(
            fn=image_to_image,
            inputs=[input_image, input_image_prompt, edit_prompt, num_inference_steps, start_step, guidance_scale],
            outputs=[output_image, mask_image, generated_cost],
        )

    return demo
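

# Optional local entry point (an assumption: on the hosted Space the demo is
# usually built and launched by a separate app.py that imports create_demo()).
if __name__ == "__main__":
    create_demo().launch()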