ReNO / rewards /clip.py
fffiloni's picture
Upload 24 files
ca25718 verified
raw
history blame contribute delete
No virus
1.8 kB
import torch
from transformers import CLIPModel
from rewards.base_reward import BaseRewardLoss
class CLIPLoss(BaseRewardLoss):
    """CLIP reward loss function for optimization.

    Loads the ``laion/CLIP-ViT-H-14-laion2B-s32B-b79K`` checkpoint, freezes
    it, and exposes image/text feature extraction together with a
    similarity-based loss.
    """

    def __init__(
        self,
        weigthing: float,
        dtype: torch.dtype,
        device: torch.device,
        cache_dir: str,
        tokenizer,
        memsave: bool = False,
    ):
        """Load the CLIP model, move it to ``device``/``dtype`` and freeze it.

        Args:
            weigthing: Value forwarded to ``BaseRewardLoss.__init__`` together
                with the name ``"CLIP"``. The spelling is part of the public
                signature and is kept so keyword callers keep working.
            dtype: Parameter dtype the CLIP model is cast to.
            device: Device the CLIP model is moved to.
            cache_dir: Directory passed to ``CLIPModel.from_pretrained`` as
                ``cache_dir``.
            tokenizer: Callable used by ``get_text_features`` to tokenize
                prompts; its result must support ``.to(device)`` and ``**``
                unpacking into ``CLIPModel.get_text_features``.
            memsave: When true, convert the model with
                ``memsave_torch.nn.convert_to_memory_saving`` before moving it
                to the target device.
        """
        self.tokenizer = tokenizer
        self.clip_model = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
            cache_dir=cache_dir,
        )
        if memsave:
            # Imported lazily so memsave_torch is only required when requested.
            import memsave_torch.nn

            self.clip_model = memsave_torch.nn.convert_to_memory_saving(
                self.clip_model
            )
        self.clip_model = self.clip_model.to(device, dtype=dtype)
        self.clip_model.eval()
        # Freeze all model parameters.
        self.freeze_parameters(self.clip_model.parameters())
        super().__init__("CLIP", weigthing)
        self.clip_model.gradient_checkpointing_enable()

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Return the CLIP image features for ``image``.

        ``image`` is passed unchanged to ``CLIPModel.get_image_features``, so
        it is presumably expected to be already preprocessed and on the
        model's device -- confirm against the caller.
        """
        clip_img_features = self.clip_model.get_image_features(image)
        return clip_img_features

    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` and return its CLIP text features.

        The prompt is padded and truncated to at most 77 tokens.
        """
        # Send the tokens to wherever the model actually lives instead of a
        # hard-coded "cuda", so CPU and non-default GPU placements work.
        prompt_token = self.tokenizer(
            prompt, return_tensors="pt", padding=True, max_length=77, truncation=True
        ).to(self.clip_model.device)
        clip_text_features = self.clip_model.get_text_features(**prompt_token)
        return clip_text_features

    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return ``100 - mean(image_features @ text_features.T) * exp(logit_scale)``.

        ``logit_scale`` is the CLIP model's own parameter. The inputs are
        presumably L2-normalised by the caller so the product is a cosine
        similarity -- TODO confirm.
        """
        clip_loss = (
            100
            - (image_features @ text_features.T).mean()
            * self.clip_model.logit_scale.exp()
        )
        return clip_loss