ReNO / rewards /utils.py
fffiloni's picture
Upload 24 files
ca25718 verified
raw
history blame contribute delete
No virus
2.34 kB
from typing import Any, List
import torch
from torchvision.transforms import (CenterCrop, Compose, InterpolationMode,
Normalize, Resize)
from transformers import AutoProcessor
from rewards.aesthetic import AestheticLoss
from rewards.base_reward import BaseRewardLoss
from rewards.clip import CLIPLoss
from rewards.hps import HPSLoss
from rewards.imagereward import ImageRewardLoss
from rewards.pickscore import PickScoreLoss
def get_reward_losses(
    args: Any, dtype: torch.dtype, device: torch.device, cache_dir: str
) -> List[BaseRewardLoss]:
    """Instantiate every reward loss that is switched on in ``args``.

    Each ``args.enable_<name>`` flag gates one loss, which is constructed
    with its matching ``args.<name>_weighting`` and ``args.memsave``.  The
    losses are appended in a fixed order: HPS, ImageReward, CLIP, PickScore,
    Aesthetic.

    Args:
        args: Object exposing the ``enable_*`` flags and, for every enabled
            loss, the corresponding ``*_weighting`` attribute and ``memsave``.
        dtype: Forwarded unchanged to each loss constructor.
        device: Forwarded unchanged to each loss constructor.
        cache_dir: Forwarded to each loss constructor and to
            ``AutoProcessor.from_pretrained``.

    Returns:
        A new list with the enabled losses; empty when no flag is set.
    """
    # CLIP and PickScore receive the same processor object, so it is loaded
    # once up front and only when at least one of the two is requested.
    wants_processor = args.enable_clip or args.enable_pickscore
    if wants_processor:
        tokenizer = AutoProcessor.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", cache_dir=cache_dir
        )

    # Positional arguments that follow the weighting in every constructor.
    shared = (dtype, device, cache_dir)
    enabled: list = []

    if args.enable_hps:
        hps = HPSLoss(args.hps_weighting, *shared, memsave=args.memsave)
        enabled.append(hps)

    if args.enable_imagereward:
        image_reward = ImageRewardLoss(
            args.imagereward_weighting, *shared, memsave=args.memsave
        )
        enabled.append(image_reward)

    if args.enable_clip:
        clip = CLIPLoss(
            args.clip_weighting, *shared, tokenizer, memsave=args.memsave
        )
        enabled.append(clip)

    if args.enable_pickscore:
        pickscore = PickScoreLoss(
            args.pickscore_weighting, *shared, tokenizer, memsave=args.memsave
        )
        enabled.append(pickscore)

    if args.enable_aesthetic:
        aesthetic = AestheticLoss(
            args.aesthetic_weighting, *shared, memsave=args.memsave
        )
        enabled.append(aesthetic)

    return enabled
def clip_img_transform(size: int = 224):
    """Return a resize -> center-crop -> normalize transform pipeline.

    Args:
        size: Passed to both ``Resize`` (with bicubic interpolation) and
            ``CenterCrop``.

    Returns:
        A ``Compose`` applying the three steps in that order.
    """
    # Per-channel normalization constants.
    # NOTE(review): presumably the CLIP preprocessing mean/std -- confirm.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    # No tensor-conversion step is included in this pipeline.
    # NOTE(review): inputs are presumably already tensors -- confirm.
    resize = Resize(size, interpolation=InterpolationMode.BICUBIC)
    crop = CenterCrop(size)
    normalize = Normalize(channel_mean, channel_std)
    return Compose([resize, crop, normalize])