diff --git a/src/otx/core/model/entity/classification.py b/src/otx/core/model/entity/classification.py index ec0a7aa25ee..63d8dc7fce2 100644 --- a/src/otx/core/model/entity/classification.py +++ b/src/otx/core/model/entity/classification.py @@ -12,7 +12,6 @@ import numpy as np import torch -from otx.algo.explain.explain_algo import feature_vector_fn from otx.core.data.dataset.classification import HLabelInfo from otx.core.data.entity.base import ( OTXBatchLossEntity, @@ -80,6 +79,8 @@ def forward_explain( inputs: T_OTXBatchDataEntity, ) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity: """Model forward function.""" + from otx.algo.explain.explain_algo import feature_vector_fn + self.model.feature_vector_fn = feature_vector_fn self.model.explain_fn = self.get_explain_fn() @@ -139,6 +140,8 @@ def get_explain_fn(self) -> Callable: return explainer.func def _reset_model_forward(self) -> None: + from otx.algo.explain.explain_algo import feature_vector_fn + if not self.explain_mode: return diff --git a/src/otx/core/model/entity/instance_segmentation.py b/src/otx/core/model/entity/instance_segmentation.py index ed2f0e86ae6..12effcebcb2 100644 --- a/src/otx/core/model/entity/instance_segmentation.py +++ b/src/otx/core/model/entity/instance_segmentation.py @@ -18,7 +18,6 @@ from openvino.model_api.tilers import InstanceSegmentationTiler from torchvision import tv_tensors -from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo, feature_vector_fn from otx.core.config.data import TileConfig from otx.core.data.entity.base import ( OTXBatchLossEntity, @@ -139,6 +138,8 @@ def forward_explain( inputs: InstanceSegBatchDataEntity, ) -> InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI | OTXBatchLossEntity: """Model forward function.""" + from otx.algo.explain.explain_algo import feature_vector_fn + self.model.feature_vector_fn = feature_vector_fn self.model.explain_fn = self.get_explain_fn() @@ -196,6 +197,8 @@ def _forward_explain_inst_seg( def get_explain_fn(self) -> Callable: """Returns explain function.""" + from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo + explainer = MaskRCNNExplainAlgo(num_classes=self.num_classes) return explainer.func