ReNO / rewards /hps.py
fffiloni's picture
Upload 24 files
ca25718 verified
raw
history blame
No virus
1.97 kB
import huggingface_hub
import torch
from hpsv2.src.open_clip import create_model, get_tokenizer
from rewards.base_reward import BaseRewardLoss
class HPSLoss(BaseRewardLoss):
    """HPS reward loss function for optimization.

    Builds a "ViT-H-14" model through ``create_model``, overwrites its
    weights with the ``HPS_v2.1_compressed.pt`` checkpoint downloaded from
    the ``xswu/HPSv2`` hub repository, and exposes the image/text encoders
    together with a loss of the form ``1 - similarity``.
    """

    def __init__(
        self,
        weighting: float,
        dtype: torch.dtype,
        device: torch.device,
        cache_dir: str,
        memsave: bool = False,
    ):
        """Load the HPS model and tokenizer and freeze the model.

        Args:
            weighting: Weight of this loss; forwarded to the base class.
            dtype: Dtype the model is cast to.
            device: Device the model (and tokenized prompts) are placed on.
            cache_dir: Directory used for the model and checkpoint downloads.
            memsave: If true, convert the model with
                ``memsave_torch.nn.convert_to_memory_saving``.
        """
        # Remember the target device so tokenized prompts can follow the
        # model instead of assuming a fixed accelerator.
        self.device = device

        # NOTE(review): ``precision`` receives a torch.dtype here, as in the
        # original code; the explicit ``.to(device, dtype=dtype)`` below is
        # what sets the final dtype -- confirm ``create_model`` accepts this.
        self.hps_model = create_model(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision=dtype,
            device=device,
            cache_dir=cache_dir,
        )
        checkpoint_path = huggingface_hub.hf_hub_download(
            "xswu/HPSv2", "HPS_v2.1_compressed.pt", cache_dir=cache_dir
        )
        # NOTE(review): torch.load unpickles a downloaded file; consider
        # ``weights_only=True`` once confirmed the checkpoint supports it.
        self.hps_model.load_state_dict(
            torch.load(checkpoint_path, map_location=device)["state_dict"]
        )
        self.hps_tokenizer = get_tokenizer("ViT-H-14")

        if memsave:
            # Imported lazily so memsave_torch is only needed when requested.
            import memsave_torch.nn

            self.hps_model = memsave_torch.nn.convert_to_memory_saving(self.hps_model)

        self.hps_model = self.hps_model.to(device, dtype=dtype)
        self.hps_model.eval()
        self.freeze_parameters(self.hps_model.parameters())
        super().__init__("HPS", weighting)
        self.hps_model.set_grad_checkpointing(True)

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Return the HPS model's image embedding for ``image``."""
        hps_image_features = self.hps_model.encode_image(image)
        return hps_image_features

    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` and return the HPS model's text embedding."""
        # Place the tokens on the device given at construction time rather
        # than a hard-coded "cuda", so CPU and non-default GPUs work.
        hps_text = self.hps_tokenizer(prompt).to(self.device)
        hps_text_features = self.hps_model.encode_text(hps_text)
        return hps_text_features

    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return ``1 - similarity`` of the first image/text pair.

        The similarity is entry ``[0]`` of the diagonal of
        ``image_features @ text_features.T``; any further pairs are ignored.
        Presumably the features are normalized by the caller -- TODO confirm.
        """
        logits_per_image = image_features @ text_features.T
        hps_loss = 1 - torch.diagonal(logits_per_image)[0]
        return hps_loss