Skip to content

Commit

Permalink
Re #45: support transforms applied on only training data
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscyyeung committed May 11, 2023
1 parent 650ca88 commit 50be901
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 30 deletions.
26 changes: 18 additions & 8 deletions UltrasoundSegmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,29 @@ def main(args):
g.manual_seed(config["seed"])

# Create transforms
transform_list = []
if config["transforms"]:
for tfm in config["transforms"]:
train_transform_list = []
val_transform_list = []
if config["transforms"]["general"]:
for tfm in config["transforms"]["general"]:
try:
transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"]))
train_transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"]))
val_transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"]))
except KeyError: # Apply transform to both image and label by default
transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"]))
transform = Compose(transform_list)
train_transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"]))
val_transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"]))
if config["transforms"]["train"]:
for tfm in config["transforms"]["train"]:
try:
train_transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"]))
except KeyError:
train_transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"]))
train_transform = Compose(train_transform_list)
val_transform = Compose(val_transform_list)

# Create dataloaders using UltrasoundDataset
train_dataset = UltrasoundDataset(args.train_data_folder, transform=transform)
train_dataset = UltrasoundDataset(args.train_data_folder, transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=False, generator=g)
val_dataset = UltrasoundDataset(args.val_data_folder, transform=transform)
val_dataset = UltrasoundDataset(args.val_data_folder, transform=val_transform)
val_dataloader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, generator=g)

# Construct model
Expand Down
46 changes: 24 additions & 22 deletions UltrasoundSegmentation/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,27 @@ learning_rate_decay_factor: !!float 0.5
learning_rate_decay_frequency: !!int 10
seed: !!int 42
transforms:
# Basic transforms, do not modify
- name: "Transposed"
params:
keys: ["image", "label"]
indices: [2, 0, 1]
- name: "ToTensord"
- name: "EnsureTyped"
params:
keys: ["image", "label"]
dtype: "float32"

# Add additional transforms here
- name: "Resized"
params:
keys: ["image", "label"]
spatial_size: [128, 128]
- name: "RandGaussianNoised"
params:
keys: ["image"]
prob: 0.15
mean: 0.0
std: 0.1
general:
# Basic transforms, do not modify
- name: "Transposed"
params:
keys: ["image", "label"]
indices: [2, 0, 1]
- name: "ToTensord"
- name: "EnsureTyped"
params:
keys: ["image", "label"]
dtype: "float32"

# Add additional transforms here
- name: "Resized"
params:
keys: ["image", "label"]
spatial_size: [128, 128]
train:
- name: "RandGaussianNoised"
params:
keys: ["image"]
prob: 0.15
mean: 0.0
std: 0.1

0 comments on commit 50be901

Please sign in to comment.