|
import argparse |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser(description="Process Reward Optimization.") |
|
|
|
|
|
parser.add_argument( |
|
"--cache_dir", |
|
type=str, |
|
help="HF cache directory", |
|
default="/shared-local/aoq951/HF_CACHE/", |
|
) |
|
parser.add_argument( |
|
"--save_dir", |
|
type=str, |
|
help="Directory to save images", |
|
default="/shared-local/aoq951/ReNO/outputs", |
|
) |
|
|
|
|
|
parser.add_argument("--model", type=str, help="Model to use", default="sdxl-turbo") |
|
parser.add_argument("--lr", type=float, help="Learning rate", default=5.0) |
|
parser.add_argument("--n_iters", type=int, help="Number of iterations", default=50) |
|
parser.add_argument( |
|
"--n_inference_steps", type=int, help="Number of iterations", default=1 |
|
) |
|
parser.add_argument( |
|
"--optim", |
|
choices=["sgd", "adam", "lbfgs"], |
|
default="sgd", |
|
help="Optimizer to be used", |
|
) |
|
parser.add_argument("--nesterov", default=True, action="store_false") |
|
parser.add_argument( |
|
"--grad_clip", type=float, help="Gradient clipping", default=0.1 |
|
) |
|
parser.add_argument("--seed", type=int, help="Seed to use", default=0) |
|
|
|
|
|
parser.add_argument( |
|
"--disable_hps", default=True, action="store_false", dest="enable_hps" |
|
) |
|
parser.add_argument( |
|
"--hps_weighting", type=float, help="Weighting for HPS", default=5.0 |
|
) |
|
parser.add_argument( |
|
"--disable_imagereward", |
|
default=True, |
|
action="store_false", |
|
dest="enable_imagereward", |
|
) |
|
parser.add_argument( |
|
"--imagereward_weighting", |
|
type=float, |
|
help="Weighting for ImageReward", |
|
default=1.0, |
|
) |
|
parser.add_argument( |
|
"--disable_clip", default=True, action="store_false", dest="enable_clip" |
|
) |
|
parser.add_argument( |
|
"--clip_weighting", type=float, help="Weighting for CLIP", default=0.01 |
|
) |
|
parser.add_argument( |
|
"--disable_pickscore", |
|
default=True, |
|
action="store_false", |
|
dest="enable_pickscore", |
|
) |
|
parser.add_argument( |
|
"--pickscore_weighting", |
|
type=float, |
|
help="Weighting for PickScore", |
|
default=0.05, |
|
) |
|
parser.add_argument( |
|
"--disable_aesthetic", |
|
default=False, |
|
action="store_false", |
|
dest="enable_aesthetic", |
|
) |
|
parser.add_argument( |
|
"--aesthetic_weighting", |
|
type=float, |
|
help="Weighting for Aesthetic", |
|
default=0.0, |
|
) |
|
parser.add_argument( |
|
"--disable_reg", default=True, action="store_false", dest="enable_reg" |
|
) |
|
parser.add_argument( |
|
"--reg_weight", type=float, help="Regularization weight", default=0.01 |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--task", |
|
type=str, |
|
help="Task to run", |
|
default="single", |
|
choices=[ |
|
"t2i-compbench", |
|
"single", |
|
"parti-prompts", |
|
"geneval", |
|
"example-prompts", |
|
], |
|
) |
|
parser.add_argument( |
|
"--prompt", |
|
type=str, |
|
help="Prompt to run", |
|
default="A red dog and a green cat", |
|
) |
|
parser.add_argument( |
|
"--benchmark_reward", |
|
help="Reward to benchmark on", |
|
default="total", |
|
choices=["ImageReward", "PickScore", "HPS", "CLIP", "total"], |
|
) |
|
|
|
|
|
parser.add_argument("--save_all_images", default=False, action="store_true") |
|
parser.add_argument("--no_optim", default=False, action="store_true") |
|
parser.add_argument("--imageselect", default=False, action="store_true") |
|
parser.add_argument("--memsave", default=False, action="store_true") |
|
parser.add_argument("--dtype", type=str, help="Data type to use", default="float16") |
|
parser.add_argument("--device_id", type=str, help="Device ID to use", default=None) |
|
parser.add_argument( |
|
"--cpu_offloading", |
|
help="Enable CPU offloading", |
|
default=False, |
|
action="store_true", |
|
) |
|
|
|
|
|
parser.add_argument("--enable_multi_apply", default=False, action="store_true") |
|
parser.add_argument( |
|
"--multi_step_model", type=str, help="Model to use", default="flux" |
|
) |
|
|
|
args = parser.parse_args() |
|
return args |