Elron committed on
Commit
1e4d944
1 Parent(s): 49c1c5f

Upload metrics.py with huggingface_hub

Files changed (1)
  1. metrics.py +1338 -199
metrics.py CHANGED
@@ -1,19 +1,24 @@
1
  import re
2
  import string
3
  import uuid
 
4
  from abc import ABC, abstractmethod
5
  from collections import Counter
 
6
  from dataclasses import field
 
7
  from typing import Any, Dict, Generator, List, Optional, Tuple
8
 
9
  import evaluate
10
  import numpy
11
  import numpy as np
12
  from scipy.stats import bootstrap
 
13
 
14
  from .artifact import Artifact
15
  from .dataclass import InternalField, OptionalField
16
  from .logging_utils import get_logger
 
17
  from .operator import (
18
  MultiStreamOperator,
19
  SingleStreamOperator,
@@ -22,14 +27,17 @@ from .operator import (
22
  )
23
  from .operators import CopyFields
24
  from .random_utils import get_seed
 
25
  from .stream import MultiStream, Stream
26
- from .type_utils import isoftype
27
 
28
  logger = get_logger()
29
- # The default number of resamples used to estimate the confidence intervals
30
- # global and instances metrics. Use None to disable confidence interval computation by default.
31
- _N_RESAMPLES_DEFAULT_FOR_INSTANCE_METRICS = 1000
32
- _N_RESAMPLES_DEFAULT_FOR_GLOBAL_METRICS = 100
 
 
33
 
34
 
35
  def abstract_factory():
@@ -40,6 +48,18 @@ def abstract_field():
40
  return field(default_factory=abstract_factory)
41
 
42
 
43
  class UpdateStream(StreamInstanceOperator):
44
  update: dict
45
 
@@ -57,6 +77,48 @@ class Metric(Artifact):
57
  def main_score(self):
58
  pass
59
 
60
 
61
  class MetricWithConfidenceInterval(Metric):
62
  # The number of resamples used to estimate the confidence intervals of this metric.
@@ -73,7 +135,12 @@ class MetricWithConfidenceInterval(Metric):
73
  return np.random.default_rng(hash(get_seed()) & _max_32bit)
74
 
75
  def disable_confidence_interval_calculation(self):
 
76
  self.n_resamples = None
 
 
 
 
77
 
78
  def _can_compute_confidence_intervals(self, num_predictions):
79
  return (
@@ -82,45 +149,117 @@ class MetricWithConfidenceInterval(Metric):
82
  and num_predictions > 1
83
  )
84
 
85
- def score_based_confidence_interval(self, instances):
86
- """Compute confidence intervals based on existing scores, already computed on the input instances.
 
87
 
88
- score_names: List[str]
89
- Compute a confidence interval for each score_name from this list.
90
- instances:
91
- The instances for which the confidence intervals are computed.
92
  """
93
- from statistics import mean
94
 
 
 
 
95
  result = {}
96
 
97
  if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
98
  return result
99
 
100
- score_names = (
101
- self.ci_scores if self.ci_scores is not None else [self.main_score]
102
- )
103
-
 
 
104
  for score_name in score_names:
105
- scores = [
106
- instance["score"]["instance"][score_name] for instance in instances
107
- ]
108
  ci = bootstrap(
109
- (scores,),
110
- statistic=mean,
111
  n_resamples=self.n_resamples,
112
  confidence_level=self.confidence_level,
113
  random_state=self.new_random_generator(),
114
  ).confidence_interval
115
- result[f"{score_name}_ci_low"] = ci.low
116
- result[f"{score_name}_ci_high"] = ci.high
 
117
  if score_name == self.main_score:
118
  result["score_ci_low"] = ci.low
119
  result["score_ci_high"] = ci.high
120
  return result
121
 
122
  def compute_global_confidence_intervals(
123
- self, references, predictions, additional_inputs, score_name
124
  ):
125
  """Computed confidence intervals for a set of references and predictions."""
126
  random_gen = self.new_random_generator()
@@ -128,12 +267,12 @@ class MetricWithConfidenceInterval(Metric):
128
  def statistic(arr, axis):
129
  # arr is a 2d array where each row is a resampling, so we
130
  # iterate over the rows and compute the metric on each resampling
131
- def metric(sample_refs, sample_preds, sample_additional_inputs):
132
  try:
133
  return self._compute(
134
  references=sample_refs,
135
  predictions=sample_preds,
136
- additional_inputs=sample_additional_inputs,
137
  )["score"]
138
  except Exception as e:
139
  # this happens in edge cases, for example, when the sampling creates a
@@ -141,40 +280,21 @@ class MetricWithConfidenceInterval(Metric):
141
  logger.info(f"Warning in {self.__class__.__name__}", e)
142
  return np.nan
143
 
 
144
  scores = numpy.apply_along_axis(
145
  lambda x: metric(
146
  sample_refs=[references[i] for i in x],
147
  sample_preds=[predictions[i] for i in x],
148
- sample_additional_inputs=[additional_inputs[i] for i in x],
149
  ),
150
  axis=axis,
151
  arr=arr,
152
  )
153
 
154
- # when running with bca interval (default), the statistic is called twice: with the
155
- # original data and with the resamples. here we want to focus only on the latter.
156
- if scores.size > 1:
157
- # here we deal with samples on which the metric could not be computed. These are
158
- # edge cases - for example, when the sample contains only empty strings.
159
- # CI is about the distribution around the statistic (e.g. mean), it doesn't deal with
160
- # cases in which the metric is not computable. Therefore, we ignore these edge cases
161
- # as part of the computation of CI. The question is how to implement this policy.
162
- # Options:
163
- # 1. skip the errors and return a shorter array => this fails because Scipy demans
164
- # this callback (i.e. the statistic() callback) to return an array of the same size
165
- # as the number of resamples
166
- # 2. Put np.nan for the errors => this fails because in such case the ci itself
167
- # becomes np.nan. So one edge case can fail the whole CI computation.
168
- # 3. Replace the errors with a sampling from the successful cases => this is what
169
- # is implemented.
170
- error_indices = numpy.isnan(scores)
171
- n_errors = sum(error_indices)
172
- if n_errors > 0:
173
- new_scores = random_gen.choice(scores, n_errors, replace=True)
174
- scores = scores[~error_indices]
175
- scores = np.concatenate([scores, new_scores])
176
-
177
- return scores
178
 
179
  result = {}
180
  num_predictions = len(predictions)
@@ -202,12 +322,15 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
202
  need to be considered. Accuracy, on the other hand, is just an average of the accuracy of all the instances.
203
  """
204
 
205
- n_resamples = _N_RESAMPLES_DEFAULT_FOR_GLOBAL_METRICS
 
 
 
206
 
207
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
208
  references = []
209
  predictions = []
210
- additional_inputs = []
211
  global_score = {}
212
 
213
  instances = []
@@ -226,31 +349,40 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
226
  predictions.append(instance_prediction)
227
  instances.append(instance)
228
 
229
- instance_additional_inputs = (
230
- instance["additional_inputs"] if "additional_inputs" in instance else {}
231
  )
232
- additional_inputs.append(instance_additional_inputs)
233
- try:
234
- instance_score = self._compute(
235
- [instance_references],
236
- [instance_prediction],
237
- [instance_additional_inputs],
238
- )
239
- except:
240
- instance_score = {"score": None, "score_name": self.main_score}
241
 
242
  if isinstance(self.main_score, str):
243
- instance_score[self.main_score] = None
244
 
245
  instance["score"]["instance"].update(instance_score)
246
 
247
- result = self._compute(references, predictions, additional_inputs)
248
 
249
  global_score.update(result)
250
 
251
  score_name = global_score["score_name"]
252
  confidence_interval = self.compute_global_confidence_intervals(
253
- references, predictions, additional_inputs, score_name
254
  )
255
  global_score.update(confidence_interval)
256
 
@@ -262,9 +394,9 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
262
  self,
263
  references: List[List[str]],
264
  predictions: List[str],
265
- additional_inputs: List[Any],
266
  ) -> dict:
267
- result = self.compute(references, predictions, additional_inputs)
268
  result["score"] = result[self.main_score]
269
  result["score_name"] = self.main_score
270
  return result
@@ -274,13 +406,25 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
274
  self,
275
  references: List[List[Any]],
276
  predictions: List[Any],
277
- additional_inputs: List[Any],
278
  ) -> dict:
279
  pass
280
 
281
 
282
  class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
283
- n_resamples = _N_RESAMPLES_DEFAULT_FOR_INSTANCE_METRICS
 
 
284
  main_score: str
285
  reduction_map: Dict[str, List[str]]
286
 
@@ -301,8 +445,8 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
301
  ),
302
  )
303
 
304
- additional_inputs = [
305
- instance["additional_inputs"] if "additional_inputs" in instance else {}
306
  for instance in stream
307
  ]
308
 
@@ -310,7 +454,7 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
310
  instance_scores = self.compute(
311
  references=references,
312
  predictions=predictions,
313
- additional_inputs=additional_inputs,
314
  )
315
 
316
  # add the score and score_name fields
@@ -334,8 +478,6 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
334
  ), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
335
 
336
  if reduction == "mean":
337
- from statistics import mean
338
-
339
  for field_name in fields:
340
  global_score[field_name] = mean(
341
  [
@@ -347,8 +489,13 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
347
  global_score["score"] = global_score[field_name]
348
  global_score["score_name"] = self.main_score
349
 
 
 
 
 
 
350
  confidence_interval = self.score_based_confidence_interval(
351
- instances=instances
352
  )
353
  global_score.update(confidence_interval)
354
 
@@ -360,33 +507,217 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
360
  self,
361
  references: List[List[Any]],
362
  predictions: List[Any],
363
- additional_inputs: List[Dict],
364
  ) -> List[Dict[str, Any]]:
365
  pass
366
 
367
 
368
  class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
369
- n_resamples = _N_RESAMPLES_DEFAULT_FOR_INSTANCE_METRICS
370
 
371
- implemented_reductions: List[str] = field(default_factory=lambda: ["mean"])
372
 
373
  @property
374
  @abstractmethod
375
  def reduction_map(self) -> dict:
376
  pass
377
 
378
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
379
  global_score = {}
380
  instances = []
381
 
382
  for instance in stream:
383
  refs, pred = instance["references"], instance["prediction"]
384
- additional_inputs = (
385
- instance["additional_inputs"] if "additional_inputs" in instance else {}
386
- )
387
 
388
  instance_score = self.compute(
389
- references=refs, prediction=pred, additional_inputs=additional_inputs
390
  )
391
  instance_score["score"] = instance_score[self.main_score]
392
  instance_score["score_name"] = self.main_score
@@ -399,36 +730,100 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
399
 
400
  instances.append(instance)
401
 
402
- for reduction, fields in self.reduction_map.items():
403
- assert (
404
- reduction in self.implemented_reductions
405
- ), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
406
 
407
- if reduction == "mean":
408
- from statistics import mean
 
409
 
410
- for field_name in fields:
411
- scores = [
412
- instance["score"]["instance"][field_name]
413
- for instance in instances
414
- ]
415
- global_score[field_name] = mean(scores)
416
- if field_name == self.main_score:
417
- global_score["score"] = global_score[field_name]
418
- global_score["score_name"] = self.main_score
419
 
420
- confidence_interval = self.score_based_confidence_interval(
421
- instances=instances
422
  )
423
- global_score.update(confidence_interval)
424
 
425
- for instance in instances:
426
- yield instance
427
 
428
  @abstractmethod
429
- def compute(
430
- self, references: List[Any], prediction: Any, additional_inputs: Dict
431
- ) -> dict:
432
  pass
433
 
434
 
@@ -445,7 +840,7 @@ class Squad(GlobalMetric):
445
  self,
446
  references: List[List[str]],
447
  predictions: List[str],
448
- additional_inputs: List[Dict],
449
  ) -> dict:
450
  ids = [str(uuid.uuid4()).replace("-", "") for _ in range(len(predictions))]
451
  formatted_predictions = [
@@ -466,9 +861,10 @@ class Squad(GlobalMetric):
466
  class Accuracy(InstanceMetric):
467
  reduction_map = {"mean": ["accuracy"]}
468
  main_score = "accuracy"
 
469
 
470
  def compute(
471
- self, references: List[Any], prediction: Any, additional_inputs: List[Dict]
472
  ) -> dict:
473
  result = {
474
  self.main_score: float(
@@ -483,13 +879,14 @@ class Accuracy(InstanceMetric):
483
  class StringContainment(InstanceMetric):
484
  reduction_map = {"mean": ["string_containment"]}
485
  main_score = "string_containment"
 
486
 
487
  def compute(
488
- self, references: List[Any], prediction: Any, additional_inputs: List[Dict]
489
  ) -> dict:
490
  result = {
491
  self.main_score: float(
492
- any(str(reference) in prediction for reference in references)
493
  )
494
  }
495
  result["score"] = result[self.main_score]
@@ -505,6 +902,13 @@ class MetricPipeline(MultiStreamOperator, Metric):
505
  )
506
  metric: Metric = None
507
 
 
 
 
 
 
 
 
508
  def verify(self):
509
  assert self.main_score is not None, "main_score is not set"
510
 
@@ -569,37 +973,37 @@ class HuggingfaceMetric(GlobalMetric):
569
  self,
570
  references: List[List[Any]],
571
  predictions: List[Any],
572
- additional_inputs: List[Dict],
573
  ) -> dict:
574
- passed_additional_inputs = {}
575
  for additional_input_field in self.hf_additional_input_fields:
576
  assert (
577
- additional_input_field in additional_inputs[0]
578
- ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in additional inputs: {additional_inputs[0]}"
579
- passed_additional_inputs[additional_input_field] = [
580
  additional_input[additional_input_field]
581
- for additional_input in additional_inputs
582
  ]
583
  for additional_input_field in self.hf_additional_input_fields_pass_one_value:
584
  assert (
585
- additional_input_field in additional_inputs[0]
586
- ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in additional inputs: {additional_inputs[0]}"
587
 
588
  values = {
589
  additional_input[additional_input_field]
590
- for additional_input in additional_inputs
591
  }
592
  assert (
593
  len(values) == 1
594
  ), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
595
 
596
- passed_additional_inputs[additional_input_field] = next(iter(values))
597
 
598
- # add check that all required fields in self.metrics are in passed_additional_inputs print(passed_additional_inputs)
599
  result = self.metric.compute(
600
  predictions=predictions,
601
  references=references,
602
- **passed_additional_inputs,
603
  **self.hf_compute_args,
604
  )
605
  if self.hf_main_score:
@@ -641,23 +1045,23 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
641
  self,
642
  references: List[List[str]],
643
  predictions: List[str],
644
- additional_inputs: List[Any],
645
  ) -> List[Dict[str, Any]]:
646
- passed_additional_inputs = {}
647
  for additional_input_field in self.hf_additional_input_fields:
648
  assert (
649
- additional_input_field in additional_inputs[0]
650
- ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in additional inputs: {additional_inputs[0]}"
651
- passed_additional_inputs[additional_input_field] = [
652
  additional_input[additional_input_field]
653
- for additional_input in additional_inputs
654
  ]
655
- # add check that all required fields in self.metrics are in passed_additional_inputs
656
 
657
  scores = self.metric.compute(
658
  predictions=predictions,
659
  references=references,
660
- **passed_additional_inputs,
661
  **self.hf_compute_args,
662
  )
663
 
@@ -692,7 +1096,7 @@ class F1(GlobalMetric):
692
  self,
693
  references: List[List[str]],
694
  predictions: List[str],
695
- additional_inputs: List[Dict],
696
  ) -> dict:
697
  assert all(
698
  len(reference) == 1 for reference in references
@@ -714,8 +1118,6 @@ class F1(GlobalMetric):
714
  average=self.average,
715
  )
716
  if isinstance(result["f1"], numpy.ndarray):
717
- from statistics import mean
718
-
719
  final_result = {self.main_score: mean(result["f1"])}
720
  for i, label in enumerate(labels):
721
  final_result["f1_" + self.id_to_str[label]] = result["f1"][i]
@@ -742,7 +1144,6 @@ class F1MultiLabel(GlobalMetric):
742
  _metric = None
743
  main_score = "f1_macro"
744
  average = None # Report per class then aggregate by mean
745
- classes_to_ignore = ["none"]
746
  metric = "f1"
747
 
748
  def prepare(self):
@@ -767,7 +1168,7 @@ class F1MultiLabel(GlobalMetric):
767
  self,
768
  references: List[List[str]],
769
  predictions: List[List[str]],
770
- additional_inputs: List[Dict],
771
  ) -> dict:
772
  self.str_to_id = {}
773
  self.id_to_str = {}
@@ -775,13 +1176,9 @@ class F1MultiLabel(GlobalMetric):
775
  self._validate_references_and_prediction(references, predictions)
776
  references = [reference[0] for reference in references]
777
 
778
- labels = [
779
- lbl
780
- for lbl in {label for reference in references for label in reference}
781
- if lbl not in self.classes_to_ignore
782
- ]
783
  # if no classes are left then F1 is not defined
784
- # (e.g. only "none" in references)
785
  if len(labels) == 0:
786
  return {self.main_score: float("nan")}
787
 
@@ -809,8 +1206,6 @@ class F1MultiLabel(GlobalMetric):
809
  labels=labels_param,
810
  )
811
  if isinstance(result[self.metric], numpy.ndarray):
812
- from statistics import mean
813
-
814
  assert (
815
  len(result[self.metric]) == len(labels)
816
  ), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
@@ -883,6 +1278,8 @@ class Rouge(HuggingfaceMetric):
883
 
884
  sent_split_newline: bool = True
885
 
 
 
886
  def prepare(self):
887
  super().prepare()
888
 
@@ -895,7 +1292,7 @@ class Rouge(HuggingfaceMetric):
895
  nltk.download("punkt")
896
  self.sent_tokenize = nltk.sent_tokenize
897
 
898
- def compute(self, references, predictions, additional_inputs: List[Dict]):
899
  if self.sent_split_newline:
900
  predictions = [
901
  "\n".join(self.sent_tokenize(prediction.strip()))
@@ -905,13 +1302,16 @@ class Rouge(HuggingfaceMetric):
905
  ["\n".join(self.sent_tokenize(r.strip())) for r in reference]
906
  for reference in references
907
  ]
908
- return super().compute(references, predictions, additional_inputs)
909
 
910
 
911
  # Computes char edit distance, ignoring whitespace
912
  class CharEditDistanceAccuracy(InstanceMetric):
913
  reduction_map = {"mean": ["char_edit_dist_accuracy"]}
914
  main_score = "char_edit_dist_accuracy"
 
 
 
915
 
916
  def prepare(self):
917
  super().prepare()
@@ -919,9 +1319,7 @@ class CharEditDistanceAccuracy(InstanceMetric):
919
 
920
  self.eval = editdistance.eval
921
 
922
- def compute(
923
- self, references, prediction: str, additional_inputs: List[Dict]
924
- ) -> dict:
925
  assert (
926
  len(references) == 1
927
  ), f"Expected only one reference , but received: {references}"
@@ -939,11 +1337,13 @@ class Wer(HuggingfaceMetric):
939
  hf_metric_name = "wer"
940
  main_score = "wer"
941
 
 
 
942
  def compute(
943
  self,
944
  references: List[List[str]],
945
  predictions: List[str],
946
- additional_inputs: List[Dict],
947
  ) -> dict:
948
  assert all(
949
  len(reference) == 1 for reference in references
@@ -955,6 +1355,43 @@ class Wer(HuggingfaceMetric):
955
  return {self.main_score: result}
956
 
957
 
958
  class MatthewsCorrelation(HuggingfaceMetric):
959
  hf_metric_name = "matthews_correlation"
960
  main_score = "matthews_correlation"
@@ -970,7 +1407,7 @@ class MatthewsCorrelation(HuggingfaceMetric):
970
  self,
971
  references: List[List[str]],
972
  predictions: List[str],
973
- additional_inputs: List[Dict],
974
  ) -> dict:
975
  formatted_references = [
976
  self.get_str_id(reference[0]) for reference in references
@@ -983,6 +1420,33 @@ class MatthewsCorrelation(HuggingfaceMetric):
983
  )
984
 
985
 
986
  class CustomF1(GlobalMetric):
987
  main_score = "f1_micro"
988
  groups = None
@@ -1036,9 +1500,9 @@ class CustomF1(GlobalMetric):
1036
  except ZeroDivisionError:
1037
  return self.zero_division
1038
 
1039
- def get_groups(self, elements, additional_inputs):
1040
  groups = set()
1041
- for sublist, additional_input in zip(elements, additional_inputs):
1042
  for e in sublist:
1043
  if self.should_ignore_element(e, additional_input):
1044
  continue
@@ -1049,7 +1513,7 @@ class CustomF1(GlobalMetric):
1049
  self,
1050
  references: List[List[Any]],
1051
  predictions: List[Any],
1052
- additional_inputs: List[Dict],
1053
  ) -> dict:
1054
  # in case reference are List[List[List[Any]]] and predictions are List[List[Any]]:
1055
  if (
@@ -1065,12 +1529,12 @@ class CustomF1(GlobalMetric):
1065
  )
1066
 
1067
  if self.groups is None:
1068
- groups = self.get_groups(references, additional_inputs)
1069
  else:
1070
  groups = self.groups
1071
  groups_statistics = {}
1072
  for references_batch, predictions_batch, additional_input in zip(
1073
- references, predictions, additional_inputs
1074
  ):
1075
  grouped_references = self.group_elements(references_batch, additional_input)
1076
  grouped_predictions = self.group_elements(
@@ -1187,10 +1651,11 @@ class TokenOverlap(InstanceMetric):
1187
  ci_scores = ["f1", "precision", "recall"]
1188
 
1189
  def compute(
1190
- self, references: List[Any], prediction: Any, additional_inputs: List[Dict]
1191
  ) -> dict:
1192
  results = [
1193
- self._compute_single_ref(reference, prediction) for reference in references
 
1194
  ]
1195
  return {
1196
  measure: max(r[i] for r in results)
@@ -1200,8 +1665,8 @@ class TokenOverlap(InstanceMetric):
1200
  def _compute_single_ref(
1201
  self, reference: Any, prediction: Any
1202
  ) -> Tuple[float, float, float]:
1203
- prediction_tokens = normalize_answer(prediction).split()
1204
- reference_tokens = normalize_answer(reference).split()
1205
  common = Counter(prediction_tokens) & Counter(reference_tokens)
1206
  num_same = sum(common.values())
1207
  if num_same == 0:
@@ -1221,9 +1686,11 @@ class BertScore(HuggingfaceBulkMetric):
1221
  ci_scores = ["f1", "precision", "recall"]
1222
  model_name: str
1223
 
 
 
1224
  def prepare(self):
1225
  super().prepare()
1226
- self.hf_compute_args = {"model_type": self.model_name}
1227
 
1228
 
1229
  class SentenceBert(BulkInstanceMetric):
@@ -1233,19 +1700,23 @@ class SentenceBert(BulkInstanceMetric):
1233
 
1234
  model_name: str
1235
 
 
 
1236
  def prepare(self):
1237
  super().prepare()
 
1238
  from sentence_transformers import SentenceTransformer
1239
  from sentence_transformers import util as sbert_util
1240
 
1241
- self.model = SentenceTransformer(self.model_name)
 
1242
  self.util = sbert_util
1243
 
1244
  def compute(
1245
  self,
1246
  references: List[List[Any]],
1247
  predictions: List[Any],
1248
- additional_inputs: List[Dict],
1249
  ) -> List[Dict[str, Any]]:
1250
  scores = []
1251
 
@@ -1260,9 +1731,9 @@ class SentenceBert(BulkInstanceMetric):
1260
  count += len(ref_group)
1261
 
1262
  # compute s-bert embeddings
1263
- preds_emb = self.model.encode(predictions)
1264
  refs_emb = self.model.encode(
1265
- [ref for ref_group in references for ref in ref_group]
1266
  )
1267
 
1268
  # for each candidate, pick the reference with the highest score
@@ -1280,17 +1751,23 @@ class Reward(BulkInstanceMetric):
1280
 
1281
  model_name: str
1282
 
 
 
1283
  def prepare(self):
1284
  super().prepare()
 
1285
  from transformers import pipeline
1286
 
1287
- self.pipe = pipeline("text-classification", model=self.model_name)
 
 
 
1288
 
1289
  def compute(
1290
  self,
1291
  references: List[List[Any]],
1292
  predictions: List[Any],
1293
- additional_inputs: List[Dict],
1294
  ) -> List[Dict[str, Any]]:
1295
  # treat the references as the questions and the predictions as answers
1296
  # assume a single reference
@@ -1316,25 +1793,27 @@ class Perplexity(BulkInstanceMetric):
1316
  batch_size: int = 32
1317
  model_name: str
1318
 
 
 
1319
  def compute(
1320
  self,
1321
  references: List[List[Any]],
1322
  predictions: List[Any],
1323
- additional_inputs: List[Dict],
1324
  ) -> List[Dict[str, Any]]:
1325
  """Computes the likelihood of generating text Y after text X - P(Y|X).
1326
 
1327
- :param references: the list of Y texts as a list of singletons.
1328
- :param predictions: the list of X texts as a plain list of strings
1329
 
1330
- :return: the likelihood of generating text Y_i after text X_i = P(Y_i|X_i) for every i.
1331
  """
1332
  sources = []
1333
  targets = []
1334
  for prediction, instance_references in zip(predictions, references):
1335
  for instance_reference in instance_references:
1336
- sources.append(f"{self.perplexity_prompt} {prediction}")
1337
- targets.append(instance_reference)
1338
 
1339
  from transformers import AutoConfig
1340
 
@@ -1375,9 +1854,11 @@ class Perplexity(BulkInstanceMetric):
1375
  from transformers import AutoTokenizer
1376
 
1377
  self.model_name = model_name
 
 
 
 
1378
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
1379
- self.model = self.model_class().from_pretrained(self.model_name)
1380
- self.is_cuda = torch.cuda.is_available()
1381
 
1382
  def compute_lm(
1383
  self, source: List[str], target: List[str], batch_size: int
@@ -1470,16 +1951,9 @@ class Perplexity(BulkInstanceMetric):
1470
  return AutoModelForSeq2SeqLM
1471
 
1472
  def compute_batch(self, tokens_source, tokens_target):
1473
- tokens_docs_ids = tokens_source["input_ids"]
1474
- attention = tokens_source["attention_mask"]
1475
- labels = tokens_target["input_ids"]
1476
-
1477
- if self.is_cuda:
1478
- tokens_docs_ids, attention, labels = (
1479
- tokens_docs_ids.cuda(),
1480
- attention.cuda(),
1481
- labels.cuda(),
1482
- )
1483
 
1484
  logits = self.model(
1485
  input_ids=tokens_docs_ids.long(),
@@ -1519,12 +1993,9 @@ class Perplexity(BulkInstanceMetric):
1519
  # replace the padding token in the labels by -100
1520
  labels[labels == self.tokenizer.pad_token_id] = -100
1521
 
1522
- if self.is_cuda:
1523
- tokens, attention, labels = (
1524
- tokens.cuda(),
1525
- attention.cuda(),
1526
- labels.cuda(),
1527
- )
1528
 
1529
  # no need to pass labels as we calculate the loss below per document
1530
  model_output = self.model(
@@ -1558,6 +2029,8 @@ class NDCG(GlobalMetric):
1558
 
1559
  main_score = "nDCG"
1560
 
 
 
1561
  def prepare(self):
1562
  from sklearn.metrics import ndcg_score
1563
 
@@ -1568,15 +2041,12 @@ class NDCG(GlobalMetric):
1568
  self,
1569
  references: List[List[Any]],
1570
  predictions: List[Any],
1571
- additional_inputs: List[Any],
1572
  ) -> dict:
1573
  from collections import defaultdict
1574
- from statistics import mean
1575
 
1576
  query_to_predictions_and_references = defaultdict(lambda: [[], []])
1577
- for reference, pred, inputs_dict in zip(
1578
- references, predictions, additional_inputs
1579
- ):
1580
  query = inputs_dict.get("query")
1581
  query_to_predictions_and_references[query][0].append(pred)
1582
  query_to_predictions_and_references[query][1].append(reference)
@@ -1606,9 +2076,7 @@ class NDCG(GlobalMetric):
1606
 
1607
 
1608
  class RetrievalMetric(InstanceMetric):
1609
- def compute(
1610
- self, references: List[Any], prediction: Any, additional_inputs: Dict
1611
- ) -> dict:
1612
  # digest input
1613
  pred_ids: List[Any] = prediction
1614
  ref_ids: List[Any] = list(dict.fromkeys(references))
@@ -1681,6 +2149,7 @@ class RetrievalMetric(InstanceMetric):
1681
  class MRR(RetrievalMetric):
1682
  reduction_map = {"mean": ["mrr"]}
1683
  main_score = "mrr"
 
1684
 
1685
  def _compute(
1686
  self,
@@ -1697,6 +2166,7 @@ class MRR(RetrievalMetric):
1697
  class MAP(RetrievalMetric):
1698
  reduction_map = {"mean": ["map"]}
1699
  main_score = "map"
 
1700
 
1701
  def _compute(
1702
  self,
@@ -1765,3 +2235,672 @@ class KPA(CustomF1):
1765
 
1766
  def should_ignore_element(self, element, additional_input):
1767
  return element == "none"
1
  import re
2
  import string
3
  import uuid
4
+ import warnings
5
  from abc import ABC, abstractmethod
6
  from collections import Counter
7
+ from copy import deepcopy
8
  from dataclasses import field
9
+ from statistics import mean
10
  from typing import Any, Dict, Generator, List, Optional, Tuple
11
 
12
  import evaluate
13
  import numpy
14
  import numpy as np
15
  from scipy.stats import bootstrap
16
+ from scipy.stats._warnings_errors import DegenerateDataWarning
17
 
18
  from .artifact import Artifact
19
  from .dataclass import InternalField, OptionalField
20
  from .logging_utils import get_logger
21
+ from .metric_utils import InstanceInput, MetricRequest, MetricResponse
22
  from .operator import (
23
  MultiStreamOperator,
24
  SingleStreamOperator,
 
27
  )
28
  from .operators import CopyFields
29
  from .random_utils import get_seed
30
+ from .settings_utils import get_settings
31
  from .stream import MultiStream, Stream
32
+ from .type_utils import isoftype, to_float_or_default
33
 
34
  logger = get_logger()
35
+ settings = get_settings()
36
+
37
+ warnings.filterwarnings("ignore", category=DegenerateDataWarning)
38
+
39
+
40
+ warnings.filterwarnings("ignore", category=DegenerateDataWarning)
41
 
42
 
43
  def abstract_factory():
 
48
  return field(default_factory=abstract_factory)
49
 
50
 
51
+ def nan_mean(x):
52
+ import warnings
53
+
54
+ with warnings.catch_warnings():
55
+ # final mean should be mean of scores, ignoring NaN, hence nanmean
56
+ # but if the group function values is NaN for ALL values, nanmean throws a
57
+ # RuntimeWarning that it is calculating the mean of an empty slice (with no non-Nans)
58
+ # this is the desired behavior, but we want to avoid the warning here
59
+ warnings.simplefilter("ignore", category=RuntimeWarning)
60
+ return np.nanmean(x)
61
+
62
+
63
  class UpdateStream(StreamInstanceOperator):
64
  update: dict
65
 
 
77
  def main_score(self):
78
  pass
79
 
80
+ def consume_stream(self, stream: Stream):
81
+ references = []
82
+ predictions = []
83
+ additional_inputs = []
84
+ instances = []
85
+ for instance in stream:
86
+ references.append(instance["references"])
87
+ predictions.append(instance["prediction"])
88
+ additional_inputs.append(
89
+ instance["additional_inputs"] if "additional_inputs" in instance else {}
90
+ )
91
+ instances.append(instance)
92
+ return predictions, references, additional_inputs, instances
93
+
94
+ @staticmethod
95
+ def update_instance_scores(instances, instances_scores: List[Dict[str, Any]]):
96
+ for instance, new_scores in zip(instances, instances_scores):
97
+ if "score" not in instance:
98
+ instance["score"] = {}
99
+ scores = instance["score"]
100
+ if "instance" not in scores:
101
+ scores["instance"] = {}
102
+ scores["instance"].update(new_scores)
103
+
104
+ @staticmethod
105
+ def set_global_score(instances, global_score: Dict[str, Any]):
106
+ for instance in instances:
107
+ if "score" not in instance:
108
+ instance["score"] = {}
109
+ scores = instance["score"]
110
+ if "global" not in scores:
111
+ scores["global"] = {}
112
+ scores["global"] = global_score
113
+
114
+ @abstractmethod
115
+ def disable_confidence_interval_calculation(self):
116
+ pass
117
+
118
+ @abstractmethod
119
+ def set_n_resamples(self, n_resample):
120
+ pass
121
+
122
 
123
  class MetricWithConfidenceInterval(Metric):
124
  # The number of resamples used to estimate the confidence intervals of this metric.
 
135
  return np.random.default_rng(hash(get_seed()) & _max_32bit)
136
 
137
  def disable_confidence_interval_calculation(self):
138
+ n = self.n_resamples
139
  self.n_resamples = None
140
+ return n
141
+
142
+ def set_n_resamples(self, n_resamples):
143
+ self.n_resamples = n_resamples
144
 
145
  def _can_compute_confidence_intervals(self, num_predictions):
146
  return (
 
149
  and num_predictions > 1
150
  )
151
 
152
+ @staticmethod
153
+ def average_item_scores(instances: List[dict], score_name: str):
154
+ """Calculate mean of a set of instance scores (given by score_name), omitting NaN values.
155
 
156
+ Args:
157
+ instances: list of dicts of each instance's instance scores.
158
+ score_name: score field names to compute the mean for.
 
159
  """
160
+ return nan_mean(
161
+ [instance["score"]["instance"][score_name] for instance in instances]
162
+ )
163
+
164
+ def score_based_confidence_interval(
165
+ self,
166
+ instances: List[dict],
167
+ score_names: List[str],
168
+ aggregation_func=None,
169
+ ci_score_prefix="",
170
+ ):
171
+ """Compute confidence intervals based on existing scores, already computed on the input instances.
172
+
173
+ Unlike GlobalMetric, this is simply a function of the instance scores (possibly taking into account task_data field),
174
+ so they don't need to be recomputed after every bootstrap draw.
175
+
176
+ Args:
177
+ instances: The instances for which the confidence intervals are computed; should already have the relevant instance scores calculated.
178
+ score_names: List of instance score field names to compute a confidence interval for.
179
+ aggregation_func: A function with arguments instances, field_name; is applied on list of instances (which may include task_data
180
+ field, as well as the prediction and references), and the field_name; default is simply to take the mean field_name from
181
+ instances after resampling, if argument is None.
182
+ ci_score_prefix: An optional string prefix to the score_name in the CI. Useful in cases where the
183
+ aggregation_func is something other than the mean
184
 
185
+ Returns:
186
+ Dict of confidence interval values
187
+ """
188
  result = {}
189
 
190
  if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
191
  return result
192
 
193
+ ci_score_prefix = str(ci_score_prefix)
194
+ if aggregation_func is None:
195
+ # if aggregation_func is None, we simply take the mean of the resampled instance scores
196
+ # otherwise, the aggregation_func needs to be applied AFTER resampling the instances;
197
+ # that is, re-form the groups, calculate the function, and take the mean of the group scores
198
+ aggregation_func = self.average_item_scores
199
  for score_name in score_names:
200
+ # need to redefine the statistic function within the loop because score_name is a loop variable
201
+ def statistic(arr, axis, score_name=score_name):
202
+ # arr is a 2d array where each row is a resampling, so we
203
+ # iterate over the rows and compute the metric on each resampling
204
+ scores = numpy.apply_along_axis(
205
+ lambda resampled_instances: aggregation_func(
206
+ resampled_instances, score_name
207
+ ),
208
+ axis=axis,
209
+ arr=arr,
210
+ )
211
+ return self.resample_from_non_nan(scores)
212
+
213
+ # apply bootstrap only on the relevant field
214
  ci = bootstrap(
215
+ (instances,),
216
+ statistic=statistic,
217
  n_resamples=self.n_resamples,
218
  confidence_level=self.confidence_level,
219
  random_state=self.new_random_generator(),
220
  ).confidence_interval
221
+ full_score_name = ci_score_prefix + score_name
222
+ result[f"{full_score_name}_ci_low"] = ci.low
223
+ result[f"{full_score_name}_ci_high"] = ci.high
224
  if score_name == self.main_score:
225
  result["score_ci_low"] = ci.low
226
  result["score_ci_high"] = ci.high
227
  return result
228
 
229
+ def resample_from_non_nan(self, values):
230
+ """Given an array values, will replace any NaN values with elements resampled with replacement from the non-NaN ones.
231
+
232
+ here we deal with samples on which the metric could not be computed. These are
233
+ edge cases - for example, when the sample contains only empty strings.
234
+ CI is about the distribution around the statistic (e.g. mean), it doesn't deal with
235
+ cases in which the metric is not computable. Therefore, we ignore these edge cases
236
+ as part of the computation of CI.
237
+
238
+ In theory there would be several ways to deal with this:
239
+ 1. skip the errors and return a shorter array => this fails because Scipy requires
240
+ this callback (i.e. the statistic() callback) to return an array of the same size
241
+ as the number of resamples
242
+ 2. Put np.nan for the errors => this fails because in such case the ci itself
243
+ becomes np.nan. So one edge case can fail the whole CI computation.
244
+ 3. Replace the errors with a sampling from the successful cases => this is what is implemented.
245
+
246
+ This resampling makes it so that, if possible, the bca confidence interval returned by bootstrap will not be NaN, since
247
+ bootstrap does not ignore NaNs. However, if there are 0 or 1 non-NaN values, or all non-NaN values are equal,
248
+ the resulting distribution will be degenerate (only one unique value) so the CI will still be NaN since there is
249
+ no variability. In this case, the CI is essentially an interval of length 0 equaling the mean itself.
250
+ """
251
+ if values.size > 1:
252
+ error_indices = numpy.isnan(values)
253
+ n_errors = sum(error_indices)
254
+ if 0 < n_errors < values.size:
255
+ # replace NaN aggregate scores with random draws from non-NaN scores, so that confidence interval isn't NaN itself
256
+ values[error_indices] = self.new_random_generator().choice(
257
+ values[~error_indices], n_errors, replace=True
258
+ )
259
+ return values
260
+
261
  def compute_global_confidence_intervals(
262
+ self, references, predictions, task_data, score_name
263
  ):
264
  """Computed confidence intervals for a set of references and predictions."""
265
  random_gen = self.new_random_generator()
 
267
  def statistic(arr, axis):
268
  # arr is a 2d array where each row is a resampling, so we
269
  # iterate over the rows and compute the metric on each resampling
270
+ def metric(sample_refs, sample_preds, sample_task_data):
271
  try:
272
  return self._compute(
273
  references=sample_refs,
274
  predictions=sample_preds,
275
+ task_data=sample_task_data,
276
  )["score"]
277
  except Exception as e:
278
  # this happens in edge cases, for example, when the sampling creates a
 
280
  logger.info(f"Warning in {self.__class__.__name__}", e)
281
  return np.nan
282
 
283
+ # resample the instance scores, and then return the global score each time
284
  scores = numpy.apply_along_axis(
285
  lambda x: metric(
286
  sample_refs=[references[i] for i in x],
287
  sample_preds=[predictions[i] for i in x],
288
+ sample_task_data=[task_data[i] for i in x],
289
  ),
290
  axis=axis,
291
  arr=arr,
292
  )
293
 
294
+ # in some resamplings of instances, the global score may be NaN since it cannot be computed;
295
+ # in these cases, the bca confidence interval will be NaN because it does not ignore these values,
296
+ # so we replace any NaN values with those resampled from the non-NaN ones.
297
+ return self.resample_from_non_nan(scores)
298
 
299
  result = {}
300
  num_predictions = len(predictions)
 
322
  need to be considered. Accuracy, on the other hand, is just an average of the accuracy of all the instances.
323
  """
324
 
325
+ n_resamples: int = OptionalField(
326
+ default_factory=lambda: settings.num_resamples_for_global_metrics
327
+ )
328
+ process_single_instances = True
329
 
330
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
331
  references = []
332
  predictions = []
333
+ task_data = []
334
  global_score = {}
335
 
336
  instances = []
 
349
  predictions.append(instance_prediction)
350
  instances.append(instance)
351
 
352
+ instance_task_data = (
353
+ instance["task_data"] if "task_data" in instance else {}
354
  )
355
+ task_data.append(instance_task_data)
356
+ instance_score = None
357
+ # for backward compatibility
358
+ no_score_value = np.nan
359
+ if self.process_single_instances:
360
+ try:
361
+ instance_score = self._compute(
362
+ [instance_references],
363
+ [instance_prediction],
364
+ [instance_task_data],
365
+ )
366
+ except:
367
+ no_score_value = None
368
+ if not instance_score:
369
+ instance_score = {
370
+ "score": no_score_value,
371
+ "score_name": self.main_score,
372
+ }
373
 
374
  if isinstance(self.main_score, str):
375
+ instance_score[self.main_score] = no_score_value
376
 
377
  instance["score"]["instance"].update(instance_score)
378
 
379
+ result = self._compute(references, predictions, task_data)
380
 
381
  global_score.update(result)
382
 
383
  score_name = global_score["score_name"]
384
  confidence_interval = self.compute_global_confidence_intervals(
385
+ references, predictions, task_data, score_name
386
  )
387
  global_score.update(confidence_interval)
388
 
 
394
  self,
395
  references: List[List[str]],
396
  predictions: List[str],
397
+ task_data: List[Any],
398
  ) -> dict:
399
+ result = self.compute(references, predictions, task_data)
400
  result["score"] = result[self.main_score]
401
  result["score_name"] = self.main_score
402
  return result
 
406
  self,
407
  references: List[List[Any]],
408
  predictions: List[Any],
409
+ task_data: List[Any],
410
  ) -> dict:
411
+ """Computes a scores dictionary on a list of references, predictions and input.
412
+
413
+ This function is called once per instance, and then another time
414
+ over all data instances.
415
+
416
+ Returns:
417
+ a dictionary of scores that is set as:
418
+ the instance scores when called on a single data instance
419
+ the global score when called on the all data instances
420
+ """
421
  pass
422
 
423
 
424
  class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
425
+ n_resamples: int = OptionalField(
426
+ default_factory=lambda: settings.num_resamples_for_instance_metrics
427
+ )
428
  main_score: str
429
  reduction_map: Dict[str, List[str]]
430
 
 
445
  ),
446
  )
447
 
448
+ task_data = [
449
+ instance["task_data"] if "task_data" in instance else {}
450
  for instance in stream
451
  ]
452
 
 
454
  instance_scores = self.compute(
455
  references=references,
456
  predictions=predictions,
457
+ task_data=task_data,
458
  )
459
 
460
  # add the score and score_name fields
 
478
  ), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
479
 
480
  if reduction == "mean":
 
 
481
  for field_name in fields:
482
  global_score[field_name] = mean(
483
  [
 
489
  global_score["score"] = global_score[field_name]
490
  global_score["score_name"] = self.main_score
491
 
492
+ ci_fields = (
493
+ list(set(self.ci_scores))
494
+ if self.ci_scores is not None
495
+ else [self.main_score]
496
+ )
497
  confidence_interval = self.score_based_confidence_interval(
498
+ instances=instances, score_names=ci_fields
499
  )
500
  global_score.update(confidence_interval)
501
 
 
507
  self,
508
  references: List[List[Any]],
509
  predictions: List[Any],
510
+ task_data: List[Dict],
511
  ) -> List[Dict[str, Any]]:
512
  pass
513
 
514
 
515
  class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
516
+ """Class for metrics for which a global score can be calculated by aggregating the instance scores (possibly with additional instance inputs).
517
+
518
+ InstanceMetric currently allows two reductions:
519
+ 1. 'mean', which calculates the mean of instance scores,
520
+ 2. 'group_mean', which first applies an aggregation function specified in the reduction_map
521
+ to instance scores grouped by the field grouping_field (which must not be None), and returns the mean
522
+ of the group scores; if grouping_field is None, grouping is disabled.
523
+ See _validate_group_mean_reduction for formatting instructions.
524
+ """
525
 
526
+ n_resamples: int = OptionalField(
527
+ default_factory=lambda: settings.num_resamples_for_instance_metrics
528
+ )
529
+
530
+ # some group_mean aggregation functions (3rd element of "agg_func" list in the reduction)
531
+ # only require a list of instance scores (e.g., mean, median, etc.). Others aggregation functions
532
+ # require an additional column (e.g., a subgroup identifier) by which the instance scores will be grouped
533
+ # if subgroup_column is not None, a column by the specified name will be required in task_data
534
+ subgroup_column = None
535
+ implemented_reductions: List[str] = field(
536
+ default_factory=lambda: ["mean", "group_mean"]
537
+ )
538
 
539
  @property
540
  @abstractmethod
541
  def reduction_map(self) -> dict:
542
  pass
543
 
544
+ def _validate_group_mean_reduction(self, instances: List[dict]):
545
+ """Ensure that group_mean reduction_map is properly formatted.
546
+
547
+ Example: Apply the variance (np.var) to group Accuracy instance scores. This class would be specified as follows:
548
+
549
+ class GroupVarianceAccuracy(Accuracy):
550
+ reduction_map = {'group_mean': {'agg_func': ['variance', np.var, True]}}
551
+
552
+ reduction_map must be a dict with values containing
553
+ - an 'agg_func' field with value being a 3-element list where
554
+ - 1st element is a string name of the aggregation function (used in naming the CI report)
555
+ - 2nd element is the callable aggregation function
556
+ - 3rd element is a Boolean indicator of whether, during boostrap CI calculation, the groups are to be sampled as single units.
557
+ If True, the group scores are calculated and then resampled. This treats the group units as the unit of
558
+ interest for which the CI is being compared.
559
+ If False, the instances are resampled individually, and the groups determined
560
+ (meaning the groups may be of slightly different size or composition from the original
561
+ depending on the resampling of the instances).
562
+ - Optional: 'score_fields' key with list value containing the string names of fields to apply the aggregation to
563
+ - If not present, the parent class main_score is used.
564
+
565
+ The aggregation function (2nd element of agg_func) can be one of two types:
566
+ 1. simple: calculate a summary statistic from a single group of values (e.g. mean, median, etc.).
567
+ This is best suited for cases where the instances are independent of each other, other than belonging to the same group
568
+ 2. comparison: requires subgroup_column to be specified. This function conducts
569
+ a comparison between scores for differing values of subgroup_column (e.g., 'original' vs 'paraphrase').
570
+ An example is where the original instance is a question, and the others are various paraphrases
571
+ or perturbations of this question. Here, the function would return, say, a comparison of the instance accuracies
572
+ rather than, say, the average instance accuracy.
573
+ In these cases, we recommend setting the 3rd parameter to be True so that the groups are resampled together.
574
+
575
+ Example:
576
+ class GroupVsBaselineDiffAccuracy(Accuracy):
577
+ subgroup_column = 'variant_type'
578
+ reduction_map = {'group_mean': {'agg_func': ['accuracy_diff', accuracy_diff, True],}}
579
+
580
+ # where the function is defined as
581
+ def accuracy_diff(subgroup_scores_dict, expected_subgroup_types=['original', 'paraphrase']):
582
+ validate_subgroup_types(subgroup_scores_dict, expected_subgroup_types)
583
+ from statistics import mean
584
+ return mean(subgroup_scores_dict['paraphrase']) - mean(subgroup_scores_dict['original'])
585
+ The input dataset should look like:
586
+
587
+ 'group_id' 'question' 'variant_type'
588
+ 1 'How do you fix a car engine?' 'original'
589
+ 1 'What is the best way to fix an engine?' 'paraphrase'
590
+ 1 'How do you repair a car engine?' 'paraphrase'
591
+ 1 'How do I repair my engine?' 'paraphrase'
592
+ 2 'Why are ants eating my food?' 'original'
593
+ """
594
+ # instances need to all have task_data field with field group_id
595
+ assert all(
596
+ "task_data" in instance for instance in instances
597
+ ), "each instance must have an task_data field"
598
+ assert all(
599
+ isinstance(instance["task_data"], dict) for instance in instances
600
+ ), "each instance must have an task_data field that is a dict"
601
+ assert all(
602
+ "group_id" in instance["task_data"] for instance in instances
603
+ ), "each instance task_data dict must have a key group_id"
604
+
605
+ # validate the reduction_map
606
+ assert (
607
+ "group_mean" in self.reduction_map
608
+ ), "reduction_map must have a 'group_mean' key"
609
+ fields = self.reduction_map["group_mean"]
610
+ # for group_mean, expects a dict
611
+ assert isinstance(fields, dict)
612
+ assert (
613
+ "agg_func" in fields
614
+ ), "fields should have a key 'agg_func' whose value is a 3-element list of a function name, function definition, and a boolean indicator"
615
+ assert isinstance(
616
+ fields["agg_func"], list
617
+ ), "fields['agg_func'] should be a list"
618
+ assert (
619
+ len(fields["agg_func"]) == 3
620
+ ), "fields['agg_func'] should be a 3-element list"
621
+ assert isinstance(
622
+ fields["agg_func"][0], str
623
+ ), "first item in fields['agg_func'] should be a string name of a function"
624
+ assert callable(
625
+ fields["agg_func"][1]
626
+ ), "second item in fields['agg_func'] should be a callable function"
627
+ assert isinstance(
628
+ fields["agg_func"][2], bool
629
+ ), "third item in fields['agg_func'] should be a boolean value"
630
+ if "score_fields" in fields:
631
+ assert isinstance(fields["score_fields"], list)
632
+
633
+ # for aggregation functions that use the subgroup_column (expect a dict of lists), check that
634
+ # this field exists
635
+ if self.subgroup_column is not None:
636
+ assert all(
637
+ self.subgroup_column in instance["task_data"] for instance in instances
638
+ ), f"each instance task_data dict must have a key {self.subgroup_column}"
639
+
640
  def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
641
+ instances, global_score = self.compute_instance_scores(stream)
642
+
643
+ for reduction_type, reduction_params in self.reduction_map.items():
644
+ assert (
645
+ reduction_type in self.implemented_reductions
646
+ ), f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"
647
+
648
+ field_name_full_prefix = ""
649
+ # used for passing to the bootstrapping, depends on whether the groups are fixed or not
650
+ aggregation_function = self.average_item_scores
651
+ if reduction_type == "mean":
652
+ reduction_fields = list(set(reduction_params))
653
+ # no group reduction, so resample instances individually
654
+ scores_to_resample = instances
655
+ elif reduction_type == "group_mean":
656
+ self._validate_group_mean_reduction(instances=instances)
657
+ reduction_fields = (
658
+ [self.main_score]
659
+ if "score_fields" not in reduction_params
660
+ else list(set(reduction_params["score_fields"]))
661
+ )
662
+ aggregation_function_name = str(reduction_params["agg_func"][0])
663
+ field_name_full_prefix = "group_" + aggregation_function_name + "_"
664
+ do_resample_as_group = reduction_params["agg_func"][2]
665
+ if do_resample_as_group:
666
+ # append fixed_ to name because resamples the groups as fixed units
667
+ field_name_full_prefix = "fixed_" + field_name_full_prefix
668
+ (
669
+ scores_to_resample,
670
+ aggregation_function,
671
+ ) = self._set_up_group_mean_aggregation(
672
+ instances, reduction_params, reduction_fields
673
+ )
674
+ else:
675
+ raise ValueError(
676
+ f"Reduction {reduction_type} is not supported, please specify a valid reduction method in reduction_map {self.reduction_map}."
677
+ )
678
+
679
+ # calculate global scores for each reduction field
680
+ for field_name in reduction_fields:
681
+ field_name_full = field_name_full_prefix + field_name
682
+ # if group resampling (3rd element of agg_func parameter) is True, then
683
+ # 1. scores_to_resample are the group scores, and
684
+ # 2. aggregation_function is to take the raw mean
685
+ # if no group resampling (3rd element of agg_func parameter) is False, then
686
+ # 1. scores_to_resample are the original instance scores, and
687
+ # 2. aggregation_function is to apply the group aggregation from the instance scores
688
+ # either way, the application of aggregation_function to scores_to_resample yields the global score
689
+ global_score[field_name_full] = aggregation_function(
690
+ scores_to_resample, field_name
691
+ )
692
+ if field_name == self.main_score:
693
+ global_score["score"] = global_score[field_name_full]
694
+ global_score["score_name"] = field_name_full
695
+
696
+ # need to specify which fields should have CIs calculated for them through ci_scores
697
+ # (will not automatically calculate CIs for fields in reduction map)
698
+ if self.ci_scores is not None:
699
+ confidence_interval = self.score_based_confidence_interval(
700
+ instances=scores_to_resample,
701
+ score_names=list(set(self.ci_scores)),
702
+ ci_score_prefix=field_name_full_prefix,
703
+ aggregation_func=aggregation_function,
704
+ )
705
+ global_score.update(confidence_interval)
706
+
707
+ yield from instances
708
+
709
+ def compute_instance_scores(
710
+ self, stream: Stream, stream_name: Optional[str] = None
711
+ ):
712
  global_score = {}
713
  instances = []
714
 
715
  for instance in stream:
716
  refs, pred = instance["references"], instance["prediction"]
717
+ task_data = instance["task_data"] if "task_data" in instance else {}
 
 
718
 
719
  instance_score = self.compute(
720
+ references=refs, prediction=pred, task_data=task_data
721
  )
722
  instance_score["score"] = instance_score[self.main_score]
723
  instance_score["score_name"] = self.main_score
 
730
 
731
  instances.append(instance)
732
 
733
+ return instances, global_score
 
 
 
734
 
735
+ def get_group_scores(
736
+ self, instances: List[dict], score_names: List[str], group_aggregation_func
737
+ ):
738
+ """Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
739
+
740
+ Args:
741
+ instances: List of observation instances with instance-level scores (fields) computed.
742
+ score_names: List of instance score names in each instance to apply the aggregation function.
743
+ group_aggregation_func: Callable aggregation function accepting a list of numeric scores;
744
+ or, if self.subgroup_column is not None, a dict mapping each subgroup type (a value of subgroup_column) to its list of scores.
745
+ The callable returns a single score for the group.
746
+
747
+ Returns:
748
+ List of dicts, each corresponding to a group of instances (defined by 'group_id'),
749
+ with an aggregate group score for each score_name
750
+ """
751
+ from collections import defaultdict
752
 
753
+ # three-level defaultdict:
754
+ # first is the grouping, second is the field name, the third is the subgroup_type (by default 'default')
755
+ group_to_instance_scores = defaultdict(
756
+ lambda: defaultdict(lambda: defaultdict(list))
757
+ )
 
 
 
 
758
 
759
+ # check if function has fields for subgroup_column
760
+ uses_subgroups = self.subgroup_column is not None
761
+ default_subgroup_name = "default"
762
+ # loop through the instances and group the scores
763
+ for instance in instances:
764
+ task_data = instance["task_data"]
765
+ group_key = task_data["group_id"]
766
+ # for functions that do comparisons between subgroup_column groups
767
+ # if the function doesn't use subgroup_column, or no subgroup column is set, use "default" as the subgroup type and pass all scores
768
+ subgroup_type = (
769
+ task_data[self.subgroup_column]
770
+ if uses_subgroups
771
+ else default_subgroup_name
772
+ )
773
+ for score_name in score_names:
774
+ group_to_instance_scores[group_key][score_name][subgroup_type].append(
775
+ instance["score"]["instance"][score_name]
776
  )
 
777
 
778
+ # if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
779
+ return [
780
+ {
781
+ "score": {
782
+ "instance": {
783
+ score_name: group_aggregation_func(
784
+ score_dict
785
+ if uses_subgroups
786
+ else score_dict[default_subgroup_name]
787
+ )
788
+ for score_name, score_dict in group_scores.items()
789
+ }
790
+ }
791
+ }
792
+ for group_scores in group_to_instance_scores.values()
793
+ ]
794
+
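A minimal sketch of what get_group_scores produces, assuming the names used elsewhere in this file (GroupMeanAccuracy, defined further below, and nan_mean) are in scope; the instances are hand-built toy dicts.

metric = GroupMeanAccuracy()  # subgroup_column is None, so all scores fall under the "default" subgroup
instances = [
    {"task_data": {"group_id": "g1"}, "score": {"instance": {"accuracy": 1.0}}},
    {"task_data": {"group_id": "g1"}, "score": {"instance": {"accuracy": 0.0}}},
    {"task_data": {"group_id": "g2"}, "score": {"instance": {"accuracy": 1.0}}},
]
group_scores = metric.get_group_scores(instances, ["accuracy"], nan_mean)
# expected: one dict per group_id, e.g.
# [{"score": {"instance": {"accuracy": 0.5}}}, {"score": {"instance": {"accuracy": 1.0}}}]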
795
+ def _set_up_group_mean_aggregation(
796
+ self, instances, reduction_params, reduction_fields
797
+ ):
798
+ group_aggregation_func = reduction_params["agg_func"][1]
799
+ # if treat groups as units
800
+ do_resample_as_group = reduction_params["agg_func"][2]
801
+ if do_resample_as_group:
802
+ # pass the group-aggregate scores (not the instance scores) to resample as usual
803
+ aggregation_function = self.average_item_scores
804
+ scores_to_resample = self.get_group_scores(
805
+ instances, reduction_fields, group_aggregation_func
806
+ )
807
+ else:
808
+ # pass the instance scores to resample, and calculate the group aggregation on the resamplings
809
+ scores_to_resample = instances
810
+
811
+ def aggregation_function(
812
+ instances,
813
+ field_name,
814
+ group_aggregation_func=group_aggregation_func,
815
+ ):
816
+ group_scores = self.get_group_scores(
817
+ instances, [field_name], group_aggregation_func
818
+ )
819
+ return nan_mean(
820
+ [group["score"]["instance"][field_name] for group in group_scores]
821
+ )
822
+
823
+ return scores_to_resample, aggregation_function
824
 
825
  @abstractmethod
826
+ def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
 
 
827
  pass
828
 
829
 
 
840
  self,
841
  references: List[List[str]],
842
  predictions: List[str],
843
+ task_data: List[Dict],
844
  ) -> dict:
845
  ids = [str(uuid.uuid4()).replace("-", "") for _ in range(len(predictions))]
846
  formatted_predictions = [
 
861
  class Accuracy(InstanceMetric):
862
  reduction_map = {"mean": ["accuracy"]}
863
  main_score = "accuracy"
864
+ ci_scores = ["accuracy"]
865
 
866
  def compute(
867
+ self, references: List[Any], prediction: Any, task_data: List[Dict]
868
  ) -> dict:
869
  result = {
870
  self.main_score: float(
 
879
  class StringContainment(InstanceMetric):
880
  reduction_map = {"mean": ["string_containment"]}
881
  main_score = "string_containment"
882
+ ci_scores = ["string_containment"]
883
 
884
  def compute(
885
+ self, references: List[Any], prediction: Any, task_data: List[Dict]
886
  ) -> dict:
887
  result = {
888
  self.main_score: float(
889
+ any(str(reference) in str(prediction) for reference in references)
890
  )
891
  }
892
  result["score"] = result[self.main_score]
 
902
  )
903
  metric: Metric = None
904
 
905
+ def disable_confidence_interval_calculation(self):
906
+ return self.metric.disable_confidence_interval_calculation()
907
+
908
+ def set_n_resamples(self, n_resample):
909
+ if isinstance(self.metric, MetricWithConfidenceInterval):
910
+ self.metric.set_n_resamples(n_resample)
911
+
912
  def verify(self):
913
  assert self.main_score is not None, "main_score is not set"
914
 
 
973
  self,
974
  references: List[List[Any]],
975
  predictions: List[Any],
976
+ task_data: List[Dict],
977
  ) -> dict:
978
+ passed_task_data = {}
979
  for additional_input_field in self.hf_additional_input_fields:
980
  assert (
981
+ additional_input_field in task_data[0]
982
+ ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
983
+ passed_task_data[additional_input_field] = [
984
  additional_input[additional_input_field]
985
+ for additional_input in task_data
986
  ]
987
  for additional_input_field in self.hf_additional_input_fields_pass_one_value:
988
  assert (
989
+ additional_input_field in task_data[0]
990
+ ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
991
 
992
  values = {
993
  additional_input[additional_input_field]
994
+ for additional_input in task_data
995
  }
996
  assert (
997
  len(values) == 1
998
  ), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
999
 
1000
+ passed_task_data[additional_input_field] = next(iter(values))
1001
 
1002
+ # TODO: add a check that all required fields in self.metrics are in passed_task_data
1003
  result = self.metric.compute(
1004
  predictions=predictions,
1005
  references=references,
1006
+ **passed_task_data,
1007
  **self.hf_compute_args,
1008
  )
1009
  if self.hf_main_score:
 
1045
  self,
1046
  references: List[List[str]],
1047
  predictions: List[str],
1048
+ task_data: List[Any],
1049
  ) -> List[Dict[str, Any]]:
1050
+ passed_task_data = {}
1051
  for additional_input_field in self.hf_additional_input_fields:
1052
  assert (
1053
+ additional_input_field in task_data[0]
1054
+ ), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
1055
+ passed_task_data[additional_input_field] = [
1056
  additional_input[additional_input_field]
1057
+ for additional_input in task_data
1058
  ]
1059
+ # TODO: add a check that all required fields in self.metrics are in passed_task_data
1060
 
1061
  scores = self.metric.compute(
1062
  predictions=predictions,
1063
  references=references,
1064
+ **passed_task_data,
1065
  **self.hf_compute_args,
1066
  )
1067
 
 
1096
  self,
1097
  references: List[List[str]],
1098
  predictions: List[str],
1099
+ task_data: List[Dict],
1100
  ) -> dict:
1101
  assert all(
1102
  len(reference) == 1 for reference in references
 
1118
  average=self.average,
1119
  )
1120
  if isinstance(result["f1"], numpy.ndarray):
 
 
1121
  final_result = {self.main_score: mean(result["f1"])}
1122
  for i, label in enumerate(labels):
1123
  final_result["f1_" + self.id_to_str[label]] = result["f1"][i]
 
1144
  _metric = None
1145
  main_score = "f1_macro"
1146
  average = None # Report per class then aggregate by mean
 
1147
  metric = "f1"
1148
 
1149
  def prepare(self):
 
1168
  self,
1169
  references: List[List[str]],
1170
  predictions: List[List[str]],
1171
+ task_data: List[Dict],
1172
  ) -> dict:
1173
  self.str_to_id = {}
1174
  self.id_to_str = {}
 
1176
  self._validate_references_and_prediction(references, predictions)
1177
  references = [reference[0] for reference in references]
1178
 
1179
+ labels = list({label for reference in references for label in reference})
1180
+
 
 
 
1181
  # if no classes are left then F1 is not defined
 
1182
  if len(labels) == 0:
1183
  return {self.main_score: float("nan")}
1184
 
 
1206
  labels=labels_param,
1207
  )
1208
  if isinstance(result[self.metric], numpy.ndarray):
 
 
1209
  assert (
1210
  len(result[self.metric]) == len(labels)
1211
  ), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
 
1278
 
1279
  sent_split_newline: bool = True
1280
 
1281
+ _requirements_list: List[str] = ["nltk", "rouge_score"]
1282
+
1283
  def prepare(self):
1284
  super().prepare()
1285
 
 
1292
  nltk.download("punkt")
1293
  self.sent_tokenize = nltk.sent_tokenize
1294
 
1295
+ def compute(self, references, predictions, task_data: List[Dict]):
1296
  if self.sent_split_newline:
1297
  predictions = [
1298
  "\n".join(self.sent_tokenize(prediction.strip()))
 
1302
  ["\n".join(self.sent_tokenize(r.strip())) for r in reference]
1303
  for reference in references
1304
  ]
1305
+ return super().compute(references, predictions, task_data)
1306
 
1307
 
1308
  # Computes char edit distance, ignoring whitespace
1309
  class CharEditDistanceAccuracy(InstanceMetric):
1310
  reduction_map = {"mean": ["char_edit_dist_accuracy"]}
1311
  main_score = "char_edit_dist_accuracy"
1312
+ ci_scores = ["char_edit_dist_accuracy"]
1313
+
1314
+ _requirements_list: List[str] = ["editdistance"]
1315
 
1316
  def prepare(self):
1317
  super().prepare()
 
1319
 
1320
  self.eval = editdistance.eval
1321
 
1322
+ def compute(self, references, prediction: str, task_data: List[Dict]) -> dict:
 
 
1323
  assert (
1324
  len(references) == 1
1325
  ), f"Expected only one reference , but received: {references}"
 
1337
  hf_metric_name = "wer"
1338
  main_score = "wer"
1339
 
1340
+ _requirements_list: List[str] = ["jiwer"]
1341
+
1342
  def compute(
1343
  self,
1344
  references: List[List[str]],
1345
  predictions: List[str],
1346
+ task_data: List[Dict],
1347
  ) -> dict:
1348
  assert all(
1349
  len(reference) == 1 for reference in references
 
1355
  return {self.main_score: result}
1356
 
1357
 
1358
+ class Spearmanr(HuggingfaceMetric):
1359
+ hf_metric_name = "spearmanr"
1360
+ main_score = "spearmanr"
1361
+ process_single_instances = False
1362
+
1363
+
1364
+ class KendallTauMetric(GlobalMetric):
1365
+ main_score = "kendalltau_b"
1366
+ variant = "b"
1367
+ process_single_instances = False
1368
+
1369
+ _requirements_list: List[str] = ["scipy"]
1370
+
1371
+ def prepare(self):
1372
+ from scipy.stats import kendalltau
1373
+
1374
+ self.kendalltau = kendalltau
1375
+
1376
+ def compute(
1377
+ self,
1378
+ references: List[List[str]],
1379
+ predictions: List[str],
1380
+ task_data: List[Dict],
1381
+ ) -> dict:
1382
+ if isinstance(references[0], list):
1383
+ references = [reference[0] for reference in references]
1384
+ references = [to_float_or_default(r) for r in references]
1385
+ predictions = [to_float_or_default(p) for p in predictions]
1386
+
1387
+ kendall_results = self.kendalltau(references, predictions, variant=self.variant)
1388
+ corr = kendall_results.correlation
1389
+ return {
1390
+ self.main_score: corr,
1391
+ f"{self.main_score}_p_val": kendall_results.pvalue,
1392
+ }
1393
+
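The class above defers to scipy.stats.kendalltau; a small standalone check of the underlying call, with toy numbers (not from any dataset):

from scipy.stats import kendalltau

references = [1.0, 2.0, 3.0, 4.0]   # numeric, as produced by to_float_or_default
predictions = [1.5, 2.5, 2.0, 4.5]
result = kendalltau(references, predictions, variant="b")
print(result.correlation, result.pvalue)  # tau-b correlation and its p-value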
1394
+
1395
  class MatthewsCorrelation(HuggingfaceMetric):
1396
  hf_metric_name = "matthews_correlation"
1397
  main_score = "matthews_correlation"
 
1407
  self,
1408
  references: List[List[str]],
1409
  predictions: List[str],
1410
+ task_data: List[Dict],
1411
  ) -> dict:
1412
  formatted_references = [
1413
  self.get_str_id(reference[0]) for reference in references
 
1420
  )
1421
 
1422
 
1423
+ class RocAuc(GlobalMetric):
1424
+ main_score = "roc_auc"
1425
+ process_single_instances = False
1426
+ _requirements_list: List[str] = ["sklearn"]
1427
+
1428
+ def prepare(self):
1429
+ from sklearn import metrics
1430
+
1431
+ self.roc_curve = metrics.roc_curve
1432
+ self.auc = metrics.auc
1433
+
1434
+ def compute(
1435
+ self,
1436
+ references: List[List[str]],
1437
+ predictions: List[str],
1438
+ task_data: List[Dict],
1439
+ ) -> dict:
1440
+ if isinstance(references[0], list):
1441
+ references = [reference[0] for reference in references]
1442
+ references = [to_float_or_default(r) for r in references]
1443
+ predictions = [to_float_or_default(p) for p in predictions]
1444
+
1445
+ fpr, tpr, thrs = self.roc_curve(y_true=references, y_score=predictions)
1446
+ roc_auc = self.auc(fpr, tpr)
1447
+ return {self.main_score: roc_auc}
1448
+
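Likewise, a quick standalone sanity check of the sklearn calls used above, with toy labels and scores:

from sklearn import metrics

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
fpr, tpr, _ = metrics.roc_curve(y_true=y_true, y_score=y_score)
print(metrics.auc(fpr, tpr))  # 0.75 for this toy data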
1449
+
1450
  class CustomF1(GlobalMetric):
1451
  main_score = "f1_micro"
1452
  groups = None
 
1500
  except ZeroDivisionError:
1501
  return self.zero_division
1502
 
1503
+ def get_groups(self, elements, task_data):
1504
  groups = set()
1505
+ for sublist, additional_input in zip(elements, task_data):
1506
  for e in sublist:
1507
  if self.should_ignore_element(e, additional_input):
1508
  continue
 
1513
  self,
1514
  references: List[List[Any]],
1515
  predictions: List[Any],
1516
+ task_data: List[Dict],
1517
  ) -> dict:
1518
  # in case reference are List[List[List[Any]]] and predictions are List[List[Any]]:
1519
  if (
 
1529
  )
1530
 
1531
  if self.groups is None:
1532
+ groups = self.get_groups(references, task_data)
1533
  else:
1534
  groups = self.groups
1535
  groups_statistics = {}
1536
  for references_batch, predictions_batch, additional_input in zip(
1537
+ references, predictions, task_data
1538
  ):
1539
  grouped_references = self.group_elements(references_batch, additional_input)
1540
  grouped_predictions = self.group_elements(
 
1651
  ci_scores = ["f1", "precision", "recall"]
1652
 
1653
  def compute(
1654
+ self, references: List[Any], prediction: Any, task_data: List[Dict]
1655
  ) -> dict:
1656
  results = [
1657
+ self._compute_single_ref(str(reference), str(prediction))
1658
+ for reference in references
1659
  ]
1660
  return {
1661
  measure: max(r[i] for r in results)
 
1665
  def _compute_single_ref(
1666
  self, reference: Any, prediction: Any
1667
  ) -> Tuple[float, float, float]:
1668
+ prediction_tokens = normalize_answer(str(prediction)).split()
1669
+ reference_tokens = normalize_answer(str(reference)).split()
1670
  common = Counter(prediction_tokens) & Counter(reference_tokens)
1671
  num_same = sum(common.values())
1672
  if num_same == 0:
 
1686
  ci_scores = ["f1", "precision", "recall"]
1687
  model_name: str
1688
 
1689
+ _requirements_list: List[str] = ["bert_score"]
1690
+
1691
  def prepare(self):
1692
  super().prepare()
1693
+ self.hf_compute_args = {"model_type": self.model_name, "batch_size": 16}
1694
 
1695
 
1696
  class SentenceBert(BulkInstanceMetric):
 
1700
 
1701
  model_name: str
1702
 
1703
+ _requirements_list: List[str] = ["sentence_transformers"]
1704
+
1705
  def prepare(self):
1706
  super().prepare()
1707
+ import torch
1708
  from sentence_transformers import SentenceTransformer
1709
  from sentence_transformers import util as sbert_util
1710
 
1711
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
1712
+ self.model = SentenceTransformer(self.model_name, device=self.device)
1713
  self.util = sbert_util
1714
 
1715
  def compute(
1716
  self,
1717
  references: List[List[Any]],
1718
  predictions: List[Any],
1719
+ task_data: List[Dict],
1720
  ) -> List[Dict[str, Any]]:
1721
  scores = []
1722
 
 
1731
  count += len(ref_group)
1732
 
1733
  # compute s-bert embeddings
1734
+ preds_emb = self.model.encode(predictions, device=self.device)
1735
  refs_emb = self.model.encode(
1736
+ [ref for ref_group in references for ref in ref_group], device=self.device
1737
  )
1738
 
1739
  # for each candidate, pick the reference with the highest score
 
1751
 
1752
  model_name: str
1753
 
1754
+ _requirements_list: List[str] = ["transformers"]
1755
+
1756
  def prepare(self):
1757
  super().prepare()
1758
+ import torch
1759
  from transformers import pipeline
1760
 
1761
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
1762
+ self.pipe = pipeline(
1763
+ "text-classification", model=self.model_name, device=device
1764
+ )
1765
 
1766
  def compute(
1767
  self,
1768
  references: List[List[Any]],
1769
  predictions: List[Any],
1770
+ task_data: List[Dict],
1771
  ) -> List[Dict[str, Any]]:
1772
  # treat the references as the questions and the predictions as answers
1773
  # assume a single reference
 
1793
  batch_size: int = 32
1794
  model_name: str
1795
 
1796
+ _requirements_list: List[str] = ["transformers"]
1797
+
1798
  def compute(
1799
  self,
1800
  references: List[List[Any]],
1801
  predictions: List[Any],
1802
+ task_data: List[Dict],
1803
  ) -> List[Dict[str, Any]]:
1804
  """Computes the likelihood of generating text Y after text X - P(Y|X).
1805
 
1806
+ :param predictions: the list of Y texts = the targets of the generation
1807
+ :param references: the list of lists of X texts = the sources of the generation
1808
 
1809
+ :return: the likelihood of generating text Y_i after each text X_i_j = P(Y_i|X_i_1), ..., P(Y_i|X_i_n) for every i.
1810
  """
1811
  sources = []
1812
  targets = []
1813
  for prediction, instance_references in zip(predictions, references):
1814
  for instance_reference in instance_references:
1815
+ sources.append(f"{self.perplexity_prompt} {instance_reference}")
1816
+ targets.append(prediction)
1817
 
1818
  from transformers import AutoConfig
1819
 
 
1854
  from transformers import AutoTokenizer
1855
 
1856
  self.model_name = model_name
1857
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
1858
+ self.model = (
1859
+ self.model_class().from_pretrained(self.model_name).to(self.device)
1860
+ )
1861
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
 
1862
 
1863
  def compute_lm(
1864
  self, source: List[str], target: List[str], batch_size: int
 
1951
  return AutoModelForSeq2SeqLM
1952
 
1953
  def compute_batch(self, tokens_source, tokens_target):
1954
+ tokens_docs_ids = tokens_source["input_ids"].to(self.device)
1955
+ attention = tokens_source["attention_mask"].to(self.device)
1956
+ labels = tokens_target["input_ids"].to(self.device)
 
 
 
 
 
 
 
1957
 
1958
  logits = self.model(
1959
  input_ids=tokens_docs_ids.long(),
 
1993
  # replace the padding token in the labels by -100
1994
  labels[labels == self.tokenizer.pad_token_id] = -100
1995
 
1996
+ tokens = tokens.to(self.device)
1997
+ attention = attention.to(self.device)
1998
+ labels = labels.to(self.device)
 
 
 
1999
 
2000
  # no need to pass labels as we calculate the loss below per document
2001
  model_output = self.model(
 
2029
 
2030
  main_score = "nDCG"
2031
 
2032
+ _requirements_list: List[str] = ["sklearn"]
2033
+
2034
  def prepare(self):
2035
  from sklearn.metrics import ndcg_score
2036
 
 
2041
  self,
2042
  references: List[List[Any]],
2043
  predictions: List[Any],
2044
+ task_data: List[Any],
2045
  ) -> dict:
2046
  from collections import defaultdict
 
2047
 
2048
  query_to_predictions_and_references = defaultdict(lambda: [[], []])
2049
+ for reference, pred, inputs_dict in zip(references, predictions, task_data):
 
 
2050
  query = inputs_dict.get("query")
2051
  query_to_predictions_and_references[query][0].append(pred)
2052
  query_to_predictions_and_references[query][1].append(reference)
 
2076
 
2077
 
2078
  class RetrievalMetric(InstanceMetric):
2079
+ def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
 
 
2080
  # digest input
2081
  pred_ids: List[Any] = prediction
2082
  ref_ids: List[Any] = list(dict.fromkeys(references))
 
2149
  class MRR(RetrievalMetric):
2150
  reduction_map = {"mean": ["mrr"]}
2151
  main_score = "mrr"
2152
+ ci_scores = ["mrr"]
2153
 
2154
  def _compute(
2155
  self,
 
2166
  class MAP(RetrievalMetric):
2167
  reduction_map = {"mean": ["map"]}
2168
  main_score = "map"
2169
+ ci_scores = ["map"]
2170
 
2171
  def _compute(
2172
  self,
 
2235
 
2236
  def should_ignore_element(self, element, additional_input):
2237
  return element == "none"
2238
+
2239
+
2240
+ class RemoteMetric(SingleStreamOperator, Metric):
2241
+ """A metric that runs another metric remotely.
2242
+
2243
+ main_score: the score updated by this metric.
2244
+ endpoint: the remote host that supports the remote metric execution.
2245
+ metric_name: the name of the metric that is executed remotely.
2246
+ api_key: optional, passed to the remote metric with the input, allows secure authentication.
2247
+ """
2248
+
2249
+ main_score: str = None
2250
+ endpoint: str
2251
+ metric_name: str
2252
+ api_key: str = None
2253
+
2254
+ @staticmethod
2255
+ def wrap_inner_metric_pipeline_metric(
2256
+ metric_pipeline: MetricPipeline, remote_metrics_endpoint: str
2257
+ ) -> MetricPipeline:
2258
+ """Wrap the inner metric in a MetricPipeline with a RemoteMetric.
2259
+
2260
+ When executing the returned MetricPipeline, the inner metric will be computed
2261
+ remotely (pre and post processing steps in the MetricPipeline will be computed locally).
2262
+ """
2263
+ local_inner_metric = metric_pipeline.metric
2264
+ metric_pipeline = deepcopy(
2265
+ metric_pipeline
2266
+ ) # To avoid unintentional changes to the catalog contents
2267
+ metric_pipeline.metric = RemoteMetric(
2268
+ main_score=local_inner_metric.main_score,
2269
+ metric_name=local_inner_metric.artifact_identifier,
2270
+ endpoint=remote_metrics_endpoint,
2271
+ )
2272
+ return metric_pipeline
2273
+
2274
+ def get_metric_url(self) -> str:
2275
+ return f"{self.endpoint}/{self.metric_name}"
2276
+
2277
+ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
2278
+ predictions, references, additional_inputs, instances = self.consume_stream(
2279
+ stream
2280
+ )
2281
+ metric_request = self.create_metric_request(
2282
+ predictions, references, additional_inputs
2283
+ )
2284
+ metric_response = self.get_metric_response(metric_request)
2285
+ self.update_instance_scores(instances, metric_response.instances_scores)
2286
+ self.set_global_score(instances, metric_response.global_score)
2287
+ yield from instances
2288
+
2289
+ @staticmethod
2290
+ def create_metric_request(predictions, references, additional_inputs):
2291
+ instance_inputs = [
2292
+ InstanceInput(
2293
+ prediction=prediction,
2294
+ references=reference,
2295
+ additional_inputs=additional_input,
2296
+ )
2297
+ for prediction, reference, additional_input in zip(
2298
+ predictions, references, additional_inputs
2299
+ )
2300
+ ]
2301
+ return MetricRequest(instance_inputs=instance_inputs)
2302
+
2303
+ def get_metric_response(self, metric_request: MetricRequest) -> MetricResponse:
2304
+ import requests
2305
+
2306
+ response = requests.post(
2307
+ url=self.get_metric_url(),
2308
+ json=metric_request.to_dict(),
2309
+ headers={"Authorization": f"Bearer {self.api_key}"},
2310
+ )
2311
+ response.raise_for_status()
2312
+ response_json = response.json()
2313
+ return MetricResponse(**response_json)
2314
+
2315
+ def disable_confidence_interval_calculation(self):
2316
+ """Confidence intervals are always disabled for RemoteMetric.
2317
+
2318
+ No need to do anything.
2319
+ """
2320
+ pass
2321
+
2322
+ def set_n_resamples(self, n_resample):
2323
+ """Since confidence intervals are always disabled for remote metrics, this is a no-op."""
2324
+ pass
2325
+
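A hedged usage sketch of wrapping a catalog metric pipeline so that its inner metric runs remotely; my_metric_pipeline and the endpoint URL are hypothetical placeholders, not values from this repository.

# my_metric_pipeline: an existing MetricPipeline object (hypothetical)
remote_pipeline = RemoteMetric.wrap_inner_metric_pipeline_metric(
    metric_pipeline=my_metric_pipeline,
    remote_metrics_endpoint="https://metrics.example.com",  # hypothetical endpoint
)
# remote_pipeline still runs its pre/post-processing steps locally, and POSTs the
# inner metric computation to <endpoint>/<metric_name>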
2326
+
2327
+ def validate_subgroup_types(
2328
+ subgroup_scores_dict: Dict[str, List],
2329
+ control_subgroup_types: List[str],
2330
+ comparison_subgroup_types: List[str],
2331
+ ):
2332
+ """Validate a dict of subgroup type instance score lists, and subgroup type lists.
2333
+
2334
+ Args:
2335
+ subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
2336
+ control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
2337
+ comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
2338
+ to be compared to the control group.
2339
+
2340
+ Returns:
2341
+ dict with all NaN scores removed; control_subgroup_types and comparison_subgroup_types are deduplicated.
2342
+ """
2343
+ # note: subgroup_scores_dict is already a defaultdict of lists, so don't need to check that keys in control_ and comparison_subgroup_types exist in it
2344
+ # remove any NaNs
2345
+ subgroup_scores_dict.update(
2346
+ {
2347
+ subgroup_name: [score for score in score_list if not np.isnan(score)]
2348
+ for subgroup_name, score_list in subgroup_scores_dict.items()
2349
+ }
2350
+ )
2351
+ assert isinstance(
2352
+ control_subgroup_types, list
2353
+ ), "control_subgroup_types must be a list"
2354
+ assert isinstance(
2355
+ comparison_subgroup_types, list
2356
+ ), "comparison_subgroup_types must be a list"
2357
+ # make sure each list is unique, so that labels aren't double-counted
2358
+ control_subgroup_types = list(set(control_subgroup_types))
2359
+ comparison_subgroup_types = list(set(comparison_subgroup_types))
2360
+
2361
+ return subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
2362
+
2363
+
2364
+ def performance_drop_rate(
2365
+ subgroup_scores_dict: Dict[str, List],
2366
+ control_subgroup_types: List[str],
2367
+ comparison_subgroup_types: List[str],
2368
+ ):
2369
+ """Percentage decrease of mean performance on test elements relative to that on a baseline (control).
2370
+
2371
+ from https://arxiv.org/pdf/2306.04528.pdf.
2372
+
2373
+ Args:
2374
+ subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
2375
+ control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
2376
+ comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
2377
+ to be compared to the control group.
2378
+
2379
+ Returns:
2380
+ numeric PDR metric.
2381
+ If either group has no scores, or the control mean is 0 (so the percentage change is undefined), return NaN;
2382
+ otherwise, calculate PDR
2383
+ """
2384
+ (
2385
+ subgroup_scores_dict,
2386
+ control_subgroup_types,
2387
+ comparison_subgroup_types,
2388
+ ) = validate_subgroup_types(
2389
+ subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
2390
+ )
2391
+
2392
+ # combine all scores from each label (if there are more than 1 in each group) into a list
2393
+ group_scores_list = [
2394
+ np.concatenate(
2395
+ [subgroup_scores_dict[subgroup_name] for subgroup_name in name_list]
2396
+ )
2397
+ for name_list in [control_subgroup_types, comparison_subgroup_types]
2398
+ ]
2399
+ if any(len(scores) == 0 for scores in group_scores_list):
2400
+ # no comparison can be made since there is not at least one score per type
2401
+ return np.nan
2402
+ control_mean = mean(group_scores_list[0])
2403
+ comparison_mean = mean(group_scores_list[1])
2404
+ if control_mean == 0:
2405
+ # return 0 if comparison is also 0
2406
+ if comparison_mean == 0:
2407
+ return 0
2408
+ return np.nan
2409
+ # otherwise, take the percentage change (which may also be 0)
2410
+ return 1 - comparison_mean / control_mean
2411
+
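A small worked example, assuming performance_drop_rate as defined above is in scope:

pdr = performance_drop_rate(
    subgroup_scores_dict={"original": [1.0, 1.0], "paraphrase": [0.5, 1.0]},
    control_subgroup_types=["original"],
    comparison_subgroup_types=["paraphrase"],
)
# control mean = 1.0, comparison mean = 0.75, so pdr = 1 - 0.75 / 1.0 = 0.25 (a 25% drop)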
2412
+
2413
+ def interpret_effect_size(x: float):
2414
+ """Return a string rule-of-thumb interpretation of an effect size value, as defined by Cohen/Sawilowsky.
2415
+
2416
+ See https://en.wikipedia.org/wiki/Effect_size;
2417
+ Cohen, Jacob (1988). Statistical Power Analysis for the Behavioral Sciences; and
2418
+ Sawilowsky, S (2009). "New effect size rules of thumb". Journal of Modern Applied Statistical Methods. 8 (2): 467-474.
2419
+
2420
+ Value has interpretation of
2421
+ - essentially 0 if |x| < 0.01
2422
+ - very small if 0.01 <= |x| < 0.2
2423
+ - small difference if 0.2 <= |x| < 0.5
2424
+ - a medium difference if 0.5 <= |x| < 0.8
2425
+ - a large difference if 0.8 <= |x| < 1.2
2426
+ - a very large difference if 1.2 <= |x| < 2.0
2427
+ - a huge difference if 2.0 <= |x|
2428
+
2429
+ Args:
2430
+ x: float effect size value
2431
+
2432
+ Returns:
2433
+ string interpretation
2434
+ """
2435
+ import pandas as pd
2436
+
2437
+ # assign a label according to threshold of the absolute value
2438
+ return pd.cut(
2439
+ x=[np.abs(x)],
2440
+ right=False,
2441
+ bins=[-1, 0.01, 0.2, 0.5, 0.8, 1.2, 2.0, np.Inf],
2442
+ labels=[
2443
+ "essentially zero",
2444
+ "very small",
2445
+ "small",
2446
+ "medium",
2447
+ "large",
2448
+ "very large",
2449
+ "huge",
2450
+ ],
2451
+ )[0]
2452
+
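For example (assuming the function above is in scope), the thresholds bin the absolute value of the effect size:

print(interpret_effect_size(0.3))   # "small"  (0.2 <= |0.3| < 0.5)
print(interpret_effect_size(-1.5))  # "very large"  (1.2 <= |-1.5| < 2.0; the sign is ignored)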
2453
+
2454
+ def normalized_cohens_h(
2455
+ subgroup_scores_dict: Dict[str, List],
2456
+ control_subgroup_types: List[str],
2457
+ comparison_subgroup_types: List[str],
2458
+ interpret=False,
2459
+ ):
2460
+ """Cohen's h effect size between two proportions, normalized to interval [-1,1].
2461
+
2462
+ Allows for change-type metric when the baseline is 0 (percentage change, and thus PDR, is undefined)
2463
+ https://en.wikipedia.org/wiki/Cohen%27s_h
2464
+
2465
+ Cohen's h effect size metric between two proportions p2 and p1 is 2 * (arcsin(sqrt(p2)) - arcsin(sqrt(p1))).
2466
+ h in -pi, pi, with +/-pi representing the largest increase/decrease (p1=0, p2=1), or (p1=1, p2=0).
2467
+ h=0 is no change. Unlike percentage change, h is defined even if the baseline (p1) is 0.
2468
+ Assumes the scores are in [0,1], either continuous or binary; hence taking the average of a group of scores yields a proportion.
2469
+ Calculates the change in the average of the comparison group's scores relative to the average of the control (baseline) group's scores. We rescale this from [-pi, pi] to [-1, 1] for clarity, where +/-1 are the most extreme changes and 0 is no change.
2470
+
2471
+ Interpretation: the original unscaled Cohen's h can be interpreted according to function interpret_effect_size
2472
+
2473
+ Thus, the rule of interpreting the effect of the normalized value is to use the same thresholds divided by pi
2474
+ - essentially 0 if |norm h| < 0.0031831
2475
+ - very small if 0.0031831 <= |norm h| < 0.06366198
2476
+ - small difference if 0.06366198 <= |norm h| < 0.15915494
2477
+ - a medium difference if 0.15915494 <= |norm h| < 0.25464791
2478
+ - a large difference if 0.25464791 <= |norm h| < 0.38197186
2479
+ - a very large difference if 0.38197186 <= |norm h| < 0.63661977
2480
+ - a huge difference if 0.63661977 <= |norm h|
2481
+ Args:
2482
+ subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
2483
+ control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
2484
+ comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
2485
+ to be compared to the control group.
2486
+ interpret: boolean, whether to interpret the significance of the score or not
2487
+ Returns:
2488
+ float score between -1 and 1, and a string interpretation if interpret=True
2489
+ """
2490
+ (
2491
+ subgroup_scores_dict,
2492
+ control_subgroup_types,
2493
+ comparison_subgroup_types,
2494
+ ) = validate_subgroup_types(
2495
+ subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
2496
+ )
2497
+
2498
+ # requires scores to be in [0,1]
2499
+ for subgroup_name, score_list in subgroup_scores_dict.items():
2500
+ assert all(
2501
+ 0 <= score <= 1 for score in score_list
2502
+ ), f"all {subgroup_name} scores must be in [0,1]"
2503
+
2504
+ # combine all scores from each label (if there are more than 1 in each group) into a list
2505
+ group_scores_list = [
2506
+ np.concatenate(
2507
+ [subgroup_scores_dict[subgroup_name] for subgroup_name in name_list]
2508
+ )
2509
+ for name_list in [control_subgroup_types, comparison_subgroup_types]
2510
+ ]
2511
+
2512
+ if any(len(scores) == 0 for scores in group_scores_list):
2513
+ # no comparison can be made since there is not at least one score per type
2514
+ h, norm_h = np.nan, np.nan
2515
+ else:
2516
+ control_mean = mean(group_scores_list[0])
2517
+ comparison_mean = mean(group_scores_list[1])
2518
+ h = 2 * (np.arcsin(np.sqrt(comparison_mean)) - np.arcsin(np.sqrt(control_mean)))
2519
+ norm_h = np.clip(a=h / np.pi, a_min=-1, a_max=1)
2520
+
2521
+ if not interpret:
2522
+ return norm_h
2523
+
2524
+ return norm_h, interpret_effect_size(h)
2525
+
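A worked example with proportions 0.25 vs. 0.75 (scores hand-picked so the arithmetic is clean), assuming the function above is in scope:

norm_h = normalized_cohens_h(
    subgroup_scores_dict={"original": [0.0, 0.5], "paraphrase": [0.5, 1.0]},
    control_subgroup_types=["original"],
    comparison_subgroup_types=["paraphrase"],
)
# h = 2 * (arcsin(sqrt(0.75)) - arcsin(sqrt(0.25))) = 2 * (pi/3 - pi/6) = pi/3
# norm_h = (pi/3) / pi = 1/3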
2526
+
2527
+ def normalized_hedges_g(
2528
+ subgroup_scores_dict: Dict[str, List[float]],
2529
+ control_subgroup_types: List[str],
2530
+ comparison_subgroup_types: List[str],
2531
+ interpret=False,
2532
+ ):
2533
+ """Hedge's g effect size between mean of two samples, normalized to interval [-1,1]. Better than Cohen's d for small sample sizes.
2534
+
2535
+ Takes into account the variances within the samples, not just the means.
2536
+
2537
+ Args:
2538
+ subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
2539
+ control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
2540
+ comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
2541
+ to be compared to the control group.
2542
+ interpret: boolean, whether to interpret the significance of the score or not
2543
+ Returns:
2544
+ float score between -1 and 1, and a string interpretation if interpret=True
2545
+ """
2546
+ (
2547
+ subgroup_scores_dict,
2548
+ control_subgroup_types,
2549
+ comparison_subgroup_types,
2550
+ ) = validate_subgroup_types(
2551
+ subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
2552
+ )
2553
+
2554
+ # combine all scores from each label (if there are more than 1 in each group) into a list
2555
+ group_scores_list = [
2556
+ np.concatenate(
2557
+ [subgroup_scores_dict[subgroup_name] for subgroup_name in name_list]
2558
+ )
2559
+ for name_list in [control_subgroup_types, comparison_subgroup_types]
2560
+ ]
2561
+
2562
+ group_n = [len(scores) for scores in group_scores_list]
2563
+ if any(nn == 0 for nn in group_n) or all(nn <= 1 for nn in group_n):
2564
+ # if at least one sample size is 0 for one type, no comparison can be made at all
2565
+ # if both sample sizes are 1, then the denominator is undefined since divide by n1 + n2 - 2
2566
+ # so require at least one sample to have > 1 observation, and both to have >= 1.
2567
+ g, norm_g = np.nan, np.nan
2568
+ else:
2569
+ # otherwise, calculate the variances
2570
+ group_mean = [mean(scores) for scores in group_scores_list]
2571
+ # sample variance with 1 degree of freedom (denominator n-1); if n=1, return 0 since otherwise throws an error
2572
+ group_var = [
2573
+ 0.0 if nn == 1 else np.var(scores, ddof=1)
2574
+ for scores, nn in zip(group_scores_list, group_n)
2575
+ ]
2576
+ var_total = sum([(nn - 1) * vv for vv, nn in zip(group_var, group_n)])
2577
+ pooled_sd = np.sqrt(var_total / (sum(group_n) - 2))
2578
+
2579
+ max_absolute_value = 5
2580
+ gmd = float(group_mean[1] - group_mean[0])
2581
+
2582
+ if gmd == 0:
2583
+ # if exactly the same, return 0
2584
+ g = 0.0
2585
+ else:
2586
+ try:
2587
+ g = gmd / pooled_sd
2588
+ except ZeroDivisionError:
2589
+ # return a large effect size to avoid explosion if there is zero variance
2590
+ g = np.sign(gmd) * max_absolute_value
2591
+
2592
+ n = sum(group_n)
2593
+ if 3 < n < 50:
2594
+ # small-sample adjustment; see https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/hedgeg.htm
2595
+ # the multiplier is 0 if n <= 3
2596
+ g *= ((n - 3) / (n - 2.25)) * np.sqrt((n - 2) / n)
2597
+ # clip it at a very large value so it doesn't become infinite if the variance (denominator) is very small or 0
2598
+ g = float(np.clip(a=g, a_min=-1 * max_absolute_value, a_max=max_absolute_value))
2599
+ norm_g = g / max_absolute_value
2600
+
2601
+ if not interpret:
2602
+ return norm_g
2603
+ return norm_g, interpret_effect_size(g)
2604
+
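And a toy call to normalized_hedges_g (same subgroup layout as above); the exact value depends on the pooled variance and the small-sample correction, so only the rough magnitude is noted:

norm_g = normalized_hedges_g(
    subgroup_scores_dict={"original": [0.0, 0.0, 1.0, 1.0], "paraphrase": [0.0, 1.0, 1.0, 1.0]},
    control_subgroup_types=["original"],
    comparison_subgroup_types=["paraphrase"],
)
# positive, since the paraphrase mean (0.75) exceeds the original mean (0.5);
# roughly 0.07 here after the small-sample correction and rescaling to [-1, 1]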
2605
+
2606
+ def mean_subgroup_score(
2607
+ subgroup_scores_dict: Dict[str, List], subgroup_types: List[str]
2608
+ ):
2609
+ """Return the mean instance score for a subset (possibly a single type) of variants (not a comparison).
2610
+
2611
+ Args:
2612
+ subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
2613
+ subgroup_types: the keys (subgroup types) for which the average will be computed.
2614
+
2615
+ Returns:
2616
+ float score
2617
+ """
2618
+ subgroup_scores_dict, subgroup_types, _ = validate_subgroup_types(
2619
+ subgroup_scores_dict, subgroup_types, []
2620
+ )
2621
+
2622
+ # combine all desired subgroup scores
2623
+ score_list = np.concatenate(
2624
+ [subgroup_scores_dict[subgroup_name] for subgroup_name in subgroup_types]
2625
+ )
2626
+ if len(score_list) == 0:
2627
+ # no scores to use
2628
+ return np.nan
2629
+ return mean(score_list)
2630
+
2631
+
2632
+ # metrics using mean reduction
2633
+ class GroupMeanAccuracy(Accuracy):
2634
+ reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, False]}}
2635
+
2636
+
2637
+ class FixedGroupMeanAccuracy(Accuracy):
2638
+ # the same as GroupMeanAccuracy, except the groups are fixed and are resampled together
2639
+ reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, True]}}
2640
+
2641
+
2642
+ # same as above, now using StringContainment
2643
+ class GroupMeanStringContainment(StringContainment):
2644
+ reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, False]}}
2645
+
2646
+
2647
+ class FixedGroupMeanStringContainment(StringContainment):
2648
+ # the same as GroupMeanStringContainment, except the groups are fixed and are resampled together
2649
+ reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, True]}}
2650
+
2651
+
2652
+ # take only the (fixed) group mean of baseline or other (paraphrases) scores
2653
+ class FixedGroupMeanBaselineAccuracy(Accuracy):
2654
+ subgroup_column = "variant_type"
2655
+ # take mean of "original" variants only
2656
+ reduction_map = {
2657
+ "group_mean": {
2658
+ "agg_func": [
2659
+ "mean_baseline",
2660
+ lambda scd: mean_subgroup_score(
2661
+ subgroup_scores_dict=scd, subgroup_types=["original"]
2662
+ ),
2663
+ True,
2664
+ ],
2665
+ }
2666
+ }
2667
+
2668
+
2669
+ class FixedGroupMeanParaphraseAccuracy(Accuracy):
2670
+ subgroup_column = "variant_type"
2671
+ # take mean of "paraphrase" variants only
2672
+ reduction_map = {
2673
+ "group_mean": {
2674
+ "agg_func": [
2675
+ "mean_paraphrase",
2676
+ lambda scd: mean_subgroup_score(
2677
+ subgroup_scores_dict=scd, subgroup_types=["paraphrase"]
2678
+ ),
2679
+ True,
2680
+ ],
2681
+ }
2682
+ }
2683
+
2684
+
2685
+ # same as above but using StringContainment
2686
+ class FixedGroupMeanBaselineStringContainment(StringContainment):
2687
+ subgroup_column = "variant_type"
2688
+ # take mean of "original" variants only
2689
+ reduction_map = {
2690
+ "group_mean": {
2691
+ "agg_func": [
2692
+ "mean_baseline",
2693
+ lambda scd: mean_subgroup_score(
2694
+ subgroup_scores_dict=scd, subgroup_types=["original"]
2695
+ ),
2696
+ True,
2697
+ ],
2698
+ }
2699
+ }
2700
+
2701
+
2702
+ class FixedGroupMeanParaphraseStringContainment(StringContainment):
2703
+ subgroup_column = "variant_type"
2704
+ # take mean of "paraphrase" variants only
2705
+ reduction_map = {
2706
+ "group_mean": {
2707
+ "agg_func": [
2708
+ "mean_paraphrase",
2709
+ lambda scd: mean_subgroup_score(
2710
+ subgroup_scores_dict=scd, subgroup_types=["paraphrase"]
2711
+ ),
2712
+ True,
2713
+ ],
2714
+ }
2715
+ }
2716
+
2717
+
2718
+ # using PDR
2719
+ class FixedGroupPDRParaphraseAccuracy(Accuracy):
2720
+ subgroup_column = "variant_type"
2721
+ reduction_map = {
2722
+ "group_mean": {
2723
+ "agg_func": [
2724
+ "pdr_paraphrase",
2725
+ lambda scd: performance_drop_rate(
2726
+ subgroup_scores_dict=scd,
2727
+ control_subgroup_types=["original"],
2728
+ comparison_subgroup_types=["paraphrase"],
2729
+ ),
2730
+ True,
2731
+ ],
2732
+ }
2733
+ }
2734
+
2735
+
2736
+ class FixedGroupPDRParaphraseStringContainment(StringContainment):
2737
+ subgroup_column = "variant_type"
2738
+ reduction_map = {
2739
+ "group_mean": {
2740
+ "agg_func": [
2741
+ "pdr_paraphrase",
2742
+ lambda scd: performance_drop_rate(
2743
+ subgroup_scores_dict=scd,
2744
+ control_subgroup_types=["original"],
2745
+ comparison_subgroup_types=["paraphrase"],
2746
+ ),
2747
+ True,
2748
+ ],
2749
+ }
2750
+ }
2751
+
2752
+
2753
+ class GroupMeanTokenOverlap(TokenOverlap):
2754
+ reduction_map = {
2755
+ "group_mean": {
2756
+ "agg_func": ["mean", nan_mean, False],
2757
+ "score_fields": ["f1", "precision", "recall"],
2758
+ }
2759
+ }
2760
+
2761
+
2762
+ # using Cohens's h for proportions
2763
+ class FixedGroupNormCohensHParaphraseAccuracy(Accuracy):
2764
+ subgroup_column = "variant_type"
2765
+ reduction_map = {
2766
+ "group_mean": {
2767
+ "agg_func": [
2768
+ "norm_cohens_h_paraphrase",
2769
+ lambda scd: normalized_cohens_h(
2770
+ subgroup_scores_dict=scd,
2771
+ control_subgroup_types=["original"],
2772
+ comparison_subgroup_types=["paraphrase"],
2773
+ ),
2774
+ True,
2775
+ ],
2776
+ }
2777
+ }
2778
+
2779
+
2780
+ class FixedGroupNormCohensHParaphraseStringContainment(StringContainment):
2781
+ subgroup_column = "variant_type"
2782
+ reduction_map = {
2783
+ "group_mean": {
2784
+ "agg_func": [
2785
+ "norm_cohens_h_paraphrase",
2786
+ lambda scd: normalized_cohens_h(
2787
+ subgroup_scores_dict=scd,
2788
+ control_subgroup_types=["original"],
2789
+ comparison_subgroup_types=["paraphrase"],
2790
+ ),
2791
+ True,
2792
+ ],
2793
+ }
2794
+ }
2795
+
2796
+
2797
+ # using Hedges' g (takes into account internal variation in group scores)
2798
+ class FixedGroupNormHedgesGParaphraseAccuracy(Accuracy):
2799
+ subgroup_column = "variant_type"
2800
+ reduction_map = {
2801
+ "group_mean": {
2802
+ "agg_func": [
2803
+ "norm_hedges_g_paraphrase",
2804
+ lambda scd: normalized_hedges_g(
2805
+ subgroup_scores_dict=scd,
2806
+ control_subgroup_types=["original"],
2807
+ comparison_subgroup_types=["paraphrase"],
2808
+ ),
2809
+ True,
2810
+ ],
2811
+ }
2812
+ }
2813
+
2814
+
2815
+ class FixedGroupNormHedgesGParaphraseStringContainment(StringContainment):
2816
+ subgroup_column = "variant_type"
2817
+ reduction_map = {
2818
+ "group_mean": {
2819
+ "agg_func": [
2820
+ "norm_hedges_g_paraphrase",
2821
+ lambda scd: normalized_hedges_g(
2822
+ subgroup_scores_dict=scd,
2823
+ control_subgroup_types=["original"],
2824
+ comparison_subgroup_types=["paraphrase"],
2825
+ ),
2826
+ True,
2827
+ ],
2828
+ }
2829
+ }
2830
+
2831
+
2832
+ # for above metrics, take absolute value of group score first; this measures variation in either direction
2833
+ class FixedGroupAbsvalNormCohensHParaphraseAccuracy(Accuracy):
2834
+ subgroup_column = "variant_type"
2835
+ reduction_map = {
2836
+ "group_mean": {
2837
+ "agg_func": [
2838
+ "absval_norm_cohens_h_paraphrase",
2839
+ lambda scd: np.abs(
2840
+ normalized_cohens_h(
2841
+ subgroup_scores_dict=scd,
2842
+ control_subgroup_types=["original"],
2843
+ comparison_subgroup_types=["paraphrase"],
2844
+ )
2845
+ ),
2846
+ True,
2847
+ ],
2848
+ }
2849
+ }
2850
+
2851
+
2852
+ class FixedGroupAbsvalNormCohensHParaphraseStringContainment(StringContainment):
2853
+ subgroup_column = "variant_type"
2854
+ reduction_map = {
2855
+ "group_mean": {
2856
+ "agg_func": [
2857
+ "absval_norm_cohens_h_paraphrase",
2858
+ lambda scd: np.abs(
2859
+ normalized_cohens_h(
2860
+ subgroup_scores_dict=scd,
2861
+ control_subgroup_types=["original"],
2862
+ comparison_subgroup_types=["paraphrase"],
2863
+ )
2864
+ ),
2865
+ True,
2866
+ ],
2867
+ }
2868
+ }
2869
+
2870
+
2871
+ class FixedGroupAbsvalNormHedgesGParaphraseAccuracy(Accuracy):
2872
+ subgroup_column = "variant_type"
2873
+ reduction_map = {
2874
+ "group_mean": {
2875
+ "agg_func": [
2876
+ "absval_norm_hedges_g_paraphrase",
2877
+ lambda scd: np.abs(
2878
+ normalized_hedges_g(
2879
+ subgroup_scores_dict=scd,
2880
+ control_subgroup_types=["original"],
2881
+ comparison_subgroup_types=["paraphrase"],
2882
+ )
2883
+ ),
2884
+ True,
2885
+ ],
2886
+ }
2887
+ }
2888
+
2889
+
2890
+ class FixedGroupAbsvalNormHedgesGParaphraseStringContainment(StringContainment):
2891
+ subgroup_column = "variant_type"
2892
+ reduction_map = {
2893
+ "group_mean": {
2894
+ "agg_func": [
2895
+ "absval_norm_hedges_g_paraphrase",
2896
+ lambda scd: np.abs(
2897
+ normalized_hedges_g(
2898
+ subgroup_scores_dict=scd,
2899
+ control_subgroup_types=["original"],
2900
+ comparison_subgroup_types=["paraphrase"],
2901
+ )
2902
+ ),
2903
+ True,
2904
+ ],
2905
+ }
2906
+ }
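The classes above are all declarative combinations of a base metric, a subgroup_column, and an agg_func triple (name, aggregation callable, resample-groups-as-fixed-units flag). A hedged sketch of how a new variant could be declared in the same pattern; "my_variant" is a hypothetical subgroup label, not one used in this file.

class FixedGroupMeanMyVariantAccuracy(Accuracy):
    subgroup_column = "variant_type"
    # take the (fixed-group) mean over the hypothetical "my_variant" subgroup only
    reduction_map = {
        "group_mean": {
            "agg_func": [
                "mean_my_variant",
                lambda scd: mean_subgroup_score(
                    subgroup_scores_dict=scd, subgroup_types=["my_variant"]
                ),
                True,
            ],
        }
    }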