"""Gradio app that converts a stick-figure drawing into a posing-doll image.

At import time the ControlNet weights/config and a hand-fix LoRA are fetched
into ``./controlnet`` and ``./lora`` via the project-local ``utils.dl_utils``
helpers. A Gradio Blocks UI is then launched whose "Generate" button runs
:func:`predict`.
"""

# Kept as the first import: Hugging Face ZeroGPU requires `spaces` to be
# imported before torch / CUDA-using libraries.
import spaces

import os
import time
from functools import lru_cache

import gradio as gr
import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from PIL import Image

from utils.dl_utils import dl_cn_model, dl_cn_config, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation
from utils.prompt_utils import remove_duplicates

# Prepended to every user prompt before de-duplication.
BASE_PROMPT = "masterpiece, best quality, simple background, white background, bald, nude, "

# Default value of the "Negative Prompt" textbox.
DEFAULT_NEGATIVE_PROMPT = (
    "nsfw, nipples, bad anatomy, liquid fingers, low quality, worst quality, "
    "out of focus, ugly, error, jpeg artifacts, lowres, blurry, bokeh"
)

# Setup directories and download necessary models
path = os.getcwd()
cn_dir = f"{path}/controlnet"
lora_dir = f"{path}/lora"
os.makedirs(cn_dir, exist_ok=True)
os.makedirs(lora_dir, exist_ok=True)
dl_cn_model(cn_dir)
dl_cn_config(cn_dir)
dl_lora_model(lora_dir)


@lru_cache(maxsize=1)
def load_model(lora_dir, cn_dir):
    """Build the SDXL ControlNet img2img pipeline.

    The result is cached, so repeated calls with the same directories reuse
    one pipeline instead of reloading every model from disk per request.

    Args:
        lora_dir: Directory containing ``Fixhands_anime_bdsqlsz_V1.safetensors``.
        cn_dir: Directory holding the ControlNet model (loaded with
            ``use_safetensors=True``).

    Returns:
        A ``StableDiffusionXLControlNetImg2ImgPipeline`` in float16, with
        model CPU offload enabled and the hand-fix LoRA loaded.
    """
    dtype = torch.float16
    # fp16-safe SDXL VAE replaces the base model's VAE.
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
    controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "cagliostrolab/animagine-xl-3.1",
        controlnet=controlnet,
        vae=vae,
        torch_dtype=dtype,
    )
    pipe.enable_model_cpu_offload()
    pipe.load_lora_weights(lora_dir, weight_name="Fixhands_anime_bdsqlsz_V1.safetensors")
    return pipe


@spaces.GPU(duration=120)
def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
    """Convert a stick-figure image into a posing-doll image.

    Args:
        input_image_path: Path of the uploaded image (the input component uses
            ``type='filepath'``); ``None`` when nothing was uploaded.
        prompt: Extra user prompt, appended to ``BASE_PROMPT``.
        negative_prompt: Negative prompt passed straight to the pipeline.
        controlnet_scale: ControlNet conditioning scale (cast to ``float``).

    Returns:
        The generated ``PIL.Image``, resized back to the input image's size.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image_path is None:
        raise gr.Error("Please upload an input image.")

    pipe = load_model(lora_dir, cn_dir)

    input_image = Image.open(input_image_path)
    # Presumably a solid white canvas the size of the input (colour
    # (255, 255, 255, 255)); it is only the img2img init image. With
    # strength=1.0 it is fully noised, so the pose comes from the control image.
    base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
    # NOTE(review): resize_image_aspect_ratio is project-local; assumed to
    # rescale to an SDXL-friendly resolution while keeping aspect ratio.
    resize_image = resize_image_aspect_ratio(input_image)
    resize_base_image = resize_image_aspect_ratio(base_image)

    # Fixed seed: identical inputs give identical outputs.
    generator = torch.manual_seed(0)
    start = time.perf_counter()

    full_prompt = remove_duplicates(BASE_PROMPT + (prompt or ""))
    print(full_prompt)

    output_image = pipe(
        image=resize_base_image,
        control_image=resize_image,
        strength=1.0,
        prompt=full_prompt,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=float(controlnet_scale),
        generator=generator,
        num_inference_steps=30,
        eta=1.0,
    ).images[0]
    print(f"Time taken: {time.perf_counter() - start}")

    return output_image.resize(input_image.size, Image.LANCZOS)


class Img2Img:
    """Gradio Blocks UI wiring the input widgets to :func:`predict`."""

    def __init__(self):
        # Placeholders are set *before* layout() so the components that
        # layout() stores on self (e.g. input_image_path) are not reset to
        # None afterwards.
        self.tagger_model = None
        self.input_image_path = None
        self.canny_image = None
        self.demo = self.layout()

    def layout(self):
        """Build and return the ``gr.Blocks`` demo, storing components on self."""
        css = """
        #intro{
            max-width: 32rem;
            text-align: center;
            margin: 0 auto;
        }
        """
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Stickman to Posing Doll Image Converter\nこのアプリは棒人間をポーズ人形画像に変換するアプリです。\n入力する棒人間の形状は以下のリンクを参考にしてください。\n[VRoid Hub Character Example](https://hub.vroid.com/characters/4765753841994800453/models/6738034259079048708)\nIf your stick figure resembles the linked shape, it should work reasonably well even if hand-drawn.")
                    self.input_image_path = gr.Image(label="Input Image", type='filepath')
                    self.prompt = gr.Textbox(label="Prompt", lines=3)
                    self.negative_prompt = gr.Textbox(label="Negative Prompt", lines=3, value=DEFAULT_NEGATIVE_PROMPT)
                    self.controlnet_scale = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.01, label="Controlnet Scale")
                    generate_button = gr.Button("Generate")
                with gr.Column(scale=1):
                    self.output_image = gr.Image(type="pil", label="Output Image")

            generate_button.click(
                fn=predict,
                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image,
            )
        return demo


img2img = Img2Img()
img2img.demo.launch(share=True)