From 0fc7de901b2fed78eba4b6dc5cbd6014ff95d06e Mon Sep 17 00:00:00 2001
From: Eugene Liu
Date: Fri, 10 May 2024 16:08:40 +0100
Subject: [PATCH] fix unittest

---
 src/otx/core/model/instance_segmentation.py   |  25 ++--
 .../maskrcnn_efficientnetb2b.yaml             | 119 +++++++++++-------
 .../rotated_detection/maskrcnn_r50.yaml       | 117 ++++++++++-------
 .../unit/core/model/test_inst_segmentation.py |  14 +--
 tests/unit/engine/utils/test_api.py           |   1 -
 5 files changed, 159 insertions(+), 117 deletions(-)

diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py
index 50c53a5f8d8..523fb5a2af6 100644
--- a/src/otx/core/model/instance_segmentation.py
+++ b/src/otx/core/model/instance_segmentation.py
@@ -12,12 +12,13 @@
 import numpy as np
 import torch
-from mmengine.structures.instance_data import InstanceData
 from model_api.models import Model
 from model_api.tilers import InstanceSegmentationTiler
 from torchvision import tv_tensors
 
 from otx.algo.explain.explain_algo import InstSegExplainAlgo, feature_vector_fn
+from otx.algo.instance_segmentation.mmdet.models.detectors.two_stage import TwoStageDetector
+from otx.algo.utils.mmengine_utils import InstanceData
 from otx.core.config.data import TileConfig
 from otx.core.data.entity.base import OTXBatchLossEntity
 from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity
@@ -35,8 +36,6 @@
 if TYPE_CHECKING:
     from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
     from mmdet.models.data_preprocessors import DetDataPreprocessor
-    from mmdet.models.detectors import TwoStageDetector
-    from mmdet.structures import OptSampleList
     from model_api.models.utils import InstanceSegmentationResult
     from omegaconf import DictConfig
     from torch import nn
@@ -257,8 +256,7 @@ def forward_explain(self, inputs: InstanceSegBatchDataEntity) -> InstanceSegBatc
     @staticmethod
     def _forward_explain_inst_seg(
         self: TwoStageDetector,
-        inputs: torch.Tensor,
-        data_samples: OptSampleList = None,
+        entity: InstanceSegBatchDataEntity,
         mode: str = "tensor",  # noqa: ARG004
     ) -> dict[str, torch.Tensor]:
         """Forward func of the BaseDetector instance, which is located in ExplainableOTXInstanceSegModel().model."""
@@ -267,20 +265,21 @@ def _forward_explain_inst_seg(
         for param in self.parameters():
             param.requires_grad = False
 
-        x = self.extract_feat(inputs)
+        x = self.extract_feat(entity.images)
         feature_vector = self.feature_vector_fn(x)
-        predictions = self.get_results_from_head(x, data_samples)
+        predictions = self.get_results_from_head(x, entity)
 
         if isinstance(predictions, tuple) and isinstance(predictions[0], torch.Tensor):
             # Export case, consists of tensors
             # For the OV task, saliency maps are generated on the MAPI side
             saliency_map = torch.empty(1, dtype=torch.uint8)
-
         elif isinstance(predictions, list) and isinstance(predictions[0], InstanceData):
             # Predict case, consists of InstanceData
             saliency_map = self.explain_fn(predictions)
-            predictions = self.add_pred_to_datasample(data_samples, predictions)
+        else:
+            msg = f"Unexpected predictions type: {type(predictions)}"
+            raise TypeError(msg)
 
         return {
             "predictions": predictions,
@@ -291,7 +290,7 @@ def get_results_from_head(
         self,
         x: tuple[torch.Tensor],
-        data_samples: OptSampleList | None,
+        entity: InstanceSegBatchDataEntity,
     ) -> tuple[torch.Tensor] | list[InstanceData]:
         """Get the results from the head of the instance segmentation model.
@@ -306,9 +305,9 @@ def get_results_from_head(
         from otx.algo.instance_segmentation.rtmdet_inst import RTMDetInstTiny
 
         if isinstance(self, RTMDetInstTiny):
-            return self.model.bbox_head.predict(x, data_samples, rescale=False)
-        rpn_results_list = self.model.rpn_head.predict(x, data_samples, rescale=False)
-        return self.model.roi_head.predict(x, rpn_results_list, data_samples, rescale=True)
+            return self.model.bbox_head.predict(x, entity, rescale=False)
+        rpn_results_list = self.model.rpn_head.predict(x, entity, rescale=False)
+        return self.model.roi_head.predict(x, rpn_results_list, entity, rescale=True)
 
     def get_explain_fn(self) -> Callable:
         """Returns explain function."""
diff --git a/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml b/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml
index 96de6588121..e454cc82d88 100644
--- a/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml
+++ b/src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml
@@ -28,66 +28,89 @@ engine:
 
 callback_monitor: val/map_50
 
-data: ../_base_/data/mmdet_base.yaml
+data: ../_base_/data/torchvision_base.yaml
 overrides:
   max_epochs: 100
   data:
     task: ROTATED_DETECTION
     config:
+      stack_images: true
+      data_format: coco_instances
       include_polygons: true
       train_subset:
         batch_size: 4
         transforms:
-          - type: LoadImageFromFile
-            backend_args: null
-          - type: LoadAnnotations
-            with_bbox: true
-            with_mask: true
-          - type: Resize
-            keep_ratio: true
-            scale:
-              - 1024
-              - 1024
-          - type: RandomFlip
-            prob: 0.5
-          - type: PackDetInputs
+          - class_path: otx.core.data.transform_libs.torchvision.Resize
+            init_args:
+              keep_ratio: true
+              transform_bbox: true
+              transform_mask: true
+              scale:
+                - 1024
+                - 1024
+          - class_path: otx.core.data.transform_libs.torchvision.Pad
+            init_args:
+              size_divisor: 32
+              transform_mask: true
+          - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
+            init_args:
+              prob: 0.5
+              is_numpy_to_tvtensor: true
+          - class_path: torchvision.transforms.v2.ToDtype
+            init_args:
+              dtype: ${as_torch_dtype:torch.float32}
+              scale: False
+          - class_path: torchvision.transforms.v2.Normalize
+            init_args:
+              mean: [123.675, 116.28, 103.53]
+              std: [1.0, 1.0, 1.0]
+        sampler:
+          class_path: otx.algo.samplers.balanced_sampler.BalancedSampler
       val_subset:
         batch_size: 1
         transforms:
-          - type: LoadImageFromFile
-            backend_args: null
-          - type: Resize
-            keep_ratio: true
-            scale:
-              - 1024
-              - 1024
-          - type: LoadAnnotations
-            with_bbox: true
-            with_mask: true
-          - type: PackDetInputs
-            meta_keys:
-              - img_id
-              - img_path
-              - ori_shape
-              - img_shape
-              - scale_factor
+          - class_path: otx.core.data.transform_libs.torchvision.Resize
+            init_args:
+              keep_ratio: true
+              transform_bbox: false
+              transform_mask: false
+              scale:
+                - 1024
+                - 1024
+          - class_path: otx.core.data.transform_libs.torchvision.Pad
+            init_args:
+              size_divisor: 32
+              transform_mask: false
+              is_numpy_to_tvtensor: true
+          - class_path: torchvision.transforms.v2.ToDtype
+            init_args:
+              dtype: ${as_torch_dtype:torch.float32}
+              scale: False
+          - class_path: torchvision.transforms.v2.Normalize
+            init_args:
+              mean: [123.675, 116.28, 103.53]
+              std: [1.0, 1.0, 1.0]
       test_subset:
         batch_size: 1
         transforms:
-          - type: LoadImageFromFile
-            backend_args: null
-          - type: Resize
-            keep_ratio: true
-            scale:
-              - 1024
-              - 1024
-          - type: LoadAnnotations
-            with_bbox: true
-            with_mask: true
-          - type: PackDetInputs
-            meta_keys:
-              - img_id
-              - img_path
-              - ori_shape
-              - img_shape
-              - scale_factor
+          - class_path: otx.core.data.transform_libs.torchvision.Resize
+            init_args:
+              keep_ratio: true
+              transform_bbox: false
+              transform_mask: false
+              scale:
+                - 1024
+                - 1024
+          - class_path: otx.core.data.transform_libs.torchvision.Pad
+            init_args:
+              size_divisor: 32
+              transform_mask: false
+              is_numpy_to_tvtensor: true
+          - class_path: torchvision.transforms.v2.ToDtype
+            init_args:
+              dtype: ${as_torch_dtype:torch.float32}
+              scale: False
+          - class_path: torchvision.transforms.v2.Normalize
+            init_args:
+              mean: [123.675, 116.28, 103.53]
+              std: [1.0, 1.0, 1.0]
diff --git a/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml b/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml
index d7bd68e5eeb..aa40d055000 100644
--- a/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml
+++ b/src/otx/recipe/rotated_detection/maskrcnn_r50.yaml
@@ -28,66 +28,87 @@ engine:
 
 callback_monitor: val/map_50
 
-data: ../_base_/data/mmdet_base.yaml
+data: ../_base_/data/torchvision_base.yaml
 overrides:
   max_epochs: 100
   data:
     task: ROTATED_DETECTION
     config:
+      stack_images: true
+      data_format: coco_instances
       include_polygons: true
       train_subset:
         batch_size: 4
         transforms:
-          - type: LoadImageFromFile
-            backend_args: null
-          - type: LoadAnnotations
-            with_bbox: true
-            with_mask: true
-          - type: Resize
-            keep_ratio: true
-            scale:
-              - 1024
-              - 1024
-          - type: RandomFlip
-            prob: 0.5
-          - type: PackDetInputs
+          - class_path: otx.core.data.transform_libs.torchvision.Resize
+            init_args:
+              keep_ratio: true
+              transform_bbox: true
+              transform_mask: true
+              scale:
+                - 1024
+                - 1024
+          - class_path: otx.core.data.transform_libs.torchvision.Pad
+            init_args:
+              size_divisor: 32
+              transform_mask: true
+          - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
+            init_args:
+              prob: 0.5
+              is_numpy_to_tvtensor: true
+          - class_path: torchvision.transforms.v2.ToDtype
+            init_args:
+              dtype: ${as_torch_dtype:torch.float32}
+              scale: False
+          - class_path: torchvision.transforms.v2.Normalize
+            init_args:
+              mean: [123.675, 116.28, 103.53]
+              std: [58.395, 57.12, 57.375]
       val_subset:
         batch_size: 1
         transforms:
-          - type: LoadImageFromFile
-            backend_args: null
-          - type: Resize
-            keep_ratio: true
-            scale:
-              - 1024
-              - 1024
-          - type: LoadAnnotations
-            with_bbox: true
-            with_mask: true
-          - type: PackDetInputs
-            meta_keys:
-              - img_id
-              - img_path
-              - ori_shape
-              - img_shape
-              - scale_factor
+          - class_path: otx.core.data.transform_libs.torchvision.Resize
+            init_args:
+              keep_ratio: true
+              transform_bbox: false
+              transform_mask: false
+              scale:
+                - 1024
+                - 1024
+          - class_path: otx.core.data.transform_libs.torchvision.Pad
+            init_args:
+              size_divisor: 32
+              transform_mask: false
+              is_numpy_to_tvtensor: true
+          - class_path: torchvision.transforms.v2.ToDtype
+            init_args:
+              dtype: ${as_torch_dtype:torch.float32}
+              scale: False
+          - class_path: torchvision.transforms.v2.Normalize
+            init_args:
+              mean: [123.675, 116.28, 103.53]
+              std: [58.395, 57.12, 57.375]
       test_subset:
         batch_size: 1
         transforms:
-          - type: LoadImageFromFile
-            backend_args: null
-          - type: Resize
-            keep_ratio: true
-            scale:
-              - 1024
-              - 1024
-          - type: LoadAnnotations
-            with_bbox: true
-            with_mask: true
-          - type: PackDetInputs
-            meta_keys:
-              - img_id
-              - img_path
-              - ori_shape
-              - img_shape
-              - scale_factor
+          - class_path: otx.core.data.transform_libs.torchvision.Resize
+            init_args:
+              keep_ratio: true
+              transform_bbox: false
+              transform_mask: false
+              scale:
+                - 1024
+                - 1024
+          - class_path: otx.core.data.transform_libs.torchvision.Pad
+            init_args:
+              size_divisor: 32
+              transform_mask: false
+              is_numpy_to_tvtensor: true
+          - class_path: torchvision.transforms.v2.ToDtype
+            init_args:
+              dtype: ${as_torch_dtype:torch.float32}
+              scale: False
+          - class_path: torchvision.transforms.v2.Normalize
+            init_args:
+              mean: [123.675, 116.28, 103.53]
+              std: [58.395, 57.12, 57.375]
diff --git a/tests/unit/core/model/test_inst_segmentation.py b/tests/unit/core/model/test_inst_segmentation.py
index f199a31fb7f..8da37a16221 100644
--- a/tests/unit/core/model/test_inst_segmentation.py
+++ b/tests/unit/core/model/test_inst_segmentation.py
@@ -26,11 +26,12 @@ def test_get_explain_fn(self, otx_model):
         explain_fn = otx_model.get_explain_fn()
         assert callable(explain_fn)
 
-    def test_forward_explain_inst_seg(self, otx_model, fxt_data_sample):
-        inputs = torch.randn(1, 3, 224, 224)
+    def test_forward_explain_inst_seg(self, otx_model, fxt_inst_seg_data_entity):
+        inputs = fxt_inst_seg_data_entity[2]
+        inputs.images = torch.randn(1, 3, 224, 224)
         otx_model.model.feature_vector_fn = feature_vector_fn
         otx_model.model.explain_fn = otx_model.get_explain_fn()
-        result = otx_model._forward_explain_inst_seg(otx_model.model, inputs, fxt_data_sample, mode="predict")
+        result = otx_model._forward_explain_inst_seg(otx_model.model, inputs, mode="predict")
 
         assert "predictions" in result
         assert "feature_vector" in result
@@ -38,13 +39,12 @@ def test_customize_inputs(self, otx_model, fxt_inst_seg_data_entity) -> None:
         output_data = otx_model._customize_inputs(fxt_inst_seg_data_entity[2])
-        assert output_data is not None
-        assert "gt_instances" in output_data["data_samples"][-1]
-        assert "masks" in output_data["data_samples"][-1].gt_instances
-        assert output_data["data_samples"][-1].metainfo["pad_shape"] == output_data["inputs"].shape[-2:]
+        assert output_data["mode"] == "loss"
+        assert output_data["entity"] == fxt_inst_seg_data_entity[2]
 
     def test_forward_explain(self, otx_model, fxt_inst_seg_data_entity):
         inputs = fxt_inst_seg_data_entity[2]
+        inputs.images = [image.float() for image in inputs.images]
         otx_model.training = False
         otx_model.explain_mode = True
         outputs = otx_model.forward_explain(inputs)
diff --git a/tests/unit/engine/utils/test_api.py b/tests/unit/engine/utils/test_api.py
index 7fc8dd9cc48..02e35cc18b2 100644
--- a/tests/unit/engine/utils/test_api.py
+++ b/tests/unit/engine/utils/test_api.py
@@ -33,7 +33,6 @@ def test_list_models_pattern() -> None:
         "efficientnet_v2",
         "maskrcnn_efficientnetb2b",
        "maskrcnn_efficientnetb2b_tile",
-        "maskrcnn_efficientnetb2b_tv",
         "tv_efficientnet_b3",
         "tv_efficientnet_v2_l",
     ]
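
Usage note: a minimal sketch of the call pattern this patch moves to, mirroring the updated test_forward_explain_inst_seg above. Here otx_model and entity are stand-ins for the unit-test fixtures (otx_model and fxt_inst_seg_data_entity[2]) and are assumed to be constructed the same way the test constructs them; the sketch is illustrative and not part of the patch.

    import torch

    from otx.algo.explain.explain_algo import feature_vector_fn

    # Assumption: `entity` is an InstanceSegBatchDataEntity built like the test
    # fixture. The explain path now reads images off the entity instead of
    # taking a raw tensor plus mmdet data_samples.
    entity.images = torch.randn(1, 3, 224, 224)

    # Wire up the hooks the explain path expects, exactly as the unit test does.
    otx_model.model.feature_vector_fn = feature_vector_fn
    otx_model.model.explain_fn = otx_model.get_explain_fn()

    # New signature: (detector, entity, mode) -- the old (inputs, data_samples)
    # pair is gone.
    result = otx_model._forward_explain_inst_seg(otx_model.model, entity, mode="predict")
    assert "predictions" in result
    assert "feature_vector" in result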