File size: 753 Bytes
8b414b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import math
from typing import Dict

import pandas as pd
from sklearn.metrics import mean_squared_error

from src.utils import validate_y
class MSEMetric:
    """Column-wise RMSE metric, averaged into one score across columns.

    Both inputs are data frames holding a ``text_id`` column plus one column
    per scored target; ``text_id`` is excluded from scoring.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def evaluate_class_rmse(y_pred: pd.DataFrame, y_true: pd.DataFrame) -> Dict[str, float]:
        """Return the root-mean-squared error of every target column.

        Args:
            y_pred: Predicted values; must contain ``text_id`` and the target
                columns.
            y_true: Ground-truth values; must contain every target column
                that is present in ``y_pred``.

        Returns:
            Mapping from target column name to its RMSE.

        Raises:
            KeyError: If ``y_pred`` has no ``text_id`` column, or ``y_true``
                lacks one of the target columns of ``y_pred``.
        """
        # validate_y is a project helper; presumably it raises on malformed
        # frames -- confirm against src.utils.
        validate_y(y_pred)
        validate_y(y_true)

        # Only the column labels are needed, so drop 'text_id' from the index
        # of labels instead of copying the whole frame.
        target_columns = y_pred.columns.drop('text_id')

        # The `squared=False` keyword of mean_squared_error was deprecated in
        # scikit-learn 1.4 and later removed, so take the square root of the
        # plain MSE explicitly; this works on every scikit-learn version.
        # NOTE(review): the two series are handed to scikit-learn as-is, which
        # appears to compare rows by position, not by text_id -- confirm that
        # callers pass both frames in the same row order.
        return {
            column: math.sqrt(mean_squared_error(y_true[column], y_pred[column]))
            for column in target_columns
        }

    def evaluate(self, y_pred: pd.DataFrame, y_true: pd.DataFrame) -> float:
        """Return the mean of the per-column RMSE values.

        Args:
            y_pred: Predicted values, as accepted by ``evaluate_class_rmse``.
            y_true: Ground-truth values, as accepted by ``evaluate_class_rmse``.

        Returns:
            Unweighted average of the column-wise RMSE scores.

        Raises:
            ZeroDivisionError: If ``y_pred`` has no column besides ``text_id``,
                so there is nothing to average.
        """
        per_column = self.evaluate_class_rmse(y_pred, y_true)
        return sum(per_column.values()) / len(per_column)
|