Skip to content

Commit

Permalink
Remove mmengine from semantic segmentation task (#3424)
Browse files Browse the repository at this point in the history
* remove unnecessery test

* remove mmengine

* fix pre-commit
  • Loading branch information
kprokofi authored Apr 30, 2024
1 parent 3ff1032 commit d75c35b
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 13 deletions.
7 changes: 0 additions & 7 deletions src/otx/algo/segmentation/backbones/litehrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import torch
import torch.utils.checkpoint as cp
from mmengine.utils import is_tuple_of
from torch import nn
from torch.nn import functional

Expand Down Expand Up @@ -153,9 +152,6 @@ def __init__(
if len(act_cfg) != 2:
msg = "act_cfg must be a dict or a tuple of dicts of length 2."
raise ValueError(msg)
if not is_tuple_of(act_cfg, dict):
msg = "act_cfg must be a dict or a tuple of dicts."
raise TypeError(msg)

self.channels = channels
total_channel = sum(channels)
Expand Down Expand Up @@ -226,9 +222,6 @@ def __init__(
if len(act_cfg) != 2:
msg = "act_cfg must be a dict or a tuple of dicts of length 2."
raise ValueError(msg)
if not is_tuple_of(act_cfg, dict):
msg = "act_cfg must be a dict or a tuple of dicts."
raise TypeError(msg)

self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.conv1 = ConvModule(
Expand Down
5 changes: 1 addition & 4 deletions src/otx/algo/segmentation/heads/ham_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import torch
import torch.nn.functional as f
from mmengine.device import get_device
from torch import nn

from otx.algo.modules import ConvModule
Expand Down Expand Up @@ -223,7 +222,7 @@ def _build_bases(
segments: int,
channels: int,
basis_vectors: int,
device: torch.device | None = None,
device: torch.device,
) -> torch.Tensor:
"""Build bases in initialization.
Expand All @@ -237,8 +236,6 @@ def _build_bases(
Returns:
torch.Tensor: Tensor of shape (batch_size * segments, channels, basis_vectors) containing the built bases.
"""
if device is None:
device = get_device()
bases = torch.rand((batch_size * segments, channels, basis_vectors)).to(device)

return f.normalize(bases, dim=1)
Expand Down
5 changes: 4 additions & 1 deletion src/otx/recipe/_base_/data/mmseg_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ config:
size:
- 512
- 512
scale:
- 0.2
- 1.0
ratio:
- 0.5
- 2.0
antialias: True
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
- class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
- class_path: torchvision.transforms.v2.RandomHorizontalFlip
init_args:
p: 0.5
Expand Down
5 changes: 4 additions & 1 deletion src/otx/recipe/semantic_segmentation/dino_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,14 @@ overrides:
size:
- 560
- 560
scale:
- 0.2
- 1.0
ratio:
- 0.5
- 2.0
antialias: True
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
- class_path: otx.core.data.transform_libs.torchvision.PhotoMetricDistortion
- class_path: torchvision.transforms.v2.RandomHorizontalFlip
init_args:
p: 0.5
Expand Down

0 comments on commit d75c35b

Please sign in to comment.