Skip to content

Commit

Permalink
Fix wrong ov datamodule update in predict, test, optimize (#3139)
Browse files Browse the repository at this point in the history
* Fix wrong ov datamodule update

* Fix pre-commit

* Fix tiler config

* Remove ov dataset update in optimize

* Fix ov pipeline update function
  • Loading branch information
harimkang authored Mar 19, 2024
1 parent 106c111 commit 021e793
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 19 deletions.
11 changes: 7 additions & 4 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def test(

is_ir_ckpt = Path(str(checkpoint)).suffix in [".xml", ".onnx"]
if is_ir_ckpt and not isinstance(model, OVModel):
datamodule = self._auto_configurator.get_ov_datamodule()
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

metric = metric if metric is not None else self._auto_configurator.get_metric()
Expand Down Expand Up @@ -384,7 +384,7 @@ def predict(

is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
if is_ir_ckpt and not isinstance(model, OVModel):
datamodule = self._auto_configurator.get_ov_datamodule()
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

lit_module = self._build_lightning_module(
Expand Down Expand Up @@ -533,7 +533,10 @@ def optimize(

model = self.model
if not isinstance(model, OVModel):
datamodule = self._auto_configurator.get_ov_datamodule()
optimize_datamodule = self._auto_configurator.update_ov_subset_pipeline(
datamodule=optimize_datamodule,
subset="train",
)
model = self._auto_configurator.get_ov_model(
model_name=str(checkpoint),
label_info=optimize_datamodule.label_info,
Expand Down Expand Up @@ -594,7 +597,7 @@ def explain(

is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
if is_ir_ckpt and not isinstance(model, OVModel):
datamodule = self._auto_configurator.get_ov_datamodule()
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

lit_module = self._build_lightning_module(
Expand Down
35 changes: 20 additions & 15 deletions src/otx/engine/utils/auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING
from warnings import warn

import datumaro
from lightning.pytorch.cli import instantiate_class
Expand Down Expand Up @@ -333,22 +334,26 @@ def get_ov_model(self, model_name: str, label_info: LabelInfo) -> OVModel:
num_classes=label_info.num_classes,
)

def get_ov_datamodule(self) -> OTXDataModule:
"""Returns an instance of OTXDataModule configured with the specified data root and data module configuration.
def update_ov_subset_pipeline(self, datamodule: OTXDataModule, subset: str = "test") -> OTXDataModule:
"""Returns an OTXDataModule object with OpenVINO subset transforms applied.
Args:
datamodule (OTXDataModule): The original OTXDataModule object.
subset (str, optional): The subset to update. Defaults to "test".
Returns:
OTXDataModule: An instance of OTXDataModule.
OTXDataModule: The modified OTXDataModule object with OpenVINO subset transforms applied.
"""
config = self._load_default_config(model_name="openvino_model")
config["data"]["config"]["data_root"] = self.data_root
data_config = config["data"]["config"].copy()
return OTXDataModule(
task=config["data"]["task"],
config=DataModuleConfig(
train_subset=SubsetConfig(**data_config.pop("train_subset")),
val_subset=SubsetConfig(**data_config.pop("val_subset")),
test_subset=SubsetConfig(**data_config.pop("test_subset")),
tile_config=TileConfig(**data_config.pop("tile_config", {})),
**data_config,
),
data_configuration = datamodule.config
ov_test_config = self._load_default_config(model_name="openvino_model")["data"]["config"][f"{subset}_subset"]
subset_config = getattr(data_configuration, f"{subset}_subset")
subset_config.transform_lib_type = ov_test_config["transform_lib_type"]
subset_config.transforms = ov_test_config["transforms"]
data_configuration.tile_config.enable_tiler = False
msg = (
f"For OpenVINO IR models, Update the following {subset} transforms: {subset_config.transforms}"
f"and transform_lib_type: {subset_config.transform_lib_type}"
"And the tiler is disabled."
)
warn(msg, stacklevel=1)
return OTXDataModule(task=datamodule.task, config=data_configuration)
24 changes: 24 additions & 0 deletions tests/unit/engine/utils/test_auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from otx.core.data.module import OTXDataModule
from otx.core.model.entity.base import OTXModel
from otx.core.types.task import OTXTaskType
from otx.core.types.transformer_libs import TransformLibType
from otx.engine.utils.auto_configurator import (
DEFAULT_CONFIG_PER_TASK,
AutoConfigurator,
Expand Down Expand Up @@ -140,3 +141,26 @@ def test_get_scheduler(self) -> None:
assert callable(sch)
else:
assert callable(scheduler)

def test_update_ov_subset_pipeline(self) -> None:
data_root = "tests/assets/car_tree_bug"
auto_configurator = AutoConfigurator(data_root=data_root, task="DETECTION")

datamodule = auto_configurator.get_datamodule()

assert datamodule.config.test_subset.transforms == [
{"type": "LoadImageFromFile"},
{"type": "Resize", "scale": [992, 736], "keep_ratio": False},
{"type": "LoadAnnotations", "with_bbox": True},
{
"type": "PackDetInputs",
"meta_keys": ["ori_filename", "scale_factor", "ori_shape", "filename", "img_shape", "pad_shape"],
},
]

assert datamodule.config.test_subset.transform_lib_type == TransformLibType.MMDET

updated_datamodule = auto_configurator.update_ov_subset_pipeline(datamodule, subset="test")
assert updated_datamodule.config.test_subset.transforms == [{"class_path": "torchvision.transforms.v2.ToImage"}]

assert updated_datamodule.config.test_subset.transform_lib_type == TransformLibType.TORCHVISION

0 comments on commit 021e793

Please sign in to comment.