From 50be901a34769372db7ecabcf04b1d25f656eed8 Mon Sep 17 00:00:00 2001 From: Chris Yeung Date: Thu, 11 May 2023 19:10:45 -0400 Subject: [PATCH] Re #45: support transforms applied on only training data --- UltrasoundSegmentation/train.py | 26 +++++++++----- UltrasoundSegmentation/train_config.yaml | 46 ++++++++++++------------ 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/UltrasoundSegmentation/train.py b/UltrasoundSegmentation/train.py index 589e391..8d4ae6e 100644 --- a/UltrasoundSegmentation/train.py +++ b/UltrasoundSegmentation/train.py @@ -122,19 +122,29 @@ def main(args): g.manual_seed(config["seed"]) # Create transforms - transform_list = [] - if config["transforms"]: - for tfm in config["transforms"]: + train_transform_list = [] + val_transform_list = [] + if config["transforms"]["general"]: + for tfm in config["transforms"]["general"]: try: - transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"])) + train_transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"])) + val_transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"])) except KeyError: # Apply transform to both image and label by default - transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"])) - transform = Compose(transform_list) + train_transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"])) + val_transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"])) + if config["transforms"]["train"]: + for tfm in config["transforms"]["train"]: + try: + train_transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"])) + except KeyError: + train_transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"])) + train_transform = Compose(train_transform_list) + val_transform = Compose(val_transform_list) # Create dataloaders using UltrasoundDataset - train_dataset = UltrasoundDataset(args.train_data_folder, transform=transform) + train_dataset = UltrasoundDataset(args.train_data_folder, transform=train_transform) train_dataloader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=False, generator=g) - val_dataset = UltrasoundDataset(args.val_data_folder, transform=transform) + val_dataset = UltrasoundDataset(args.val_data_folder, transform=val_transform) val_dataloader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, generator=g) # Construct model diff --git a/UltrasoundSegmentation/train_config.yaml b/UltrasoundSegmentation/train_config.yaml index c491b92..20f35be 100644 --- a/UltrasoundSegmentation/train_config.yaml +++ b/UltrasoundSegmentation/train_config.yaml @@ -12,25 +12,27 @@ learning_rate_decay_factor: !!float 0.5 learning_rate_decay_frequency: !!int 10 seed: !!int 42 transforms: - # Basic transforms, do not modify - - name: "Transposed" - params: - keys: ["image", "label"] - indices: [2, 0, 1] - - name: "ToTensord" - - name: "EnsureTyped" - params: - keys: ["image", "label"] - dtype: "float32" - - # Add additional transforms here - - name: "Resized" - params: - keys: ["image", "label"] - spatial_size: [128, 128] - - name: "RandGaussianNoised" - params: - keys: ["image"] - prob: 0.15 - mean: 0.0 - std: 0.1 + general: + # Basic transforms, do not modify + - name: "Transposed" + params: + keys: ["image", "label"] + indices: [2, 0, 1] + - name: "ToTensord" + - name: "EnsureTyped" + params: + keys: ["image", "label"] + dtype: "float32" + + # Add additional transforms here + - name: "Resized" + params: + keys: ["image", "label"] + spatial_size: [128, 128] + train: + - name: "RandGaussianNoised" + params: + keys: ["image"] + prob: 0.15 + mean: 0.0 + std: 0.1