Skip to content

Commit

Permalink
Fix circular import issue #3202 (#3203)
Browse files Browse the repository at this point in the history
Fix circular issues
  • Loading branch information
harimkang authored Mar 26, 2024
1 parent 7070028 commit 6683778
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/otx/core/model/entity/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import numpy as np
import torch

from otx.algo.explain.explain_algo import feature_vector_fn
from otx.core.data.dataset.classification import HLabelInfo
from otx.core.data.entity.base import (
OTXBatchLossEntity,
Expand Down Expand Up @@ -80,6 +79,8 @@ def forward_explain(
inputs: T_OTXBatchDataEntity,
) -> T_OTXBatchPredEntity | T_OTXBatchPredEntityWithXAI | OTXBatchLossEntity:
"""Model forward function."""
from otx.algo.explain.explain_algo import feature_vector_fn

self.model.feature_vector_fn = feature_vector_fn
self.model.explain_fn = self.get_explain_fn()

Expand Down Expand Up @@ -139,6 +140,8 @@ def get_explain_fn(self) -> Callable:
return explainer.func

def _reset_model_forward(self) -> None:
from otx.algo.explain.explain_algo import feature_vector_fn

if not self.explain_mode:
return

Expand Down
5 changes: 4 additions & 1 deletion src/otx/core/model/entity/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from openvino.model_api.tilers import InstanceSegmentationTiler
from torchvision import tv_tensors

from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo, feature_vector_fn
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import (
OTXBatchLossEntity,
Expand Down Expand Up @@ -139,6 +138,8 @@ def forward_explain(
inputs: InstanceSegBatchDataEntity,
) -> InstanceSegBatchPredEntity | InstanceSegBatchPredEntityWithXAI | OTXBatchLossEntity:
"""Model forward function."""
from otx.algo.explain.explain_algo import feature_vector_fn

self.model.feature_vector_fn = feature_vector_fn
self.model.explain_fn = self.get_explain_fn()

Expand Down Expand Up @@ -196,6 +197,8 @@ def _forward_explain_inst_seg(

def get_explain_fn(self) -> Callable:
"""Returns explain function."""
from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo

explainer = MaskRCNNExplainAlgo(num_classes=self.num_classes)
return explainer.func

Expand Down

0 comments on commit 6683778

Please sign in to comment.