fffiloni commited on
Commit
b355934
1 Parent(s): 9bd0658

reset loaded_setup state on start over

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -49,11 +49,13 @@ def clean_dir(save_dir):
49
  else:
50
  print(f"{save_dir} does not exist.")
51
 
52
- def start_over(gallery_state):
53
  torch.cuda.empty_cache() # Free up cached memory
54
  if gallery_state is not None:
55
  gallery_state = None
56
- return gallery_state, None, None, gr.update(visible=False)
 
 
57
 
58
  def setup_model(prompt, model, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
59
 
@@ -241,12 +243,12 @@ with gr.Blocks(analytics_enabled=False) as demo:
241
 
242
  submit_btn.click(
243
  fn = start_over,
244
- inputs =[gallery_state],
245
- outputs = [gallery_state, output_image, status, iter_gallery]
246
  ).then(
247
  fn = setup_model,
248
  inputs = [prompt, chosen_model, n_iter, hps_w, imgrw_w, pcks_w, clip_w, learning_rate],
249
- outputs = [output_image, loaded_model_setup]
250
  ).then(
251
  fn = generate_image,
252
  inputs = [loaded_model_setup, n_iter],
 
49
  else:
50
  print(f"{save_dir} does not exist.")
51
 
52
+ def start_over(gallery_state, loaded_model_setup):
53
  torch.cuda.empty_cache() # Free up cached memory
54
  if gallery_state is not None:
55
  gallery_state = None
56
+ if loaded_model_setup is not None:
57
+ loaded_model_setup = None # Reset loaded model setup to prevent re-triggering old state
58
+ return gallery_state, None, None, gr.update(visible=False), loaded_model_setup
59
 
60
  def setup_model(prompt, model, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
61
 
 
243
 
244
  submit_btn.click(
245
  fn = start_over,
246
+ inputs =[gallery_state, loaded_model_setup], # Reset loaded model setup as well
247
+ outputs = [gallery_state, output_image, status, iter_gallery, loaded_model_setup] # Ensure loaded_model_setup is reset
248
  ).then(
249
  fn = setup_model,
250
  inputs = [prompt, chosen_model, n_iter, hps_w, imgrw_w, pcks_w, clip_w, learning_rate],
251
+ outputs = [output_image, loaded_model_setup] # Load the new setup into the state
252
  ).then(
253
  fn = generate_image,
254
  inputs = [loaded_model_setup, n_iter],