import torch
import gradio as gr
from main import setup, execute_task
from arguments import parse_args
import os
import shutil
import glob
import time
import threading
import argparse


def list_iter_images(save_dir):
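    """Collect the paths of all intermediate images saved in `save_dir`."""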
    image_extensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp']

    image_paths = []
    for ext in image_extensions:
        image_paths.extend(glob.glob(os.path.join(save_dir, f'*.{ext}')))

    return image_paths


def clean_dir(save_dir):
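    """Delete every file and subdirectory inside `save_dir`, leaving it empty for a new run."""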
    if os.path.exists(save_dir):
        if len(os.listdir(save_dir)) > 0:
            for filename in os.listdir(save_dir):
                file_path = os.path.join(save_dir, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                except Exception as e:
                    print(f"Failed to delete {file_path}. Reason: {e}")
            print(f"All files in {save_dir} have been deleted.")
        else:
            print(f"{save_dir} exists but is empty.")
    else:
        print(f"{save_dir} does not exist.")


def start_over(gallery_state, loaded_model_setup):
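    """Reset the gallery and loaded-model state and free CUDA memory before a new submission."""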
    torch.cuda.empty_cache()
    if gallery_state is not None:
        gallery_state = None
    if loaded_model_setup is not None:
        loaded_model_setup = None
    return gallery_state, None, None, gr.update(visible=False), loaded_model_setup


def setup_model(prompt, model, seed, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
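    """Build the argument set from the UI controls and load the selected model via `setup()`."""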
    if prompt is None or prompt == "":
        raise gr.Error("You forgot to provide a prompt!")

    # Clear CUDA memory before starting the run.
    torch.cuda.empty_cache()

    args = parse_args()
    args.task = "single"
    args.prompt = prompt
    args.model = model
    args.seed = seed
    args.n_iters = num_iterations
    args.lr = learning_rate
    args.cache_dir = "./HF_model_cache"
    args.save_dir = "./outputs"
    args.save_all_images = True

    args.hps_weighting = hps_w
    args.imagereward_weighting = imgrw_w
    args.pickscore_weighting = pcks_w
    args.clip_weighting = clip_w

    if model == "flux":
        args.cpu_offloading = True
        args.enable_multi_apply = True
        args.multi_step_model = "flux"

    try:
        args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
        loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
        return f"{model} model loaded successfully!", loaded_setup

    except Exception as e:
        print(f"Unexpected Error: {e}")
        return f"Something went wrong with {model} loading", None


def generate_image(setup_args, num_iterations):
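    """Run the reward-based noise optimization in a background thread, yielding intermediate images as they are saved."""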
    torch.cuda.empty_cache()

    args = setup_args[0]
    trainer = setup_args[1]
    device = setup_args[2]
    dtype = setup_args[3]
    shape = setup_args[4]
    enable_grad = setup_args[5]
    multi_apply_fn = setup_args[6]
    settings = setup_args[7]
    print(f"SETTINGS: {settings}")

    save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
    clean_dir(save_dir)

    try:
        torch.cuda.empty_cache()
        steps_completed = []
        result_container = {"best_image": None, "total_init_rewards": None, "total_best_rewards": None}
        error_status = {"error_occurred": False}
        thread_status = {"running": False}

        def progress_callback(step):
            if not steps_completed or step > steps_completed[-1]:
                steps_completed.append(step)
                print(f"Progress: Step {step} completed.")

        def run_main():
            thread_status["running"] = True
            try:
                execute_task(
                    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback
                )
            except torch.cuda.OutOfMemoryError as e:
                print(f"CUDA Out of Memory Error: {e}")
                error_status["error_occurred"] = True
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(f"Runtime Error: {e}")
                    error_status["error_occurred"] = True
                else:
                    raise
            finally:
                thread_status["running"] = False

        if not thread_status["running"]:
            main_thread = threading.Thread(target=run_main)
            main_thread.start()

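        # Poll the worker thread and stream each newly saved iteration image to the UI.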
        last_step_yielded = 0
        while main_thread.is_alive() and not error_status["error_occurred"]:
            if steps_completed and steps_completed[-1] > last_step_yielded:
                last_step_yielded = steps_completed[-1]
                png_number = last_step_yielded - 1

                image_path = os.path.join(save_dir, f"{png_number}.png")
                if os.path.exists(image_path):
                    yield (image_path, f"Iteration {last_step_yielded}/{num_iterations} - Image saved", None)
                else:
                    yield (None, f"Iteration {last_step_yielded}/{num_iterations} - Image not found", None)
            else:
                time.sleep(0.1)

        if error_status["error_occurred"]:
            torch.cuda.empty_cache()
            yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
        else:
            main_thread.join()
            final_image_path = os.path.join(save_dir, "best_image.png")
            if os.path.exists(final_image_path):
                iter_images = list_iter_images(save_dir)
                torch.cuda.empty_cache()
                time.sleep(0.5)
                yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
            else:
                torch.cuda.empty_cache()
                yield (None, "Image generation completed, but no final image was found.", None)

        torch.cuda.empty_cache()

    except torch.cuda.OutOfMemoryError as e:
        print(f"Global CUDA Out of Memory Error: {e}")
        yield (None, "CUDA out of memory.", None)
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print(f"Runtime Error: {e}")
            yield (None, "CUDA out of memory.", None)
        else:
            yield (None, f"An error occurred: {str(e)}", None)
    except Exception as e:
        print(f"Unexpected Error: {e}")
        yield (None, f"An unexpected error occurred: {str(e)}", None)


def show_gallery_output(gallery_state):
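    """Reveal the iteration gallery when intermediate images are available; keep it hidden otherwise."""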
    if gallery_state is not None:
        return gr.update(value=gallery_state, visible=True)
    else:
        return gr.update(value=None, visible=False)


title = "# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description = "Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."

css = """
#model-status-id{
    height: 126px;
}
#model-status-id .progress-text{
    font-size: 10px!important;
}
#model-status-id .progress-level-inner{
    font-size: 8px!important;
}
"""
with gr.Blocks(css=css, analytics_enabled=False) as demo:
    loaded_model_setup = gr.State()
    gallery_state = gr.State()
    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        gr.HTML("""
            <div style="display:flex;column-gap:4px;">
                <a href='https://github.com/ExplainableML/ReNO'>
                    <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
                </a>
                <a href='https://arxiv.org/abs/2406.04312v1'>
                    <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
                </a>
            </div>
        """)

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                with gr.Row():
                    chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
                    seed = gr.Number(label="seed", value=0)

                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Column():
                        hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0)
                        imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0)
                        pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05)
                        clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01)

                submit_btn = gr.Button("Submit")

                gr.Examples(
                    examples=[
                        "A red dog and a green cat",
                        "A pink elephant and a grey cow",
                        "A toaster riding a bike",
                        "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski",
                        "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
                        "An epic oil painting: a red portal in front of a cityscape, a solitary figure, and a colorful sky over snowy mountains"
                    ],
                    inputs=[prompt]
                )

            with gr.Column():
                model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
                output_image = gr.Image(type="filepath", label="Best Generated Image")
                status = gr.Textbox(label="Status")
                iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)
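
    # On submit: reset state, load the selected model, run the optimization, then reveal the iteration gallery.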
    submit_btn.click(
        fn=start_over,
        inputs=[gallery_state, loaded_model_setup],
        outputs=[gallery_state, output_image, status, iter_gallery, loaded_model_setup]
    ).then(
        fn=setup_model,
        inputs=[prompt, chosen_model, seed, n_iter, learning_rate, hps_w, imgrw_w, pcks_w, clip_w],
        outputs=[model_status, loaded_model_setup]
    ).then(
        fn=generate_image,
        inputs=[loaded_model_setup, n_iter],
        outputs=[output_image, status, gallery_state]
    ).then(
        fn=show_gallery_output,
        inputs=[gallery_state],
        outputs=iter_gallery
    )


demo.queue().launch(show_error=True, show_api=False)