FluxMusicGUI / utils.py
flosstradamus's picture
Upload 194 files
afe1a07 verified
raw
history blame
1.4 kB
import torch
from modules.autoencoder import AutoEncoder, AutoEncoderParams
from modules.conditioner import HFEmbedder
from safetensors.torch import load_file as load_sft
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Create the ``google/t5-v1_1-xxl`` embedder in bfloat16 and move it to ``device``.

    Args:
        device: Target passed to ``.to()`` on the constructed embedder.
        max_length: Forwarded unchanged to ``HFEmbedder``. Values of 64, 128,
            256 and 512 should work (if your sequence is short enough).

    Returns:
        The result of calling ``.to(device)`` on the new ``HFEmbedder``.
    """
    t5_embedder = HFEmbedder(
        "google/t5-v1_1-xxl",
        max_length=max_length,
        torch_dtype=torch.bfloat16,
    )
    return t5_embedder.to(device)
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    """Create the ``openai/clip-vit-large-patch14`` embedder in bfloat16 on ``device``.

    The embedder is always built with ``max_length=77``; only the target
    device is configurable.

    Args:
        device: Target passed to ``.to()`` on the constructed embedder.

    Returns:
        The result of calling ``.to(device)`` on the new ``HFEmbedder``.
    """
    clip_embedder = HFEmbedder(
        "openai/clip-vit-large-patch14",
        max_length=77,
        torch_dtype=torch.bfloat16,
    )
    return clip_embedder.to(device)
def load_clap(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Create the ``laion/larger_clap_music`` embedder in bfloat16 on ``device``.

    Args:
        device: Target passed to ``.to()`` on the constructed embedder.
        max_length: Currently unused; the embedder is always constructed with
            ``max_length=256`` regardless of the value passed here.

    Returns:
        The result of calling ``.to(device)`` on the new ``HFEmbedder``.
    """
    # NOTE(review): ``max_length`` is accepted but ignored -- 256 is hard-coded
    # below. Confirm whether the parameter should be honoured (and its default
    # aligned with 256) before changing this, since callers may rely on 256.
    clap_embedder = HFEmbedder(
        "laion/larger_clap_music",
        max_length=256,
        torch_dtype=torch.bfloat16,
    )
    return clap_embedder.to(device)
def load_ae(ckpt_path, device: str | torch.device = "cuda") -> AutoEncoder:
    """Build the autoencoder, load weights from a safetensors file, move it to ``device``.

    The state dict is loaded with ``strict=False``, so a checkpoint whose keys
    do not exactly match the module does not raise. Any missing or unexpected
    keys are reported through ``logging`` at WARNING level instead of being
    silently dropped.

    Args:
        ckpt_path: Path to the safetensors checkpoint, passed to
            ``safetensors.torch.load_file``.
        device: Device the module is moved to after its weights are loaded.

    Returns:
        The ``AutoEncoder`` instance with the checkpoint applied, on ``device``.
    """
    # Imported locally so this block does not depend on file-level imports.
    import logging

    logger = logging.getLogger(__name__)

    # Fixed architecture hyper-parameters for the autoencoder.
    ae_params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )

    # Loading the autoencoder
    ae = AutoEncoder(ae_params)
    sd = load_sft(ckpt_path)

    # strict=False tolerates key mismatches; assign=True is forwarded to
    # load_state_dict as in the original call.
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    if missing:
        logger.warning(
            "load_ae: %d missing key(s) when loading %s: %s",
            len(missing),
            ckpt_path,
            ", ".join(missing),
        )
    if unexpected:
        logger.warning(
            "load_ae: %d unexpected key(s) when loading %s: %s",
            len(unexpected),
            ckpt_path,
            ", ".join(unexpected),
        )

    ae.to(device)
    return ae