upload
Browse files- .idea/.gitignore +3 -0
- .idea/ece.iml +12 -0
- .idea/inspectionProfiles/Project_Default.xml +14 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- README.md +34 -20
- ece.py +61 -40
.idea/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
.idea/ece.iml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="PYTHON_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$" />
|
5 |
+
<orderEntry type="jdk" jdkName="Python 3.8 (venv) (6)" jdkType="Python SDK" />
|
6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
+
</component>
|
8 |
+
<component name="PyDocumentationSettings">
|
9 |
+
<option name="format" value="PLAIN" />
|
10 |
+
<option name="myDocStringFormat" value="Plain" />
|
11 |
+
</component>
|
12 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<profile version="1.0">
|
3 |
+
<option name="myName" value="Project Default" />
|
4 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
5 |
+
<option name="ignoredPackages">
|
6 |
+
<value>
|
7 |
+
<list size="1">
|
8 |
+
<item index="0" class="java.lang.String" itemvalue="torch" />
|
9 |
+
</list>
|
10 |
+
</value>
|
11 |
+
</option>
|
12 |
+
</inspection_tool>
|
13 |
+
</profile>
|
14 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (venv) (6)" project-jdk-type="Python SDK" />
|
4 |
+
</project>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/ece.iml" filepath="$PROJECT_DIR$/.idea/ece.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="VcsDirectoryMappings">
|
4 |
+
<mapping directory="" vcs="Git" />
|
5 |
+
</component>
|
6 |
+
</project>
|
README.md
CHANGED
@@ -5,7 +5,7 @@ datasets:
|
|
5 |
tags:
|
6 |
- evaluate
|
7 |
- metric
|
8 |
-
description: "
|
9 |
sdk: gradio
|
10 |
sdk_version: 3.19.1
|
11 |
app_file: app.py
|
@@ -14,37 +14,51 @@ pinned: false
|
|
14 |
|
15 |
# Metric Card for ECE
|
16 |
|
17 |
-
***Module Card Instructions:*** *Fill out the following subsections. Feel free to take a look at existing metric cards if you'd like examples.*
|
18 |
-
|
19 |
## Metric Description
|
20 |
-
*Give a brief overview of this metric, including what task(s) it is usually used for, if any.*
|
21 |
|
22 |
-
|
23 |
-
|
|
|
24 |
|
25 |
-
|
26 |
|
27 |
### Inputs
|
28 |
*List all input arguments in the format below*
|
29 |
-
- **input_field** *(
|
|
|
30 |
|
31 |
### Output Values
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
*State the range of possible values that the metric's output can take, as well as what in that range is considered good. For example: "This metric can take on any value between 0 and 100, inclusive. Higher scores are better."*
|
36 |
-
|
37 |
-
#### Values from Popular Papers
|
38 |
-
*Give examples, preferrably with links to leaderboards or publications, to papers that have reported this metric, along with the values they have reported.*
|
39 |
|
40 |
### Examples
|
41 |
-
*Give code examples of the metric being used. Try to include examples that clear up any potential ambiguity left from the metric description above. If possible, provide a range of examples that show both typical and atypical results, as well as examples where a variety of input parameters are passed.*
|
42 |
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
## Citation
|
47 |
-
*Cite the source where this metric was introduced.*
|
48 |
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
tags:
|
6 |
- evaluate
|
7 |
- metric
|
8 |
+
description: "Expected calibration error (ECE)"
|
9 |
sdk: gradio
|
10 |
sdk_version: 3.19.1
|
11 |
app_file: app.py
|
|
|
14 |
|
15 |
# Metric Card for ECE
|
16 |
|
|
|
|
|
17 |
## Metric Description
|
|
|
18 |
|
19 |
+
This metric computes the expected calibration error (ECE).
|
20 |
+
It directly calls the torchmetrics package:
|
21 |
+
https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
|
22 |
|
23 |
+
## How to Use
|
24 |
|
25 |
### Inputs
|
26 |
*List all input arguments in the format below*
|
27 |
+
- **predictions** *(tensor or numpy array, float32): predicted probabilities (after softmax). They must have a shape (N,C,...) if multiclass, or (N,...) if binary.*
|
28 |
+
- **references** *(tensor or numpy array, int64): reference for each prediction, with a shape (N,...).*
|
29 |
|
30 |
### Output Values
|
31 |
|
32 |
+
ECE as float.
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
### Examples
|
|
|
35 |
|
36 |
+
```Python
|
37 |
+
ece = evaluate.load("Natooz/ece")
|
38 |
+
results = ece.compute(
|
39 |
+
predictions=np.array([[0.25, 0.20, 0.55],
|
40 |
+
[0.55, 0.05, 0.40],
|
41 |
+
[0.10, 0.30, 0.60],
|
42 |
+
[0.90, 0.05, 0.05]]),
|
43 |
+
references=np.array([0, 1, 2, 0]),
|
44 |
+
num_classes=3,
|
45 |
+
n_bins=3,
|
46 |
+
norm="l1",
|
47 |
+
)
|
48 |
+
print(results)
|
49 |
+
```
|
50 |
|
51 |
## Citation
|
|
|
52 |
|
53 |
+
```bibtex
|
54 |
+
@inproceedings{NEURIPS2019_f8c0c968,
|
55 |
+
author = {Kumar, Ananya and Liang, Percy S and Ma, Tengyu},
|
56 |
+
booktitle = {Advances in Neural Information Processing Systems},
|
57 |
+
editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
|
58 |
+
publisher = {Curran Associates, Inc.},
|
59 |
+
title = {Verified Uncertainty Calibration},
|
60 |
+
url = {https://papers.nips.cc/paper_files/paper/2019/hash/f8c0c968632845cd133308b1a494967f-Abstract.html},
|
61 |
+
volume = {32},
|
62 |
+
year = {2019}
|
63 |
+
}
|
64 |
+
```
|
ece.py
CHANGED
@@ -11,58 +11,63 @@
|
|
11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
-
|
|
|
15 |
|
16 |
import evaluate
|
17 |
import datasets
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
-
# TODO: Add BibTeX citation
|
21 |
_CITATION = """\
|
22 |
-
@InProceedings{huggingface:
|
23 |
-
title = {
|
24 |
-
authors={
|
25 |
-
year={
|
26 |
}
|
27 |
"""
|
28 |
|
29 |
-
# TODO: Add description of the module here
|
30 |
_DESCRIPTION = """\
|
31 |
-
This
|
|
|
|
|
32 |
"""
|
33 |
|
34 |
|
35 |
-
# TODO: Add description of the arguments of the module here
|
36 |
_KWARGS_DESCRIPTION = """
|
37 |
Calculates how good are predictions given some references, using certain scores
|
38 |
Args:
|
39 |
-
predictions: list of predictions to score.
|
40 |
-
|
41 |
-
references: list of reference for each prediction. Each
|
42 |
-
reference should be a string with tokens separated by spaces.
|
43 |
Returns:
|
44 |
-
|
45 |
-
another_score: description of the second score,
|
46 |
Examples:
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
>>> print(results)
|
53 |
-
{'
|
54 |
"""
|
55 |
|
56 |
-
# TODO: Define external resources urls if needed
|
57 |
-
BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
|
58 |
-
|
59 |
|
60 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
61 |
class ECE(evaluate.Metric):
|
62 |
-
"""
|
|
|
|
|
|
|
63 |
|
64 |
def _info(self):
|
65 |
-
# TODO: Specifies the evaluate.EvaluationModuleInfo object
|
66 |
return evaluate.MetricInfo(
|
67 |
# This is the description that will appear on the modules page.
|
68 |
module_type="metric",
|
@@ -71,25 +76,41 @@ class ECE(evaluate.Metric):
|
|
71 |
inputs_description=_KWARGS_DESCRIPTION,
|
72 |
# This defines the format of each prediction and reference
|
73 |
features=datasets.Features({
|
74 |
-
'predictions': datasets.Value('
|
75 |
'references': datasets.Value('int64'),
|
76 |
}),
|
77 |
# Homepage of the module for documentation
|
78 |
-
homepage="
|
79 |
# Additional links to the codebase or references
|
80 |
-
codebase_urls=["
|
81 |
-
reference_urls=["
|
82 |
)
|
83 |
|
84 |
-
def
|
85 |
-
"""
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
93 |
return {
|
94 |
-
"
|
95 |
-
}
|
|
|
11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
# See the License for the specific language governing permissions and
|
13 |
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Dict
|
16 |
|
17 |
import evaluate
|
18 |
import datasets
|
19 |
+
from torch import from_numpy, amax
|
20 |
+
from torchmetrics.functional.classification.calibration_error import binary_calibration_error, multiclass_calibration_error
|
21 |
+
from numpy import ndarray
|
22 |
|
23 |
|
|
|
24 |
_CITATION = """\
|
25 |
+
@InProceedings{huggingface:ece,
|
26 |
+
title = {Expected calibration error (ECE)},
|
27 |
+
authors={Nathan Fradet},
|
28 |
+
year={2023}
|
29 |
}
|
30 |
"""
|
31 |
|
|
|
32 |
_DESCRIPTION = """\
|
33 |
+
This metrics computes the expected calibration error (ECE).
|
34 |
+
It directly calls the torchmetrics package:
|
35 |
+
https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
|
36 |
"""
|
37 |
|
38 |
|
|
|
39 |
_KWARGS_DESCRIPTION = """
|
40 |
Calculates how good are predictions given some references, using certain scores
|
41 |
Args:
|
42 |
+
predictions: list of predictions to score. They must have a shape (N,C,...) if multiclass, or (N,...) if binary.
|
43 |
+
references: list of reference for each prediction, with a shape (N,...).
|
|
|
|
|
44 |
Returns:
|
45 |
+
ece: expected calibration error
|
|
|
46 |
Examples:
|
47 |
+
>>> ece = evaluate.load("Natooz/ece")
|
48 |
+
>>> results = ece.compute(
|
49 |
+
... references=np.array([[0.25, 0.20, 0.55],
|
50 |
+
... [0.55, 0.05, 0.40],
|
51 |
+
... [0.10, 0.30, 0.60],
|
52 |
+
... [0.90, 0.05, 0.05]]),
|
53 |
+
... predictions=np.array(),
|
54 |
+
... num_classes=3,
|
55 |
+
... n_bins=3,
|
56 |
+
... norm="l1",
|
57 |
+
... )
|
58 |
>>> print(results)
|
59 |
+
{'ece': 0.2000}
|
60 |
"""
|
61 |
|
|
|
|
|
|
|
62 |
|
63 |
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
64 |
class ECE(evaluate.Metric):
|
65 |
+
"""
|
66 |
+
Proxy to the BinaryCalibrationError (ECE) metric of the torchmetrics package:
|
67 |
+
https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
|
68 |
+
"""
|
69 |
|
70 |
def _info(self):
|
|
|
71 |
return evaluate.MetricInfo(
|
72 |
# This is the description that will appear on the modules page.
|
73 |
module_type="metric",
|
|
|
76 |
inputs_description=_KWARGS_DESCRIPTION,
|
77 |
# This defines the format of each prediction and reference
|
78 |
features=datasets.Features({
|
79 |
+
'predictions': datasets.Value('float32'),
|
80 |
'references': datasets.Value('int64'),
|
81 |
}),
|
82 |
# Homepage of the module for documentation
|
83 |
+
homepage="https://huggingface.co/spaces/Natooz/ece",
|
84 |
# Additional links to the codebase or references
|
85 |
+
codebase_urls=["https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"],
|
86 |
+
reference_urls=["https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"]
|
87 |
)
|
88 |
|
89 |
+
def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]:
    """Compute the expected calibration error (ECE) through torchmetrics.

    See the torchmetrics documentation for the extra keyword arguments
    (``n_bins``, ``norm``, ``ignore_index``, ...) that are forwarded as-is:
    https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html

    Args:
        predictions: predicted probabilities, shaped (N, C, ...) in the
            multiclass case or (N, ...) in the binary case. Accepts a torch
            tensor, a numpy array, or anything ``torch.as_tensor`` can convert.
        references: integer labels shaped (N, ...). Same accepted types.
        **kwargs: forwarded to ``multiclass_calibration_error`` or
            ``binary_calibration_error``.

    Returns:
        A dict with a single key ``"ece"`` holding the calibration error as a
        Python float.

    The multiclass function is used when ``num_classes`` is given, when
    ``predictions`` has one more dimension than ``references`` (a class
    dimension), or when a label greater than 1 is present. If ``num_classes``
    is not provided in a multiclass setting, it is inferred from the class
    dimension of ``predictions`` or, failing that, as the maximum label
    index plus one.
    """
    # Local import: only needed to convert inputs that are neither numpy
    # arrays nor tensors (e.g. plain Python lists).
    from torch import Tensor, as_tensor

    # Convert the inputs to tensors.
    if isinstance(predictions, ndarray):
        predictions = from_numpy(predictions)
    elif not isinstance(predictions, Tensor):
        predictions = as_tensor(predictions)
    if isinstance(references, ndarray):
        references = from_numpy(references)
    elif not isinstance(references, Tensor):
        references = as_tensor(references)

    # Plain int so it can be used for comparisons and as a class count.
    max_label = int(references.max())

    # A class dimension on the predictions is the reliable multiclass signal:
    # a multiclass batch may happen to contain only labels 0 and 1.
    has_class_dim = predictions.dim() == references.dim() + 1
    is_multiclass = "num_classes" in kwargs or has_class_dim or max_label > 1

    if is_multiclass and "num_classes" not in kwargs:
        # num_classes is a count, hence "max index + 1" in the fallback.
        kwargs["num_classes"] = int(predictions.shape[1]) if has_class_dim else max_label + 1

    # Compute the calibration error.
    if is_multiclass:
        ece = multiclass_calibration_error(predictions, references, **kwargs)
    else:
        ece = binary_calibration_error(predictions, references, **kwargs)
    return {
        "ece": float(ece),
    }
|