diff --git a/src/otx/algo/anomaly/padim.py b/src/otx/algo/anomaly/padim.py
index 501f0c1a121..ede378ab365 100644
--- a/src/otx/algo/anomaly/padim.py
+++ b/src/otx/algo/anomaly/padim.py
@@ -9,23 +9,17 @@
 from typing import TYPE_CHECKING, Literal

-from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
-from anomalib.callbacks.post_processor import _PostProcessorCallback
 from anomalib.models.image import Padim as AnomalibPadim

-from otx.core.model.anomaly import OTXAnomaly
+from otx.core.model.anomaly import AnomalyMixin, OTXAnomaly
 from otx.core.types.label import AnomalyLabelInfo
 from otx.core.types.task import OTXTaskType

 if TYPE_CHECKING:
-    from lightning.pytorch.utilities.types import STEP_OUTPUT
-    from torch.optim.optimizer import Optimizer
-
-    from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
     from otx.core.types.label import LabelInfoTypes


-class Padim(OTXAnomaly, AnomalibPadim):
+class Padim(AnomalyMixin, AnomalibPadim, OTXAnomaly):
     """OTX Padim model.

     Args:
@@ -55,100 +49,11 @@ def __init__(
         ] = OTXTaskType.ANOMALY_CLASSIFICATION,
         input_size: tuple[int, int] = (256, 256),
     ) -> None:
-        OTXAnomaly.__init__(self, label_info=label_info, input_size=input_size)
-        AnomalibPadim.__init__(
-            self,
+        self.input_size = input_size
+        self.task = OTXTaskType(task)
+        super().__init__(
             backbone=backbone,
             layers=layers,
             pre_trained=pre_trained,
             n_features=n_features,
         )
-        self.task = task
-        self.input_size = input_size
-
-    def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
-        """PADIM doesn't require optimization, therefore returns no optimizers."""
-        return
-
-    def configure_metric(self) -> None:
-        """This does not follow OTX metric configuration."""
-        return
-
-    def on_train_epoch_end(self) -> None:
-        """Callback triggered when the training epoch ends."""
-        return AnomalibPadim.on_train_epoch_end(self)
-
-    def on_validation_start(self) -> None:
-        """Callback triggered when the validation starts."""
-        return AnomalibPadim.on_validation_start(self)
-
-    def on_validation_epoch_start(self) -> None:
-        """Callback triggered when the validation epoch starts."""
-        AnomalibPadim.on_validation_epoch_start(self)
-
-    def on_test_epoch_start(self) -> None:
-        """Callback triggered when the test epoch starts."""
-        AnomalibPadim.on_test_epoch_start(self)
-
-    def on_validation_epoch_end(self) -> None:
-        """Callback triggered when the validation epoch ends."""
-        AnomalibPadim.on_validation_epoch_end(self)
-
-    def on_test_epoch_end(self) -> None:
-        """Callback triggered when the test epoch ends."""
-        AnomalibPadim.on_test_epoch_end(self)
-
-    def training_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-    ) -> STEP_OUTPUT:
-        """Call training step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibPadim.training_step(self, inputs, batch_idx)  # type: ignore[misc]
-
-    def validation_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-    ) -> STEP_OUTPUT:
-        """Call validation step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibPadim.validation_step(self, inputs, batch_idx)  # type: ignore[misc]
-
-    def test_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-        **kwargs,
-    ) -> STEP_OUTPUT:
-        """Call test step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibPadim.test_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
-
-    def predict_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-        **kwargs,
-    ) -> STEP_OUTPUT:
-        """Call test step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibPadim.predict_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
-
-    def forward(
-        self,
-        inputs: AnomalyModelInputs,
-    ) -> AnomalyModelOutputs:
-        """Wrap forward method of the Anomalib model."""
-        outputs = self.validation_step(inputs)
-        # TODO(Ashwin): update forward implementation to comply with other OTX models
-        _PostProcessorCallback._post_process(outputs)  # noqa: SLF001
-        _PostProcessorCallback._compute_scores_and_labels(self, outputs)  # noqa: SLF001
-        _MinMaxNormalizationCallback._normalize_batch(outputs, self)  # noqa: SLF001
-
-        return self._customize_outputs(outputs=outputs, inputs=inputs)
diff --git a/src/otx/algo/anomaly/stfpm.py b/src/otx/algo/anomaly/stfpm.py
index f75281e3fa9..fcb4b3fa88f 100644
--- a/src/otx/algo/anomaly/stfpm.py
+++ b/src/otx/algo/anomaly/stfpm.py
@@ -9,23 +9,17 @@
 from typing import TYPE_CHECKING, Literal, Sequence

-from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
-from anomalib.callbacks.post_processor import _PostProcessorCallback
 from anomalib.models.image.stfpm import Stfpm as AnomalibStfpm

-from otx.core.model.anomaly import OTXAnomaly
+from otx.core.model.anomaly import AnomalyMixin, OTXAnomaly
 from otx.core.types.label import AnomalyLabelInfo
 from otx.core.types.task import OTXTaskType

 if TYPE_CHECKING:
-    from lightning.pytorch.utilities.types import STEP_OUTPUT
-    from torch.optim.optimizer import Optimizer
-
-    from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
     from otx.core.types.label import LabelInfoTypes


-class Stfpm(OTXAnomaly, AnomalibStfpm):
+class Stfpm(AnomalyMixin, AnomalibStfpm, OTXAnomaly):
     """OTX STFPM model.
     Args:
@@ -52,95 +46,9 @@ def __init__(
         input_size: tuple[int, int] = (256, 256),
         **kwargs,
     ) -> None:
-        OTXAnomaly.__init__(self, label_info=label_info, input_size=input_size)
-        AnomalibStfpm.__init__(
-            self,
+        self.input_size = input_size
+        self.task = OTXTaskType(task)
+        super().__init__(
             backbone=backbone,
             layers=layers,
         )
-        self.task = task
-        self.input_size = input_size
-
-    @property
-    def trainable_model(self) -> str:
-        """Used by configure optimizer."""
-        return "student_model"
-
-    def configure_metric(self) -> None:
-        """This does not follow OTX metric configuration."""
-        return
-
-    def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
-        """STFPM does not follow OTX optimizer configuration."""
-        return AnomalibStfpm.configure_optimizers(self)
-
-    def on_validation_epoch_start(self) -> None:
-        """Callback triggered when the validation epoch starts."""
-        AnomalibStfpm.on_validation_epoch_start(self)
-
-    def on_test_epoch_start(self) -> None:
-        """Callback triggered when the test epoch starts."""
-        AnomalibStfpm.on_test_epoch_start(self)
-
-    def on_validation_epoch_end(self) -> None:
-        """Callback triggered when the validation epoch ends."""
-        AnomalibStfpm.on_validation_epoch_end(self)
-
-    def on_test_epoch_end(self) -> None:
-        """Callback triggered when the test epoch ends."""
-        AnomalibStfpm.on_test_epoch_end(self)
-
-    def training_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-    ) -> STEP_OUTPUT:
-        """Call training step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibStfpm.training_step(self, inputs, batch_idx)  # type: ignore[misc]
-
-    def validation_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-    ) -> STEP_OUTPUT:
-        """Call validation step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibStfpm.validation_step(self, inputs, batch_idx)  # type: ignore[misc]
-
-    def test_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-        **kwargs,
-    ) -> STEP_OUTPUT:
-        """Call test step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibStfpm.test_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
-
-    def predict_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-        **kwargs,
-    ) -> STEP_OUTPUT:
-        """Call test step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibStfpm.predict_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
-
-    def forward(
-        self,
-        inputs: AnomalyModelInputs,
-    ) -> AnomalyModelOutputs:
-        """Wrap forward method of the Anomalib model."""
-        outputs = self.validation_step(inputs)
-        # TODO(Ashwin): update forward implementation to comply with other OTX models
-        _PostProcessorCallback._post_process(outputs)  # noqa: SLF001
-        _PostProcessorCallback._compute_scores_and_labels(self, outputs)  # noqa: SLF001
-        _MinMaxNormalizationCallback._normalize_batch(outputs, self)  # noqa: SLF001
-
-        return self._customize_outputs(outputs=outputs, inputs=inputs)
diff --git a/src/otx/core/model/anomaly.py b/src/otx/core/model/anomaly.py
index 59b2a5dd1da..4354386a442 100644
--- a/src/otx/core/model/anomaly.py
+++ b/src/otx/core/model/anomaly.py
@@ -4,7 +4,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, TypeAlias
+from typing import TYPE_CHECKING, Any, Sequence, TypeAlias
 import torch
 from anomalib import TaskType as AnomalibTaskType
@@ -14,6 +14,7 @@
 from anomalib.callbacks.thresholding import _ThresholdCallback
 from torch import nn

+from otx import __version__
 from otx.core.data.entity.anomaly import (
     AnomalyClassificationBatchPrediction,
     AnomalyClassificationDataBatch,
@@ -26,10 +27,13 @@
 from otx.core.exporter.anomaly import OTXAnomalyModelExporter
 from otx.core.model.base import OTXModel
 from otx.core.types.export import OTXExportFormatType
+from otx.core.types.label import AnomalyLabelInfo
 from otx.core.types.precision import OTXPrecisionType
 from otx.core.types.task import OTXTaskType
+from otx.core.utils.utils import remove_state_dict_prefix

 if TYPE_CHECKING:
+    import types
     from pathlib import Path

     from anomalib.metrics import AnomalibMetricCollection
@@ -37,9 +41,10 @@
     from lightning.pytorch import Trainer
     from lightning.pytorch.callbacks.callback import Callback
     from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
+    from lightning.pytorch.utilities.types import STEP_OUTPUT
+    from torch.optim.optimizer import Optimizer
     from torchmetrics import Metric

-    from otx.core.types.label import LabelInfoTypes

 AnomalyModelInputs: TypeAlias = (
     AnomalyClassificationDataBatch | AnomalySegmentationDataBatch | AnomalyDetectionDataBatch
@@ -57,8 +62,8 @@ class OTXAnomaly(OTXModel):
             Model input size in the order of height and width. Defaults to None.
     """

-    def __init__(self, label_info: LabelInfoTypes, input_size: tuple[int, int]) -> None:
-        super().__init__(label_info=label_info, input_size=input_size)
+    def __init__(self) -> None:
+        super().__init__(label_info=AnomalyLabelInfo(), input_size=self.input_size)
         self.optimizer: list[OptimizerCallable] | OptimizerCallable = None
         self.scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = None
         self.trainer: Trainer
@@ -71,32 +76,21 @@ def __init__(self, label_info: LabelInfoTypes, input_size: tuple[int, int]) -> N
         self.image_metrics: AnomalibMetricCollection
         self.pixel_metrics: AnomalibMetricCollection

-    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
-        """Callback on saving checkpoint."""
-        super().on_save_checkpoint(checkpoint)  # type: ignore[misc]
-
-        attrs = ["_task_type", "_input_size", "image_threshold", "pixel_threshold"]
-        checkpoint["anomaly"] = {key: getattr(self, key, None) for key in attrs}
-
-    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
-        """Callback on loading checkpoint."""
-        super().on_load_checkpoint(checkpoint)  # type: ignore[misc]
-        if anomaly_attrs := checkpoint.get("anomaly"):
-            for key, value in anomaly_attrs.items():
-                setattr(self, key, value)
-
-    @property  # type: ignore[override]
-    def input_size(self) -> tuple[int, int]:
-        """Returns the input size of the model.
+    def save_hyperparameters(
+        self,
+        *args: Any,  # noqa: ANN401
+        ignore: Sequence[str] | str | None = None,
+        frame: types.FrameType | None = None,
+        logger: bool = True,
+    ) -> None:
+        """Ignore task from hyperparameters.

-        Returns:
-            tuple[int, int]: The input size of the model as a tuple of (height, width).
+        ``task`` must be ignored because the CLI passes it as a string; this makes
+        ``log_hyperparameters`` fail, since the string does not match the ``OTXTaskType``
+        instance from ``OTXDataModule``.
""" - return self._input_size - - @input_size.setter - def input_size(self, value: tuple[int, int]) -> None: - self._input_size = value + ignore = ["task"] if ignore is None else [*ignore, "task"] + return super().save_hyperparameters(*args, ignore=ignore, frame=frame, logger=logger) @property def task(self) -> AnomalibTaskType: @@ -130,6 +124,10 @@ def _get_values_from_transforms( std_value = tuple(value * 255 for value in transform.std) # type: ignore[assignment] return mean_value, std_value + def configure_metric(self) -> None: + """This does not follow OTX metric configuration.""" + return + @property def trainable_model(self) -> str | None: """Use this to return the name of the model that needs to be trained. @@ -157,6 +155,22 @@ def configure_callbacks(self) -> list[Callback]: ), ] + def on_validation_epoch_start(self) -> None: + """Don't call OTXModel's ``on_validation_epoch_start``.""" + return + + def on_test_epoch_start(self) -> None: + """Don't call OTXModel's ``on_test_epoch_start``.""" + return + + def on_validation_epoch_end(self) -> None: + """Don't call OTXModel's ``on_validation_epoch_end``.""" + return + + def on_test_epoch_end(self) -> None: + """Don't call OTXModel's ``on_test_epoch_end``.""" + return + def on_predict_batch_end( self, outputs: dict, @@ -246,6 +260,9 @@ def _exporter(self) -> OTXAnomalyModelExporter: "input_names": ["input"], "output_names": ["output"], } + if self.input_size is None: + msg = "Input size is not defined" + raise ValueError(msg) return OTXAnomalyModelExporter( image_shape=self.input_size, image_threshold=self.image_threshold.value.cpu().numpy().tolist(), @@ -299,14 +316,14 @@ def get_dummy_input(self, batch_size: int = 1) -> AnomalyModelInputs: ori_shape=img.shape, ), ) - if self.task == AnomalibTaskType.CLASSIFICATION: + if self.task == OTXTaskType.ANOMALY_CLASSIFICATION: return AnomalyClassificationDataBatch( batch_size=batch_size, images=images, imgs_info=infos, labels=[torch.LongTensor(0)], ) - if self.task == AnomalibTaskType.SEGMENTATION: + if self.task == OTXTaskType.ANOMALY_SEGMENTATION: return AnomalySegmentationDataBatch( batch_size=batch_size, images=images, @@ -314,7 +331,7 @@ def get_dummy_input(self, batch_size: int = 1) -> AnomalyModelInputs: labels=[torch.LongTensor(0)], masks=torch.tensor(0), ) - if self.task == AnomalibTaskType.DETECTION: + if self.task == OTXTaskType.ANOMALY_DETECTION: return AnomalyDetectionDataBatch( batch_size=batch_size, images=images, @@ -326,3 +343,111 @@ def get_dummy_input(self, batch_size: int = 1) -> AnomalyModelInputs: msg = "Wrong anomaly task type" raise RuntimeError(msg) + + +class AnomalyMixin: + """Mixin inherited before AnomalibModule to override OTXModel methods.""" + + def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None: + """Call AnomlibModule's configure optimizer.""" + return super().configure_optimizers() # type: ignore[misc] + + def on_train_epoch_end(self) -> None: + """Callback triggered when the training epoch ends.""" + return super().on_train_epoch_end() # type: ignore[misc] + + def on_validation_start(self) -> None: + """Callback triggered when the validation starts.""" + return super().on_validation_start() # type: ignore[misc] + + def training_step( + self, + inputs: AnomalyModelInputs, + batch_idx: int = 0, + ) -> STEP_OUTPUT: + """Call training step of the anomalib model.""" + if not isinstance(inputs, dict): + inputs = self._customize_inputs(inputs) # type: ignore[attr-defined] + return super().training_step(inputs, batch_idx) # type: 
ignore[misc] + + def validation_step( + self, + inputs: AnomalyModelInputs, + batch_idx: int = 0, + ) -> STEP_OUTPUT: + """Call validation step of the anomalib model.""" + if not isinstance(inputs, dict): + inputs = self._customize_inputs(inputs) # type: ignore[attr-defined] + return super().validation_step(inputs, batch_idx) # type: ignore[misc] + + def test_step( + self, + inputs: AnomalyModelInputs, + batch_idx: int = 0, + **kwargs, + ) -> STEP_OUTPUT: + """Call test step of the anomalib model.""" + if not isinstance(inputs, dict): + inputs = self._customize_inputs(inputs) # type: ignore[attr-defined] + return super().test_step(inputs, batch_idx, **kwargs) # type: ignore[misc] + + def predict_step( + self, + inputs: AnomalyModelInputs, + batch_idx: int = 0, + **kwargs, + ) -> STEP_OUTPUT: + """Call test step of the anomalib model.""" + if not isinstance(inputs, dict): + inputs = self._customize_inputs(inputs) # type: ignore[attr-defined] + return super().predict_step(inputs, batch_idx, **kwargs) # type: ignore[misc] + + def forward( + self, + inputs: AnomalyModelInputs, + ) -> AnomalyModelOutputs: + """Wrap forward method of the Anomalib model.""" + outputs = self.validation_step(inputs) + # TODO(Ashwin): update forward implementation to comply with other OTX models + _PostProcessorCallback._post_process(outputs) # noqa: SLF001 + _PostProcessorCallback._compute_scores_and_labels(self, outputs) # noqa: SLF001 + _MinMaxNormalizationCallback._normalize_batch(outputs, self) # noqa: SLF001 + + return self._customize_outputs(outputs=outputs, inputs=inputs) # type: ignore[attr-defined] + + @property # type: ignore[override] + def input_size(self) -> tuple[int, int]: + """Returns the input size of the model. + + Returns: + tuple[int, int]: The input size of the model as a tuple of (height, width). + """ + return self._input_shape # since _input_size is re-defined in the base class. + + @input_size.setter + def input_size(self, value: tuple[int, int]) -> None: + self._input_shape = value + + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on saving checkpoint.""" + if self.torch_compile: # type: ignore[attr-defined] + # If torch_compile is True, a prefix key named _orig_mod. is added to the state_dict. Remove this. + compiled_state_dict = checkpoint["state_dict"] + checkpoint["state_dict"] = remove_state_dict_prefix(compiled_state_dict, "_orig_mod.") + # calls Anomalib's on_save_checkpoint + super().on_save_checkpoint(checkpoint) # type: ignore[misc] + + checkpoint["label_info"] = self.label_info # type: ignore[attr-defined] + checkpoint["otx_version"] = __version__ + checkpoint["tile_config"] = self.tile_config # type: ignore[attr-defined] + + attrs = ["_input_shape", "image_threshold", "pixel_threshold"] + checkpoint["anomaly"] = {key: getattr(self, key, None) for key in attrs} + + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on loading checkpoint.""" + # calls Anomalib's on_load_checkpoint + super().on_load_checkpoint(checkpoint) # type: ignore[misc] + if anomaly_attrs := checkpoint.get("anomaly"): + for key, value in anomaly_attrs.items(): + setattr(self, key, value)
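Below is a minimal, self-contained sketch of the cooperative multiple-inheritance pattern this diff introduces with ``class Padim(AnomalyMixin, AnomalibPadim, OTXAnomaly)``. All class names in the sketch are hypothetical stand-ins, not the real otx/anomalib classes: the point is only that the mixin sits first in the MRO so its overrides shadow both bases, its ``super()`` calls dispatch to the Anomalib model next in line, and attributes such as ``input_size`` must be assigned before ``super().__init__()`` because the OTX base reads them during initialization.

class OTXBase:
    """Stand-in for OTXAnomaly: takes no arguments and closes the chain."""

    def __init__(self) -> None:
        # relies on the subclass setting ``input_size`` *before* super().__init__(),
        # mirroring how OTXAnomaly.__init__ reads self.input_size in the diff
        self.resolved_input_size = self.input_size
        super().__init__()

    def validation_step(self, batch: dict) -> dict:
        raise NotImplementedError


class AnomalibModel(OTXBase):
    """Stand-in for AnomalibPadim: receives the kwargs forwarded by the subclass."""

    def __init__(self, backbone: str) -> None:
        super().__init__()
        self.backbone = backbone

    def validation_step(self, batch: dict) -> dict:
        return {"pred_scores": [0.0 for _ in batch["images"]]}


class Mixin:
    """Stand-in for AnomalyMixin: normalizes inputs, then defers via super()."""

    def validation_step(self, batch) -> dict:
        if not isinstance(batch, dict):  # mirrors the isinstance check in the diff
            batch = {"images": list(batch)}
        return super().validation_step(batch)  # resolves to AnomalibModel


class Model(Mixin, AnomalibModel, OTXBase):
    """Stand-in for Padim: mixin first, Anomalib model second, OTX base last."""

    def __init__(self) -> None:
        self.input_size = (256, 256)           # must precede super().__init__()
        super().__init__(backbone="resnet18")  # one cooperative call runs the whole chain


print([cls.__name__ for cls in Model.__mro__])
# -> ['Model', 'Mixin', 'AnomalibModel', 'OTXBase', 'object']
print(Model().validation_step(("img_0", "img_1")))
# -> {'pred_scores': [0.0, 0.0]}

Because the mixin precedes the Anomalib model in the bases list, a single ``super()`` chain replaces the explicit ``AnomalibPadim.training_step(self, ...)`` style calls that the old subclasses had to spell out method by method.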