import json
import logging
import os

import blobfile as bf
import torch
from datasets import load_dataset
from pytorch_lightning import seed_everything
from tqdm import tqdm

from arguments import parse_args
from models import get_model
from rewards import get_reward_losses
from training import LatentNoiseTrainer, get_optimizer


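# Script entry point for reward-based latent-noise optimization: `setup`
# prepares logging, the diffusion model, reward losses, and trainer;
# `execute_task` runs the selected benchmark (single, example-prompts,
# t2i-compbench, parti-prompts, or geneval).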
def setup(args):
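    """Seed RNGs, set up logging, and build the model, reward losses, trainer,
    and latent shape from the parsed arguments."""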
    seed_everything(args.seed)
    bf.makedirs(f"{args.save_dir}/logs/{args.task}")

    # Encode the run configuration into a settings string that names all logs
    # and output directories.
    logger = logging.getLogger()
    settings = (
        f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
        f"{'_no-optim' if args.no_optim else ''}_{args.seed if args.task != 'geneval' else ''}"
        f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}"
        f"_reg{args.reg_weight if args.enable_reg else '0'}"
        f"{'_pickscore' + str(args.pickscore_weighting) if args.enable_pickscore else ''}"
        f"{'_clip' + str(args.clip_weighting) if args.enable_clip else ''}"
        f"{'_hps' + str(args.hps_weighting) if args.enable_hps else ''}"
        f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
        f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
    )
    # Log to a per-run file and to the console.
    file_handler = logging.FileHandler(
        f"{args.save_dir}/logs/{args.task}/{settings}.txt", mode="w"
    )
    formatter = logging.Formatter("%(asctime)s - %(message)s")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.setLevel(logging.INFO)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    logging.info(args)
    if args.device_id is not None:
        logging.info(f"Using CUDA device {args.device_id}")
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
    if args.device == "cuda":
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Run in half precision.
    dtype = torch.float16

    reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)

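    # Build the diffusion model and the trainer that optimizes initial latent
    # noise against the combined reward losses.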
    sd_model = get_model(args.model, dtype, device, args.cache_dir, args.memsave)
    trainer = LatentNoiseTrainer(
        reward_losses=reward_losses,
        model=sd_model,
        n_iters=args.n_iters,
        n_inference_steps=args.n_inference_steps,
        seed=args.seed,
        save_all_images=args.save_all_images,
        device=device,
        no_optim=args.no_optim,
        regularize=args.enable_reg,
        regularization_weight=args.reg_weight,
        grad_clip=args.grad_clip,
        log_metrics=args.task == "single" or not args.no_optim,
        imageselect=args.imageselect,
    )

    # The latent shape depends on the backbone: SD-style models expose a UNet,
    # while PixArt uses a transformer.
    if args.model != "pixart":
        height = sd_model.unet.config.sample_size * sd_model.vae_scale_factor
        width = sd_model.unet.config.sample_size * sd_model.vae_scale_factor
        shape = (
            1,
            sd_model.unet.config.in_channels,
            height // sd_model.vae_scale_factor,
            width // sd_model.vae_scale_factor,
        )
    else:
        height = sd_model.transformer.config.sample_size * sd_model.vae_scale_factor
        width = sd_model.transformer.config.sample_size * sd_model.vae_scale_factor
        shape = (
            1,
            sd_model.transformer.config.in_channels,
            height // sd_model.vae_scale_factor,
            width // sd_model.vae_scale_factor,
        )
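    # Sanity check (assuming Stable Diffusion 1.5 defaults: sample_size=64,
    # vae_scale_factor=8, in_channels=4): this yields latents of shape
    # (1, 4, 64, 64), which decode to 512x512 images.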
    enable_grad = not args.no_optim

    return args, trainer, device, dtype, shape, enable_grad, settings


def execute_task(
    args, trainer, device, dtype, shape, enable_grad, settings, progress_callback=None
):
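    """Run the benchmark selected by `args.task` and log mean initial and best
    rewards across prompts."""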
    if args.task == "single":
        # Optimize the initial latent noise for a single prompt.
        init_latents = torch.randn(shape, device=device, dtype=dtype)
        latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
        optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
        save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
        os.makedirs(save_dir, exist_ok=True)
        best_image, total_init_rewards, total_best_rewards = trainer.train(
            latents,
            args.prompt,
            optimizer,
            save_dir,
            progress_callback=progress_callback,
        )
        best_image.save(f"{save_dir}/best_image.png")
        return best_image, total_init_rewards, total_best_rewards
    elif args.task == "example-prompts":
        with open("assets/example_prompts.txt", "r") as fo:
            prompts = fo.readlines()
        for i, prompt in tqdm(enumerate(prompts), total=len(prompts)):
            # Fresh noise and optimizer for every prompt.
            init_latents = torch.randn(shape, device=device, dtype=dtype)
            latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
            optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)

            prompt = prompt.strip()
            name = f"{i:03d}_{prompt}.png"
            save_dir = f"{args.save_dir}/{args.task}/{settings}/{name}"
            os.makedirs(save_dir, exist_ok=True)
            best_image, init_rewards, best_rewards = trainer.train(
                latents, prompt, optimizer, save_dir
            )
            if i == 0:
                total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
                total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
            for k in best_rewards.keys():
                total_best_rewards[k] += best_rewards[k]
                total_init_rewards[k] += init_rewards[k]
            best_image.save(f"{save_dir}/best_image.png")
            logging.info(f"Initial rewards: {init_rewards}")
            logging.info(f"Best rewards: {best_rewards}")
        for k in total_best_rewards.keys():
            total_best_rewards[k] /= len(prompts)
            total_init_rewards[k] /= len(prompts)

        with open(f"{args.save_dir}/example-prompts/{settings}/results.txt", "w") as f:
            f.write(
                f"Mean initial all rewards: {total_init_rewards}\n"
                f"Mean best all rewards: {total_best_rewards}\n"
            )
    elif args.task == "t2i-compbench":
        prompt_list_file = f"../T2I-CompBench/examples/dataset/{args.prompt}.txt"
        with open(prompt_list_file, "r") as fo:
            prompts = fo.readlines()
        os.makedirs(f"{args.save_dir}/{args.task}/{settings}/samples", exist_ok=True)
        for i, prompt in tqdm(enumerate(prompts), total=len(prompts)):
            init_latents = torch.randn(shape, device=device, dtype=dtype)
            latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
            optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)

            prompt = prompt.strip()
            best_image, init_rewards, best_rewards = trainer.train(
                latents, prompt, optimizer
            )
            if i == 0:
                total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
                total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
            for k in best_rewards.keys():
                total_best_rewards[k] += best_rewards[k]
                total_init_rewards[k] += init_rewards[k]
            # Naming scheme expected by the T2I-CompBench evaluation scripts.
            name = f"{prompt}_{i:06d}.png"
            best_image.save(f"{args.save_dir}/{args.task}/{settings}/samples/{name}")
            logging.info(f"Initial rewards: {init_rewards}")
            logging.info(f"Best rewards: {best_rewards}")
        for k in total_best_rewards.keys():
            total_best_rewards[k] /= len(prompts)
            total_init_rewards[k] /= len(prompts)
    elif args.task == "parti-prompts":
        parti_dataset = load_dataset("nateraw/parti-prompts", split="train")
        total_reward_diff = 0.0
        total_best_reward = 0.0
        total_init_reward = 0.0
        total_improved_samples = 0
        for index, sample in enumerate(parti_dataset):
            os.makedirs(
                f"{args.save_dir}/{args.task}/{settings}/{index}", exist_ok=True
            )
            # Fresh noise and optimizer for every prompt.
            init_latents = torch.randn(shape, device=device, dtype=dtype)
            latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
            optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)

            prompt = sample["Prompt"]
            best_image, init_rewards, best_rewards = trainer.train(
                latents, prompt, optimizer
            )
            best_image.save(
                f"{args.save_dir}/{args.task}/{settings}/{index}/best_image.png"
            )
            with open(
                f"{args.save_dir}/{args.task}/{settings}/{index}/prompt.txt", "w"
            ) as f:
                f.write(
                    f"{prompt}\nInitial Rewards: {init_rewards}\nBest Rewards: {best_rewards}"
                )
            logging.info(f"Initial rewards: {init_rewards}")
            logging.info(f"Best rewards: {best_rewards}")
            # Rewards are reported as losses here, so lower is better.
            initial_reward = init_rewards[args.benchmark_reward]
            best_reward = best_rewards[args.benchmark_reward]
            total_reward_diff += best_reward - initial_reward
            total_best_reward += best_reward
            total_init_reward += initial_reward
            if best_reward < initial_reward:
                total_improved_samples += 1
            if index == 0:
                total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
                total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
            for k in best_rewards.keys():
                total_best_rewards[k] += best_rewards[k]
                total_init_rewards[k] += init_rewards[k]
        improvement_percentage = total_improved_samples / parti_dataset.num_rows
        mean_best_reward = total_best_reward / parti_dataset.num_rows
        mean_init_reward = total_init_reward / parti_dataset.num_rows
        mean_reward_diff = total_reward_diff / parti_dataset.num_rows
        logging.info(
            f"Improvement percentage: {improvement_percentage:.4f}, "
            f"mean initial reward: {mean_init_reward:.4f}, "
            f"mean best reward: {mean_best_reward:.4f}, "
            f"mean reward diff: {mean_reward_diff:.4f}"
        )
        for k in total_best_rewards.keys():
            total_best_rewards[k] /= len(parti_dataset)
            total_init_rewards[k] /= len(parti_dataset)

        os.makedirs(f"{args.save_dir}/parti-prompts/{settings}", exist_ok=True)
        with open(f"{args.save_dir}/parti-prompts/{settings}/results.txt", "w") as f:
            f.write(
                f"Mean improvement: {improvement_percentage:.4f}, "
                f"mean initial reward: {mean_init_reward:.4f}, "
                f"mean best reward: {mean_best_reward:.4f}, "
                f"mean reward diff: {mean_reward_diff:.4f}\n"
                f"Mean initial all rewards: {total_init_rewards}\n"
                f"Mean best all rewards: {total_best_rewards}"
            )
    elif args.task == "geneval":
        prompt_list_file = "../geneval/prompts/evaluation_metadata.jsonl"
        with open(prompt_list_file) as fp:
            metadatas = [json.loads(line) for line in fp]
        outdir = f"{args.save_dir}/{args.task}/{settings}"
        for index, metadata in enumerate(metadatas):
            init_latents = torch.randn(shape, device=device, dtype=dtype)
            latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
            optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)

            prompt = metadata["prompt"]
            best_image, init_rewards, best_rewards = trainer.train(
                latents, prompt, optimizer
            )
            logging.info(f"Initial rewards: {init_rewards}")
            logging.info(f"Best rewards: {best_rewards}")
            # GenEval expects one directory per prompt holding its metadata
            # and generated samples.
            outpath = f"{outdir}/{index:0>5}"
            os.makedirs(f"{outpath}/samples", exist_ok=True)
            with open(f"{outpath}/metadata.jsonl", "w") as fp:
                json.dump(metadata, fp)
            best_image.save(f"{outpath}/samples/{args.seed:05}.png")
            if index == 0:
                total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
                total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
            for k in best_rewards.keys():
                total_best_rewards[k] += best_rewards[k]
                total_init_rewards[k] += init_rewards[k]
        for k in total_best_rewards.keys():
            total_best_rewards[k] /= len(metadatas)
            total_init_rewards[k] /= len(metadatas)
    else:
        raise ValueError(f"Unknown task {args.task}")

    logging.info(f"Mean initial rewards: {total_init_rewards}")
    logging.info(f"Mean best rewards: {total_best_rewards}")


def main():
    args = parse_args()
    args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
    execute_task(args, trainer, device, dtype, shape, enable_grad, settings)


if __name__ == "__main__":
    main()
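
# Example invocation (hypothetical flag spellings; the actual CLI is defined
# in arguments.py):
#   python main.py --task single --prompt "a photo of a red car"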