diff --git a/src/otx/cli/cli.py b/src/otx/cli/cli.py index c9aab524d3a..a16c22cc35f 100644 --- a/src/otx/cli/cli.py +++ b/src/otx/cli/cli.py @@ -8,6 +8,7 @@ import dataclasses import sys +from copy import deepcopy from pathlib import Path from typing import TYPE_CHECKING, Any, Optional from warnings import warn @@ -28,7 +29,6 @@ if TYPE_CHECKING: from jsonargparse._actions import _ActionSubCommands - from otx.core.metrics import MetricCallable _ENGINE_AVAILABLE = True try: @@ -157,7 +157,7 @@ def engine_subcommand_parser(subcommand: str, **kwargs) -> tuple[ArgumentParser, OTXModel, "model", required=False, - skip={"optimizer", "scheduler", "metric"}, + skip={"optimizer", "scheduler"}, **model_kwargs, ) # Datamodule Settings @@ -345,7 +345,6 @@ def instantiate_classes(self, instantiate_engine: bool = True) -> None: if self.subcommand in self.engine_subcommands(): # For num_classes update, Model and Metric are instantiated separately. model_config = self.config[self.subcommand].pop("model") - metric_config = self.config[self.subcommand].get("metric") # Instantiate the things that don't need to special handling self.config_init = self.parser.instantiate_classes(self.config) @@ -353,10 +352,7 @@ def instantiate_classes(self, instantiate_engine: bool = True) -> None: self.datamodule = self.get_config_value(self.config_init, "data") # Instantiate the model and needed components - self.model, self.optimizer, self.scheduler = self.instantiate_model( - model_config=model_config, - metric_config=metric_config, - ) + self.model, self.optimizer, self.scheduler = self.instantiate_model(model_config=model_config) if instantiate_engine: self.engine = self.instantiate_engine() @@ -377,26 +373,7 @@ def instantiate_engine(self) -> Engine: **engine_kwargs, ) - def instantiate_metric(self, metric_config: Namespace) -> MetricCallable | None: - """Instantiate the metric based on the metric_config. - - It also pathces the num_classes according to the model classes information. - - Args: - metric_config (Namespace): The metric configuration. - """ - from otx.core.utils.instantiators import partial_instantiate_class - - if metric_config and self.subcommand in ["train", "test"]: - metric_kwargs = self.get_config_value(metric_config, "metric", namespace_to_dict(metric_config)) - metric = partial_instantiate_class(metric_kwargs) - return metric[0] if isinstance(metric, list) else metric - - msg = "The configuration of metric is None." - warn(msg, stacklevel=2) - return None - - def instantiate_model(self, model_config: Namespace, metric_config: Namespace) -> tuple: + def instantiate_model(self, model_config: Namespace) -> tuple: """Instantiate the model based on the subcommand. This method checks if the subcommand is one of the engine subcommands. @@ -436,7 +413,7 @@ def instantiate_model(self, model_config: Namespace, metric_config: Namespace) - if optimizers: # Updates the instantiated optimizer. model_config.init_args.optimizer = optimizers - self.config_init[self.subcommand]["optimizer"] = optimizers + self.config_init[self.subcommand]["optimizer"] = optimizer_kwargs scheduler_kwargs = self.get_config_value(self.config_init, "scheduler", {}) scheduler_kwargs = scheduler_kwargs if isinstance(scheduler_kwargs, list) else [scheduler_kwargs] @@ -444,13 +421,7 @@ def instantiate_model(self, model_config: Namespace, metric_config: Namespace) - if schedulers: # Updates the instantiated scheduler. 
             model_config.init_args.scheduler = schedulers
-            self.config_init[self.subcommand]["scheduler"] = schedulers
-
-        # Instantiate the metric with changing the num_classes
-        metric = self.instantiate_metric(metric_config)
-        if metric:
-            model_config.init_args.metric = metric
-            self.config_init[self.subcommand]["metric"] = metric
+            self.config_init[self.subcommand]["scheduler"] = scheduler_kwargs
 
         # Parses the OTXModel separately to update num_classes.
         model_parser = ArgumentParser()
@@ -520,18 +491,19 @@ def save_config(self, work_dir: Path) -> None:
         The configuration is saved as a YAML file in the engine's working directory.
         """
         self.config[self.subcommand].pop("workspace", None)
-        # TODO(vinnamki): Do not save for now.
-        # Revisit it after changing the optimizer and scheduler instantiating.
-        # self.get_subcommand_parser(self.subcommand).save(
-        #     cfg=self.config.get(str(self.subcommand), self.config),
-        #     path=work_dir / "configs.yaml",
-        #     overwrite=True,
-        #     multifile=False,
-        #     skip_check=True,
-        # )
-        # For assert statement in the test
-        with (work_dir / "configs.yaml").open("w") as fp:
-            yaml.safe_dump({"model": None, "engine": None, "data": None}, fp)
+        # TODO(vinnamki): Revisit this after changing how the optimizer and scheduler are instantiated.
+        cfg = deepcopy(self.config.get(str(self.subcommand), self.config))
+        cfg.model.init_args.pop("optimizer")
+        cfg.model.init_args.pop("scheduler")
+        cfg.model.init_args.pop("hlabel_info")
+
+        self.get_subcommand_parser(self.subcommand).save(
+            cfg=cfg,
+            path=work_dir / "configs.yaml",
+            overwrite=True,
+            multifile=False,
+            skip_check=True,
+        )
 
         # if train -> Update `.latest` folder
         self.update_latest(work_dir=work_dir)
@@ -586,6 +558,7 @@ def run(self) -> None:
                 fn(**fn_kwargs)
             except Exception:
                 self.console.print_exception(width=self.console.width)
+                raise
             self.save_config(work_dir=Path(self.engine.work_dir))
         else:
             msg = f"Unrecognized subcommand: {self.subcommand}"
diff --git a/src/otx/core/metrics/fmeasure.py b/src/otx/core/metrics/fmeasure.py
index 19be72844dc..bdcaa6af21f 100644
--- a/src/otx/core/metrics/fmeasure.py
+++ b/src/otx/core/metrics/fmeasure.py
@@ -11,6 +11,8 @@
 from torch import Tensor
 from torchmetrics import Metric
 
+from otx.core.types.label import LabelInfo
+
 logger = logging.getLogger()
 
 ALL_CLASSES_NAME = "All Classes"
@@ -632,9 +634,7 @@ class FMeasure(Metric):
         to True.
 
     Args:
-        num_classes (int): The number of classes.
-        best_confidence_threshold (float | None): Pre-defined best confidence threshold. If this value is None, then
-            FMeasure will find best confidence threshold. Defaults to None.
+        label_info (LabelInfo): Dataclass including label information.
         vary_nms_threshold (bool): if True the maximal F-measure is determined by optimizing for different
             NMS threshold values. Defaults to False.
         cross_class_nms (bool): Whether non-max suppression should be applied cross-class. If True this will eliminate
@@ -643,22 +643,31 @@ class FMeasure(Metric):
 
     def __init__(
         self,
-        num_classes: int,
-        best_confidence_threshold: float | None = None,
+        label_info: LabelInfo,
         vary_nms_threshold: bool = False,
         cross_class_nms: bool = False,
     ):
         super().__init__()
         self.vary_nms_threshold = vary_nms_threshold
         self.cross_class_nms = cross_class_nms
-        self.preds: list[list[tuple]] = []
-        self.targets: list[list[tuple]] = []
-        self.num_classes: int = num_classes
+        self.label_info: LabelInfo = label_info
 
         self._f_measure_per_confidence: dict | None = None
         self._f_measure_per_nms: dict | None = None
-        self._best_confidence_threshold: float | None = best_confidence_threshold
+        self._best_confidence_threshold: float | None = None
         self._best_nms_threshold: float | None = None
+        self._f_measure = 0.0
+
+        self.reset()
+
+    def reset(self) -> None:
+        """Reset for every validation and test epoch.
+
+        Please be careful that some variables should not be reset for each epoch.
+        """
+        super().reset()
+        self.preds: list[list[tuple]] = []
+        self.targets: list[list[tuple]] = []
 
     def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]) -> None:
         """Update total predictions and targets from given batch predicitons and targets."""
@@ -680,8 +689,14 @@ def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]]
             ],
         )
 
-    def compute(self) -> dict:
-        """Compute f1 score metric."""
+    def compute(self, best_confidence_threshold: float | None = None) -> dict:
+        """Compute f1 score metric.
+
+        Args:
+            best_confidence_threshold (float | None): Pre-defined best confidence threshold.
+                If this value is None, then FMeasure will find the best confidence threshold and
+                store it as a member variable. Defaults to None.
+        """
         boxes_pair = _FMeasureCalculator(self.targets, self.preds)
         result = boxes_pair.evaluate_detections(
             result_based_nms_threshold=self.vary_nms_threshold,
@@ -690,26 +705,37 @@ def compute(self) -> dict:
         )
         self._f_measure_per_label = {label: result.best_f_measure_per_class[label] for label in self.classes}
 
-        if self.best_confidence_threshold is not None:
+        if best_confidence_threshold is not None:
             (index,) = np.where(
-                np.isclose(list(np.arange(*boxes_pair.confidence_range)), self.best_confidence_threshold),
+                np.isclose(list(np.arange(*boxes_pair.confidence_range)), best_confidence_threshold),
             )
-            self._f_measure = result.per_confidence.all_classes_f_measure_curve[int(index)]
+            computed_f_measure = result.per_confidence.all_classes_f_measure_curve[int(index)]
         else:
             self._f_measure_per_confidence = {
                 "xs": list(np.arange(*boxes_pair.confidence_range)),
                 "ys": result.per_confidence.all_classes_f_measure_curve,
             }
-            self._best_confidence_threshold = result.per_confidence.best_threshold
+            computed_f_measure = result.best_f_measure
+            best_confidence_threshold = result.per_confidence.best_threshold
+
+        # TODO(jaegukhyun): This metric previously had no reset() function, and several member
+        # variables depend on the best F1 score, e.g., best_confidence_threshold. A reset()
+        # function has now been added and the related mechanism revised, but it is still unclear
+        # whether it works correctly. Revisit this; see how other metric implementations and
+        # https://github.com/Lightning-AI/torchmetrics/blob/v1.2.1/src/torchmetrics/metric.py
+        # handle reset.
+        if self._f_measure < computed_f_measure:
             self._f_measure = result.best_f_measure
+            self._best_confidence_threshold = best_confidence_threshold
 
-        if self.vary_nms_threshold and result.per_nms is not None:
-            self._f_measure_per_nms = {
-                "xs": list(np.arange(*boxes_pair.nms_range)),
-                "ys": result.per_nms.all_classes_f_measure_curve,
-            }
-            self._best_nms_threshold = result.per_nms.best_threshold
-        return {"f1-score": Tensor([self.f_measure])}
+            if self.vary_nms_threshold and result.per_nms is not None:
+                self._f_measure_per_nms = {
+                    "xs": list(np.arange(*boxes_pair.nms_range)),
+                    "ys": result.per_nms.all_classes_f_measure_curve,
+                }
+                self._best_nms_threshold = result.per_nms.best_threshold
+
+        return {"f1-score": Tensor([computed_f_measure])}
 
     @property
     def f_measure(self) -> float:
@@ -727,15 +753,16 @@ def f_measure_per_confidence(self) -> None | dict:
         return self._f_measure_per_confidence
 
     @property
-    def best_confidence_threshold(self) -> None | float:
+    def best_confidence_threshold(self) -> float:
         """Returns best confidence threshold as ScoreMetric if exists."""
+        if self._best_confidence_threshold is None:
+            msg = (
+                "Cannot obtain best_confidence_threshold because it has not been computed yet. "
+                "Please call self.compute(best_confidence_threshold=None) first."
+            )
+            raise RuntimeError(msg)
         return self._best_confidence_threshold
 
-    @best_confidence_threshold.setter
-    def best_confidence_threshold(self, value: float) -> None:
-        """Setter for best_confidence_threshold."""
-        self._best_confidence_threshold = value
-
     @property
     def f_measure_per_nms(self) -> None | dict:
         """Returns the curve for f-measure per nms threshold as CurveMetric if exists."""
@@ -749,7 +776,11 @@ def best_nms_threshold(self) -> None | float:
     @property
     def classes(self) -> list[str]:
         """Class information of dataset."""
-        if self.num_classes is None:
-            msg = "classes is called before num_classes is set."
- raise ValueError(msg) - return [str(idx) for idx in range(self.num_classes)] + return self.label_info.label_names + + +def _f_measure_callable(label_info: LabelInfo) -> FMeasure: + return FMeasure(label_info=label_info) + + +FMeasureCallable = _f_measure_callable diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index 5ed32c9525d..ab781c5b3ca 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -6,11 +6,12 @@ from __future__ import annotations import contextlib +import inspect import json import logging import warnings from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Generic, NamedTuple +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple import numpy as np import openvino @@ -290,8 +291,14 @@ def _convert_pred_entity_to_compute_metric( """Convert given inputs to a Python dictionary for the metric computation.""" raise NotImplementedError - def _log_metrics(self, meter: Metric, key: str) -> None: - results: dict[str, Tensor] = meter.compute() + def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None: + sig = inspect.signature(meter.compute) + filtered_kwargs = {key: value for key, value in compute_kwargs.items() if key in sig.parameters} + if removed_kwargs := set(compute_kwargs.keys()).difference(filtered_kwargs.keys()): + msg = f"These keyword arguments are removed since they are not in the function signature: {removed_kwargs}" + logger.debug(msg) + + results: dict[str, Tensor] = meter.compute(**filtered_kwargs) if not isinstance(results, dict): raise TypeError(results) diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py index 5f7a02c2f42..a29b3432f06 100644 --- a/src/otx/core/model/detection.py +++ b/src/otx/core/model/detection.py @@ -6,10 +6,9 @@ from __future__ import annotations import copy -import json import logging as log import types -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal import torch from openvino.model_api.models import Model @@ -62,7 +61,6 @@ def __init__( torch_compile=torch_compile, ) self.tile_config = TileConfig() - self.test_meta_info: dict[str, Any] = {} def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity | DetBatchPredEntityWithXAI: """Unpack detection tiles. @@ -106,9 +104,10 @@ def _export_parameters(self) -> dict[str, Any]: { ("model_info", "model_type"): "ssd", ("model_info", "task_type"): "detection", - ("model_info", "confidence_threshold"): str(0.0), # it was able to be set in OTX 1.X + ("model_info", "confidence_threshold"): str( + self.hparams.get("best_confidence_threshold", 0.0), + ), # it was able to be set in OTX 1.X ("model_info", "iou_threshold"): str(0.5), - ("model_info", "test_meta_info"): json.dumps(self.test_meta_info), }, ) if self.tile_config.enable_tiler: @@ -155,25 +154,33 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: For detection, it is need to update confidence threshold information when the metric is FMeasure. 
""" - if confidence_threshold := ckpt.get("confidence_threshold", None) or ( + if best_confidence_threshold := ckpt.get("confidence_threshold", None) or ( (hyper_parameters := ckpt.get("hyper_parameters", None)) - and (confidence_threshold := hyper_parameters.get("confidence_threshold", None)) + and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None)) ): - self.test_meta_info["best_confidence_threshold"] = confidence_threshold - self.test_meta_info["vary_confidence_threshold"] = False + self.hparams["best_confidence_threshold"] = best_confidence_threshold super().load_state_dict(ckpt, *args, **kwargs) - def configure_metric(self) -> None: - """Configure the metric.""" - super().configure_metric() - for key, value in self.test_meta_info.items(): - if hasattr(self.metric, key): - setattr(self.metric, key, value) + def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None: + if key == "val": + retval = super()._log_metrics(meter, key) - def _log_metrics(self, meter: Metric, key: str) -> None: - super()._log_metrics(meter, key) - if hasattr(meter, "best_confidence_threshold"): - self.hparams["confidence_threshold"] = meter.best_confidence_threshold + # NOTE: Validation metric logging can update `best_confidence_threshold` + if best_confidence_threshold := getattr(meter, "best_confidence_threshold", None): + self.hparams["best_confidence_threshold"] = best_confidence_threshold + + return retval + + if key == "test": + # NOTE: Test metric logging should use `best_confidence_threshold` found previously. + best_confidence_threshold = self.hparams.get("best_confidence_threshold", None) + compute_kwargs = ( + {"best_confidence_threshold": best_confidence_threshold} if best_confidence_threshold else {} + ) + + return super()._log_metrics(meter, key, **compute_kwargs) + + raise ValueError(key) class ExplainableOTXDetModel(OTXDetectionModel): @@ -503,7 +510,6 @@ def __init__( metric: MetricCallable = MeanAPCallable, **kwargs, ) -> None: - self.test_meta_info: dict[str, Any] = {} super().__init__( model_name=model_name, model_type=model_type, @@ -540,10 +546,20 @@ def _create_model(self) -> Model: plugin_config=plugin_config, model_parameters=self.model_adapter_parameters, ) - for name, info in model_adapter.model.rt_info["model_info"].items(): - if name == "test_meta_info": - for key, value in json.loads(info.value).items(): - self.test_meta_info[key] = value + + if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]): + best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value + self.hparams["best_confidence_threshold"] = best_confidence_threshold + else: + msg = ( + "Cannot get best_confidence_threshold from OpenVINO IR's rt_info. " + "Please check whether this model is trained by OTX or not. " + "Without this information, it can produce a wrong F1 metric score. " + "At this time, it will be set as the default value = 0.0." 
+ ) + log.warning(msg) + self.hparams["best_confidence_threshold"] = 0.0 + return Model.create_model(model_adapter, model_type=self.model_type, configuration=self.model_api_configuration) def _customize_outputs( @@ -635,3 +651,8 @@ def _convert_pred_entity_to_compute_metric( for bboxes, labels in zip(inputs.bboxes, inputs.labels) ], } + + def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None: + best_confidence_threshold = self.hparams.get("best_confidence_threshold", 0.0) + compute_kwargs = {"best_confidence_threshold": best_confidence_threshold} + return super()._log_metrics(meter, key, **compute_kwargs) diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index 0592c2bea0e..491877f9fd6 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -5,11 +5,10 @@ from __future__ import annotations -import json import logging as log import types from copy import copy -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal import numpy as np import torch @@ -46,6 +45,7 @@ from omegaconf import DictConfig from openvino.model_api.models.utils import InstanceSegmentationResult from torch import nn + from torchmetrics import Metric from otx.core.metrics import MetricCallable @@ -76,7 +76,6 @@ def __init__( torch_compile=torch_compile, ) self.tile_config = TileConfig() - self.test_meta_info: dict[str, Any] = {} def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchPredEntity: """Unpack instance segmentation tiles. @@ -122,9 +121,10 @@ def _export_parameters(self) -> dict[str, Any]: { ("model_info", "model_type"): "MaskRCNN", ("model_info", "task_type"): "instance_segmentation", - ("model_info", "confidence_threshold"): str(0.0), # it was able to be set in OTX 1.X + ("model_info", "confidence_threshold"): str( + self.hparams.get("best_confidence_threshold", 0.0), + ), # it was able to be set in OTX 1.X ("model_info", "iou_threshold"): str(0.5), - ("model_info", "test_meta_info"): json.dumps(self.test_meta_info), }, ) @@ -155,20 +155,33 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: For detection, it is need to update confidence threshold information when the metric is FMeasure. 
""" - if "confidence_threshold" in ckpt: - self.test_meta_info["best_confidence_threshold"] = ckpt["confidence_threshold"] - self.test_meta_info["vary_confidence_threshold"] = False - elif "confidence_threshold" in ckpt["hyper_parameters"]: - self.test_meta_info["best_confidence_threshold"] = ckpt["hyper_parameters"]["confidence_threshold"] - self.test_meta_info["vary_confidence_threshold"] = False + if best_confidence_threshold := ckpt.get("confidence_threshold", None) or ( + (hyper_parameters := ckpt.get("hyper_parameters", None)) + and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None)) + ): + self.hparams["best_confidence_threshold"] = best_confidence_threshold super().load_state_dict(ckpt, *args, **kwargs) - def configure_metric(self) -> None: - """Configure the metric.""" - super().configure_metric() - for key, value in self.test_meta_info.items(): - if hasattr(self.metric, key): - setattr(self.metric, key, value) + def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None: + if key == "val": + retval = super()._log_metrics(meter, key) + + # NOTE: Validation metric logging can update `best_confidence_threshold` + if best_confidence_threshold := getattr(meter, "best_confidence_threshold", None): + self.hparams["best_confidence_threshold"] = best_confidence_threshold + + return retval + + if key == "test": + # NOTE: Test metric logging should use `best_confidence_threshold` found previously. + best_confidence_threshold = self.hparams.get("best_confidence_threshold", None) + compute_kwargs = ( + {"best_confidence_threshold": best_confidence_threshold} if best_confidence_threshold else {} + ) + + return super()._log_metrics(meter, key, **compute_kwargs) + + raise ValueError(key) def _convert_pred_entity_to_compute_metric( self, @@ -549,7 +562,6 @@ def __init__( metric: MetricCallable = MaskRLEMeanAPCallable, **kwargs, ) -> None: - self.test_meta_info: dict[str, Any] = {} super().__init__( model_name=model_name, model_type=model_type, @@ -586,10 +598,20 @@ def _create_model(self) -> Model: plugin_config=plugin_config, model_parameters=self.model_adapter_parameters, ) - for name, info in model_adapter.model.rt_info["model_info"].items(): - if name == "test_meta_info": - for key, value in json.loads(info.value).items(): - self.test_meta_info[key] = value + + if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]): + best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value + self.hparams["best_confidence_threshold"] = best_confidence_threshold + else: + msg = ( + "Cannot get best_confidence_threshold from OpenVINO IR's rt_info. " + "Please check whether this model is trained by OTX or not. " + "Without this information, it can produce a wrong F1 metric score. " + "At this time, it will be set as the default value = 0.0." 
+            )
+            log.warning(msg)
+            self.hparams["best_confidence_threshold"] = 0.0
+
         return Model.create_model(model_adapter, model_type=self.model_type, configuration=self.model_api_configuration)
 
     def _customize_outputs(
@@ -709,3 +731,8 @@ def _convert_pred_entity_to_compute_metric(
                 },
             )
         return {"preds": pred_info, "target": target_info}
+
+    def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
+        best_confidence_threshold = self.hparams.get("best_confidence_threshold", 0.0)
+        compute_kwargs = {"best_confidence_threshold": best_confidence_threshold}
+        return super()._log_metrics(meter, key, **compute_kwargs)
diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py
index 8c228133bd5..44d757630b5 100644
--- a/src/otx/engine/engine.py
+++ b/src/otx/engine/engine.py
@@ -7,8 +7,9 @@
 import inspect
 import logging
+from contextlib import contextmanager
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Literal
+from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Iterator, Literal
 from warnings import warn
 
 import torch
@@ -35,7 +36,6 @@
     from lightning.pytorch.loggers import Logger
     from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
     from pytorch_lightning.trainer.connectors.accelerator_connector import _PRECISION_INPUT
-    from torchmetrics import Metric
 
     from otx.core.metrics import MetricCallable
@@ -55,6 +55,24 @@
 }
 
 
+@contextmanager
+def override_metric_callable(model: OTXModel, new_metric_callable: MetricCallable | None) -> Iterator[OTXModel]:
+    """Override `OTXModel.metric_callable` to change the evaluation metric.
+
+    Args:
+        model: Model whose metric callable is overridden.
+        new_metric_callable: If not None, override the model's metric callable with this one; otherwise do nothing.
+    """
+    if new_metric_callable is None:
+        yield model
+        return
+
+    orig_metric_callable = model.metric_callable
+    model.metric_callable = new_metric_callable
+    yield model
+    model.metric_callable = orig_metric_callable
+
+
 class Engine:
     """OTX Engine.
@@ -166,7 +184,7 @@ def train(
         callbacks: list[Callback] | Callback | None = None,
         logger: Logger | Iterable[Logger] | bool | None = None,
         resume: bool = False,
-        metric: Metric | MetricCallable | None = None,
+        metric: MetricCallable | None = None,
        run_hpo: bool = False,
        hpo_config: HpoConfig | None = None,
        **kwargs,
@@ -184,8 +202,8 @@ def train(
             callbacks (list[Callback] | Callback | None, optional): The callbacks to be used during training.
             logger (Logger | Iterable[Logger] | bool | None, optional): The logger(s) to be used. Defaults to None.
             resume (bool, optional): If True, tries to resume training from existing checkpoint.
-            metric (Metric | MetricCallable | None): The metric for the validation and test.
-                It could be None at export, predict, etc.
+            metric (MetricCallable | None): If not None, it will override `OTXModel.metric_callable` with the given
+                metric callable. It will temporarily change the evaluation metric for the validation and test.
             run_hpo (bool, optional): If True, optimizer hyper parameters before training a model.
             hpo_config (HpoConfig | None, optional): Configuration for HPO.
             **kwargs: Additional keyword arguments for pl.Trainer configuration.
@@ -223,7 +241,6 @@ def train(
                 otx train --data_root --config
                 ```
         """
-        metric = metric if metric is not None else self._auto_configurator.get_metric()
         if run_hpo:
             if hpo_config is None:
                 hpo_config = HpoConfig()
@@ -263,19 +280,28 @@ def train(
                 logging.warning(msg)
                 self.model.label_info = self.datamodule.label_info
 
-        self.trainer.fit(
-            model=self.model,
-            datamodule=self.datamodule,
-            **fit_kwargs,
-        )
+        with override_metric_callable(model=self.model, new_metric_callable=metric) as model:
+            self.trainer.fit(
+                model=model,
+                datamodule=self.datamodule,
+                **fit_kwargs,
+            )
         self.checkpoint = self.trainer.checkpoint_callback.best_model_path
+
+        if not isinstance(self.checkpoint, (Path, str)):
+            msg = "self.checkpoint should be Path or str at this time."
+            raise TypeError(msg)
+
+        best_checkpoint_symlink = Path(self.work_dir) / "best_checkpoint.ckpt"
+        best_checkpoint_symlink.symlink_to(self.checkpoint)
+
         return self.trainer.callback_metrics
 
     def test(
         self,
         checkpoint: PathLike | None = None,
         datamodule: EVAL_DATALOADERS | OTXDataModule | None = None,
-        metric: Metric | MetricCallable | None = None,
+        metric: MetricCallable | None = None,
         **kwargs,
     ) -> dict:
         """Run the testing phase of the engine.
@@ -284,8 +310,8 @@ def test(
             datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module containing the test data.
             checkpoint (PathLike | None, optional): Path to the checkpoint file to load the model from.
                 Defaults to None.
-            metric (Metric | MetricCallable | None): The metric for the validation and test.
-                It could be None at export, predict, etc.
+            metric (MetricCallable | None): If not None, it will override `OTXModel.metric_callable` with the given
+                metric callable. It will temporarily change the evaluation metric for the validation and test.
             **kwargs: Additional keyword arguments for pl.Trainer configuration.
 
         Returns:
@@ -318,8 +344,6 @@ def test(
             datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
             model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)
 
-        metric = metric if metric is not None else self._auto_configurator.get_metric()
-
         # NOTE, trainer.test takes only lightning based checkpoint.
         # So, it can't take the OTX1.x checkpoint.
if checkpoint is not None and not is_ir_ckpt: @@ -340,10 +364,11 @@ def test( # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test # raise ValueError() - self.trainer.test( - model=model, - dataloaders=datamodule, - ) + with override_metric_callable(model=model, new_metric_callable=metric) as model: + self.trainer.test( + model=model, + dataloaders=datamodule, + ) return self.trainer.callback_metrics diff --git a/tests/integration/cli/test_auto_configuration.py b/tests/integration/cli/test_auto_configuration.py index 069e54ca2c6..61fe8a2cfa8 100644 --- a/tests/integration/cli/test_auto_configuration.py +++ b/tests/integration/cli/test_auto_configuration.py @@ -8,7 +8,7 @@ from otx.core.types.task import OTXTaskType from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK -from tests.integration.cli.utils import run_main +from tests.utils import run_main @pytest.mark.parametrize("task", pytest.TASK_LIST) diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index 7d8618ddf27..166b2b250f6 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -7,9 +7,10 @@ import numpy as np import pytest import yaml +from otx.core.types.task import OTXTaskType from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK -from tests.integration.cli.utils import run_main +from tests.utils import run_main @pytest.mark.parametrize( @@ -441,7 +442,7 @@ def test_otx_ov_test( @pytest.mark.parametrize("task", pytest.TASK_LIST) def test_otx_hpo_e2e( - task: str, + task: OTXTaskType, tmp_path: Path, fxt_accelerator: str, fxt_target_dataset_per_task: dict, @@ -462,7 +463,23 @@ def test_otx_hpo_e2e( pytest.xfail(reason="xFail until this root cause is resolved on the Datumaro side.") if task not in DEFAULT_CONFIG_PER_TASK: pytest.skip(f"Task {task} is not supported in the auto-configuration.") - + if task.lower().startswith("anomaly_"): + pytest.xfail( + reason="""This will be fixed soon +│ /home/vinnamki/otx/training_extensions/src/otx/engine/hpo/hpo_api.py:137 in │ +│ hpo_config │ +│ │ +│ 134 │ @hpo_config.setter │ +│ 135 │ def hpo_config(self, hpo_config: HpoConfig | None) -> None: │ +│ 136 │ │ train_dataset_size = len(self._engine.datamodule.subsets["trai │ +│ ❱ 137 │ │ val_dataset_size = len(self._engine.datamodule.subsets["val"]) │ +│ 138 │ │ │ +│ 139 │ │ self._hpo_config: dict[str, Any] = { # default setting │ +│ 140 │ │ │ "save_path": str(self._hpo_workdir), │ +╰──────────────────────────────────────────────────────────────────────────────╯ +KeyError: 'val' + """, + ) task = task.lower() tmp_path_hpo = tmp_path / f"otx_hpo_{task}" tmp_path_hpo.mkdir(parents=True) diff --git a/tests/integration/cli/test_export_inference.py b/tests/integration/cli/test_export_inference.py index 919d46035a7..2bc43e587ba 100644 --- a/tests/integration/cli/test_export_inference.py +++ b/tests/integration/cli/test_export_inference.py @@ -8,7 +8,7 @@ import pandas as pd import pytest -from tests.integration.cli.utils import run_main +from tests.utils import run_main log = logging.getLogger(__name__) diff --git a/tests/perf/test_detection.py b/tests/perf/test_detection.py index f048eb7c357..67617c52ca6 100644 --- a/tests/perf/test_detection.py +++ b/tests/perf/test_detection.py @@ -34,7 +34,7 @@ class TestPerfObjectDetection(PerfTestBase): num_repeat=5, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", 
"callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -48,7 +48,7 @@ class TestPerfObjectDetection(PerfTestBase): num_repeat=5, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -60,7 +60,7 @@ class TestPerfObjectDetection(PerfTestBase): num_repeat=5, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, diff --git a/tests/perf/test_instance_segmentation.py b/tests/perf/test_instance_segmentation.py index 2711ea961f4..9a65157d0ce 100644 --- a/tests/perf/test_instance_segmentation.py +++ b/tests/perf/test_instance_segmentation.py @@ -30,7 +30,7 @@ class TestPerfInstanceSegmentation(PerfTestBase): num_repeat=5, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -44,7 +44,7 @@ class TestPerfInstanceSegmentation(PerfTestBase): num_repeat=5, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -56,7 +56,7 @@ class TestPerfInstanceSegmentation(PerfTestBase): num_repeat=5, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -119,7 +119,7 @@ class TestPerfTilingInstanceSegmentation(PerfTestBase): num_repeat=5, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -133,7 +133,7 @@ class TestPerfTilingInstanceSegmentation(PerfTestBase): num_repeat=5, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, diff --git a/tests/regression/test_regression.py b/tests/regression/test_regression.py index 781d856387f..ff485a93742 100644 --- a/tests/regression/test_regression.py +++ b/tests/regression/test_regression.py @@ -8,10 +8,9 @@ from pathlib import Path import pytest -from otx.cli.cli import OTXCLI -from unittest.mock import patch - +from tests.utils import run_main import mlflow +import pandas as pd @dataclass @@ -69,13 +68,20 @@ def _test_regression( ) with mlflow.start_run(tags=tags, run_name=run_name): command_cfg = [ - "otx", "train", - "--config", f"src/otx/recipe/{test_case.model.task}/{test_case.model.name}.yaml", - "--model.num_classes", str(test_case.dataset.num_classes), - "--data_root", str(data_root), - "--data.config.data_format", test_case.dataset.data_format, - "--work_dir", str(test_case.output_dir), - "--engine.device", fxt_accelerator, + "otx", + "train", + "--config", + f"src/otx/recipe/{test_case.model.task}/{test_case.model.name}.yaml", + "--model.num_classes", + str(test_case.dataset.num_classes), + "--data_root", + str(data_root), + 
"--data.config.data_format", + test_case.dataset.data_format, + "--work_dir", + str(test_case.output_dir), + "--engine.device", + fxt_accelerator, ] deterministic = test_case.dataset.extra_overrides.pop("deterministic", "False") for key, value in test_case.dataset.extra_overrides.items(): @@ -84,19 +90,40 @@ def _test_regression( train_cfg = command_cfg.copy() train_cfg.extend(["--seed", str(seed)]) train_cfg.extend(["--deterministic", deterministic]) - with patch("sys.argv", train_cfg): - cli = OTXCLI() - train_metrics = cli.engine.trainer.callback_metrics - checkpoint = cli.engine.checkpoint - command_cfg[1] = "test" - command_cfg += ["--checkpoint", checkpoint] - with patch("sys.argv", command_cfg): - cli = OTXCLI() - test_metrics = cli.engine.trainer.callback_metrics - metrics = {**train_metrics, **test_metrics} + + run_main(command_cfg=train_cfg, open_subprocess=True) + checkpoint = test_case.output_dir / ".latest" / "train" / "best_checkpoint.ckpt" + assert checkpoint.exists() + + test_cfg = command_cfg.copy() + test_cfg[1] = "test" + test_cfg += ["--checkpoint", str(checkpoint)] + + # TODO(harimkang): This command cannot create `metrics.csv`` file under test output directory + # Without fixing this, we cannot submit the test metrics from the csv logged file + run_main(command_cfg=test_cfg, open_subprocess=True) + + # This is also not working. It produces an empty dictionary for test_metrics = {} + # with patch("sys.argv", test_cfg): + # cli = OTXCLI() + # test_metrics = cli.engine.trainer.callback_metrics + # mlflow.log_metrics(test_metrics) # Submit metrics to MLFlow Tracker server - mlflow.log_metrics(metrics) + for metric_csv_file in test_case.output_dir.glob("**/metrics.csv"): + self._submit_metric(metric_csv_file) + + def _submit_metric(self, metric_csv_file: Path) -> None: + df = pd.read_csv(metric_csv_file) + for step, sub_df in df.groupby("step"): + sub_df = sub_df.drop("step", axis=1) + + for _, row in sub_df.iterrows(): + row = row.dropna() + metrics = row.to_dict() + mlflow.log_metrics(metrics=metrics, step=step) + + mlflow.log_artifact(local_path=str(metric_csv_file), artifact_path="metrics") class TestMultiClassCls(BaseTest): @@ -118,7 +145,7 @@ class TestMultiClassCls(BaseTest): extra_overrides={ "deterministic": "True", "metric": "otx.core.metrics.accuracy.MulticlassAccuracywithLabelGroup", - } + }, ) for idx in range(1, 4) ] + [ @@ -130,7 +157,7 @@ class TestMultiClassCls(BaseTest): extra_overrides={ "deterministic": "True", "metric": "otx.core.metrics.accuracy.MulticlassAccuracywithLabelGroup", - } + }, ), DatasetTestCase( name=f"multiclass_food101_large", @@ -140,8 +167,8 @@ class TestMultiClassCls(BaseTest): extra_overrides={ "deterministic": "True", "metric": "otx.core.metrics.accuracy.MulticlassAccuracywithLabelGroup", - } - ) + }, + ), ] @pytest.mark.parametrize( @@ -193,7 +220,7 @@ class TestMultilabelCls(BaseTest): extra_overrides={ "deterministic": "True", "metric": "otx.core.metrics.accuracy.MultilabelAccuracywithLabelGroup", - } + }, ) for idx in range(1, 4) ] + [ @@ -205,7 +232,7 @@ class TestMultilabelCls(BaseTest): extra_overrides={ "deterministic": "True", "metric": "otx.core.metrics.accuracy.MultilabelAccuracywithLabelGroup", - } + }, ), DatasetTestCase( name=f"multilabel_food101_large", @@ -215,8 +242,8 @@ class TestMultilabelCls(BaseTest): extra_overrides={ "deterministic": "True", "metric": "otx.core.metrics.accuracy.MultilabelAccuracywithLabelGroup", - } - ) + }, + ), ] @pytest.mark.parametrize( @@ -282,7 +309,6 @@ class TestHlabelCls(BaseTest): 
"metric": "otx.core.metrics.accuracy.HlabelAccuracy", }, ) - ] @pytest.mark.parametrize( @@ -336,7 +362,7 @@ class TestObjectDetection(BaseTest): num_classes=1, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -350,7 +376,7 @@ class TestObjectDetection(BaseTest): num_classes=1, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -362,11 +388,11 @@ class TestObjectDetection(BaseTest): num_classes=1, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, - ) + ), ] @pytest.mark.parametrize( @@ -399,6 +425,7 @@ def test_regression( tmpdir=tmpdir, ) + class TestSemanticSegmentation(BaseTest): # Test case parametrization for model MODEL_TEST_CASES = [ # noqa: RUF012 @@ -434,7 +461,7 @@ class TestSemanticSegmentation(BaseTest): data_format="common_semantic_segmentation_with_subset_dirs", num_classes=2, extra_overrides={}, - ) + ), ] @pytest.mark.parametrize( @@ -467,6 +494,7 @@ def test_regression( tmpdir=tmpdir, ) + class TestInstanceSegmentation(BaseTest): # Test case parametrization for model MODEL_TEST_CASES = [ # noqa: RUF012 @@ -483,7 +511,7 @@ class TestInstanceSegmentation(BaseTest): num_classes=5, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -497,7 +525,7 @@ class TestInstanceSegmentation(BaseTest): num_classes=2, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -509,11 +537,11 @@ class TestInstanceSegmentation(BaseTest): num_classes=1, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, - ) + ), ] @pytest.mark.parametrize( @@ -577,7 +605,7 @@ class TestVisualPrompting(BaseTest): data_format="coco", num_classes=1, extra_overrides={"deterministic": "warn"}, - ) + ), ] @pytest.mark.parametrize( @@ -624,10 +652,7 @@ class TestZeroShotVisualPrompting(BaseTest): data_root=Path("zero_shot_visual_prompting/coco_car_person_medium_datumaro"), data_format="datumaro", num_classes=2, - extra_overrides={ - "max_epochs": "1", - "deterministic": "warn" - } + extra_overrides={"max_epochs": "1", "deterministic": "warn"}, ), ] @@ -681,7 +706,7 @@ class TestTileObjectDetection(BaseTest): num_classes=1, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -693,11 +718,11 @@ class TestTileObjectDetection(BaseTest): num_classes=1, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", 
"callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, - ) + ), ] @pytest.mark.parametrize( @@ -747,7 +772,7 @@ class TestTileInstanceSegmentation(BaseTest): num_classes=1, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, @@ -759,11 +784,11 @@ class TestTileInstanceSegmentation(BaseTest): num_classes=1, extra_overrides={ "deterministic": "True", - "metric": "otx.core.metrics.fmeasure.FMeasure", + "metric": "otx.core.metrics.fmeasure.FMeasureCallable", "callback_monitor": "val/f1-score", "scheduler.monitor": "val/f1-score", }, - ) + ), ] @pytest.mark.parametrize( @@ -809,14 +834,14 @@ class TestActionClassification(BaseTest): data_root=Path("action_classification/ucf-kinetics-5percent"), data_format="kinetics", num_classes=101, - extra_overrides={"max_epochs": "10", "deterministic": "True"} + extra_overrides={"max_epochs": "10", "deterministic": "True"}, ), DatasetTestCase( name="ucf-30percent", data_root=Path("action_classification/ucf-kinetics-30percent"), data_format="kinetics", num_classes=101, - extra_overrides={"max_epochs": "10", "deterministic": "True"} + extra_overrides={"max_epochs": "10", "deterministic": "True"}, ), ] diff --git a/tests/unit/cli/test_cli.py b/tests/unit/cli/test_cli.py index 440d30f3b01..4db177f30d9 100644 --- a/tests/unit/cli/test_cli.py +++ b/tests/unit/cli/test_cli.py @@ -113,6 +113,15 @@ def test_instantiate_classes(self, fxt_train_command, mocker) -> None: assert cli.datamodule == cli.engine.datamodule assert cli.model == cli.engine.model + def test_raise_error_correctly(self, fxt_train_command, mocker) -> None: + mock_engine = mocker.patch("otx.cli.OTXCLI.instantiate_engine") + mock_engine.return_value.train.side_effect = RuntimeError("my_error") + + with pytest.raises(RuntimeError) as exc_info: + OTXCLI() + + exc_info.match("my_error") + @pytest.fixture() def fxt_print_config_scheduler_override_command(self, monkeypatch) -> None: argv = [ @@ -166,7 +175,7 @@ def fxt_metric_override_command(self, monkeypatch) -> None: "--data_root", "tests/assets/car_tree_bug", "--metric", - "otx.core.metrics.fmeasure.FMeasure", + "otx.core.metrics.fmeasure.FMeasureCallable", "--print_config", ] monkeypatch.setattr("sys.argv", argv) @@ -177,9 +186,4 @@ def test_print_metric_override_command(self, fxt_metric_override_command, capfd) OTXCLI() out, _ = capfd.readouterr() result_config = yaml.safe_load(out) - expected_str = """ - metric: - - class_path: otx.core.metrics.fmeasure.FMeasure - """ - expected_config = yaml.safe_load(expected_str) - assert expected_config["metric"][0]["class_path"] == result_config["metric"]["class_path"] + assert result_config["metric"] == "otx.core.metrics.fmeasure._f_measure_callable" diff --git a/tests/unit/core/metrics/test_fmeasure.py b/tests/unit/core/metrics/test_fmeasure.py index 0f364fa1654..b934275b827 100644 --- a/tests/unit/core/metrics/test_fmeasure.py +++ b/tests/unit/core/metrics/test_fmeasure.py @@ -8,6 +8,7 @@ import pytest import torch from otx.core.metrics.fmeasure import FMeasure +from otx.core.types.label import LabelInfo class TestFMeasure: @@ -41,16 +42,26 @@ def fxt_targets(self) -> list[dict[str, torch.Tensor]]: def test_fmeasure(self, fxt_preds, fxt_targets) -> None: """Check whether f1 score is same with OTX1.x version.""" - metric = FMeasure(num_classes=1) + metric = FMeasure(label_info=LabelInfo.from_num_classes(1)) 
metric.update(fxt_preds, fxt_targets) result = metric.compute() assert result["f1-score"] == 0.5 + best_confidence_threshold = metric.best_confidence_threshold + assert isinstance(best_confidence_threshold, float) + + metric.reset() + assert metric.preds == [] + assert metric.targets == [] + + # TODO(jaegukhyun): Add the following scenario + # 1. Prepare preds and targets which can produce f1-score < 0.5 + # 2. Execute metric.compute() + # 3. Assert best_confidence_threshold == metric.best_confidence_threshold def test_fmeasure_with_fixed_threshold(self, fxt_preds, fxt_targets) -> None: """Check fmeasure can compute f1 score given confidence threshold.""" - metric = FMeasure(num_classes=1) + metric = FMeasure(label_info=LabelInfo.from_num_classes(1)) - metric.best_confidence_threshold = 0.85 metric.update(fxt_preds, fxt_targets) - result = metric.compute() + result = metric.compute(best_confidence_threshold=0.85) assert result["f1-score"] == 0.3333333432674408 diff --git a/tests/unit/core/model/test_detection.py b/tests/unit/core/model/test_detection.py index 860fb52ed8f..786fa47aa09 100644 --- a/tests/unit/core/model/test_detection.py +++ b/tests/unit/core/model/test_detection.py @@ -5,13 +5,12 @@ from __future__ import annotations -from functools import partial from unittest.mock import create_autospec import pytest from lightning.pytorch.cli import ReduceLROnPlateau from otx.algo.schedulers.warmup_schedulers import LinearWarmupScheduler -from otx.core.metrics.fmeasure import FMeasure +from otx.core.metrics.fmeasure import FMeasureCallable from otx.core.model.detection import OTXDetectionModel from torch.optim import Optimizer @@ -32,7 +31,7 @@ def mock_scheduler(self) -> list[LinearWarmupScheduler | ReduceLROnPlateau]: "state_dict": {}, }, { - "hyper_parameters": {"confidence_threshold": 0.35}, + "hyper_parameters": {"best_confidence_threshold": 0.35}, "state_dict": {}, }, ], @@ -52,12 +51,9 @@ def test_configure_metric_with_ckpt( torch_compile=False, optimizer=mock_optimizer, scheduler=mock_scheduler, - metric=partial(FMeasure), + metric=FMeasureCallable, ) model.load_state_dict(mock_ckpt) - assert model.test_meta_info["best_confidence_threshold"] == 0.35 - - model.configure_metric() - assert model.metric.best_confidence_threshold == 0.35 + assert model.hparams["best_confidence_threshold"] == 0.35 diff --git a/tests/integration/cli/utils.py b/tests/utils.py similarity index 62% rename from tests/integration/cli/utils.py rename to tests/utils.py index 749de187250..2d72c6e7a3f 100644 --- a/tests/integration/cli/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ from __future__ import annotations +import logging import subprocess import sys from unittest.mock import patch @@ -18,14 +19,17 @@ def run_main(command_cfg: list[str], open_subprocess: bool) -> None: def _run_main_with_open_subprocess(command_cfg) -> None: - completed = subprocess.run( - [sys.executable, __file__, *command_cfg], # noqa: S603 - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - check=True, - ) - - completed.check_returncode() + try: + subprocess.run( + [sys.executable, __file__, *command_cfg], # noqa: S603 + capture_output=True, + check=True, + ) + except subprocess.CalledProcessError as exc: + stderr = exc.stderr.decode() + msg = f"Fail to run main: stderr={stderr}" + logging.exception(msg) + raise def _run_main(command_cfg) -> None:
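For reference, here is a minimal usage sketch of the reworked `FMeasure` API introduced in this patch, mirroring the updated unit test above. The `"boxes"`/`"scores"`/`"labels"` dict keys and the tensor values are illustrative assumptions (the actual `fxt_preds`/`fxt_targets` fixtures are defined elsewhere in the test file); the constructor, `compute()`, `reset()`, and `best_confidence_threshold` calls follow the diff:

```python
import torch

from otx.core.metrics.fmeasure import FMeasure
from otx.core.types.label import LabelInfo

# FMeasure is now constructed from LabelInfo instead of num_classes / best_confidence_threshold.
metric = FMeasure(label_info=LabelInfo.from_num_classes(1))

# Assumed prediction/target format with torchmetrics-style keys; values are placeholders.
preds = [{"boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]), "scores": torch.tensor([0.9]), "labels": torch.tensor([0])}]
target = [{"boxes": torch.tensor([[12.0, 10.0, 52.0, 50.0]]), "labels": torch.tensor([0])}]

metric.update(preds, target)

# Validation-style call: compute() searches for the best confidence threshold and caches it.
val_result = metric.compute()
threshold = metric.best_confidence_threshold  # raises RuntimeError if compute() was never called

# Test-style call: reset the accumulated predictions, then reuse the previously found threshold.
metric.reset()
metric.update(preds, target)
test_result = metric.compute(best_confidence_threshold=threshold)
```

This mirrors how `OTXDetectionModel._log_metrics` now behaves: the "val" path lets the metric search for `best_confidence_threshold` and stores it in `self.hparams`, while the "test" path passes that stored value back into `compute()`.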