Skip to content

Commit

Permalink
fix some metrics feature types (huggingface#867)
Browse files Browse the repository at this point in the history
lhoestq authored Nov 19, 2020
1 parent e983e49 commit 7478935
Showing 4 changed files with 93 additions and 12 deletions.
21 changes: 18 additions & 3 deletions metrics/accuracy/accuracy.py
Original file line number Diff line number Diff line change
@@ -39,22 +39,37 @@
accuracy: Accuracy score.
"""

_CITATION = """\
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""


class Accuracy(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Value("int"),
"references": datasets.Value("int"),
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
),
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
)

def _compute(self, predictions, references, normalize=True, sample_weight=None):
return {
"accuracy": accuracy_score(references, predictions, normalize, sample_weight),
"accuracy": accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight),
}
28 changes: 25 additions & 3 deletions metrics/f1/f1.py
Original file line number Diff line number Diff line change
@@ -55,22 +55,44 @@
f1: F1 score.
"""

_CITATION = """\
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""


class F1(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Value("int"),
"references": datasets.Value("int"),
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
),
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"],
)

def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None):
return {
"f1": f1_score(references, predictions, labels, pos_label, average, sample_weight),
"f1": f1_score(
references,
predictions,
labels=labels,
pos_label=pos_label,
average=average,
sample_weight=sample_weight,
),
}
28 changes: 25 additions & 3 deletions metrics/precision/precision.py
Original file line number Diff line number Diff line change
@@ -57,22 +57,44 @@
precision: Precision score.
"""

_CITATION = """\
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""


class Precision(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Value("int"),
"references": datasets.Value("int"),
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
),
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html"],
)

def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None):
return {
"precision": precision_score(references, predictions, labels, pos_label, average, sample_weight),
"precision": precision_score(
references,
predictions,
labels=labels,
pos_label=pos_label,
average=average,
sample_weight=sample_weight,
),
}
28 changes: 25 additions & 3 deletions metrics/recall/recall.py
Original file line number Diff line number Diff line change
@@ -57,22 +57,44 @@
recall: Recall score.
"""

_CITATION = """\
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""


class Recall(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Value("int"),
"references": datasets.Value("int"),
"predictions": datasets.Value("int32"),
"references": datasets.Value("int32"),
}
),
reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html"],
)

def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None):
return {
"recall": recall_score(references, predictions, labels, pos_label, average, sample_weight),
"recall": recall_score(
references,
predictions,
labels=labels,
pos_label=pos_label,
average=average,
sample_weight=sample_weight,
),
}

0 comments on commit 7478935

Please sign in to comment.