K-Sort-Arena / model /models /huggingface_models.py
ksort's picture
Update ssh
671eb7e
raw
history blame
No virus
1.22 kB
from diffusers import DiffusionPipeline
from diffusers import AutoPipelineForText2Image
import torch
# Arena model name -> Hugging Face Hub repository id of its text-to-image pipeline.
_TEXT2IMAGE_REPO_IDS = {
    "SD-turbo": "stabilityai/sd-turbo",
    "SDXL-turbo": "stabilityai/sdxl-turbo",
}


def load_huggingface_model(model_name, model_type, *, device="cuda"):
    """Load a supported Hugging Face text-to-image pipeline and move it to *device*.

    Args:
        model_name: Arena name of the model; one of "SD-turbo" or "SDXL-turbo".
        model_type: Currently unused by this function; kept so existing
            callers (e.g. the ``__main__`` block, which passes "text2image")
            keep working.
        device: Target passed to the pipeline's ``.to()`` call. Defaults to
            "cuda", matching the previously hard-coded behaviour.

    Returns:
        The pipeline created by ``AutoPipelineForText2Image.from_pretrained``,
        after ``.to(device)``.

    Raises:
        NotImplementedError: if *model_name* is not a supported model.
    """
    # Resolve the name first so unsupported models fail fast, before any
    # download or GPU work is attempted.
    repo_id = _TEXT2IMAGE_REPO_IDS.get(model_name)
    if repo_id is None:
        supported = ", ".join(sorted(_TEXT2IMAGE_REPO_IDS))
        raise NotImplementedError(
            f"Unsupported model {model_name!r}; supported models: {supported}"
        )

    # Request the "fp16" weight variant and float16 dtype for every model.
    # NOTE(review): an earlier (removed, commented-out) CPU path loaded the
    # default full-precision weights instead; float16 on non-CUDA devices
    # may need revisiting if device="cpu" is used.
    pipe = AutoPipelineForText2Image.from_pretrained(
        repo_id, torch_dtype=torch.float16, variant="fp16"
    )
    return pipe.to(device)
if __name__ == "__main__":
    # Smoke test: load each supported text-to-image pipeline in turn via
    # load_huggingface_model. The returned pipelines are discarded; this only
    # checks that loading succeeds.
    smoke_test_models = ("SD-turbo", "SDXL-turbo")
    for model_id in smoke_test_models:
        load_huggingface_model(model_id, "text2image")