Upload metrics.py with huggingface_hub
Browse files- metrics.py +76 -13
metrics.py
CHANGED
@@ -16,7 +16,7 @@ from scipy.stats import bootstrap
|
|
16 |
from scipy.stats._warnings_errors import DegenerateDataWarning
|
17 |
|
18 |
from .artifact import Artifact
|
19 |
-
from .dataclass import InternalField, OptionalField
|
20 |
from .logging_utils import get_logger
|
21 |
from .metric_utils import InstanceInput, MetricRequest, MetricResponse
|
22 |
from .operator import (
|
@@ -58,6 +58,16 @@ def nan_mean(x):
|
|
58 |
return np.nanmean(x)
|
59 |
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
class UpdateStream(StreamInstanceOperator):
|
62 |
update: dict
|
63 |
|
@@ -69,11 +79,7 @@ class UpdateStream(StreamInstanceOperator):
|
|
69 |
|
70 |
|
71 |
class Metric(Artifact):
|
72 |
-
|
73 |
-
@abstractmethod
|
74 |
-
def main_score(self):
|
75 |
-
pass
|
76 |
-
|
77 |
# Override 'prediction_type' with the expected type of predictions
|
78 |
# and references. Example: "List[str]", "List[Dict]"", "string".
|
79 |
# If left with default None, a warning will be displayed.
|
@@ -229,6 +235,18 @@ class MetricWithConfidenceInterval(Metric):
|
|
229 |
[instance["score"]["instance"][score_name] for instance in instances]
|
230 |
)
|
231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
@staticmethod
|
233 |
def _all_instance_scores_equal(instances, score_name):
|
234 |
instance_scores = [
|
@@ -625,13 +643,10 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
625 |
# if subgroup_column is not None, a column by the specified name will be required in task_data
|
626 |
subgroup_column = None
|
627 |
implemented_reductions: List[str] = field(
|
628 |
-
default_factory=lambda: ["mean", "group_mean"]
|
629 |
)
|
630 |
|
631 |
-
|
632 |
-
@abstractmethod
|
633 |
-
def reduction_map(self) -> dict:
|
634 |
-
pass
|
635 |
|
636 |
def _validate_group_mean_reduction(self, instances: List[dict]):
|
637 |
"""Ensure that group_mean reduction_map is properly formatted.
|
@@ -739,12 +754,19 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
739 |
|
740 |
field_name_full_prefix = ""
|
741 |
# used for passing to the bootstrapping, depends on whether the groups are fixed or not
|
742 |
-
aggregation_function =
|
743 |
if reduction_type == "mean":
|
|
|
|
|
|
|
|
|
|
|
|
|
744 |
reduction_fields = list(set(reduction_params))
|
745 |
# no group reduction, so resample instances individually
|
746 |
scores_to_resample = instances
|
747 |
elif reduction_type == "group_mean":
|
|
|
748 |
self._validate_group_mean_reduction(instances=instances)
|
749 |
reduction_fields = (
|
750 |
[self.main_score]
|
@@ -941,6 +963,12 @@ class Accuracy(InstanceMetric):
|
|
941 |
return result
|
942 |
|
943 |
|
|
|
|
|
|
|
|
|
|
|
|
|
944 |
class UnsortedListExactMatch(InstanceMetric):
|
945 |
reduction_map = {"mean": ["unsorted_list_exact_match"]}
|
946 |
main_score = "unsorted_list_exact_match"
|
@@ -988,7 +1016,15 @@ class MetricPipeline(MultiStreamOperator, Metric):
|
|
988 |
self.metric.disable_confidence_interval_calculation()
|
989 |
|
990 |
def verify(self):
|
991 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
992 |
|
993 |
def prepare(self):
|
994 |
super().prepare()
|
@@ -3266,3 +3302,30 @@ class BinaryMaxAccuracy(GlobalMetric):
|
|
3266 |
best_thr = thr
|
3267 |
|
3268 |
return {self.main_score: best_acc, "best_thr_max_acc": best_thr}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from scipy.stats._warnings_errors import DegenerateDataWarning
|
17 |
|
18 |
from .artifact import Artifact
|
19 |
+
from .dataclass import AbstractField, InternalField, OptionalField
|
20 |
from .logging_utils import get_logger
|
21 |
from .metric_utils import InstanceInput, MetricRequest, MetricResponse
|
22 |
from .operator import (
|
|
|
58 |
return np.nanmean(x)
|
59 |
|
60 |
|
61 |
+
def nan_max(x):
|
62 |
+
with warnings.catch_warnings():
|
63 |
+
# final mean should be mean of scores, ignoring NaN, hence nanmax
|
64 |
+
# but if the group function values is NaN for ALL values, nanmean throws a
|
65 |
+
# RuntimeWarning that it is calculating the mean of an empty slice (with no non-Nans)
|
66 |
+
# this is the desired behavior, but we want to avoid the warning here
|
67 |
+
warnings.simplefilter("ignore", category=RuntimeWarning)
|
68 |
+
return np.nanmax(x)
|
69 |
+
|
70 |
+
|
71 |
class UpdateStream(StreamInstanceOperator):
|
72 |
update: dict
|
73 |
|
|
|
79 |
|
80 |
|
81 |
class Metric(Artifact):
|
82 |
+
main_score: str = AbstractField()
|
|
|
|
|
|
|
|
|
83 |
# Override 'prediction_type' with the expected type of predictions
|
84 |
# and references. Example: "List[str]", "List[Dict]"", "string".
|
85 |
# If left with default None, a warning will be displayed.
|
|
|
235 |
[instance["score"]["instance"][score_name] for instance in instances]
|
236 |
)
|
237 |
|
238 |
+
@staticmethod
|
239 |
+
def max_item_scores(instances: List[dict], score_name: str):
|
240 |
+
"""Calculate max of a set of instance scores (given by score_name), omitting NaN values.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
instances: list of dicts of each instance's instance scores.
|
244 |
+
score_name: score field names to compute the mean for.
|
245 |
+
"""
|
246 |
+
return nan_max(
|
247 |
+
[instance["score"]["instance"][score_name] for instance in instances]
|
248 |
+
)
|
249 |
+
|
250 |
@staticmethod
|
251 |
def _all_instance_scores_equal(instances, score_name):
|
252 |
instance_scores = [
|
|
|
643 |
# if subgroup_column is not None, a column by the specified name will be required in task_data
|
644 |
subgroup_column = None
|
645 |
implemented_reductions: List[str] = field(
|
646 |
+
default_factory=lambda: ["mean", "group_mean", "max"]
|
647 |
)
|
648 |
|
649 |
+
reduction_map: Dict[str, List[str]] = AbstractField()
|
|
|
|
|
|
|
650 |
|
651 |
def _validate_group_mean_reduction(self, instances: List[dict]):
|
652 |
"""Ensure that group_mean reduction_map is properly formatted.
|
|
|
754 |
|
755 |
field_name_full_prefix = ""
|
756 |
# used for passing to the bootstrapping, depends on whether the groups are fixed or not
|
757 |
+
aggregation_function = None
|
758 |
if reduction_type == "mean":
|
759 |
+
aggregation_function = self.average_item_scores
|
760 |
+
reduction_fields = list(set(reduction_params))
|
761 |
+
# no group reduction, so resample instances individually
|
762 |
+
scores_to_resample = instances
|
763 |
+
elif reduction_type == "max":
|
764 |
+
aggregation_function = self.max_item_scores
|
765 |
reduction_fields = list(set(reduction_params))
|
766 |
# no group reduction, so resample instances individually
|
767 |
scores_to_resample = instances
|
768 |
elif reduction_type == "group_mean":
|
769 |
+
aggregation_function = self.average_item_scores
|
770 |
self._validate_group_mean_reduction(instances=instances)
|
771 |
reduction_fields = (
|
772 |
[self.main_score]
|
|
|
963 |
return result
|
964 |
|
965 |
|
966 |
+
class MaxAccuracy(Accuracy):
|
967 |
+
"""Calculate the maximal accuracy over all instances as the global score."""
|
968 |
+
|
969 |
+
reduction_map = {"max": ["accuracy"]}
|
970 |
+
|
971 |
+
|
972 |
class UnsortedListExactMatch(InstanceMetric):
|
973 |
reduction_map = {"mean": ["unsorted_list_exact_match"]}
|
974 |
main_score = "unsorted_list_exact_match"
|
|
|
1016 |
self.metric.disable_confidence_interval_calculation()
|
1017 |
|
1018 |
def verify(self):
|
1019 |
+
assert (
|
1020 |
+
self.metric is not None
|
1021 |
+
), f"'metric' is not set in {self.get_metric_name()}"
|
1022 |
+
assert (
|
1023 |
+
self.main_score is not None
|
1024 |
+
), f"'main_score' is not set in {self.get_metric_name()}"
|
1025 |
+
assert isinstance(
|
1026 |
+
self.metric, Metric
|
1027 |
+
), f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
|
1028 |
|
1029 |
def prepare(self):
|
1030 |
super().prepare()
|
|
|
3302 |
best_thr = thr
|
3303 |
|
3304 |
return {self.main_score: best_acc, "best_thr_max_acc": best_thr}
|
3305 |
+
|
3306 |
+
|
3307 |
+
KO_ERROR_MESSAGE = """
|
3308 |
+
|
3309 |
+
Additional dependencies required. To install them, run:
|
3310 |
+
`pip install "sacrebleu[ko]"`.
|
3311 |
+
|
3312 |
+
For MacOS: If error on 'mecab-config' show up during installation ], one should run:
|
3313 |
+
|
3314 |
+
`brew install mecab`
|
3315 |
+
`pip install "sacrebleu[ko]"`
|
3316 |
+
|
3317 |
+
"""
|
3318 |
+
|
3319 |
+
|
3320 |
+
class NormalizedSacrebleu(HuggingfaceMetric):
|
3321 |
+
hf_metric_name = "sacrebleu"
|
3322 |
+
hf_main_score = "score"
|
3323 |
+
prediction_type = "str"
|
3324 |
+
main_score = "sacrebleu"
|
3325 |
+
scale = 100.0
|
3326 |
+
scaled_fields = ["sacrebleu", "precisions"]
|
3327 |
+
hf_additional_input_fields_pass_one_value = ["tokenize"]
|
3328 |
+
_requirements_list = {
|
3329 |
+
"mecab_ko": KO_ERROR_MESSAGE,
|
3330 |
+
"mecab_ko_dic": KO_ERROR_MESSAGE,
|
3331 |
+
}
|