fffiloni committed on
Commit 25f49a4
1 Parent(s): dd8f929

display model loading status separately

Files changed (1): app.py (+25 -9)
app.py CHANGED
@@ -58,8 +58,8 @@ def start_over(gallery_state, loaded_model_setup):
     return gallery_state, None, None, gr.update(visible=False), loaded_model_setup
 
 def setup_model(prompt, model, seed, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
-    if prompt is None:
-        raise gr.Error("You forgot the prompt !")
+    if prompt is None or prompt == "":
+        raise gr.Error("You forgot to provide a prompt !")
 
     """Clear CUDA memory before starting the training."""
     torch.cuda.empty_cache() # Free up cached memory
@@ -86,10 +86,14 @@ def setup_model(prompt, model, seed, num_iterations, learning_rate, hps_w, imgrw
     args.enable_multi_apply= True
     args.multi_step_model = "flux"
 
-    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
-    loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
-
-    return None, loaded_setup
+    try:
+        args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
+        loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
+        return f"{model} model loaded succesfully !", loaded_setup
+
+    except Exception as e:
+        print(f"Unexpected Error: {e}")
+        return f"Something went wrong with {model} loading", None
 
 def generate_image(setup_args, num_iterations):
     torch.cuda.empty_cache() # Free up cached memory
@@ -198,7 +202,19 @@ def show_gallery_output(gallery_state):
 title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
 description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
 
-with gr.Blocks(analytics_enabled=False) as demo:
+css="""
+#model-status-id{
+    height: 126px;
+}
+#model-status-id .progress-text{
+    font-size: 10px!important;
+}
+#model-status-id .progress-level-inner{
+    font-size: 8px!important;
+}
+"""
+
+with gr.Blocks(css=css, analytics_enabled=False) as demo:
     loaded_model_setup = gr.State()
     gallery_state = gr.State()
     with gr.Column():
@@ -221,7 +237,6 @@ with gr.Blocks(analytics_enabled=False) as demo:
         with gr.Row():
             chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
             seed = gr.Number(label="seed", value=0)
-            model_status = gr.Textbox(label="model status", visible=False)
 
         with gr.Row():
             n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
@@ -249,6 +264,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
         )
 
         with gr.Column():
+            model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
             output_image = gr.Image(type="filepath", label="Best Generated Image")
             status = gr.Textbox(label="Status")
             iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)
@@ -260,7 +276,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
     ).then(
         fn = setup_model,
         inputs = [prompt, chosen_model, seed, n_iter, hps_w, imgrw_w, pcks_w, clip_w, learning_rate],
-        outputs = [output_image, loaded_model_setup] # Load the new setup into the state
+        outputs = [model_status, loaded_model_setup] # Load the new setup into the state
     ).then(
         fn = generate_image,
        inputs = [loaded_model_setup, n_iter],
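
In sum, setup_model now returns a human-readable status string as its first output, and the event chain routes that string to the new model_status textbox instead of output_image. Below is a minimal, self-contained sketch of that wiring; the dictionary "setup", the simplified generate_image, and the trimmed component list are placeholders standing in for the app's real setup(args), optimization loop, and UI, not part of the commit.

import gradio as gr

def setup_model(prompt, model):
    # Mirrors the committed flow: reject an empty prompt, then return a
    # loading-status string that lands in the dedicated model_status textbox.
    if prompt is None or prompt == "":
        raise gr.Error("You forgot to provide a prompt!")
    try:
        loaded_setup = {"model": model}  # placeholder for [args, trainer, device, ...] from setup(args)
        return f"{model} model loaded successfully!", loaded_setup
    except Exception as e:
        print(f"Unexpected Error: {e}")
        return f"Something went wrong with {model} loading", None

def generate_image(loaded_setup, prompt):
    # Placeholder for the reward-based noise optimization loop.
    # .then() fires even if the previous step failed, so guard against a missing setup.
    if not loaded_setup:
        return "model not loaded"
    return f"done (model: {loaded_setup['model']}, prompt: {prompt})"

css = "#model-status-id{height: 126px;}"

with gr.Blocks(css=css) as demo:
    loaded_model_setup = gr.State()
    prompt = gr.Textbox(label="Prompt")
    chosen_model = gr.Dropdown(["sd-turbo", "flux"], label="Model", value="sd-turbo")
    submit = gr.Button("Submit")
    model_status = gr.Textbox(label="model status", elem_id="model-status-id")
    status = gr.Textbox(label="Status")

    # Step 1 writes the loading message to model_status and stores the setup in state;
    # step 2 runs generation from that state, so loading feedback never overwrites the outputs.
    submit.click(
        fn=setup_model,
        inputs=[prompt, chosen_model],
        outputs=[model_status, loaded_model_setup],
    ).then(
        fn=generate_image,
        inputs=[loaded_model_setup, prompt],
        outputs=[status],
    )

demo.launch()

Running the sketch, the loading message appears in "model status" first and the generation result follows in "Status", which is the separate model-loading feedback the commit title describes.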