|
import logging |
|
|
|
import torch |
|
from diffusers import (AutoencoderKL, DDPMScheduler, |
|
EulerAncestralDiscreteScheduler, LCMScheduler, |
|
Transformer2DModel, UNet2DConditionModel) |
|
from huggingface_hub import hf_hub_download |
|
from safetensors.torch import load_file |
|
|
|
from models.RewardPixart import RewardPixartPipeline, freeze_params |
|
from models.RewardStableDiffusion import RewardStableDiffusion |
|
from models.RewardStableDiffusionXL import RewardStableDiffusionXL |
|
|
|
|
|
def get_model(
    model_name: str,
    dtype: torch.dtype,
    device: torch.device,
    cache_dir: str,
    memsave: bool = False,
):
    """Load and configure the reward pipeline selected by ``model_name``.

    Supported names are ``"sd-turbo"``, ``"sdxl-turbo"``, ``"pixart"`` and
    ``"hyper-sd"``.

    Args:
        model_name: Identifier of the pipeline to build.
        dtype: Passed as ``torch_dtype`` when loading weights and, for every
            branch except ``"pixart"``, also used when moving the pipeline
            to ``device``.
        device: Device the returned pipeline is moved to.
        cache_dir: Directory handed to every ``from_pretrained`` /
            ``from_config`` / ``hf_hub_download`` call as ``cache_dir``.
        memsave: Forwarded unchanged to the reward pipeline's
            ``from_pretrained``.

    Returns:
        The configured pipeline object, moved to ``device``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    logging.info("Loading model: %s", model_name)

    if model_name == "sd-turbo":
        pipe = RewardStableDiffusion.from_pretrained(
            "stabilityai/sd-turbo",
            torch_dtype=dtype,
            variant="fp16",
            cache_dir=cache_dir,
            memsave=memsave,
        )
        pipe = pipe.to(device, dtype)

    elif model_name == "sdxl-turbo":
        # Use the "sdxl-vae-fp16-fix" VAE in place of the one bundled with
        # the sdxl-turbo checkpoint.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
        )
        pipe = RewardStableDiffusionXL.from_pretrained(
            "stabilityai/sdxl-turbo",
            vae=vae,
            torch_dtype=dtype,
            variant="fp16",
            use_safetensors=True,
            cache_dir=cache_dir,
            memsave=memsave,
        )
        # Rebuild the scheduler from the existing config, overriding the
        # timestep spacing with "trailing".
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        pipe = pipe.to(device, dtype)

    elif model_name == "pixart":
        pipe = RewardPixartPipeline.from_pretrained(
            "PixArt-alpha/PixArt-XL-2-1024-MS",
            torch_dtype=dtype,
            cache_dir=cache_dir,
            memsave=memsave,
        )
        # Replace the transformer and scheduler with the ones shipped in the
        # PixArt-Alpha-DMD 512x512 checkpoint.
        pipe.transformer = Transformer2DModel.from_pretrained(
            "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512",
            subfolder="transformer",
            torch_dtype=dtype,
            cache_dir=cache_dir,
        )
        pipe.scheduler = DDPMScheduler.from_pretrained(
            "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512",
            subfolder="scheduler",
            cache_dir=cache_dir,
        )

        pipe.text_encoder.to_bettertransformer()
        # The transformer is put in eval mode, its parameters are handed to
        # freeze_params, and gradient checkpointing is switched on.
        pipe.transformer.eval()
        freeze_params(pipe.transformer.parameters())
        pipe.transformer.enable_gradient_checkpointing()
        # NOTE: this branch moves the pipeline by device only (no dtype).
        pipe = pipe.to(device)

    elif model_name == "hyper-sd":
        base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        repo_name = "ByteDance/Hyper-SD"
        ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"

        # Instantiate the UNet from the base model's config, then fill it
        # with the Hyper-SD one-step weights.
        unet = UNet2DConditionModel.from_config(
            base_model_id, subfolder="unet", cache_dir=cache_dir
        ).to(device, dtype)
        unet.load_state_dict(
            load_file(
                hf_hub_download(repo_name, ckpt_name, cache_dir=cache_dir),
                # Load onto the requested device instead of a hard-coded
                # "cuda" so CPU-only hosts and non-default GPUs work.
                device=str(device),
            )
        )
        pipe = RewardStableDiffusionXL.from_pretrained(
            base_model_id,
            unet=unet,
            torch_dtype=dtype,
            variant="fp16",
            cache_dir=cache_dir,
            is_hyper=True,
            memsave=memsave,
        )

        pipe.scheduler = LCMScheduler.from_config(
            pipe.scheduler.config, cache_dir=cache_dir
        )
        pipe = pipe.to(device, dtype)

        # The VAE is cast back to float32 after the pipeline-wide move;
        # presumably to avoid half-precision issues in the SDXL VAE -- TODO
        # confirm.
        pipe.vae = pipe.vae.to(dtype=torch.float32)

    else:
        raise ValueError(f"Unknown model name: {model_name}")

    return pipe
|
|