Improve classification metrics (#3285)
yifanmai authored Jan 22, 2025
1 parent 5c6e6c2 commit 2c14291
Showing 3 changed files with 222 additions and 82 deletions.
85 changes: 73 additions & 12 deletions src/helm/benchmark/metrics/classification_metrics.py
@@ -1,6 +1,6 @@
from typing import List, Optional

from sklearn.metrics import f1_score
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.preprocessing import MultiLabelBinarizer

from helm.benchmark.adaptation.request_state import RequestState
@@ -9,9 +9,14 @@
from helm.benchmark.metrics.metric import MetricName
from helm.benchmark.metrics.statistic import Stat
from helm.benchmark.scenarios.scenario import Reference
from helm.common.hierarchical_logger import hlog
from helm.common.request import GeneratedOutput


def _normalize_label_text(text: str) -> str:
return normalize_text(text, should_remove_articles=False)


class ClassificationMetric(EvaluateInstancesMetric):
"""Defines metrics for multi-class classification using the generation adapter.
@@ -23,17 +28,57 @@ class ClassificationMetric(EvaluateInstancesMetric):
reference. The predicted class for each instance is the normalized text of the generation.
Note:
- The set of classes is derived from the correct references from all the instances.
This means that classes may be omitted if they are never used as a correct reference.
- It is highly recommended that the set of classes be specified using the
`labels` parameter. Otherwise, the set of classes is derived from the correct references
from all the instances. This means that classes may be incorrectly omitted if they are never
used as a correct reference.
- The `averages` parameter is a list of averaging methods to be used.
It has the same meaning as `average` in scikit-learn.
- Generations that are not in any of the known classes are counted as a
negative prediction for every class.
- Perturbed classes are considered different classes from unperturbed
classes.
- Currently, multi-label classification is not supported.
"""

def __init__(self, delimiter: Optional[str] = None):
AVERAGE_OPTIONS = ["micro", "macro", "weighted", None]
SCORE_OPTIONS = ["f1", "precision", "recall"]

def __init__(
self,
averages: Optional[List[Optional[str]]] = None,
labels: Optional[List[str]] = None,
scores: Optional[List[str]] = None,
delimiter: Optional[str] = None,
) -> None:
"""Creates metrics for multi-class classification.
:param delimiter: For multi-label classification, the string delimiter between classes in the model's output.
:param scores: The list of scores to compute (e.g. "f1", "precision", "recall").
Defaults to ["f1"].
:param averages: The averaging methods (e.g. "micro", "macro", "weighted") to be used.
It has the same meaning as `average` in scikit-learn.
Defaults to ["macro", "micro"].
:param labels: The set of class labels.
"""
self.averages = averages or ["macro", "micro"]
for average in self.averages:
if average not in ClassificationMetric.AVERAGE_OPTIONS:
raise ValueError(
f"Each value in `averages` must be set to one of {ClassificationMetric.AVERAGE_OPTIONS}."
)
self.scores = scores or ["f1"]
for score_name in self.scores:
if score_name not in ClassificationMetric.SCORE_OPTIONS:
raise ValueError(f"Each value in `scores` must be set to one of {ClassificationMetric.SCORE_OPTIONS}.")
self.delimiter = delimiter
self.labels = labels
if not self.labels:
    hlog(
        "WARNING: `labels` were not set on `ClassificationMetric`, "
        "so they will be inferred from target references. "
        "It is recommended to explicitly set `labels` on `ClassificationMetric`."
    )

def is_multi_label(self) -> bool:
return bool(self.delimiter)
@@ -57,20 +102,36 @@ def evaluate_instances(self, request_states: List[RequestState], eval_cache_path
references = request_state.instance.all_correct_references
if not self.is_multi_label():
assert len(references) == 1
correct_ref_texts = [normalize_text(ref.output.text) for ref in references if ref.output.text]
correct_ref_texts = [_normalize_label_text(ref.output.text) for ref in references if ref.output.text]
y_true.append(correct_ref_texts)

input_text = request_state.result.completions[0].text
predictions = input_text.split(self.delimiter) if self.is_multi_label() else [input_text]
y_pred.append([normalize_text(pred) for pred in predictions if pred])
labels: List[str] = list(set(y for ys in y_true for y in ys))
mlb = MultiLabelBinarizer().fit([labels])
y_pred.append([_normalize_label_text(pred) for pred in predictions if pred])
mlb = MultiLabelBinarizer().fit([self.labels] if self.labels else y_true)
y_true = mlb.transform(y_true)
y_pred = mlb.transform(y_pred)
return [
Stat(MetricName("classification_macro_f1")).add(f1_score(y_pred=y_pred, y_true=y_true, average="macro")),
Stat(MetricName("classification_micro_f1")).add(f1_score(y_pred=y_pred, y_true=y_true, average="micro")),
]
stats: List[Stat] = []
for average in self.averages:
for score_name in self.scores:
if score_name == "f1":
score_value = f1_score(y_pred=y_pred, y_true=y_true, average=average)
elif score_name == "precision":
score_value = precision_score(y_pred=y_pred, y_true=y_true, average=average)
elif score_name == "recall":
score_value = recall_score(y_pred=y_pred, y_true=y_true, average=average)
else:
raise ValueError(
f"Unknown score name: '{score_name}' - expected one of ['f1', 'precision', 'recall']"
)
if average is None:
for mlb_class, class_score_value in zip(mlb.classes_, score_value):
stats.append(
Stat(MetricName(f"classification_{mlb_class}_{score_name}")).add(class_score_value)
)
else:
stats.append(Stat(MetricName(f"classification_{average}_{score_name}")).add(score_value))
return stats


class MultipleChoiceClassificationMetric(EvaluateInstancesMetric):
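For reference, below is a minimal sketch of how the new per-class scoring path behaves, using scikit-learn directly. The label set and predictions are invented for illustration; only the use of `MultiLabelBinarizer` and `average=None` mirrors the code above.

from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.preprocessing import MultiLabelBinarizer

# Invented three-class example; the metric above would be configured as, e.g.,
# ClassificationMetric(labels=labels, averages=["macro", None], scores=["f1"]).
labels = ["contradiction", "entailment", "neutral"]
y_true_raw = [["entailment"], ["neutral"], ["contradiction"], ["neutral"]]
y_pred_raw = [["entailment"], ["contradiction"], ["contradiction"], ["neutral"]]

# Fitting on the explicit label set keeps classes that never appear as a
# correct reference, which is the motivation for the new `labels` parameter.
mlb = MultiLabelBinarizer().fit([labels])
y_true = mlb.transform(y_true_raw)
y_pred = mlb.transform(y_pred_raw)

# average=None returns one score per class, reported by the metric as
# classification_<class>_<score>; string averages collapse to a single value.
per_class_f1 = f1_score(y_true=y_true, y_pred=y_pred, average=None)
print(dict(zip(mlb.classes_, per_class_f1)))
print("macro f1:", f1_score(y_true=y_true, y_pred=y_pred, average="macro"))
print("micro precision:", precision_score(y_true=y_true, y_pred=y_pred, average="micro"))
print("micro recall:", recall_score(y_true=y_true, y_pred=y_pred, average="micro"))

With `average=None` the metric emits one `Stat` per class, which is why explicit labels matter: classes absent from the gold references would otherwise be dropped from the report.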
11 changes: 9 additions & 2 deletions src/helm/benchmark/metrics/common_metric_specs.py
@@ -44,11 +44,18 @@ def get_language_modeling_metric_specs(names: List[str]) -> List[MetricSpec]:
]


def get_classification_metric_specs(delimiter: Optional[str] = None) -> List[MetricSpec]:
def get_classification_metric_specs(
labels: Optional[List[str]] = None, delimiter: Optional[str] = None
) -> List[MetricSpec]:
extra_args: Dict[str, Any] = {}
if labels:
extra_args["labels"] = labels
if delimiter:
extra_args["delimiter"] = delimiter
return [
MetricSpec(
class_name="helm.benchmark.metrics.classification_metrics.ClassificationMetric",
args={"delimiter": delimiter},
args=extra_args,
)
]

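As a usage sketch, a scenario's metric specs can now pass the label set and an optional multi-label delimiter through this helper. The label values below are hypothetical; only the `labels` and `delimiter` keyword arguments come from the change above.

from helm.benchmark.metrics.common_metric_specs import get_classification_metric_specs

# Hypothetical single-label scenario with three sentiment classes.
metric_specs = get_classification_metric_specs(labels=["positive", "neutral", "negative"])

# Hypothetical multi-label scenario whose model output joins classes with commas.
multi_label_metric_specs = get_classification_metric_specs(
    labels=["sports", "politics", "technology"],
    delimiter=",",
)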