Fix F1 metric (#3187)

* Move LabelInfo and HLabelInfo to otx.core.types.label

Signed-off-by: Kim, Vinnam <[email protected]>

* Refactor LabelInfo

 - Prevent label_info from being overwritten for OVModel
 - Remove runtime patch for HLabel head
 - Move LabelInfo to the dedicated source file otx.core.types.label

* Fix tests

* Remove label_info equivalence check from visualprompting label_info setter

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix test_otx_explain_e2e not to overwrite model.num_classes

Signed-off-by: Kim, Vinnam <[email protected]>

* Add missing unit tests

Signed-off-by: Kim, Vinnam <[email protected]>

* Improve accessibility for otx.core.types

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix NullLabelInfo unit test

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix FMeasure and best_confidence_threshold mechanism

* Update F1 tests to override the metric with FMeasureCallable

 - Revisit regression test as well
 - Fix unit tests according to the FMeasure rework

* Add missing tests/utils.py

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix OTXCLI to return an error code if an exception is raised

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix integration test errors

Signed-off-by: Kim, Vinnam <[email protected]>

* Update src/otx/core/metrics/fmeasure.py

Co-authored-by: Harim Kang <[email protected]>

* Add debug msg

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix ruff error

Signed-off-by: Kim, Vinnam <[email protected]>

---------

Signed-off-by: Kim, Vinnam <[email protected]>
Co-authored-by: Harim Kang <[email protected]>
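
As a rough, illustrative sketch of the metric rework described above (not code from this commit): FMeasureCallable builds the metric from the model's label information, so tests no longer need to pass num_classes around. The LabelInfo constructor fields below are assumptions for the sketch.

# Hedged sketch; LabelInfo field names are assumed from otx.core.types.label.
from otx.core.metrics.fmeasure import FMeasureCallable
from otx.core.types.label import LabelInfo

label_info = LabelInfo(label_names=["car", "person"], label_groups=[["car", "person"]])
metric = FMeasureCallable(label_info)  # equivalent to FMeasure(label_info=label_info)
print(metric.label_info.label_names)  # ['car', 'person']
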
vinnamkim and harimkang authored Mar 25, 2024
1 parent 6abf23d commit ec23ba8
Showing 16 changed files with 381 additions and 240 deletions.
67 changes: 20 additions & 47 deletions src/otx/cli/cli.py
@@ -8,6 +8,7 @@

import dataclasses
import sys
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
from warnings import warn
@@ -28,7 +29,6 @@
if TYPE_CHECKING:
from jsonargparse._actions import _ActionSubCommands

from otx.core.metrics import MetricCallable

_ENGINE_AVAILABLE = True
try:
@@ -157,7 +157,7 @@ def engine_subcommand_parser(subcommand: str, **kwargs) -> tuple[ArgumentParser,
OTXModel,
"model",
required=False,
skip={"optimizer", "scheduler", "metric"},
skip={"optimizer", "scheduler"},
**model_kwargs,
)
# Datamodule Settings
@@ -345,18 +345,14 @@ def instantiate_classes(self, instantiate_engine: bool = True) -> None:
if self.subcommand in self.engine_subcommands():
# For num_classes update, Model and Metric are instantiated separately.
model_config = self.config[self.subcommand].pop("model")
metric_config = self.config[self.subcommand].get("metric")

# Instantiate the things that don't need special handling
self.config_init = self.parser.instantiate_classes(self.config)
self.workspace = self.get_config_value(self.config_init, "workspace")
self.datamodule = self.get_config_value(self.config_init, "data")

# Instantiate the model and needed components
self.model, self.optimizer, self.scheduler = self.instantiate_model(
model_config=model_config,
metric_config=metric_config,
)
self.model, self.optimizer, self.scheduler = self.instantiate_model(model_config=model_config)

if instantiate_engine:
self.engine = self.instantiate_engine()
@@ -377,26 +373,7 @@ def instantiate_engine(self) -> Engine:
**engine_kwargs,
)

def instantiate_metric(self, metric_config: Namespace) -> MetricCallable | None:
"""Instantiate the metric based on the metric_config.
It also patches the num_classes according to the model's class information.
Args:
metric_config (Namespace): The metric configuration.
"""
from otx.core.utils.instantiators import partial_instantiate_class

if metric_config and self.subcommand in ["train", "test"]:
metric_kwargs = self.get_config_value(metric_config, "metric", namespace_to_dict(metric_config))
metric = partial_instantiate_class(metric_kwargs)
return metric[0] if isinstance(metric, list) else metric

msg = "The configuration of metric is None."
warn(msg, stacklevel=2)
return None

def instantiate_model(self, model_config: Namespace, metric_config: Namespace) -> tuple:
def instantiate_model(self, model_config: Namespace) -> tuple:
"""Instantiate the model based on the subcommand.
This method checks if the subcommand is one of the engine subcommands.
@@ -436,21 +413,15 @@ def instantiate_model(self, model_config: Namespace, metric_config: Namespace) -
if optimizers:
# Updates the instantiated optimizer.
model_config.init_args.optimizer = optimizers
self.config_init[self.subcommand]["optimizer"] = optimizers
self.config_init[self.subcommand]["optimizer"] = optimizer_kwargs

scheduler_kwargs = self.get_config_value(self.config_init, "scheduler", {})
scheduler_kwargs = scheduler_kwargs if isinstance(scheduler_kwargs, list) else [scheduler_kwargs]
schedulers = partial_instantiate_class([_sch for _sch in scheduler_kwargs if _sch])
if schedulers:
# Updates the instantiated scheduler.
model_config.init_args.scheduler = schedulers
self.config_init[self.subcommand]["scheduler"] = schedulers

# Instantiate the metric with changing the num_classes
metric = self.instantiate_metric(metric_config)
if metric:
model_config.init_args.metric = metric
self.config_init[self.subcommand]["metric"] = metric
self.config_init[self.subcommand]["scheduler"] = scheduler_kwargs

# Parses the OTXModel separately to update num_classes.
model_parser = ArgumentParser()
@@ -520,18 +491,19 @@ def save_config(self, work_dir: Path) -> None:
The configuration is saved as a YAML file in the engine's working directory.
"""
self.config[self.subcommand].pop("workspace", None)
# TODO(vinnamki): Do not save for now.
# Revisit it after changing the optimizer and scheduler instantiating.
# self.get_subcommand_parser(self.subcommand).save(
# cfg=self.config.get(str(self.subcommand), self.config),
# path=work_dir / "configs.yaml",
# overwrite=True,
# multifile=False,
# skip_check=True,
# )
# For assert statement in the test
with (work_dir / "configs.yaml").open("w") as fp:
yaml.safe_dump({"model": None, "engine": None, "data": None}, fp)
# TODO(vinnamki): Revisit it after changing the optimizer and scheduler instantiating.
cfg = deepcopy(self.config.get(str(self.subcommand), self.config))
cfg.model.init_args.pop("optimizer")
cfg.model.init_args.pop("scheduler")
cfg.model.init_args.pop("hlabel_info")

self.get_subcommand_parser(self.subcommand).save(
cfg=cfg,
path=work_dir / "configs.yaml",
overwrite=True,
multifile=False,
skip_check=True,
)

# if train -> Update `.latest` folder
self.update_latest(work_dir=work_dir)
@@ -586,6 +558,7 @@ def run(self) -> None:
fn(**fn_kwargs)
except Exception:
self.console.print_exception(width=self.console.width)
raise
self.save_config(work_dir=Path(self.engine.work_dir))
else:
msg = f"Unrecognized subcommand: {self.subcommand}"
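
A hedged illustration of the exit-code fix above (the command line is a placeholder, not taken from this commit): because the exception is re-raised after being printed, a failing run now exits non-zero, which callers such as CI scripts can detect.

# Sketch only; the arguments are placeholders for an intentionally failing run.
import subprocess

completed = subprocess.run(["otx", "train", "--config", "broken_config.yaml"], check=False)
print(completed.returncode)  # expected to be non-zero now that the exception propagates
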
95 changes: 63 additions & 32 deletions src/otx/core/metrics/fmeasure.py
@@ -11,6 +11,8 @@
from torch import Tensor
from torchmetrics import Metric

from otx.core.types.label import LabelInfo

logger = logging.getLogger()
ALL_CLASSES_NAME = "All Classes"

@@ -632,9 +634,7 @@ class FMeasure(Metric):
to True.
Args:
num_classes (int): The number of classes.
best_confidence_threshold (float | None): Pre-defined best confidence threshold. If this value is None, then
FMeasure will find best confidence threshold. Defaults to None.
label_info (LabelInfo): Dataclass including label information.
vary_nms_threshold (bool): if True the maximal F-measure is determined by optimizing for different NMS threshold
values. Defaults to False.
cross_class_nms (bool): Whether non-max suppression should be applied cross-class. If True this will eliminate
@@ -643,22 +643,31 @@

def __init__(
self,
num_classes: int,
best_confidence_threshold: float | None = None,
label_info: LabelInfo,
vary_nms_threshold: bool = False,
cross_class_nms: bool = False,
):
super().__init__()
self.vary_nms_threshold = vary_nms_threshold
self.cross_class_nms = cross_class_nms
self.preds: list[list[tuple]] = []
self.targets: list[list[tuple]] = []
self.num_classes: int = num_classes
self.label_info: LabelInfo = label_info

self._f_measure_per_confidence: dict | None = None
self._f_measure_per_nms: dict | None = None
self._best_confidence_threshold: float | None = best_confidence_threshold
self._best_confidence_threshold: float | None = None
self._best_nms_threshold: float | None = None
self._f_measure = 0.0

self.reset()

def reset(self) -> None:
"""Reset for every validation and test epoch.
Please be careful that some variables should not be reset for each epoch.
"""
super().reset()
self.preds: list[list[tuple]] = []
self.targets: list[list[tuple]] = []

def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None:
"""Update total predictions and targets from given batch predicitons and targets."""
@@ -680,8 +689,14 @@ def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]
],
)

def compute(self) -> dict:
"""Compute f1 score metric."""
def compute(self, best_confidence_threshold: float | None = None) -> dict:
"""Compute f1 score metric.
Args:
best_confidence_threshold (float | None): Pre-defined best confidence threshold.
If this value is None, then FMeasure will find the best confidence threshold and
store it as a member variable. Defaults to None.
"""
boxes_pair = _FMeasureCalculator(self.targets, self.preds)
result = boxes_pair.evaluate_detections(
result_based_nms_threshold=self.vary_nms_threshold,
@@ -690,26 +705,37 @@
)
self._f_measure_per_label = {label: result.best_f_measure_per_class[label] for label in self.classes}

if self.best_confidence_threshold is not None:
if best_confidence_threshold is not None:
(index,) = np.where(
np.isclose(list(np.arange(*boxes_pair.confidence_range)), self.best_confidence_threshold),
np.isclose(list(np.arange(*boxes_pair.confidence_range)), best_confidence_threshold),
)
self._f_measure = result.per_confidence.all_classes_f_measure_curve[int(index)]
computed_f_measure = result.per_confidence.all_classes_f_measure_curve[int(index)]
else:
self._f_measure_per_confidence = {
"xs": list(np.arange(*boxes_pair.confidence_range)),
"ys": result.per_confidence.all_classes_f_measure_curve,
}
self._best_confidence_threshold = result.per_confidence.best_threshold
computed_f_measure = result.best_f_measure
best_confidence_threshold = result.per_confidence.best_threshold

# TODO(jaegukhyun): There was no reset() function in this metric before.
# Some variables depend on the best F1 metric, e.g., best_confidence_threshold.
# A reset() function has now been added and the related mechanism revised; however,
# it is still uncertain whether it works correctly with the implemented reset function.
# Need to revisit. See how other metric implementations handle this, e.g.,
# https://github.com/Lightning-AI/torchmetrics/blob/v1.2.1/src/torchmetrics/metric.py
if self._f_measure < computed_f_measure:
self._f_measure = result.best_f_measure
self._best_confidence_threshold = best_confidence_threshold

if self.vary_nms_threshold and result.per_nms is not None:
self._f_measure_per_nms = {
"xs": list(np.arange(*boxes_pair.nms_range)),
"ys": result.per_nms.all_classes_f_measure_curve,
}
self._best_nms_threshold = result.per_nms.best_threshold
return {"f1-score": Tensor([self.f_measure])}
if self.vary_nms_threshold and result.per_nms is not None:
self._f_measure_per_nms = {
"xs": list(np.arange(*boxes_pair.nms_range)),
"ys": result.per_nms.all_classes_f_measure_curve,
}
self._best_nms_threshold = result.per_nms.best_threshold

return {"f1-score": Tensor([computed_f_measure])}

@property
def f_measure(self) -> float:
@@ -727,15 +753,16 @@ def f_measure_per_confidence(self) -> None | dict:
return self._f_measure_per_confidence

@property
def best_confidence_threshold(self) -> None | float:
def best_confidence_threshold(self) -> float:
"""Returns best confidence threshold as ScoreMetric if exists."""
if self._best_confidence_threshold is None:
msg = (
"Cannot obtain best_confidence_threshold updated previously. "
"Please execute self.update(best_confidence_threshold=None) first."
)
raise RuntimeError(msg)
return self._best_confidence_threshold

@best_confidence_threshold.setter
def best_confidence_threshold(self, value: float) -> None:
"""Setter for best_confidence_threshold."""
self._best_confidence_threshold = value

@property
def f_measure_per_nms(self) -> None | dict:
"""Returns the curve for f-measure per nms threshold as CurveMetric if exists."""
@@ -749,7 +776,11 @@ def best_nms_threshold(self) -> None | float:
@property
def classes(self) -> list[str]:
"""Class information of dataset."""
if self.num_classes is None:
msg = "classes is called before num_classes is set."
raise ValueError(msg)
return [str(idx) for idx in range(self.num_classes)]
return self.label_info.label_names


def _f_measure_callable(label_info: LabelInfo) -> FMeasure:
return FMeasure(label_info=label_info)


FMeasureCallable = _f_measure_callable
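
A minimal usage sketch of the reworked FMeasure and best_confidence_threshold mechanism (illustrative, not a test from this commit; the boxes/scores/labels keys and LabelInfo fields are assumptions inferred from this diff):

# Hedged sketch of the new flow: compute() finds and stores the best threshold,
# which can later be passed back in (e.g. at test time) instead of re-searching.
import torch
from otx.core.metrics.fmeasure import FMeasure
from otx.core.types.label import LabelInfo

metric = FMeasure(label_info=LabelInfo(label_names=["car"], label_groups=[["car"]]))
metric.update(
    preds=[{"boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]), "scores": torch.tensor([0.9]), "labels": torch.tensor([0])}],
    target=[{"boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]), "labels": torch.tensor([0])}],
)
print(metric.compute())  # {'f1-score': tensor([...])}; the best threshold is stored internally
best_threshold = metric.best_confidence_threshold
metric.reset()  # clears preds/targets for the next epoch; the stored threshold survives
# Later, after new update() calls, reuse the previously found threshold:
# metric.compute(best_confidence_threshold=best_threshold)
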
13 changes: 10 additions & 3 deletions src/otx/core/model/base.py
@@ -6,11 +6,12 @@
from __future__ import annotations

import contextlib
import inspect
import json
import logging
import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Generic, NamedTuple
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple

import numpy as np
import openvino
@@ -290,8 +291,14 @@ def _convert_pred_entity_to_compute_metric(
"""Convert given inputs to a Python dictionary for the metric computation."""
raise NotImplementedError

def _log_metrics(self, meter: Metric, key: str) -> None:
results: dict[str, Tensor] = meter.compute()
def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
sig = inspect.signature(meter.compute)
filtered_kwargs = {key: value for key, value in compute_kwargs.items() if key in sig.parameters}
if removed_kwargs := set(compute_kwargs.keys()).difference(filtered_kwargs.keys()):
msg = f"These keyword arguments are removed since they are not in the function signature: {removed_kwargs}"
logger.debug(msg)

results: dict[str, Tensor] = meter.compute(**filtered_kwargs)

if not isinstance(results, dict):
raise TypeError(results)
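
The signature-based filtering in _log_metrics can be shown in isolation; this standalone sketch mirrors the idea without any OTX dependencies:

# Only keyword arguments present in the callee's signature are forwarded;
# anything else is dropped (and reported at debug level in _log_metrics above).
from __future__ import annotations

import inspect


def compute(best_confidence_threshold: float | None = None) -> dict:
    return {"f1-score": best_confidence_threshold}


compute_kwargs = {"best_confidence_threshold": 0.35, "not_in_signature": 123}
sig = inspect.signature(compute)
filtered_kwargs = {key: value for key, value in compute_kwargs.items() if key in sig.parameters}
removed_kwargs = set(compute_kwargs) - set(filtered_kwargs)  # {'not_in_signature'}
print(compute(**filtered_kwargs))  # {'f1-score': 0.35}
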
