tfwang commited on
Commit
780da15
1 Parent(s): 71c0c21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -24,6 +24,7 @@ import numpy as np
24
  from huggingface_hub import hf_hub_download
25
 
26
  def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
 
27
  parser, parser_up = create_argparser()
28
 
29
  args = parser.parse_args()
@@ -70,8 +71,8 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
70
  model_ckpt2 , strict=True )
71
 
72
 
73
- model.cuda()
74
- model_up.cuda()
75
  model.eval()
76
  model_up.eval()
77
 
@@ -120,7 +121,7 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
120
  prompt=model_kwargs,
121
  batch_size= args.num_samples,
122
  guidance_scale=args.sample_c,
123
- device=torch.device('cuda'),
124
  prediction_respacing= str(sample_step),
125
  upsample_enabled= False,
126
  upsample_temp=0.997,
@@ -140,7 +141,7 @@ def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
140
  prompt=model_kwargs,
141
  batch_size=args.num_samples,
142
  guidance_scale=1,
143
- device=torch.device('cuda'),
144
  prediction_respacing= "fast27",
145
  upsample_enabled=True,
146
  upsample_temp=0.997,
 
24
  from huggingface_hub import hf_hub_download
25
 
26
  def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
27
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
28
  parser, parser_up = create_argparser()
29
 
30
  args = parser.parse_args()
 
71
  model_ckpt2 , strict=True )
72
 
73
 
74
+ model.to(device)
75
+ model_up.to(device)
76
  model.eval()
77
  model_up.eval()
78
 
 
121
  prompt=model_kwargs,
122
  batch_size= args.num_samples,
123
  guidance_scale=args.sample_c,
124
+ device=device,
125
  prediction_respacing= str(sample_step),
126
  upsample_enabled= False,
127
  upsample_temp=0.997,
 
141
  prompt=model_kwargs,
142
  batch_size=args.num_samples,
143
  guidance_scale=1,
144
+ device=device,
145
  prediction_respacing= "fast27",
146
  upsample_enabled=True,
147
  upsample_temp=0.997,