|
import logging
from typing import Any, Optional

import torch
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LCMScheduler,
    Transformer2DModel,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from models.RewardFlux import RewardFluxPipeline
from models.RewardPixart import RewardPixartPipeline, freeze_params
from models.RewardStableDiffusion import RewardStableDiffusion
from models.RewardStableDiffusionXL import RewardStableDiffusionXL
|
|
def get_model(
    model_name: str,
    dtype: torch.dtype,
    device: torch.device,
    cache_dir: str,
    memsave: bool = False,
    enable_sequential_cpu_offload: bool = False,
):
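    """Load a reward-wrapped few-step text-to-image pipeline by name.

    Supported names: "sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux".
    The pipeline is moved to `device`/`dtype`; sequential CPU offload can be
    enabled to trade speed for lower GPU memory use.
    """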
|
logging.info(f"Loading model: {model_name}") |
|
if model_name == "sd-turbo": |
|
pipe = RewardStableDiffusion.from_pretrained( |
|
"stabilityai/sd-turbo", |
|
torch_dtype=dtype, |
|
variant="fp16", |
|
cache_dir=cache_dir, |
|
memsave=memsave, |
|
) |
|
pipe = pipe.to(device, dtype) |
|
elif model_name == "sdxl-turbo": |
|
vae = AutoencoderKL.from_pretrained( |
|
"madebyollin/sdxl-vae-fp16-fix", |
|
torch_dtype=torch.float16, |
|
cache_dir=cache_dir, |
|
) |
|
pipe = RewardStableDiffusionXL.from_pretrained( |
|
"stabilityai/sdxl-turbo", |
|
vae=vae, |
|
torch_dtype=dtype, |
|
variant="fp16", |
|
use_safetensors=True, |
|
cache_dir=cache_dir, |
|
memsave=memsave, |
|
) |
|
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config( |
|
pipe.scheduler.config, timestep_spacing="trailing" |
|
) |
|
pipe = pipe.to(device, dtype) |
|
elif model_name == "pixart": |
|
pipe = RewardPixartPipeline.from_pretrained( |
|
"PixArt-alpha/PixArt-XL-2-1024-MS", |
|
torch_dtype=dtype, |
|
cache_dir=cache_dir, |
|
memsave=memsave, |
|
) |
|
pipe.transformer = Transformer2DModel.from_pretrained( |
|
"PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", |
|
subfolder="transformer", |
|
torch_dtype=dtype, |
|
cache_dir=cache_dir, |
|
) |
|
pipe.scheduler = DDPMScheduler.from_pretrained( |
|
"PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", |
|
subfolder="scheduler", |
|
cache_dir=cache_dir, |
|
) |
|
|
|
|
|
pipe.text_encoder.to_bettertransformer() |
|
pipe.transformer.eval() |
|
freeze_params(pipe.transformer.parameters()) |
|
pipe.transformer.enable_gradient_checkpointing() |
|
pipe = pipe.to(device) |
|
elif model_name == "hyper-sd": |
|
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0" |
|
repo_name = "ByteDance/Hyper-SD" |
|
ckpt_name = "Hyper-SDXL-1step-Unet.safetensors" |
|
|
|
unet = UNet2DConditionModel.from_config( |
|
base_model_id, subfolder="unet", cache_dir=cache_dir |
|
).to(device, dtype) |
|
unet.load_state_dict( |
|
load_file( |
|
hf_hub_download(repo_name, ckpt_name, cache_dir=cache_dir), |
|
device="cuda", |
|
) |
|
) |
|
pipe = RewardStableDiffusionXL.from_pretrained( |
|
base_model_id, |
|
unet=unet, |
|
torch_dtype=dtype, |
|
variant="fp16", |
|
cache_dir=cache_dir, |
|
is_hyper=True, |
|
memsave=memsave, |
|
) |
|
|
|
pipe.scheduler = LCMScheduler.from_config( |
|
pipe.scheduler.config, cache_dir=cache_dir |
|
) |
|
pipe = pipe.to(device, dtype) |
|
|
|
pipe.vae = pipe.vae.to(dtype=torch.float32) |
|
elif model_name == "flux": |
|
pipe = RewardFluxPipeline.from_pretrained( |
|
"black-forest-labs/FLUX.1-schnell", |
|
torch_dtype=torch.bfloat16, |
|
cache_dir=cache_dir, |
|
) |
|
pipe.to(device, dtype) |
|
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    return pipe
|
|
def get_multi_apply_fn(
    model_type: str,
    seed: int,
    pipe: Optional[Any] = None,
    cache_dir: Optional[str] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
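    """Return a no-grad callable `(latents, prompt) -> images` for multi-step sampling.

    For "flux" the already-loaded `pipe` is reused; for "sdxl" and "sd2" a
    full multi-step pipeline is loaded here with sequential CPU offload enabled.
    """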
|
generator = torch.Generator("cuda").manual_seed(seed) |
|
if model_type == "flux": |
|
return lambda latents, prompt: torch.no_grad(pipe.apply)( |
|
latents=latents, |
|
prompt=prompt, |
|
num_inference_steps=4, |
|
generator=generator, |
|
) |
|
elif model_type == "sdxl": |
|
vae = AutoencoderKL.from_pretrained( |
|
"madebyollin/sdxl-vae-fp16-fix", |
|
torch_dtype=torch.float16, |
|
cache_dir=cache_dir, |
|
) |
|
pipe = RewardStableDiffusionXL.from_pretrained( |
|
"stabilityai/stable-diffusion-xl-base-1.0", |
|
torch_dtype=torch.float16, |
|
variant="fp16", |
|
vae=vae, |
|
use_safetensors=True, |
|
cache_dir=cache_dir, |
|
) |
|
pipe = pipe.to(device, dtype) |
|
pipe.enable_sequential_cpu_offload() |
|
return lambda latents, prompt: torch.no_grad(pipe.apply)( |
|
latents=latents, |
|
prompt=prompt, |
|
guidance_scale=5.0, |
|
num_inference_steps=50, |
|
generator=generator, |
|
) |
|
elif model_type == "sd2": |
|
sd2_base = "stabilityai/stable-diffusion-2-1-base" |
|
scheduler = EulerDiscreteScheduler.from_pretrained( |
|
sd2_base, |
|
subfolder="scheduler", |
|
cache_dir=cache_dir, |
|
) |
|
pipe = RewardStableDiffusion.from_pretrained( |
|
sd2_base, |
|
torch_dtype=dtype, |
|
cache_dir=cache_dir, |
|
scheduler=scheduler, |
|
) |
|
pipe = pipe.to(device, dtype) |
|
pipe.enable_sequential_cpu_offload() |
|
return lambda latents, prompt: torch.no_grad(pipe.apply)( |
|
latents=latents, |
|
prompt=prompt, |
|
guidance_scale=7.5, |
|
num_inference_steps=50, |
|
generator=generator, |
|
) |
|
    else:
        raise ValueError(f"Unknown model type: {model_type}")