Skip to content

Commit

Permalink
fix unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
eugene123tw committed May 10, 2024
1 parent 0c33d44 commit 0fc7de9
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 117 deletions.
25 changes: 12 additions & 13 deletions src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@

import numpy as np
import torch
from mmengine.structures.instance_data import InstanceData
from model_api.models import Model
from model_api.tilers import InstanceSegmentationTiler
from torchvision import tv_tensors

from otx.algo.explain.explain_algo import InstSegExplainAlgo, feature_vector_fn
from otx.algo.instance_segmentation.mmdet.models.detectors.two_stage import TwoStageDetector
from otx.algo.utils.mmengine_utils import InstanceData
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity
Expand All @@ -35,8 +36,6 @@
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from mmdet.models.data_preprocessors import DetDataPreprocessor
from mmdet.models.detectors import TwoStageDetector
from mmdet.structures import OptSampleList
from model_api.models.utils import InstanceSegmentationResult
from omegaconf import DictConfig
from torch import nn
Expand Down Expand Up @@ -257,8 +256,7 @@ def forward_explain(self, inputs: InstanceSegBatchDataEntity) -> InstanceSegBatc
@staticmethod
def _forward_explain_inst_seg(
self: TwoStageDetector,
inputs: torch.Tensor,
data_samples: OptSampleList = None,
entity: InstanceSegBatchDataEntity,
mode: str = "tensor", # noqa: ARG004
) -> dict[str, torch.Tensor]:
"""Forward func of the BaseDetector instance, which located in is in ExplainableOTXInstanceSegModel().model."""
Expand All @@ -267,20 +265,21 @@ def _forward_explain_inst_seg(
for param in self.parameters():
param.requires_grad = False

x = self.extract_feat(inputs)
x = self.extract_feat(entity.images)

feature_vector = self.feature_vector_fn(x)
predictions = self.get_results_from_head(x, data_samples)
predictions = self.get_results_from_head(x, entity)

if isinstance(predictions, tuple) and isinstance(predictions[0], torch.Tensor):
# Export case, consists of tensors
# For OV task saliency map are generated on MAPI side
saliency_map = torch.empty(1, dtype=torch.uint8)

elif isinstance(predictions, list) and isinstance(predictions[0], InstanceData):
# Predict case, consists of InstanceData
saliency_map = self.explain_fn(predictions)
predictions = self.add_pred_to_datasample(data_samples, predictions)
else:
msg = f"Unexpected predictions type: {type(predictions)}"
raise TypeError(msg)

return {
"predictions": predictions,
Expand All @@ -291,7 +290,7 @@ def _forward_explain_inst_seg(
def get_results_from_head(
self,
x: tuple[torch.Tensor],
data_samples: OptSampleList | None,
entity: InstanceSegBatchDataEntity,
) -> tuple[torch.Tensor] | list[InstanceData]:
"""Get the results from the head of the instance segmentation model.
Expand All @@ -306,9 +305,9 @@ def get_results_from_head(
from otx.algo.instance_segmentation.rtmdet_inst import RTMDetInstTiny

if isinstance(self, RTMDetInstTiny):
return self.model.bbox_head.predict(x, data_samples, rescale=False)
rpn_results_list = self.model.rpn_head.predict(x, data_samples, rescale=False)
return self.model.roi_head.predict(x, rpn_results_list, data_samples, rescale=True)
return self.model.bbox_head.predict(x, entity, rescale=False)
rpn_results_list = self.model.rpn_head.predict(x, entity, rescale=False)
return self.model.roi_head.predict(x, rpn_results_list, entity, rescale=True)

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
Expand Down
119 changes: 71 additions & 48 deletions src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,66 +28,89 @@ engine:

callback_monitor: val/map_50

data: ../_base_/data/mmdet_base.yaml
data: ../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 100
data:
task: ROTATED_DETECTION
config:
stack_images: true
data_format: coco_instances
include_polygons: true
train_subset:
batch_size: 4
transforms:
- type: LoadImageFromFile
backend_args: null
- type: LoadAnnotations
with_bbox: true
with_mask: true
- type: Resize
keep_ratio: true
scale:
- 1024
- 1024
- type: RandomFlip
prob: 0.5
- type: PackDetInputs
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
transform_bbox: true
transform_mask: true
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
transform_mask: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
scale: False
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [1.0, 1.0, 1.0]
sampler:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler
val_subset:
batch_size: 1
transforms:
- type: LoadImageFromFile
backend_args: null
- type: Resize
keep_ratio: true
scale:
- 1024
- 1024
- type: LoadAnnotations
with_bbox: true
with_mask: true
- type: PackDetInputs
meta_keys:
- img_id
- img_path
- ori_shape
- img_shape
- scale_factor
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
transform_bbox: false
transform_mask: false
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
transform_mask: false
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
scale: False
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [1.0, 1.0, 1.0]
test_subset:
batch_size: 1
transforms:
- type: LoadImageFromFile
backend_args: null
- type: Resize
keep_ratio: true
scale:
- 1024
- 1024
- type: LoadAnnotations
with_bbox: true
with_mask: true
- type: PackDetInputs
meta_keys:
- img_id
- img_path
- ori_shape
- img_shape
- scale_factor
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
transform_bbox: false
transform_mask: false
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
transform_mask: false
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
scale: False
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [1.0, 1.0, 1.0]
117 changes: 69 additions & 48 deletions src/otx/recipe/rotated_detection/maskrcnn_r50.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,66 +28,87 @@ engine:

callback_monitor: val/map_50

data: ../_base_/data/mmdet_base.yaml
data: ../_base_/data/torchvision_base.yaml
overrides:
max_epochs: 100
data:
task: ROTATED_DETECTION
config:
stack_images: true
data_format: coco_instances
include_polygons: true
train_subset:
batch_size: 4
transforms:
- type: LoadImageFromFile
backend_args: null
- type: LoadAnnotations
with_bbox: true
with_mask: true
- type: Resize
keep_ratio: true
scale:
- 1024
- 1024
- type: RandomFlip
prob: 0.5
- type: PackDetInputs
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
transform_bbox: true
transform_mask: true
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
transform_mask: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
scale: False
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
val_subset:
batch_size: 1
transforms:
- type: LoadImageFromFile
backend_args: null
- type: Resize
keep_ratio: true
scale:
- 1024
- 1024
- type: LoadAnnotations
with_bbox: true
with_mask: true
- type: PackDetInputs
meta_keys:
- img_id
- img_path
- ori_shape
- img_shape
- scale_factor
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
transform_bbox: false
transform_mask: false
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
transform_mask: false
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
scale: False
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
test_subset:
batch_size: 1
transforms:
- type: LoadImageFromFile
backend_args: null
- type: Resize
keep_ratio: true
scale:
- 1024
- 1024
- type: LoadAnnotations
with_bbox: true
with_mask: true
- type: PackDetInputs
meta_keys:
- img_id
- img_path
- ori_shape
- img_shape
- scale_factor
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
transform_bbox: false
transform_mask: false
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
transform_mask: false
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
scale: False
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
14 changes: 7 additions & 7 deletions tests/unit/core/model/test_inst_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,25 @@ def test_get_explain_fn(self, otx_model):
explain_fn = otx_model.get_explain_fn()
assert callable(explain_fn)

def test_forward_explain_inst_seg(self, otx_model, fxt_data_sample):
inputs = torch.randn(1, 3, 224, 224)
def test_forward_explain_inst_seg(self, otx_model, fxt_inst_seg_data_entity):
inputs = fxt_inst_seg_data_entity[2]
inputs.images = torch.randn(1, 3, 224, 224)
otx_model.model.feature_vector_fn = feature_vector_fn
otx_model.model.explain_fn = otx_model.get_explain_fn()
result = otx_model._forward_explain_inst_seg(otx_model.model, inputs, fxt_data_sample, mode="predict")
result = otx_model._forward_explain_inst_seg(otx_model.model, inputs, mode="predict")

assert "predictions" in result
assert "feature_vector" in result
assert "saliency_map" in result

def test_customize_inputs(self, otx_model, fxt_inst_seg_data_entity) -> None:
output_data = otx_model._customize_inputs(fxt_inst_seg_data_entity[2])
assert output_data is not None
assert "gt_instances" in output_data["data_samples"][-1]
assert "masks" in output_data["data_samples"][-1].gt_instances
assert output_data["data_samples"][-1].metainfo["pad_shape"] == output_data["inputs"].shape[-2:]
assert output_data["mode"] == "loss"
assert output_data["entity"] == fxt_inst_seg_data_entity[2]

def test_forward_explain(self, otx_model, fxt_inst_seg_data_entity):
inputs = fxt_inst_seg_data_entity[2]
inputs.images = [image.float() for image in inputs.images]
otx_model.training = False
otx_model.explain_mode = True
outputs = otx_model.forward_explain(inputs)
Expand Down
1 change: 0 additions & 1 deletion tests/unit/engine/utils/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def test_list_models_pattern() -> None:
"efficientnet_v2",
"maskrcnn_efficientnetb2b",
"maskrcnn_efficientnetb2b_tile",
"maskrcnn_efficientnetb2b_tv",
"tv_efficientnet_b3",
"tv_efficientnet_v2_l",
]
Expand Down

0 comments on commit 0fc7de9

Please sign in to comment.