# ctrlx-ssd-1b / run_ctrlx.py
# twodgirl's picture
# Update run_ctrlx.py
# fd1b9b3 verified
# raw
# history blame
# No virus
# 9.05 kB
from argparse import ArgumentParser
from datetime import datetime
from diffusers import DDIMScheduler, StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
from os import makedirs, path
from pipelines.pipeline_sdxl import CtrlXStableDiffusionXLPipeline
import torch
from time import time
from utils import *
from utils.media import preprocess
from utils.sdxl import *
import yaml
###
# Code from genforce/ctrl-x/run_ctrlx.py
@torch.no_grad()
def inference(
    pipe, refiner, device,
    structure_image, appearance_image,
    prompt, structure_prompt, appearance_prompt,
    positive_prompt, negative_prompt,
    guidance_scale, structure_guidance_scale, appearance_guidance_scale,
    num_inference_steps, eta, seed,
    width, height,
    structure_schedule, appearance_schedule,
):
    """Run a single Ctrl-X generation with ``pipe`` and return the images.

    Args:
        pipe: Pipeline object that is configured (scheduler timesteps,
            control registration, progress bar) and then called.
        refiner: Accepted for interface compatibility; not used here.
        device: Device passed to ``pipe.scheduler.set_timesteps``.
        structure_image: ``None`` or a string that ``load_image`` can open.
        appearance_image: ``None`` or an input accepted by ``load_image``.
        prompt, structure_prompt, appearance_prompt, positive_prompt,
        negative_prompt: Forwarded unchanged to ``pipe``.
        guidance_scale, structure_guidance_scale, appearance_guidance_scale,
        eta: Forwarded unchanged to ``pipe``.
        num_inference_steps: Number of scheduler steps; also used to build
            the self-recurrence schedule.
        seed: Value handed to ``seed_everything``.
        width, height: Target size used for preprocessing and generation.
        structure_schedule, appearance_schedule: Inputs to
            ``get_control_config``.

    Returns:
        A 4-tuple ``(result, result_refiner, structure, appearance)`` where
        ``result``, ``structure`` and ``appearance`` are the first element of
        the three sequences returned by ``pipe`` and ``result_refiner`` is
        always ``None`` (no refiner pass is executed in this function).
    """
    seed_everything(seed)

    # Process images.
    # Moved from CtrlXStableDiffusionXLPipeline.__call__.
    # Use this function's own argument rather than the module-level ``args``
    # namespace, which only exists when the file is executed as a script.
    if isinstance(structure_image, str):
        structure_image = load_image(structure_image)
        structure_image = preprocess(
            structure_image, pipe.image_processor,
            height=height, width=width, resize_mode="crop",
        )
    if appearance_image is not None:
        appearance_image = load_image(appearance_image)
        appearance_image = preprocess(
            appearance_image, pipe.image_processor,
            height=height, width=width, resize_mode="crop",
        )

    # Scheduler.
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    # The control configuration is produced as YAML text and parsed back into
    # a dict; the text is echoed so the user can see what is being applied.
    control_config = get_control_config(structure_schedule, appearance_schedule)
    print(f"\nUsing the following control config:\n{control_config}\n")
    config = yaml.safe_load(control_config)
    register_control(
        model=pipe,
        timesteps=timesteps,
        control_schedule=config["control_schedule"],
        control_target=config["control_target"],
    )

    # Pipe settings.
    pipe.safety_checker = None
    pipe.requires_safety_checker = False
    self_recurrence_schedule = get_self_recurrence_schedule(
        config["self_recurrence_schedule"], num_inference_steps,
    )
    pipe.set_progress_bar_config(desc="Ctrl-X inference")

    # Inference.
    result, structure, appearance = pipe(
        prompt=prompt,
        structure_prompt=structure_prompt,
        appearance_prompt=appearance_prompt,
        structure_image=structure_image,
        appearance_image=appearance_image,
        num_inference_steps=num_inference_steps,
        negative_prompt=negative_prompt,
        positive_prompt=positive_prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        structure_guidance_scale=structure_guidance_scale,
        appearance_guidance_scale=appearance_guidance_scale,
        eta=eta,
        output_type="pil",
        return_dict=False,
        control_schedule=config["control_schedule"],
        self_recurrence_schedule=self_recurrence_schedule,
    )

    # No refiner stage runs here, so its slot in the return value stays None.
    result_refiner = [None]

    # NOTE(review): presumably ``refiner_args`` is set on the pipeline during
    # the call above; confirm against the pipeline implementation.
    del pipe.refiner_args

    return result[0], result_refiner[0], structure[0], appearance[0]
@torch.no_grad()
def main(args):
    """Load the Ctrl-X SDXL pipeline, run one generation and save the images.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model``, ``model_offload``, ``sequential_offload``,
            ``output_folder``, ``benchmark`` and every generation option that
            is forwarded to ``inference``.

    Raises:
        ModuleNotFoundError: If CPU offloading was requested but the
            ``accelerate`` package cannot be imported.
    """
    cuda_available = torch.cuda.is_available()
    torch_dtype = torch.float16 if cuda_available else torch.float32
    model_id_or_path = "OzzyGT/SSD-1B"
    # refiner_id_or_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
    device = "cuda" if cuda_available else "cpu"
    # NOTE(review): on CPU this requests an "fp32" weight variant; confirm
    # that the model repository actually publishes one.
    variant = "fp16" if device == "cuda" else "fp32"

    # Passing a repo id/path to ``from_config`` is deprecated in diffusers;
    # ``from_pretrained`` is the supported way to load a scheduler config.
    scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")

    if args.model is None:
        pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
            model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype, variant=variant, use_safetensors=True,
        )
    else:
        print(f"Using weights {args.model} for SDXL base model.")
        pipe = CtrlXStableDiffusionXLPipeline.from_single_file(args.model, scheduler=scheduler, torch_dtype=torch_dtype)

    if args.model_offload or args.sequential_offload:
        try:
            import accelerate  # noqa: F401 -- only checking that it is installed
        except ImportError as err:
            raise ModuleNotFoundError("`accelerate` must be installed for Model/CPU offloading.") from err

        # Sequential offloading takes precedence when both flags are given.
        if args.sequential_offload:
            pipe.enable_sequential_cpu_offload()
        elif args.model_offload:
            pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(device)

    # This script never loads a refiner (``refiner=None`` is passed below),
    # so the status line reports the base model only.
    model_load_print = "Base model "
    if args.sequential_offload:
        model_load_print += "loaded with sequential CPU offloading."
    elif args.model_offload:
        model_load_print += "loaded with model CPU offloading."
    else:
        model_load_print += "loaded."
    print(f"{model_load_print} Running on device: {device}.")

    t = time()
    result, result_refiner, structure, appearance = inference(
        pipe=pipe,
        refiner=None,
        device=device,
        structure_image=args.structure_image,
        appearance_image=args.appearance_image,
        prompt=args.prompt,
        structure_prompt=args.structure_prompt,
        appearance_prompt=args.appearance_prompt,
        positive_prompt=args.positive_prompt,
        negative_prompt=args.negative_prompt,
        guidance_scale=args.guidance_scale,
        structure_guidance_scale=args.structure_guidance_scale,
        appearance_guidance_scale=args.appearance_guidance_scale,
        num_inference_steps=args.num_inference_steps,
        eta=args.eta,
        seed=args.seed,
        width=args.width,
        height=args.height,
        structure_schedule=args.structure_schedule,
        appearance_schedule=args.appearance_schedule,
    )
    # Stop the clock before writing files so the reported figure covers
    # inference only, not JPEG encoding and disk I/O.
    inference_time = time() - t

    makedirs(args.output_folder, exist_ok=True)
    prefix = "ctrlx__" + datetime.now().strftime("%Y%m%d_%H%M%S")
    structure.save(path.join(args.output_folder, f"{prefix}__structure.jpg"), quality=JPEG_QUALITY)
    appearance.save(path.join(args.output_folder, f"{prefix}__appearance.jpg"), quality=JPEG_QUALITY)
    result.save(path.join(args.output_folder, f"{prefix}__result.jpg"), quality=JPEG_QUALITY)
    if result_refiner is not None:
        result_refiner.save(path.join(args.output_folder, f"{prefix}__result_refiner.jpg"), quality=JPEG_QUALITY)

    if args.benchmark:
        # Only query the CUDA allocator when CUDA is actually in use.
        peak_memory_usage = torch.cuda.max_memory_reserved() if device == "cuda" else 0
        print(f"Inference time: {inference_time:.2f}s")
        print(f"Peak memory usage: {peak_memory_usage / pow(1024, 3):.2f}GiB")

    print("Done.")
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = ArgumentParser()

    # Input images (both optional).
    parser.add_argument("--structure_image", "-si", type=str, default=None)
    parser.add_argument("--appearance_image", "-ai", type=str, default=None)

    # Text prompts.
    parser.add_argument("--prompt", "-p", type=str, required=True)
    parser.add_argument("--structure_prompt", "-sp", type=str, default="")
    parser.add_argument("--appearance_prompt", "-ap", type=str, default="")
    parser.add_argument("--positive_prompt", "-pp", type=str, default="high quality")
    parser.add_argument("--negative_prompt", "-np", type=str, default="ugly, blurry, dark, low res, unrealistic")

    # Sampling options.
    parser.add_argument("--guidance_scale", "-g", type=float, default=5.0)
    parser.add_argument("--structure_guidance_scale", "-sg", type=float, default=5.0)
    parser.add_argument("--appearance_guidance_scale", "-ag", type=float, default=5.0)
    parser.add_argument("--num_inference_steps", "-n", type=int, default=50)
    parser.add_argument("--eta", "-e", type=float, default=1.0)
    parser.add_argument("--seed", "-s", type=int, default=90095)
    parser.add_argument("--width", "-W", type=int, default=1024)
    parser.add_argument("--height", "-H", type=int, default=1024)

    # Ctrl-X control schedules.
    parser.add_argument("--structure_schedule", "-ss", type=float, default=0.6)
    parser.add_argument("--appearance_schedule", "-as", type=float, default=0.6)

    # Output location.
    parser.add_argument("--output_folder", "-o", type=str, default="./results")

    # Memory / runtime trade-offs.
    parser.add_argument(
        "-mo", "--model_offload", action="store_true",
        help="Model CPU offload, lowers memory usage with slight runtime increase. `accelerate` must be installed.",
    )
    parser.add_argument(
        "-so", "--sequential_offload", action="store_true",
        help=(
            # Trailing space keeps the two concatenated sentences apart.
            "Sequential layer CPU offload, significantly lowers memory usage with massive runtime increase. "
            "`accelerate` must be installed. If both model_offload and sequential_offload are set, then use the latter."
        ),
    )

    parser.add_argument(
        "-r", "--disable_refiner", action="store_true",
        help="Kept for command-line compatibility; this script does not load a refiner.",
    )
    parser.add_argument("-m", "--model", type=str, default=None, help="Optionally, load model safetensors.")
    parser.add_argument("-b", "--benchmark", action="store_true", help="Show inference time and max memory usage.")

    args = parser.parse_args()
    main(args)