Natooz commited on
Commit
65c885b
1 Parent(s): 3a718e9

adding input assertions

Browse files
Files changed (1) hide show
  1. ece.py +9 -5
ece.py CHANGED
@@ -16,7 +16,7 @@ from typing import Dict
16
 
17
  import evaluate
18
  import datasets
19
- from torch import Tensor, LongTensor, amax
20
  from torchmetrics.functional.classification.calibration_error import (
21
  binary_calibration_error,
22
  multiclass_calibration_error,
@@ -109,14 +109,18 @@ class ECE(evaluate.Metric):
109
  references = LongTensor(references)
110
 
111
  # Determine number of classes / binary or multiclass
 
 
112
  binary = True
113
- if predictions.dim() == references.dim() + 1:
114
  binary = False
115
  if "num_classes" not in kwargs:
116
  kwargs["num_classes"] = int(predictions.shape[1])
117
- else:
118
- raise ValueError("Bad input shape. Expected to have predictions with shape (N,C,...) and references"
119
- f"with shape (N,...), but got {predictions.shape} and {references.shape}")
 
 
120
 
121
  # Compute the calibration
122
  if binary:
 
16
 
17
  import evaluate
18
  import datasets
19
+ from torch import Tensor, LongTensor
20
  from torchmetrics.functional.classification.calibration_error import (
21
  binary_calibration_error,
22
  multiclass_calibration_error,
 
109
  references = LongTensor(references)
110
 
111
  # Determine number of classes / binary or multiclass
112
+ error_msg = "Expected to have predictions with shape (N,C,...) for multiclass or (N,...) for binary, " \
113
+ f"and references with shape (N,...), but got {predictions.shape} and {references.shape}"
114
  binary = True
115
+ if predictions.dim() == references.dim() + 1: # multiclass
116
  binary = False
117
  if "num_classes" not in kwargs:
118
  kwargs["num_classes"] = int(predictions.shape[1])
119
+ elif predictions.dim() == references.dim() and "num_classes" in kwargs:
120
+ raise ValueError("You gave the num_classes argument, with predictions and references having the"
121
+ "same number of dimensions. " + error_msg)
122
+ elif predictions.dim() != references.dim():
123
+ raise ValueError("Bad input shape. " + error_msg)
124
 
125
  # Compute the calibration
126
  if binary: