tfwang commited on
Commit
f4a50a2
1 Parent(s): b3ff61f

Update glide_text2im/glide_util.py

Browse files
Files changed (1) hide show
  1. glide_text2im/glide_util.py +3 -3
glide_text2im/glide_util.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  from typing import Tuple
3
- from . import dist_util
4
  import PIL
5
  import numpy as np
6
  import torch as th
@@ -44,7 +44,7 @@ def sample(
44
  uncond_ref = th.ones_like(cond_ref)
45
 
46
  model_kwargs = {}
47
- model_kwargs['ref'] = th.cat([cond_ref, uncond_ref], 0).to(dist_util.dev())
48
 
49
  def cfg_model_fn(x_t, ts, **kwargs):
50
  half = x_t[: len(x_t) // 2]
@@ -60,7 +60,7 @@ def sample(
60
 
61
 
62
  if upsample_enabled:
63
- model_kwargs['low_res'] = prompt['low_res'].to(dist_util.dev())
64
  noise = th.randn((batch_size, 3, side_y, side_x), device=device) * upsample_temp
65
  model_fn = glide_model # just use the base model, no need for CFG.
66
  model_kwargs['ref'] = model_kwargs['ref'][:batch_size]
 
1
  import os
2
  from typing import Tuple
3
+ #from . import dist_util
4
  import PIL
5
  import numpy as np
6
  import torch as th
 
44
  uncond_ref = th.ones_like(cond_ref)
45
 
46
  model_kwargs = {}
47
+ model_kwargs['ref'] = th.cat([cond_ref, uncond_ref], 0).cuda()
48
 
49
  def cfg_model_fn(x_t, ts, **kwargs):
50
  half = x_t[: len(x_t) // 2]
 
60
 
61
 
62
  if upsample_enabled:
63
+ model_kwargs['low_res'] = prompt['low_res'].cuda()
64
  noise = th.randn((batch_size, 3, side_y, side_x), device=device) * upsample_temp
65
  model_fn = glide_model # just use the base model, no need for CFG.
66
  model_kwargs['ref'] = model_kwargs['ref'][:batch_size]