diff --git a/src/otx/core/data/dataset/segmentation.py b/src/otx/core/data/dataset/segmentation.py index fbec5c562e5..53975456b67 100644 --- a/src/otx/core/data/dataset/segmentation.py +++ b/src/otx/core/data/dataset/segmentation.py @@ -98,7 +98,7 @@ def _extract_class_mask(item: DatasetItem, img_shape: tuple[int, int], ignore_in msg = "It is not currently support an ignore index which is more than 255." raise ValueError(msg, ignore_index) - # fill mask with background label if we have Polygon/Ellipse annotations + # fill mask with background label if we have Polygon/Ellipse/Bbox annotations fill_value = 0 if isinstance(item.annotations[0], (Ellipse, Polygon, Bbox, RotatedBbox)) else ignore_index class_mask = np.full(shape=img_shape[:2], fill_value=fill_value, dtype=np.uint8) @@ -179,9 +179,9 @@ def __init__( to_tv_image, ) - if self.has_polygons and "background" not in [label_name.lower() for label_name in self.label_info.label_names]: + if self.has_polygons: # insert background class at index 0 since polygons represent only objects - self.label_info.label_names.insert(0, "background") + self.label_info.label_names.insert(0, "otx_background_lbl") self.label_info = SegLabelInfo( label_names=self.label_info.label_names, diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index 5a1cdcac64c..ea6308afc6d 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -1095,11 +1095,6 @@ def model_adapter_parameters(self) -> dict: def _set_label_info(self, label_info: LabelInfoTypes) -> None: """Set this model label information.""" new_label_info = self._dispatch_label_info(label_info) - - if self._label_info != new_label_info: - msg = "OVModel strictly does not allow overwrite label_info if they are different each other." - raise ValueError(msg) - self._label_info = new_label_info def _create_label_info_from_ov_ir(self) -> LabelInfo: diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py index 85182944474..a7eecdffe8c 100644 --- a/src/otx/core/model/segmentation.py +++ b/src/otx/core/model/segmentation.py @@ -5,6 +5,7 @@ from __future__ import annotations +import copy import json from abc import abstractmethod from collections.abc import Sequence @@ -165,12 +166,20 @@ def _customize_outputs( @property def _export_parameters(self) -> TaskLevelExportParameters: """Defines parameters required to export a particular model implementation.""" + if self.label_info.label_names[0] == "otx_background_lbl": + # remove otx background label for export + modified_label_info = copy.deepcopy(self.label_info) + modified_label_info.label_names.pop(0) + else: + modified_label_info = self.label_info + return super()._export_parameters.wrap( model_type="Segmentation", task_type="segmentation", return_soft_prediction=True, soft_threshold=0.5, blur_strength=-1, + label_info=modified_label_info, ) @property diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 47647caf7d6..3c0addd547b 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -5,6 +5,7 @@ from __future__ import annotations +import copy import csv import inspect import logging @@ -370,14 +371,22 @@ def test( model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams) if model.label_info != self.datamodule.label_info: - msg = ( - "To launch a test pipeline, the label information should be same " - "between the training and testing datasets. " - "Please check whether you use the same dataset: " - f"model.label_info={model.label_info}, " - f"datamodule.label_info={self.datamodule.label_info}" - ) - raise ValueError(msg) + if ( + self.task == "SEMANTIC_SEGMENTATION" + and "otx_background_lbl" in self.datamodule.label_info.label_names + and (len(self.datamodule.label_info.label_names) - len(model.label_info.label_names) == 1) + ): + # workaround for background label + model.label_info = copy.deepcopy(self.datamodule.label_info) + else: + msg = ( + "To launch a test pipeline, the label information should be same " + "between the training and testing datasets. " + "Please check whether you use the same dataset: " + f"model.label_info={model.label_info}, " + f"datamodule.label_info={self.datamodule.label_info}" + ) + raise ValueError(msg) self._build_trainer(**kwargs) diff --git a/tests/unit/core/data/dataset/test_segmentation.py b/tests/unit/core/data/dataset/test_segmentation.py index 141dc4bf74b..c7e35d0a924 100644 --- a/tests/unit/core/data/dataset/test_segmentation.py +++ b/tests/unit/core/data/dataset/test_segmentation.py @@ -19,7 +19,7 @@ def test_get_item( max_refetch=3, ) assert isinstance(dataset[0], SegDataEntity) - assert "background" in [label_name.lower() for label_name in dataset.label_info.label_names] + assert "otx_background_lbl" in [label_name.lower() for label_name in dataset.label_info.label_names] def test_get_item_from_bbox_dataset( self, @@ -33,4 +33,4 @@ def test_get_item_from_bbox_dataset( ) assert isinstance(dataset[0], SegDataEntity) # OTXSegmentationDataset should add background when getting a dataset which includes only bbox annotations - assert "background" in [label_name.lower() for label_name in dataset.label_info.label_names] + assert "otx_background_lbl" in [label_name.lower() for label_name in dataset.label_info.label_names]