ksort commited on
Commit
0e503f3
1 Parent(s): 3ec9103

Update ssh

Browse files
model/model_manager.py CHANGED
@@ -26,7 +26,30 @@ class ModelManager:
26
  @spaces.GPU(duration=120)
27
  def generate_image_ig(self, prompt, model_name):
28
  pipe = self.load_model_pipe(model_name)
29
- result = pipe(prompt=prompt).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  return result
31
 
32
  def generate_image_ig_api(self, prompt, model_name):
 
26
  @spaces.GPU(duration=120)
27
  def generate_image_ig(self, prompt, model_name):
28
  pipe = self.load_model_pipe(model_name)
29
+ if 'cascade' not in name:
30
+ result = pipe(prompt=prompt).images[0]
31
+ else:
32
+ prior, decoder = pipe
33
+ prior.enable_model_cpu_offload()
34
+ prior_output = prior(
35
+ prompt=prompt,
36
+ height=512,
37
+ width=512,
38
+ negative_prompt='',
39
+ guidance_scale=4.0,
40
+ num_images_per_prompt=1,
41
+ num_inference_steps=20
42
+ )
43
+
44
+ decoder.enable_model_cpu_offload()
45
+ result = decoder(
46
+ image_embeddings=prior_output.image_embeddings.to(torch.float16),
47
+ prompt=prompt,
48
+ negative_prompt='',
49
+ guidance_scale=0.0,
50
+ output_type="pil",
51
+ num_inference_steps=10
52
+ ).images[0]
53
  return result
54
 
55
  def generate_image_ig_api(self, prompt, model_name):
model/models/huggingface_models.py CHANGED
@@ -1,5 +1,6 @@
1
  from diffusers import DiffusionPipeline
2
  from diffusers import AutoPipelineForText2Image
 
3
  import torch
4
 
5
 
@@ -8,11 +9,16 @@ import torch
8
  def load_huggingface_model(model_name, model_type):
9
  if model_name == "SD-turbo":
10
  pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
 
11
  elif model_name == "SDXL-turbo":
12
  pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
 
 
 
 
 
13
  else:
14
  raise NotImplementedError
15
- pipe = pipe.to("cuda")
16
  # if model_name == "SD-turbo":
17
  # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
18
  # elif model_name == "SDXL-turbo":
@@ -24,8 +30,9 @@ def load_huggingface_model(model_name, model_type):
24
 
25
 
26
  if __name__ == "__main__":
27
- for name in ["SD-turbo", "SDXL-turbo"]:
28
- load_huggingface_model(name, "text2image")
 
29
 
30
  # for name in ["IF-I-XL-v1.0"]:
31
  # pipe = load_huggingface_model(name, 'text2image')
 
1
  from diffusers import DiffusionPipeline
2
  from diffusers import AutoPipelineForText2Image
3
+ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
4
  import torch
5
 
6
 
 
9
  def load_huggingface_model(model_name, model_type):
10
  if model_name == "SD-turbo":
11
  pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
12
+ pipe = pipe.to("cuda")
13
  elif model_name == "SDXL-turbo":
14
  pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
15
+ pipe = pipe.to("cuda")
16
+ elif model_name == "Stable-cascade":
17
+ prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16)
18
+ decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
19
+ pipe = [prior, decoder]
20
  else:
21
  raise NotImplementedError
 
22
  # if model_name == "SD-turbo":
23
  # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
24
  # elif model_name == "SDXL-turbo":
 
30
 
31
 
32
  if __name__ == "__main__":
33
+ for name in ["SD-turbo", "SDXL-turbo"]: #"SD-turbo", "SDXL-turbo"
34
+ pipe = load_huggingface_model(name, "text2image")
35
+
36
 
37
  # for name in ["IF-I-XL-v1.0"]:
38
  # pipe = load_huggingface_model(name, 'text2image')
serve/upload.py CHANGED
@@ -79,18 +79,18 @@ def upload_ssh_all(states, output_dir, data, data_path):
79
  output_file_list.append(output_file)
80
  image_list.append(states[i].output)
81
 
82
- # with sftp_client as sftp:
83
- for i in range(len(output_file_list)):
84
- if isinstance(image_list[i], str):
85
- print("get url image")
86
- image_list[i] = get_image_from_url(image_list[i])
87
- with io.BytesIO() as image_byte_stream:
88
- image_list[i].save(image_byte_stream, format='JPEG')
89
- image_byte_stream.seek(0)
90
- sftp_client.putfo(image_byte_stream, output_file_list[i])
91
- print(f"Successfully uploaded image to {output_file_list[i]}")
92
- json_data = json.dumps(data, indent=4)
93
- with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
94
- sftp_client.putfo(json_byte_stream, data_path)
95
- print(f"Successfully uploaded JSON data to {data_path}")
96
  # create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
 
79
  output_file_list.append(output_file)
80
  image_list.append(states[i].output)
81
 
82
+ with sftp_client as sftp:
83
+ for i in range(len(output_file_list)):
84
+ if isinstance(image_list[i], str):
85
+ print("get url image")
86
+ image_list[i] = get_image_from_url(image_list[i])
87
+ with io.BytesIO() as image_byte_stream:
88
+ image_list[i].save(image_byte_stream, format='JPEG')
89
+ image_byte_stream.seek(0)
90
+ sftp.putfo(image_byte_stream, output_file_list[i])
91
+ print(f"Successfully uploaded image to {output_file_list[i]}")
92
+ json_data = json.dumps(data, indent=4)
93
+ with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
94
+ sftp.putfo(json_byte_stream, data_path)
95
+ print(f"Successfully uploaded JSON data to {data_path}")
96
  # create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)