Skip to content

Commit

Permalink
Refactor optimizer and lr scheduler part (#3216)
Browse files Browse the repository at this point in the history
* Refactor optimizer and lr schedulers

* Update tests

* Update recipe

* Fix torch import problem at CLI installation

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix test errors

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

* Rollback how OVVisualPromptingModel get model names

Signed-off-by: Kim, Vinnam <[email protected]>

* Update src/otx/algo/anomaly/openvino_model.py

Co-authored-by: Harim Kang <[email protected]>

---------

Signed-off-by: Kim, Vinnam <[email protected]>
Co-authored-by: Harim Kang <[email protected]>
  • Loading branch information
vinnamkim and harimkang authored Mar 29, 2024
1 parent 9a47ee8 commit 375e89e
Show file tree
Hide file tree
Showing 123 changed files with 1,626 additions and 1,340 deletions.
5 changes: 3 additions & 2 deletions src/otx/algo/action_classification/movinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.action_classification import MMActionCompatibleModel
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.schedulers import LRSchedulerListCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand All @@ -25,8 +26,8 @@ class MoViNet(MMActionCompatibleModel):
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions src/otx/algo/action_classification/x3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.action_classification import MMActionCompatibleModel
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.schedulers import LRSchedulerListCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand All @@ -24,8 +25,8 @@ class X3D(MMActionCompatibleModel):
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions src/otx/algo/action_detection/x3d_fastrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from otx.core.metrics.mean_ap import MeanAPCallable
from otx.core.model.action_detection import MMActionCompatibleModel
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.schedulers import LRSchedulerListCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand All @@ -25,8 +26,8 @@ def __init__(
self,
num_classes: int,
topk: int | tuple[int],
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MeanAPCallable,
torch_compile: bool = False,
) -> None:
Expand Down
10 changes: 6 additions & 4 deletions src/otx/algo/anomaly/openvino_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
All anomaly models use the same AnomalyDetection model from ModelAPI.
"""

# TODO(someone): Revisit mypy errors after OTXLitModule deprecation and anomaly refactoring
# mypy: ignore-errors

Expand All @@ -12,17 +13,16 @@

from typing import TYPE_CHECKING, Any

from lightning.pytorch import LightningModule

from otx.core.metrics.types import MetricCallable, NullMetricCallable
from otx.core.model.anomaly import AnomalyModelInputs
from otx.core.model.base import OTXModel, OVModel
from otx.core.model.base import OVModel

if TYPE_CHECKING:
from openvino.model_api.models import Model
from openvino.model_api.models.anomaly import AnomalyResult


class AnomalyOpenVINO(OVModel, OTXModel, LightningModule):
class AnomalyOpenVINO(OVModel):
"""Anomaly OpenVINO model."""

# [TODO](ashwinvaidya17): Remove LightningModule once OTXModel is updated to use LightningModule.
Expand All @@ -36,6 +36,7 @@ def __init__(
use_throughput_mode: bool = True,
model_api_configuration: dict[str, Any] | None = None,
num_classes: int = 2,
metric: MetricCallable = NullMetricCallable,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -46,6 +47,7 @@ def __init__(
max_num_requests=max_num_requests,
use_throughput_mode=use_throughput_mode,
model_api_configuration=model_api_configuration,
metric=metric,
)

def _create_model(self) -> Model:
Expand Down
13 changes: 7 additions & 6 deletions src/otx/algo/classification/deit_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
MMPretrainMulticlassClsModel,
MMPretrainMultilabelClsModel,
)
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo

if TYPE_CHECKING:
Expand Down Expand Up @@ -150,8 +151,8 @@ class DeitTinyForHLabelCls(ExplainableDeit, MMPretrainHlabelClsModel):
def __init__(
self,
hlabel_info: HLabelInfo,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
) -> None:
Expand All @@ -177,8 +178,8 @@ class DeitTinyForMulticlassCls(ExplainableDeit, MMPretrainMulticlassClsModel):
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand All @@ -203,8 +204,8 @@ class DeitTinyForMultilabelCls(ExplainableDeit, MMPretrainMultilabelClsModel):
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand Down
13 changes: 7 additions & 6 deletions src/otx/algo/classification/efficientnet_b0.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MMPretrainMulticlassClsModel,
MMPretrainMultilabelClsModel,
)
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo

if TYPE_CHECKING:
Expand All @@ -29,8 +30,8 @@ class EfficientNetB0ForHLabelCls(MMPretrainHlabelClsModel):
def __init__(
self,
hlabel_info: HLabelInfo,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
) -> None:
Expand All @@ -57,8 +58,8 @@ def __init__(
self,
num_classes: int,
light: bool = False,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand All @@ -84,8 +85,8 @@ class EfficientNetB0ForMultilabelCls(MMPretrainMultilabelClsModel):
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand Down
13 changes: 7 additions & 6 deletions src/otx/algo/classification/efficientnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MMPretrainMulticlassClsModel,
MMPretrainMultilabelClsModel,
)
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo

if TYPE_CHECKING:
Expand All @@ -29,8 +30,8 @@ class EfficientNetV2ForHLabelCls(MMPretrainHlabelClsModel):
def __init__(
self,
hlabel_info: HLabelInfo,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
) -> None:
Expand All @@ -57,8 +58,8 @@ def __init__(
self,
num_classes: int,
light: bool = False,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand All @@ -84,8 +85,8 @@ class EfficientNetV2ForMultilabelCls(MMPretrainMultilabelClsModel):
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
backbone:
name: dinov2_vits14_reg
frozen: false
head:
in_channels: 384
num_classes: 1000
13 changes: 7 additions & 6 deletions src/otx/algo/classification/mobilenet_v3_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MMPretrainMulticlassClsModel,
MMPretrainMultilabelClsModel,
)
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo

if TYPE_CHECKING:
Expand All @@ -29,8 +30,8 @@ class MobileNetV3ForHLabelCls(MMPretrainHlabelClsModel):
def __init__(
self,
hlabel_info: HLabelInfo,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
) -> None:
Expand Down Expand Up @@ -64,8 +65,8 @@ def __init__(
self,
num_classes: int,
light: bool = False,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand Down Expand Up @@ -98,8 +99,8 @@ class MobileNetV3ForMultilabelCls(MMPretrainMultilabelClsModel):
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
Expand Down
13 changes: 9 additions & 4 deletions src/otx/algo/classification/otx_dino_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from torch import nn

from otx.algo.utils.mmconfig import read_mmconfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
MulticlassClsBatchDataEntity,
Expand All @@ -20,11 +21,11 @@
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import OTXMulticlassClsModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.utils.config import inplace_num_classes

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from omegaconf import DictConfig

from otx.core.metrics import MetricCallable

Expand Down Expand Up @@ -76,14 +77,18 @@ class DINOv2RegisterClassifier(OTXMulticlassClsModel):
def __init__(
self,
num_classes: int,
config: DictConfig,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
freeze_backbone: bool = False,
) -> None:
config = read_mmconfig(model_name="dino_v2", subdir_name="multiclass_classification")
config = inplace_num_classes(cfg=config, num_classes=num_classes)
config.backbone.frozen = freeze_backbone

self.config = config

super().__init__(
num_classes=num_classes,
optimizer=optimizer,
Expand Down
5 changes: 3 additions & 2 deletions src/otx/algo/classification/torchvision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import OTXMulticlassClsModel
from otx.core.schedulers import LRSchedulerListCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand Down Expand Up @@ -182,8 +183,8 @@ def __init__(
backbone: TVModelType,
num_classes: int,
loss_callable: Callable[[], nn.Module] = nn.CrossEntropyLoss,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
freeze_backbone: bool = False,
Expand Down
9 changes: 5 additions & 4 deletions src/otx/algo/detection/atss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from otx.core.metrics.mean_ap import MeanAPCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.detection import MMDetCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand All @@ -26,8 +27,8 @@ def __init__(
self,
num_classes: int,
variant: Literal["mobilenetv2", "resnext101"],
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MeanAPCallable,
torch_compile: bool = False,
) -> None:
Expand Down Expand Up @@ -67,8 +68,8 @@ class ATSSR50FPN(MMDetCompatibleModel):
def __init__(
self,
num_classes: int,
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MeanAPCallable,
torch_compile: bool = False,
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions src/otx/algo/detection/rtmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from otx.core.metrics.mean_ap import MeanAPCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.detection import MMDetCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand All @@ -26,8 +27,8 @@ def __init__(
self,
num_classes: int,
variant: Literal["tiny"],
optimizer: list[OptimizerCallable] | OptimizerCallable = DefaultOptimizerCallable,
scheduler: list[LRSchedulerCallable] | LRSchedulerCallable = DefaultSchedulerCallable,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MeanAPCallable,
torch_compile: bool = False,
) -> None:
Expand Down
Loading

0 comments on commit 375e89e

Please sign in to comment.