from abc import ABC, abstractmethod
from collections.abc import Iterable

import torch


class BaseRewardLoss(ABC):
    """
    Base class for reward functions implementing a differentiable reward
    function for optimization.

    Subclasses supply three hooks:

    * ``get_image_features``
    * ``get_text_features``
    * ``compute_loss``

    Calling an instance runs the full pipeline:

    1. extract features from the image and the prompt;
    2. L2-normalize both along the last dimension;
    3. pass the normalized features to ``compute_loss``.
    """

    def __init__(self, name: str, weighting: float):
        # Human-readable identifier for this reward/loss.
        self.name = name
        # Stored for use by callers. Note that it is NOT applied inside
        # ``__call__``; presumably the caller scales the returned loss by it.
        self.weighting = weighting

    @staticmethod
    def freeze_parameters(params: Iterable[torch.nn.Parameter]) -> None:
        """Disable gradient tracking for every parameter in *params*.

        Accepts any iterable of parameters, for example a
        ``torch.nn.ParameterList`` or ``module.parameters()``.
        """
        for param in params:
            param.requires_grad = False

    @abstractmethod
    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Return the feature embedding for *image*."""

    @abstractmethod
    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Return the feature embedding for *prompt*."""

    @abstractmethod
    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return the loss for already-normalized image and text features."""

    def process_features(
        self, features: torch.Tensor, eps: float = 1e-12
    ) -> torch.Tensor:
        """L2-normalize *features* along the last dimension.

        The norm is clamped from below by *eps*, as
        ``torch.nn.functional.normalize`` does. Without the clamp, an
        all-zero feature row would become NaN (0 / 0) and poison the loss
        and its gradients. With it, such a row maps to zeros. Rows whose
        norm exceeds *eps* are unaffected.
        """
        norms = features.norm(dim=-1, keepdim=True).clamp_min(eps)
        return features / norms

    def __call__(self, image: torch.Tensor, prompt: str) -> torch.Tensor:
        """Compute the loss for *image* against *prompt*.

        Features from both modalities are normalized with
        ``process_features`` before being handed to ``compute_loss``.
        ``self.weighting`` is not applied here.
        """
        image_features = self.process_features(self.get_image_features(image))
        text_features = self.process_features(self.get_text_features(prompt))
        return self.compute_loss(image_features, text_features)