sdxl-wombo-optimized / src /pipeline.py
mhussainahmad's picture
Update src/pipeline.py
e76c41a verified
raw
history blame contribute delete
No virus
1.06 kB
import torch
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline, LCMScheduler
from pipelines.models import TextToImageRequest
from torch import Generator
def load_pipeline() -> StableDiffusionXLPipeline:
    """Load the SDXL pipeline with LCM-LoRA weights and run one warm-up pass.

    The base checkpoint is read from ``./models/newdream-sdxl-20`` with
    ``local_files_only=True``, so no download is attempted.

    Returns:
        A ``StableDiffusionXLPipeline`` in ``torch.float16`` on ``"cuda"``,
        with its scheduler replaced by ``LCMScheduler`` and the LoRA weights
        from ``./models/sdxl-lcmlora-1024-100k-3000steps`` loaded.
    """
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "./models/newdream-sdxl-20",
        torch_dtype=torch.float16,
        local_files_only=True,
    ).to("cuda")

    # Swap in the LCM scheduler, built from the checkpoint's own scheduler
    # config. The LoRA loaded next is (per its directory name) an LCM-LoRA,
    # which is meant to be sampled with this scheduler.
    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
    pipeline.load_lora_weights("./models/sdxl-lcmlora-1024-100k-3000steps")

    # Warm-up pass so first-call overhead is not paid by the first real request.
    # Use the same step count and guidance scale as `infer` (4 / 1.5) instead of
    # the pipeline's call defaults. Those defaults run far more denoising steps
    # (50 in current diffusers — confirm for the pinned version), which only
    # lengthens loading; presumably nothing extra gets warmed by them.
    pipeline(prompt="", num_inference_steps=4, guidance_scale=1.5)
    return pipeline
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate a single image for ``request``.

    Args:
        request: Supplies ``prompt``, ``negative_prompt``, ``width``,
            ``height`` and ``seed``. ``seed`` may presumably be ``None`` to
            request unseeded generation — confirm against
            ``TextToImageRequest``.
        pipeline: The loaded pipeline (e.g. the result of ``load_pipeline``).

    Returns:
        The first image in the pipeline output's ``images`` list.
    """
    # Compare with None explicitly: 0 is a valid seed, and a truthiness test
    # would silently drop it and make that request non-reproducible.
    generator = (
        Generator(pipeline.device).manual_seed(request.seed)
        if request.seed is not None
        else None
    )
    return pipeline(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        width=request.width,
        height=request.height,
        generator=generator,
        # Low step count and guidance, presumably tuned for the LCM scheduler
        # that load_pipeline installs.
        num_inference_steps=4,
        guidance_scale=1.5,
    ).images[0]