Skip to content

Commit

Permalink
provide another solution
Browse files Browse the repository at this point in the history
  • Loading branch information
kprokofi committed Oct 10, 2024
1 parent f9da6cd commit 28f2eb9
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 26 deletions.
6 changes: 3 additions & 3 deletions src/otx/core/data/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _extract_class_mask(item: DatasetItem, img_shape: tuple[int, int], ignore_in
msg = "It is not currently support an ignore index which is more than 255."
raise ValueError(msg, ignore_index)

# fill mask with background label if we have Polygon/Ellipse annotations
# fill mask with background label if we have Polygon/Ellipse/Bbox annotations
fill_value = 0 if isinstance(item.annotations[0], (Ellipse, Polygon, Bbox, RotatedBbox)) else ignore_index
class_mask = np.full(shape=img_shape[:2], fill_value=fill_value, dtype=np.uint8)

Expand Down Expand Up @@ -179,9 +179,9 @@ def __init__(
to_tv_image,
)

if self.has_polygons and "background" not in [label_name.lower() for label_name in self.label_info.label_names]:
if self.has_polygons:
# insert background class at index 0 since polygons represent only objects
self.label_info.label_names.insert(0, "background")
self.label_info.label_names.insert(0, "otx_background_lbl")

self.label_info = SegLabelInfo(
label_names=self.label_info.label_names,
Expand Down
5 changes: 0 additions & 5 deletions src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,11 +1095,6 @@ def model_adapter_parameters(self) -> dict:
def _set_label_info(self, label_info: LabelInfoTypes) -> None:
"""Set this model label information."""
new_label_info = self._dispatch_label_info(label_info)

if self._label_info != new_label_info:
msg = "OVModel strictly does not allow overwrite label_info if they are different each other."
raise ValueError(msg)

self._label_info = new_label_info

def _create_label_info_from_ov_ir(self) -> LabelInfo:
Expand Down
9 changes: 9 additions & 0 deletions src/otx/core/model/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import copy
import json
from abc import abstractmethod
from collections.abc import Sequence
Expand Down Expand Up @@ -165,12 +166,20 @@ def _customize_outputs(
@property
def _export_parameters(self) -> TaskLevelExportParameters:
"""Defines parameters required to export a particular model implementation."""
if self.label_info.label_names[0] == "otx_background_lbl":
# remove otx background label for export
modified_label_info = copy.deepcopy(self.label_info)
modified_label_info.label_names.pop(0)
else:
modified_label_info = self.label_info

return super()._export_parameters.wrap(
model_type="Segmentation",
task_type="segmentation",
return_soft_prediction=True,
soft_threshold=0.5,
blur_strength=-1,
label_info=modified_label_info,
)

@property
Expand Down
12 changes: 2 additions & 10 deletions src/otx/core/types/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,17 @@ def to_metadata(self) -> dict[tuple[str, str], str]:
dict[tuple[str, str], str]: It will be directly delivered to
OpenVINO IR's `rt_info` or ONNX metadata slot.
"""
label_names = self.label_info.label_names
if self.task_type == "instance_segmentation":
# Instance segmentation needs to add empty label
all_labels = "otx_empty_lbl "
all_label_ids = "None "
for lbl in label_names:
all_labels += lbl.replace(" ", "_") + " "
all_label_ids += lbl.replace(" ", "_") + " "
elif self.task_type == "semantic_segmentation" and label_names[0].lower() == "background":
# Semantic segmentation needs to remove first background label
all_labels = ""
all_label_ids = ""
for lbl in label_names[1:]:
for lbl in self.label_info.label_names:
all_labels += lbl.replace(" ", "_") + " "
all_label_ids += lbl.replace(" ", "_") + " "
else:
all_labels = ""
all_label_ids = ""
for lbl in label_names:
for lbl in self.label_info.label_names:
all_labels += lbl.replace(" ", "_") + " "
all_label_ids += lbl.replace(" ", "_") + " "

Expand Down
25 changes: 17 additions & 8 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import copy
import csv
import inspect
import logging
Expand Down Expand Up @@ -370,14 +371,22 @@ def test(
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)

if model.label_info != self.datamodule.label_info:
msg = (
"To launch a test pipeline, the label information should be same "
"between the training and testing datasets. "
"Please check whether you use the same dataset: "
f"model.label_info={model.label_info}, "
f"datamodule.label_info={self.datamodule.label_info}"
)
raise ValueError(msg)
if (
self.task == "SEMANTIC_SEGMENTATION"
and "otx_background_lbl" in self.datamodule.label_info.label_names
and (len(self.datamodule.label_info.label_names) - len(model.label_info.label_names) == 1)
):
# workaround for background label
model.label_info = copy.deepcopy(self.datamodule.label_info)
else:
msg = (
"To launch a test pipeline, the label information should be same "
"between the training and testing datasets. "
"Please check whether you use the same dataset: "
f"model.label_info={model.label_info}, "
f"datamodule.label_info={self.datamodule.label_info}"
)
raise ValueError(msg)

self._build_trainer(**kwargs)

Expand Down

0 comments on commit 28f2eb9

Please sign in to comment.