ksort commited on
Commit
671eb7e
1 Parent(s): 6dc341c

Update ssh

Browse files
Files changed (1) hide show
  1. model/models/huggingface_models.py +10 -10
model/models/huggingface_models.py CHANGED
@@ -6,20 +6,20 @@ import torch
6
 
7
 
8
  def load_huggingface_model(model_name, model_type):
9
- # if model_name == "SD-turbo":
10
- # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
11
- # elif model_name == "SDXL-turbo":
12
- # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
13
- # else:
14
- # raise NotImplementedError
15
- # pipe = pipe.to("cuda")
16
  if model_name == "SD-turbo":
17
- pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
18
  elif model_name == "SDXL-turbo":
19
- pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
20
  else:
21
  raise NotImplementedError
22
- pipe = pipe.to("cpu")
 
 
 
 
 
 
 
23
  return pipe
24
 
25
 
 
6
 
7
 
8
  def load_huggingface_model(model_name, model_type):
 
 
 
 
 
 
 
9
  if model_name == "SD-turbo":
10
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
11
  elif model_name == "SDXL-turbo":
12
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
13
  else:
14
  raise NotImplementedError
15
+ pipe = pipe.to("cuda")
16
+ # if model_name == "SD-turbo":
17
+ # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
18
+ # elif model_name == "SDXL-turbo":
19
+ # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
20
+ # else:
21
+ # raise NotImplementedError
22
+ # pipe = pipe.to("cpu")
23
  return pipe
24
 
25