multimodalart HF staff commited on
Commit
3ecd23d
1 Parent(s): 698bcd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -47,10 +47,8 @@ if torch.cuda.is_available():
47
  previewer = Previewer()
48
  previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
49
  previewer.load_state_dict(previewer_state_dict)
50
- def callback_prior(self_remote, i, t, kwargs):
51
- latents = kwargs["latents"]
52
  output = previewer(latents)
53
- print(output)
54
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
55
  return output
56
  callback_steps = 1
@@ -100,7 +98,8 @@ def generate(
100
  guidance_scale=prior_guidance_scale,
101
  num_images_per_prompt=num_images_per_prompt,
102
  generator=generator,
103
- callback_on_step_end=callback_prior,
 
104
  )
105
 
106
  if PREVIEW_IMAGES:
 
47
  previewer = Previewer()
48
  previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
49
  previewer.load_state_dict(previewer_state_dict)
50
+ def callback_prior(i, t, latents):
 
51
  output = previewer(latents)
 
52
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
53
  return output
54
  callback_steps = 1
 
98
  guidance_scale=prior_guidance_scale,
99
  num_images_per_prompt=num_images_per_prompt,
100
  generator=generator,
101
+ callback=callback_prior,
102
+ callback_steps=callback_steps
103
  )
104
 
105
  if PREVIEW_IMAGES: