Skip to content

Commit

Permalink
SwinUNETR refactor to accept additional parameters (#8212)
Browse files Browse the repository at this point in the history
### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [  ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.

---------

Signed-off-by: Eloi Navet [email protected]
Signed-off-by: Eloi Navet <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
EloiNavet and KumoLiu authored Nov 25, 2024
1 parent d94df3f commit 3ee4cd2
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import itertools
from collections.abc import Sequence
from typing import Final

import numpy as np
import torch
Expand Down Expand Up @@ -51,8 +50,6 @@ class SwinUNETR(nn.Module):
<https://arxiv.org/abs/2201.01266>"
"""

patch_size: Final[int] = 2

@deprecated_arg(
name="img_size",
since="1.3",
Expand All @@ -65,18 +62,24 @@ def __init__(
img_size: Sequence[int] | int,
in_channels: int,
out_channels: int,
patch_size: int = 2,
depths: Sequence[int] = (2, 2, 2, 2),
num_heads: Sequence[int] = (3, 6, 12, 24),
window_size: Sequence[int] | int = 7,
qkv_bias: bool = True,
mlp_ratio: float = 4.0,
feature_size: int = 24,
norm_name: tuple | str = "instance",
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
dropout_path_rate: float = 0.0,
normalize: bool = True,
norm_layer: type[LayerNorm] = nn.LayerNorm,
patch_norm: bool = True,
use_checkpoint: bool = False,
spatial_dims: int = 3,
downsample="merging",
use_v2=False,
downsample: str | nn.Module = "merging",
use_v2: bool = False,
) -> None:
"""
Args:
Expand All @@ -86,14 +89,20 @@ def __init__(
It will be removed in an upcoming version.
in_channels: dimension of input channels.
out_channels: dimension of output channels.
patch_size: size of the patch token.
feature_size: dimension of network feature size.
depths: number of layers in each stage.
num_heads: number of attention heads.
window_size: local window size.
qkv_bias: add a learnable bias to query, key, value.
mlp_ratio: ratio of mlp hidden dim to embedding dim.
norm_name: feature normalization type and arguments.
drop_rate: dropout rate.
attn_drop_rate: attention dropout rate.
dropout_path_rate: drop path rate.
normalize: normalize output intermediate features in each stage.
norm_layer: normalization layer.
patch_norm: whether to apply normalization to the patch embedding.
use_checkpoint: use gradient checkpointing for reduced memory usage.
spatial_dims: number of spatial dims.
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
Expand All @@ -116,13 +125,15 @@ def __init__(

super().__init__()

img_size = ensure_tuple_rep(img_size, spatial_dims)
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
window_size = ensure_tuple_rep(7, spatial_dims)

if spatial_dims not in (2, 3):
raise ValueError("spatial dimension should be 2 or 3.")

self.patch_size = patch_size

img_size = ensure_tuple_rep(img_size, spatial_dims)
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
window_size = ensure_tuple_rep(window_size, spatial_dims)

self._check_input_size(img_size)

if not (0 <= drop_rate <= 1):
Expand All @@ -146,12 +157,13 @@ def __init__(
patch_size=patch_sizes,
depths=depths,
num_heads=num_heads,
mlp_ratio=4.0,
qkv_bias=True,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dropout_path_rate,
norm_layer=nn.LayerNorm,
norm_layer=norm_layer,
patch_norm=patch_norm,
use_checkpoint=use_checkpoint,
spatial_dims=spatial_dims,
downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
Expand Down

0 comments on commit 3ee4cd2

Please sign in to comment.