Skip to content

Commit

Permalink
Promote OTXModel to PyTorchLightningModule and deprecate OTXLitModule (
Browse files Browse the repository at this point in the history
…#3155)

* Refactor

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix draem

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix ruff

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix pickling errors during HPO

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix draem test error

Signed-off-by: Kim, Vinnam <[email protected]>

* Mark xfail to test_otx_ov_test

Signed-off-by: Kim, Vinnam <[email protected]>

* Remove metric overriding for hlabel model in test_otx_export_infer

Signed-off-by: Kim, Vinnam <[email protected]>

---------

Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim authored Mar 21, 2024
1 parent 21373fa commit bec32b6
Show file tree
Hide file tree
Showing 163 changed files with 3,181 additions and 3,198 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ ignore = [
"TCH001", # typing-only-first-party-import, Sometimes this causes an incorrect error.
# flake8-fixme
"FIX002", # line-contains-todo

"E731", # Do not assign a `lambda` expression, use a `def`
"TD003", # Missing issue link on the line following this TODO
]

# Allow autofix for all enabled rules (when `--fix`) is provided.
Expand Down
31 changes: 28 additions & 3 deletions src/otx/algo/action_classification/movinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,42 @@
#
"""X3D model implementation."""

from __future__ import annotations

from typing import TYPE_CHECKING

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.action_classification import MMActionCompatibleModel
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.action_classification import MMActionCompatibleModel
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable

from otx.core.metrics import MetricCallable


class MoViNet(MMActionCompatibleModel):
"""MoViNet Model."""

def __init__(self, num_classes: int) -> None:
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
config = read_mmconfig("movinet")
super().__init__(num_classes=num_classes, config=config)
super().__init__(
num_classes=num_classes,
config=config,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
Expand Down
30 changes: 27 additions & 3 deletions src/otx/algo/action_classification/x3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,42 @@
# SPDX-License-Identifier: Apache-2.0
#
"""X3D model implementation."""
from __future__ import annotations

from typing import TYPE_CHECKING

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.action_classification import MMActionCompatibleModel
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.action_classification import MMActionCompatibleModel
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable

from otx.core.metrics import MetricCallable


class X3D(MMActionCompatibleModel):
"""X3D Model."""

def __init__(self, num_classes: int) -> None:
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
config = read_mmconfig("x3d")
super().__init__(num_classes=num_classes, config=config)
super().__init__(
num_classes=num_classes,
config=config,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
Expand Down
30 changes: 27 additions & 3 deletions src/otx/algo/action_detection/x3d_fastrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,42 @@
"""X3DFastRCNN model implementation."""
from __future__ import annotations

from typing import TYPE_CHECKING

from otx.algo.utils.mmconfig import read_mmconfig
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.model.entity.action_detection import MMActionCompatibleModel
from otx.core.metrics.mean_ap import MeanAPCallable
from otx.core.model.action_detection import MMActionCompatibleModel
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable

from otx.core.metrics import MetricCallable


class X3DFastRCNN(MMActionCompatibleModel):
"""X3D Model."""

def __init__(self, num_classes: int, topk: int | tuple[int]):
def __init__(
self,
num_classes: int,
topk: int | tuple[int],
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
metric: MetricCallable = MeanAPCallable,
torch_compile: bool = False,
) -> None:
config = read_mmconfig("x3d_fastrcnn")
config.roi_head.bbox_head.topk = topk
super().__init__(num_classes=num_classes, config=config)
super().__init__(
num_classes=num_classes,
config=config,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
Expand Down
84 changes: 80 additions & 4 deletions src/otx/algo/anomaly/draem.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
"""OTX Draem model."""
"""OTX AnomalibDraem model."""
# TODO(someone): Revisit mypy errors after OTXLitModule deprecation and anomaly refactoring
# mypy: ignore-errors

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import TYPE_CHECKING

from anomalib.models.image import Draem as AnomalibDraem

from otx.core.model.entity.base import OTXModel
from otx.core.model.module.anomaly import OTXAnomaly
from otx.core.model.anomaly import OTXAnomaly
from otx.core.model.base import OTXModel

if TYPE_CHECKING:
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.optim.optimizer import Optimizer

from otx.core.model.anomaly import AnomalyModelInputs


class Draem(OTXAnomaly, OTXModel, AnomalibDraem):
"""OTX Draem model.
"""OTX AnomalibDraem model.
Args:
enable_sspcab (bool): Enable SSPCAB training. Defaults to ``False``.
Expand Down Expand Up @@ -40,3 +50,69 @@ def __init__(
anomaly_source_path=anomaly_source_path,
beta=beta,
)

def configure_metric(self) -> None:
"""This does not follow OTX metric configuration."""
return

def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
"""DRAEM does not follow OTX optimizer configuration."""
return AnomalibDraem.configure_optimizers(self)

def on_validation_epoch_start(self) -> None:
"""Callback triggered when the validation epoch starts."""
AnomalibDraem.on_validation_epoch_start(self)

def on_test_epoch_start(self) -> None:
"""Callback triggered when the test epoch starts."""
AnomalibDraem.on_test_epoch_start(self)

def on_validation_epoch_end(self) -> None:
"""Callback triggered when the validation epoch ends."""
AnomalibDraem.on_validation_epoch_end(self)

def on_test_epoch_end(self) -> None:
"""Callback triggered when the test epoch ends."""
AnomalibDraem.on_test_epoch_end(self)

def training_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call training step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibDraem.training_step(self, inputs, batch_idx) # type: ignore[misc]

def validation_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call validation step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibDraem.validation_step(self, inputs, batch_idx) # type: ignore[misc]

def test_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibDraem.test_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

def predict_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibDraem.predict_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]
7 changes: 5 additions & 2 deletions src/otx/algo/anomaly/openvino_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
All anomaly models use the same AnomalyDetection model from ModelAPI.
"""
# TODO(someone): Revisit mypy errors after OTXLitModule deprecation and anomaly refactoring
# mypy: ignore-errors

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
Expand All @@ -12,8 +14,8 @@

from lightning.pytorch import LightningModule

from otx.core.model.entity.base import OTXModel, OVModel
from otx.core.model.module.anomaly.anomaly_lightning import AnomalyModelInputs
from otx.core.model.anomaly import AnomalyModelInputs
from otx.core.model.base import OTXModel, OVModel

if TYPE_CHECKING:
from openvino.model_api.models import Model
Expand All @@ -34,6 +36,7 @@ def __init__(
use_throughput_mode: bool = True,
model_api_configuration: dict[str, Any] | None = None,
num_classes: int = 2,
**kwargs,
) -> None:
super().__init__(
num_classes=num_classes, # NOTE: Ideally this should be set to 2 always
Expand Down
80 changes: 78 additions & 2 deletions src/otx/algo/anomaly/padim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# TODO(someone): Revisit mypy errors after OTXLitModule deprecation and anomaly refactoring
# mypy: ignore-errors

from __future__ import annotations

from typing import TYPE_CHECKING

from anomalib.models.image import Padim as AnomalibPadim

from otx.core.model.entity.base import OTXModel
from otx.core.model.module.anomaly import OTXAnomaly
from otx.core.model.anomaly import OTXAnomaly
from otx.core.model.base import OTXModel

if TYPE_CHECKING:
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.optim.optimizer import Optimizer

from otx.core.model.anomaly import AnomalyModelInputs


class Padim(OTXAnomaly, OTXModel, AnomalibPadim):
Expand Down Expand Up @@ -40,3 +50,69 @@ def __init__(
pre_trained=pre_trained,
n_features=n_features,
)

def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
"""PADIM doesn't require optimization, therefore returns no optimizers."""
return

def configure_metric(self) -> None:
"""This does not follow OTX metric configuration."""
return

def on_validation_epoch_start(self) -> None:
"""Callback triggered when the validation epoch starts."""
AnomalibPadim.on_validation_epoch_start(self)

def on_test_epoch_start(self) -> None:
"""Callback triggered when the test epoch starts."""
AnomalibPadim.on_test_epoch_start(self)

def on_validation_epoch_end(self) -> None:
"""Callback triggered when the validation epoch ends."""
AnomalibPadim.on_validation_epoch_end(self)

def on_test_epoch_end(self) -> None:
"""Callback triggered when the test epoch ends."""
AnomalibPadim.on_test_epoch_end(self)

def training_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call training step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibPadim.training_step(self, inputs, batch_idx) # type: ignore[misc]

def validation_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call validation step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibPadim.validation_step(self, inputs, batch_idx) # type: ignore[misc]

def test_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibPadim.test_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

def predict_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibPadim.predict_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]
Loading

0 comments on commit bec32b6

Please sign in to comment.