from argparse import ArgumentParser

from diffusers import DDIMScheduler, StableDiffusionXLImg2ImgPipeline
import gradio as gr
import torch
import yaml

from ctrl_x.pipelines.pipeline_sdxl import CtrlXStableDiffusionXLPipeline
from ctrl_x.utils import *
from ctrl_x.utils.sdxl import *

import spaces

parser = ArgumentParser()
parser.add_argument("-m", "--model", type=str, default=None)  # Optionally, load model checkpoint from single file
args = parser.parse_args()

torch.backends.cudnn.enabled = False  # Sometimes necessary to suppress CUDNN_STATUS_NOT_SUPPORTED

torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_id_or_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"
# variant = "fp16" if device == "cuda" else "fp32"

scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")  # TODO: Support other schedulers

if args.model is None:
    pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
        model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype, use_safetensors=True
    )
else:
    print(f"Using weights {args.model} for SDXL base model.")
    pipe = CtrlXStableDiffusionXLPipeline.from_single_file(args.model, scheduler=scheduler, torch_dtype=torch_dtype)

refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    refiner_id_or_path,
    scheduler=scheduler,
    text_encoder_2=pipe.text_encoder_2,
    vae=pipe.vae,
    torch_dtype=torch_dtype,
    use_safetensors=True,
)

if torch.cuda.is_available():
    pipe = pipe.to("cuda")
    refiner = refiner.to("cuda")


def get_control_config(structure_schedule, appearance_schedule):
    s = structure_schedule
    a = appearance_schedule

    control_config = \
f"""control_schedule:
    #       structure_conv   structure_attn     appearance_attn     conv/attn
    encoder:                                                      # (num layers)
        0: [[         ], [              ], [              ]]      # 2/0
        1: [[         ], [              ], [{a}, {a}      ]]      # 2/2
        2: [[         ], [              ], [{a}, {a}      ]]      # 2/2
    middle: [[        ], [              ], [              ]]      # 2/1
    decoder:
        0: [[{s}      ], [{s}, {s}, {s} ], [0.0, {a}, {a} ]]      # 3/3
        1: [[         ], [              ], [{a}, {a}      ]]      # 3/3
        2: [[         ], [              ], [              ]]      # 3/0

control_target:
    - [output_tensor]  # structure_conv   choices: {{hidden_states, output_tensor}}
    - [query, key]     # structure_attn   choices: {{query, key, value}}
    - [before]         # appearance_attn  choices: {{before, value, after}}

self_recurrence_schedule:
    - [0.1, 0.5, 2]  # format: [start, end, num_recurrence]"""

    return control_config


css = """
.config textarea {font-family: monospace; font-size: 80%; white-space: pre}
.mono {font-family: monospace}
"""

title = """

Ctrl-X: Controlling Structure and Appearance for Text-To-Image Generation Without Guidance

SDXL v1.0

[Page]            [Paper]            [Code]

""" description = """

Ctrl-X is a simple training-free and guidance-free framework for text-to-image (T2I) generation with structure and appearance control. Given a structure image and an appearance image, Ctrl-X applies feedforward structure control to align the output with the (arbitrary) structure image, and semantic-aware appearance transfer to carry over the visual appearance of the appearance image.

Here are some notes and tips for this demo:

Have fun! :D

""" @spaces.GPU def inference( structure_image, appearance_image, prompt, structure_prompt, appearance_prompt, positive_prompt="high quality", negative_prompt="ugly, blurry, dark, low res, unrealistic", guidance_scale=5.0, structure_guidance_scale=5.0, appearance_guidance_scale=5.0, num_inference_steps=28, eta=1.0, seed=42, width=1024, height=1024, structure_schedule=0.6, appearance_schedule=0.6, use_advanced_config=False, control_config="", progress=gr.Progress(track_tqdm=True) ): torch.manual_seed(seed) pipe.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = pipe.scheduler.timesteps print(f"\nUsing the following control config (use_advanced_config={use_advanced_config}):") if not use_advanced_config: control_config = get_control_config(structure_schedule, appearance_schedule) print(control_config, end="\n\n") config = yaml.safe_load(control_config) register_control( model = pipe, timesteps = timesteps, control_schedule = config["control_schedule"], control_target = config["control_target"], ) pipe.safety_checker = None pipe.requires_safety_checker = False self_recurrence_schedule = get_self_recurrence_schedule(config["self_recurrence_schedule"], num_inference_steps) pipe.set_progress_bar_config(desc="Ctrl-X inference") refiner.set_progress_bar_config(desc="Refiner") result, structure, appearance = pipe( prompt = prompt, structure_prompt = structure_prompt, appearance_prompt = appearance_prompt, structure_image = structure_image, appearance_image = appearance_image, num_inference_steps = num_inference_steps, negative_prompt = negative_prompt, positive_prompt = positive_prompt, height = height, width = width, guidance_scale = guidance_scale, structure_guidance_scale = structure_guidance_scale, appearance_guidance_scale = appearance_guidance_scale, eta = eta, output_type = "pil", return_dict = False, control_schedule = config["control_schedule"], self_recurrence_schedule = self_recurrence_schedule, ) result_refiner = refiner( image = pipe.refiner_args["latents"], prompt = pipe.refiner_args["prompt"], negative_prompt = pipe.refiner_args["negative_prompt"], height = height, width = width, num_inference_steps = num_inference_steps, guidance_scale = guidance_scale, guidance_rescale = 0.7, num_images_per_prompt = 1, eta = eta, output_type = "pil", ).images del pipe.refiner_args return [result[0], result_refiner[0], structure[0], appearance[0]] with gr.Blocks(theme=gr.themes.Default(), css=css, title="Ctrl-X (SDXL v1.0)") as app: gr.HTML(title) with gr.Accordion("Instructions", open=False): gr.HTML(description) with gr.Row(): with gr.Column(scale=45): with gr.Group(): kwargs = {} # {"width": 400, "height": 400} with gr.Row(): structure_image = gr.Image(label="Upload structure image (optional)", type="pil", **kwargs) appearance_image = gr.Image(label="Upload appearance image (optional)", type="pil", **kwargs) with gr.Row(): structure_prompt = gr.Textbox(label="Structure prompt (optional)", placeholder="Describes the structure image") appearance_prompt = gr.Textbox(label="Appearance prompt (optional)", placeholder="Describes the style image") with gr.Row(): prompt = gr.Textbox(label="Output prompt", placeholder="Prompt which describes the output image") with gr.Row(): positive_prompt = gr.Textbox(label="Positive prompt", value="high quality", placeholder="") negative_prompt = gr.Textbox(label="Negative prompt", value="ugly, blurry, dark, low res, unrealistic", placeholder="") with gr.Accordion("Advanced Options", open=False): with gr.Row(): guidance_scale = 
gr.Slider(label="Target guidance scale", value=5.0, minimum=1, maximum=10) structure_guidance_scale = gr.Slider(label="Structure guidance scale", value=5.0, minimum=1, maximum=10) appearance_guidance_scale = gr.Slider(label="Appearance guidance scale", value=5.0, minimum=1, maximum=10) with gr.Row(): num_inference_steps = gr.Slider(label="# inference steps", value=28, minimum=1, maximum=200, step=1) eta = gr.Slider(label="Eta (noise)", value=1.0, minimum=0, maximum=1.0, step=0.01) seed = gr.Slider(0, 2147483647, label="Seed", value=90095, step=1) with gr.Row(): width = gr.Slider(label="Width", value=1024, minimum=256, maximum=2048, step=pipe.vae_scale_factor) height = gr.Slider(label="Height", value=1024, minimum=256, maximum=2048, step=pipe.vae_scale_factor) with gr.Row(): structure_schedule = gr.Slider(label="Structure schedule", value=0.6, minimum=0.0, maximum=1.0, step=0.01, scale=2) appearance_schedule = gr.Slider(label="Appearance schedule", value=0.6, minimum=0.0, maximum=1.0, step=0.01, scale=2) use_advanced_config = gr.Checkbox(label="Use advanced config", value=False, scale=1) with gr.Row(): control_config = gr.Textbox( label="Advanced control config", lines=20, value=get_control_config(0.6, 0.6), elem_classes=["config"], visible=False, ) use_advanced_config.change( fn=lambda value: gr.update(visible=value), inputs=use_advanced_config, outputs=control_config, ) with gr.Row(): generate = gr.Button(value="Run") with gr.Column(scale=55): with gr.Group(): with gr.Row(): result_refiner = gr.Image(label="Output image w/ refiner", format="jpg", **kwargs) with gr.Row(): result = gr.Image(label="Output image", format="jpg", **kwargs) structure_recon = gr.Image(label="Structure image", format="jpg", **kwargs) appearance_recon = gr.Image(label="Style image", format="jpg", **kwargs) inputs = [ structure_image, appearance_image, prompt, structure_prompt, appearance_prompt, positive_prompt, negative_prompt, guidance_scale, structure_guidance_scale, appearance_guidance_scale, num_inference_steps, eta, seed, width, height, structure_schedule, appearance_schedule, use_advanced_config, control_config, ] outputs = [result, result_refiner, structure_recon, appearance_recon] generate.click(inference, inputs=inputs, outputs=outputs) examples = gr.Examples( [ [ "assets/images/horse__point_cloud.jpg", "assets/images/horse.jpg", "a photo of a horse standing on grass", "a 3D point cloud of a horse", "", ], [ "assets/images/cat__mesh.jpg", "assets/images/tiger.jpg", "a photo of a tiger standing on snow", "a 3D mesh of a cat", "", ], [ "assets/images/dog__sketch.jpg", "assets/images/squirrel.jpg", "a photo of a squirrel", "a sketch of a dog", "", ], [ "assets/images/living_room__seg.jpg", "assets/images/van_gogh.jpg", "a Van Gogh painting of a living room", "a segmentation map of a living room", "", ], [ "assets/images/bedroom__sketch.jpg", "assets/images/living_room_modern.jpg", "a sketch of a bedroom", "a photo of a modern bedroom during sunset", "", ], [ "assets/images/running__pose.jpg", "assets/images/man_park.jpg", "a photo of a man running in a park", "a pose image of a person running", "", ], [ "assets/images/fruit_bowl.jpg", "assets/images/grapes.jpg", "a photo of a bowl of grapes in the trees", "a photo of a bowl of fruits", "", ], [ "assets/images/bear_avocado__spatext.jpg", None, "a realistic photo of a bear and an avocado in a forest", "a segmentation map of a bear and an avocado", "", ], [ "assets/images/cat__point_cloud.jpg", None, "an embroidery of a white cat sitting on a rock under the 
night sky", "a 3D point cloud of a cat", "", ], [ "assets/images/library__mesh.jpg", None, "a Polaroid photo of an old library, sunlight streaming in", "a 3D mesh of a library", "", ], [ "assets/images/knight__humanoid.jpg", None, "a photo of a medieval soldier standing on a barren field, raining", "a 3D model of a person holding a sword and shield", "", ], [ "assets/images/person__mesh.jpg", None, "a photo of a Karate man performing in a cyberpunk city at night", "a 3D mesh of a person", "", ], ], [ structure_image, appearance_image, prompt, structure_prompt, appearance_prompt, ], examples_per_page=50, cache_examples="lazy", fn=inference, outputs=[result, result_refiner, structure_recon, appearance_recon] ) app.launch(debug=False, share=False)