# Stick2Body / app.py
# Hugging Face Space source by tori29umai (commit fe724f3, "Update app.py", 4.41 kB).
import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
from PIL import Image
import os
import time
from utils.dl_utils import dl_cn_model, dl_cn_config, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation
from utils.prompt_utils import remove_duplicates
# Setup directories and download necessary models.
# All model files live under the current working directory. The downloads run
# at import time, so the files are on disk before the UI is launched.
path = os.getcwd()
cn_dir, lora_dir = (f"{path}/{sub}" for sub in ("controlnet", "lora"))
for _model_dir in (cn_dir, lora_dir):
    os.makedirs(_model_dir, exist_ok=True)
del _model_dir  # keep the loop variable out of the module namespace

# ControlNet files (model + config, per the helper names) go to cn_dir;
# the LoRA goes to lora_dir.
dl_cn_model(cn_dir)
dl_cn_config(cn_dir)
dl_lora_model(lora_dir)
# Model loading function
def load_model(lora_dir, cn_dir, dtype=torch.float16):
    """Build the SDXL ControlNet img2img pipeline and attach the LoRA.

    Args:
        lora_dir: Directory containing ``Fixhands_anime_bdsqlsz_V1.safetensors``.
        cn_dir: Directory holding the ControlNet weights/config (loaded with
            ``use_safetensors=True``).
        dtype: Torch dtype used for the VAE, the ControlNet and the base
            pipeline. Defaults to ``torch.float16`` (the previous behaviour);
            previously only the ControlNet honoured this value while the
            other two components hard-coded float16.

    Returns:
        A ``StableDiffusionXLControlNetImg2ImgPipeline`` built on
        ``cagliostrolab/animagine-xl-3.1`` with model CPU offload enabled
        and the LoRA weights loaded.
    """
    # Swap in the "fp16-fix" SDXL VAE (per its repo name, meant for
    # half-precision use) instead of the base model's own VAE.
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
    controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "cagliostrolab/animagine-xl-3.1",
        controlnet=controlnet,
        vae=vae,
        torch_dtype=dtype,
    )
    # Offload idle sub-models to CPU to reduce peak GPU memory use.
    pipe.enable_model_cpu_offload()
    pipe.load_lora_weights(lora_dir, weight_name="Fixhands_anime_bdsqlsz_V1.safetensors")
    return pipe
# Image prediction and processing function
@spaces.GPU(duration=120)
def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
    """Convert a stick-figure image into a posing-doll image.

    Args:
        input_image_path: Path of the uploaded image (from the
            ``gr.Image(type='filepath')`` input); ``None`` if nothing was
            uploaded.
        prompt: User prompt; a fixed quality/style prefix is prepended and
            duplicate tags are removed.
        negative_prompt: Negative prompt passed straight to the pipeline.
        controlnet_scale: ControlNet conditioning scale (coerced to float).

    Returns:
        The generated PIL image, resized back to the input image's size.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Fail fast with a readable message instead of an opaque error from
    # Image.open(None), and before paying for the model load.
    if input_image_path is None:
        raise gr.Error("Please upload an input image.")

    pipe = load_model(lora_dir, cn_dir)

    # Image.open is lazy and would otherwise keep the file handle open;
    # copy the pixels into memory and close the file right away.
    with Image.open(input_image_path) as opened:
        input_image = opened.copy()

    # A plain white canvas of the input's size serves as the img2img init
    # image; the pose is supplied through the control image.
    base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
    resize_image = resize_image_aspect_ratio(input_image)
    resize_base_image = resize_image_aspect_ratio(base_image)

    # Fixed seed so identical inputs give identical outputs.
    generator = torch.manual_seed(0)
    start = time.perf_counter()

    prompt = "masterpiece, best quality, simple background, white background, bald, nude, " + prompt
    prompt = remove_duplicates(prompt)
    print(prompt)

    output_image = pipe(
        image=resize_base_image,
        control_image=resize_image,
        strength=1.0,
        prompt=prompt,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=float(controlnet_scale),
        generator=generator,
        num_inference_steps=30,
        eta=1.0,
    ).images[0]
    print(f"Time taken: {time.perf_counter() - start}")

    # Return the result at the original upload resolution.
    output_image = output_image.resize(input_image.size, Image.LANCZOS)
    return output_image
class Img2Img:
    """Gradio Blocks UI for the stickman-to-posing-doll converter.

    Attributes are set by ``layout()``:
    ``input_image_path``, ``prompt``, ``negative_prompt``,
    ``controlnet_scale`` and ``output_image`` are the UI components wired to
    ``predict``, and ``demo`` holds the built ``gr.Blocks`` app.
    ``tagger_model`` and ``canny_image`` are placeholders that nothing in
    this class uses.
    """

    def __init__(self):
        # Initialise placeholders BEFORE building the layout. layout()
        # assigns the real components (e.g. input_image_path), so resetting
        # them afterwards, as the code used to do, discarded those references.
        self.tagger_model = None
        self.input_image_path = None
        self.canny_image = None
        self.demo = self.layout()

    def layout(self):
        """Build the Gradio interface, store its components on ``self`` and return it.

        The Generate button is wired to ``predict`` with the input image
        path, prompt, negative prompt and ControlNet scale, and its result is
        shown in the output image.
        """
        # NOTE(review): no component sets elem_id="intro", so this rule
        # currently matches nothing.
        css = """
#intro{
max-width: 32rem;
text-align: center;
margin: 0 auto;
}
"""
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                gr.Image(value="title.png", label="Title Image")
                gr.Markdown("### Stickman to Posing Doll Image Converter\n\nこのアプリは棒人間をポーズ人形画像に変換するアプリです。入力する棒人間の形状は以下のリンクを参考にしてください。\nある程度形状が一致していれば手書きの棒人間でも認識されます\n\n[VRoid Hub Character Example](https://hub.vroid.com/characters/4765753841994800453/models/6738034259079048708)")
            with gr.Row():
                with gr.Column(scale=1):
                    # type='filepath' makes predict() receive a path string.
                    self.input_image_path = gr.Image(label="Input Image", type='filepath')
                    self.prompt = gr.Textbox(label="Prompt", lines=3)
                    self.negative_prompt = gr.Textbox(label="Negative Prompt", lines=3, value="nsfw, nipples, bad anatomy, liquid fingers, low quality, worst quality, out of focus, ugly, error, jpeg artifacts, lowers, blurry, bokeh")
                    self.controlnet_scale = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.01, label="Controlnet Scale")
                    generate_button = gr.Button("Generate")
                with gr.Column(scale=1):
                    self.output_image = gr.Image(type="pil", label="Output Image")
            # Event listeners must be registered inside the Blocks context.
            generate_button.click(
                fn=predict,
                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image
            )
        return demo
# Build the UI once, then start the Gradio server (share=True additionally
# requests a public share link).
img2img = Img2Img()
demo = img2img.demo
demo.launch(share=True)