[WIP] Update DeepEdit - training on TotalSegmentator Dataset #1362

Closed · wants to merge 16 commits · Changes from 6 commits
14 changes: 7 additions & 7 deletions sample-apps/radiology/README.md
@@ -113,16 +113,16 @@ A command example to use active learning strategies with DeepEdit would be:
```json
{
"spleen": 1,
"right kidney": 2,
"left kidney": 3,
"liver": 6,
"stomach": 7,
"aorta": 8,
"inferior vena cava": 9,
"kidney_right": 2,
"kidney_left": 3,
"liver": 5,
"stomach": 6,
"aorta": 7,
"inferior_vena_cava": 8,
"background": 0
}
```
- Dataset: The model is pre-trained over dataset: https://www.synapse.org/#!Synapse:syn3193805/wiki/217789
- Dataset: The model is pre-trained on the TotalSegmentator dataset: https://zenodo.org/record/6802614#.ZCStTI7MKCg

- Inputs
- 1 channel for the image modality -> Automated mode
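Note: the new label values mirror TotalSegmentator's own class IDs for this organ subset, which is why the numbering skips 4 (that ID belongs to a structure not used here). Below is a minimal, illustrative sketch of how the per-structure masks from the Zenodo release could be merged into one label volume using this mapping; the `segmentations/<organ>.nii.gz` layout and the helper name are assumptions, not part of this PR.

```python
# Merge TotalSegmentator per-structure binary masks into a single multi-label volume.
# Paths follow the assumed Zenodo layout <case>/segmentations/<organ>.nii.gz.
import numpy as np
import nibabel as nib

LABELS = {
    "spleen": 1,
    "kidney_right": 2,
    "kidney_left": 3,
    "liver": 5,
    "stomach": 6,
    "aorta": 7,
    "inferior_vena_cava": 8,
}

def merge_case(case_dir: str) -> nib.Nifti1Image:
    """Stack single-organ masks into one label map (background stays 0)."""
    combined, affine = None, None
    for name, value in LABELS.items():
        mask_img = nib.load(f"{case_dir}/segmentations/{name}.nii.gz")
        mask = np.asanyarray(mask_img.dataobj) > 0
        if combined is None:
            combined = np.zeros(mask.shape, dtype=np.uint8)
            affine = mask_img.affine
        combined[mask] = value  # later organs overwrite earlier ones on overlap
    return nib.Nifti1Image(combined, affine)

merge_case("Totalsegmentator_dataset/s0001").to_filename("s0001_label.nii.gz")
```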
89 changes: 21 additions & 68 deletions sample-apps/radiology/lib/configs/deepedit.py
@@ -15,7 +15,7 @@

import lib.infers
import lib.trainers
from monai.networks.nets import UNETR, DynUNet
from monai.networks.nets import SegResNet

from monailabel.interfaces.config import TaskConfig
from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType
@@ -41,12 +41,12 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
# Multilabel
self.labels = {
"spleen": 1,
"right kidney": 2,
"left kidney": 3,
"liver": 6,
"stomach": 7,
"aorta": 8,
"inferior vena cava": 9,
"kidney_right": 2,
"kidney_left": 3,
"liver": 5,
"stomach": 6,
"aorta": 7,
"inferior_vena_cava": 8,
"background": 0,
}

@@ -59,81 +59,34 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
# Number of input channels - 4 for BRATS and 1 for spleen
self.number_intensity_ch = 1

network = self.conf.get("network", "dynunet")

# Model Files
self.path = [
os.path.join(self.model_dir, f"pretrained_{self.name}_{network}.pt"), # pretrained
os.path.join(self.model_dir, f"{self.name}_{network}.pt"), # published
os.path.join(self.model_dir, f"pretrained_{self.name}_segresnet.pt"), # pretrained
os.path.join(self.model_dir, f"{self.name}_segresnet.pt"), # published
]

# Download PreTrained Model
if strtobool(self.conf.get("use_pretrained_model", "true")):
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
url = f"{url}/radiology_deepedit_{network}_multilabel.pt"
url = f"{url}/radiology_deepedit_segresnet_multilabel.pt"
download_file(url, self.path[0])

self.target_spacing = (1.0, 1.0, 1.0) # target space for image
self.spatial_size = (128, 128, 128) # train input size

# Network
self.network = (
UNETR(
spatial_dims=3,
in_channels=len(self.labels) + self.number_intensity_ch,
out_channels=len(self.labels),
img_size=self.spatial_size,
feature_size=64,
hidden_size=1536,
mlp_dim=3072,
num_heads=48,
pos_embed="conv",
norm_name="instance",
res_block=True,
)
if network == "unetr"
else DynUNet(
spatial_dims=3,
in_channels=len(self.labels) + self.number_intensity_ch,
out_channels=len(self.labels),
kernel_size=[3, 3, 3, 3, 3, 3],
strides=[1, 2, 2, 2, 2, [2, 2, 1]],
upsample_kernel_size=[2, 2, 2, 2, [2, 2, 1]],
norm_name="instance",
deep_supervision=False,
res_block=True,
)
self.network = SegResNet(
spatial_dims=3,
in_channels=len(self.labels) + self.number_intensity_ch,
out_channels=len(self.labels),
init_filters=32,
blocks_down=(1, 2, 2, 4),
blocks_up=(1, 1, 1),
norm="batch",
dropout_prob=0.2,
)

self.network_with_dropout = (
UNETR(
spatial_dims=3,
in_channels=len(self.labels) + self.number_intensity_ch,
out_channels=len(self.labels),
img_size=self.spatial_size,
feature_size=64,
hidden_size=1536,
mlp_dim=3072,
num_heads=48,
pos_embed="conv",
norm_name="instance",
res_block=True,
dropout_rate=0.2,
)
if network == "unetr"
else DynUNet(
spatial_dims=3,
in_channels=len(self.labels) + self.number_intensity_ch,
out_channels=len(self.labels),
kernel_size=[3, 3, 3, 3, 3, 3],
strides=[1, 2, 2, 2, 2, [2, 2, 1]],
upsample_kernel_size=[2, 2, 2, 2, [2, 2, 1]],
norm_name="instance",
deep_supervision=False,
res_block=True,
dropout=0.2,
)
)
self.network_with_dropout = self.network

# Others
self.epistemic_enabled = strtobool(conf.get("epistemic_enabled", "false"))
@@ -162,7 +115,7 @@ def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
}

def trainer(self) -> Optional[TrainTask]:
output_dir = os.path.join(self.model_dir, f"{self.name}_" + self.conf.get("network", "dynunet"))
output_dir = os.path.join(self.model_dir, f"{self.name}_" + self.conf.get("network", "segresnet"))
load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1]

task: TrainTask = lib.trainers.DeepEdit(
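For reference, DeepEdit feeds the network the image plus one guidance/click channel per label, so the SegResNet configured here takes `len(labels) + 1 = 9` input channels and predicts 8 classes. A quick standalone shape check, illustrative only (it uses a smaller dummy volume than the app's 128³ `spatial_size` to keep it light):

```python
import torch
from monai.networks.nets import SegResNet

labels = ["spleen", "kidney_right", "kidney_left", "liver",
          "stomach", "aorta", "inferior_vena_cava", "background"]

net = SegResNet(
    spatial_dims=3,
    in_channels=len(labels) + 1,   # CT channel + one guidance channel per label
    out_channels=len(labels),
    init_filters=32,
    blocks_down=(1, 2, 2, 4),
    blocks_up=(1, 1, 1),
    norm="batch",
    dropout_prob=0.2,
)

# 64^3 dummy input instead of the app's 128^3 spatial_size, just for a fast check.
x = torch.zeros(1, len(labels) + 1, 64, 64, 64)
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected: torch.Size([1, 8, 64, 64, 64])
```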
12 changes: 9 additions & 3 deletions sample-apps/radiology/lib/infers/deepedit.py
@@ -23,10 +23,13 @@
AsDiscreted,
EnsureChannelFirstd,
EnsureTyped,
GaussianSmoothd,
KeepLargestConnectedComponentd,
LoadImaged,
NormalizeIntensityd,
Orientationd,
Resized,
ScaleIntensityRanged,
ScaleIntensityd,
SqueezeDimd,
ToNumpyd,
)
@@ -74,10 +77,12 @@ def __init__(

def pre_transforms(self, data=None):
t = [
LoadImaged(keys="image", reader="ITKReader"),
LoadImaged(keys="image"),
EnsureChannelFirstd(keys="image"),
Orientationd(keys="image", axcodes="RAS"),
ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
NormalizeIntensityd(keys="image", nonzero=True),
GaussianSmoothd(keys="image", sigma=0.4),
ScaleIntensityd(keys="image", minv=-1.0, maxv=1.0),
]

self.add_cache_transform(t, data)
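This swaps the fixed CT window (`ScaleIntensityRanged`, -175 to 250 HU) for a modality-agnostic chain: z-score normalisation over non-zero voxels, light Gaussian smoothing, then rescaling to [-1, 1]. A standalone sketch of the same chain, with a placeholder path, for inspecting its effect outside MONAI Label:

```python
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd,
    NormalizeIntensityd, GaussianSmoothd, ScaleIntensityd,
)

preprocess = Compose([
    LoadImaged(keys="image"),
    EnsureChannelFirstd(keys="image"),
    Orientationd(keys="image", axcodes="RAS"),
    NormalizeIntensityd(keys="image", nonzero=True),
    GaussianSmoothd(keys="image", sigma=0.4),
    ScaleIntensityd(keys="image", minv=-1.0, maxv=1.0),
])

# Placeholder path; point this at a local CT volume.
sample = preprocess({"image": "path/to/ct_volume.nii.gz"})
print(sample["image"].shape, sample["image"].min(), sample["image"].max())
```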
@@ -117,6 +122,7 @@ def post_transforms(self, data=None) -> Sequence[Callable]:
EnsureTyped(keys="pred", device=data.get("device") if data else None),
Activationsd(keys="pred", softmax=True),
AsDiscreted(keys="pred", argmax=True),
KeepLargestConnectedComponentd(keys="pred"),
SqueezeDimd(keys="pred", dim=0),
ToNumpyd(keys="pred"),
Restored(keys="pred", ref_image="image"),
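`KeepLargestConnectedComponentd` is added after the `argmax`, so isolated false-positive voxels are dropped from each predicted organ. A toy check of that behaviour, assuming the transform's defaults treat every foreground label independently and keep only its largest component:

```python
import torch
from monai.transforms import KeepLargestConnectedComponentd

# Toy argmax-style prediction: channel-first [1, D, H, W] with one large blob of
# class 1 and a single disconnected stray voxel of the same class.
pred = torch.zeros(1, 8, 8, 8, dtype=torch.long)
pred[0, 1:5, 1:5, 1:5] = 1
pred[0, 7, 7, 7] = 1

out = KeepLargestConnectedComponentd(keys="pred")({"pred": pred})["pred"]
print(int(out[0, 7, 7, 7]))  # expected 0: the stray voxel is removed
```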
44 changes: 25 additions & 19 deletions sample-apps/radiology/lib/trainers/deepedit.py
@@ -29,13 +29,12 @@
Activationsd,
AsDiscreted,
EnsureChannelFirstd,
GaussianSmoothd,
LoadImaged,
NormalizeIntensityd,
Orientationd,
RandFlipd,
RandRotate90d,
RandShiftIntensityd,
Resized,
ScaleIntensityRanged,
ScaleIntensityd,
SelectItemsd,
ToNumpyd,
ToTensord,
@@ -100,17 +99,13 @@ def get_click_transforms(self, context: Context):

def train_pre_transforms(self, context: Context):
return [
LoadImaged(keys=("image", "label"), reader="ITKReader"),
LoadImaged(keys=("image", "label")),
EnsureChannelFirstd(keys=("image", "label")),
NormalizeLabelsInDatasetd(keys="label", label_names=self._labels),
Orientationd(keys=["image", "label"], axcodes="RAS"),
# This transform may not work well for MR images
ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
RandFlipd(keys=("image", "label"), spatial_axis=[0], prob=0.10),
RandFlipd(keys=("image", "label"), spatial_axis=[1], prob=0.10),
RandFlipd(keys=("image", "label"), spatial_axis=[2], prob=0.10),
RandRotate90d(keys=("image", "label"), prob=0.10, max_k=3),
RandShiftIntensityd(keys="image", offsets=0.10, prob=0.50),
NormalizeIntensityd(keys="image", nonzero=True),
GaussianSmoothd(keys="image", sigma=0.4),
ScaleIntensityd(keys="image", minv=-1.0, maxv=1.0),
Resized(keys=("image", "label"), spatial_size=self.spatial_size, mode=("area", "nearest")),
# Transforms for click simulation
FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
@@ -134,12 +129,13 @@ def train_post_transforms(self, context: Context):

def val_pre_transforms(self, context: Context):
return [
LoadImaged(keys=("image", "label"), reader="ITKReader"),
LoadImaged(keys=("image", "label")),
EnsureChannelFirstd(keys=("image", "label")),
NormalizeLabelsInDatasetd(keys="label", label_names=self._labels),
Orientationd(keys=["image", "label"], axcodes="RAS"),
# This transform may not work well for MR images
ScaleIntensityRanged(keys=("image"), a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
NormalizeIntensityd(keys="image", nonzero=True),
GaussianSmoothd(keys="image", sigma=0.4),
ScaleIntensityd(keys="image", minv=-1.0, maxv=1.0),
Resized(keys=("image", "label"), spatial_size=self.spatial_size, mode=("area", "nearest")),
# Transforms for click simulation
FindAllValidSlicesMissingLabelsd(keys="label", sids="sids"),
@@ -153,13 +149,23 @@ def val_pre_transforms(self, context: Context):
def val_inferer(self, context: Context):
return SimpleInferer()

def norm_labels(self):
# This should be applied along with NormalizeLabelsInDatasetd transform
new_label_nums = {}
for idx, (key_label, val_label) in enumerate(self._labels.items(), start=1):
if key_label != "background":
new_label_nums[key_label] = idx
if key_label == "background":
new_label_nums["background"] = 0
return new_label_nums
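The TotalSegmentator-style IDs in the config are not consecutive (liver is 5, there is no 4), while the network's `argmax` output only ranges over `0 .. len(labels) - 1`; `norm_labels` therefore rebuilds a consecutive mapping, presumably matching what `NormalizeLabelsInDatasetd` does to the label volumes (per the comment above). An illustrative check of the resulting mapping, not part of the PR:

```python
# Same remapping logic as norm_labels(), applied to this app's label dict.
labels = {
    "spleen": 1, "kidney_right": 2, "kidney_left": 3, "liver": 5,
    "stomach": 6, "aorta": 7, "inferior_vena_cava": 8, "background": 0,
}

remapped = {}
for idx, name in enumerate(labels, start=1):
    remapped[name] = 0 if name == "background" else idx

print(remapped)
# {'spleen': 1, 'kidney_right': 2, 'kidney_left': 3, 'liver': 4,
#  'stomach': 5, 'aorta': 6, 'inferior_vena_cava': 7, 'background': 0}
```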

def train_iteration_update(self, context: Context):
return Interaction(
deepgrow_probability=self.deepgrow_probability_train,
transforms=self.get_click_transforms(context),
click_probability_key="probability",
train=True,
label_names=self._labels,
label_names=self.norm_labels(),
)

def val_iteration_update(self, context: Context):
@@ -168,13 +174,13 @@ def val_iteration_update(self, context: Context):
transforms=self.get_click_transforms(context),
click_probability_key="probability",
train=False,
label_names=self._labels,
label_names=self.norm_labels(),
)

def train_key_metric(self, context: Context):
all_metrics = dict()
all_metrics["train_dice"] = MeanDice(output_transform=from_engine(["pred", "label"]), include_background=False)
for key_label in self._labels:
for key_label in self.norm_labels():
if key_label != "background":
all_metrics[key_label + "_dice"] = MeanDice(
output_transform=from_engine(["pred_" + key_label, "label_" + key_label]), include_background=False
@@ -186,7 +192,7 @@ def val_key_metric(self, context: Context):
all_metrics["val_mean_dice"] = MeanDice(
output_transform=from_engine(["pred", "label"]), include_background=False
)
for key_label in self._labels:
for key_label in self.norm_labels():
if key_label != "background":
all_metrics[key_label + "_dice"] = MeanDice(
output_transform=from_engine(["pred_" + key_label, "label_" + key_label]), include_background=False