Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscyyeung committed Dec 5, 2021
1 parent 1389396 commit 0a56ad6
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 45 deletions.
38 changes: 27 additions & 11 deletions code/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __len__(self):
def __getitem__(self, index):
# generate one batch
indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
print(indexes)
image_dir_temp = [self.image_dir_list[idx] for idx in indexes]

# create emtpy arrays for all images in batch
Expand Down Expand Up @@ -129,9 +130,8 @@ class RandomCrop:
"""Randomly crops the image to the specified output_size. The output
size can be a tuple or an integer (for a cubic crop).
"""
def __init__(self, output_size, seed):
def __init__(self, output_size):
self.name = "RandomCrop"
self.seed = seed

assert isinstance(output_size, (int, tuple, list))
if isinstance(output_size, int):
Expand All @@ -143,8 +143,8 @@ def __init__(self, output_size, seed):
def __call__(self, sample):
image, label = sample["image"], sample["label"]

image = tf.image.random_crop(image, self.output_size, self.seed)
label = tf.image.random_crop(label, self.output_size, self.seed)
image = tf.image.random_crop(image, self.output_size)
label = tf.image.random_crop(label, self.output_size)

return {"image": image, "label": label}

Expand All @@ -166,20 +166,19 @@ def __call__(self, sample):

class RandomFlip:
"""Randomly flips the image in a sample along its x or y axis."""
def __init__(self, seed):
def __init__(self):
self.name = "RandomFlip"
self.seed = seed

def __call__(self, sample):
image, label = sample["image"], sample["label"]

axis = np.random.randint(2)
if axis: # flip along y axis
image = tf.image.random_flip_left_right(image, self.seed)
label = tf.image.random_flip_left_right(label, self.seed)
image = tf.image.random_flip_left_right(image)
label = tf.image.random_flip_left_right(label)
else: # flip along x axis
image = tf.image.random_flip_up_down(image, self.seed)
label = tf.image.random_flip_up_down(label, self.seed)
image = tf.image.random_flip_up_down(image)
label = tf.image.random_flip_up_down(label)

return {"image": image, "label": label}

Expand All @@ -203,4 +202,21 @@ def __call__(self, sample):
train_transforms.append(tfm_class)

dataset = DataGenerator(data_dir, transforms=train_transforms)
print(len(dataset.get_sample_indices()) == 8 * 4)
image, label = dataset[0]
print(image.shape, label.shape)

print(dataset.indexes)
for i in range(len(dataset)):
image, label = dataset[i]
print(image.shape, label.shape)

# import nibabel as nib
# nib.save(nib.Nifti1Image(image[0, ..., 0].astype(np.float32), np.eye(4)), "test_image_0.nii.gz")
# nib.save(nib.Nifti1Image(label[0].astype(np.float32), np.eye(4)), "test_label_0.nii.gz")
# nib.save(nib.Nifti1Image(image[1, ..., 0].astype(np.float32), np.eye(4)), "test_image_1.nii.gz")
# nib.save(nib.Nifti1Image(label[1].astype(np.float32), np.eye(4)), "test_label_1.nii.gz")

# from model2 import VNet
# network = VNet((112, 112, 80, 1), 0.0001)
# out_seg, out_tanh = network(image)
# print(out_seg.shape, out_tanh.shape)
5 changes: 3 additions & 2 deletions code/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ def __call__(self, y_true, y_pred, y_pred_tanh):
pred_soft = tf.keras.activations.sigmoid(y_pred) # convert to logits

# labeled predictions
pred_labeled = pred_soft[:self.labeled_bs]
pred_labeled = y_pred[:self.labeled_bs]
pred_soft_labeled = pred_soft[:self.labeled_bs]
pred_tanh_labeled = y_pred_tanh[:self.labeled_bs]
true_labeled = y_true[:self.labeled_bs]

# supervised loss (labeled images)
true_lsf = tf.py_function(compute_lsf_gt, [y_true[:], tf.shape(pred_labeled)], tf.float32)
loss_lsf = self.mse(true_lsf, pred_tanh_labeled)
loss_seg_dice = self.dice_loss(true_labeled == 1, pred_labeled)
loss_seg_dice = self.dice_loss(true_labeled == 1, pred_soft_labeled)
supervised_loss = loss_seg_dice + self.beta * loss_lsf

# unsupervised loss (no labels)
Expand Down
45 changes: 37 additions & 8 deletions code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,42 @@ def get_parser():
default="configs/config.json",
help="JSON file for model configuration"
)
args = parser.parse_args()
return args
parser.add_argument(
"--debug_mode",
type=bool,
choices=[True, False],
help="Enable TensorFlow debugger V2"
)
parser.add_argument(
"--dump_dir",
type=str,
default="/tmp/tfdbg2_logdir"
)
parser.add_argument(
"--dump_tensor_debug_mode",
type=str,
default="FULL_HEALTH"
)
parser.add_argument(
"--dump_circular_buffer_size",
type=int,
default=-1
)
return parser.parse_args()


def main(FLAGS):
# debugging
if FLAGS.debug_mode:
tf.debugging.experimental.enable_dump_debug_info(
FLAGS.dump_dir,
tensor_debug_mode=FLAGS.dump_tensor_debug_mode,
circular_buffer_size=FLAGS.dump_circular_buffer_size
)
tf.debugging.enable_check_numerics()

def main(args):
# load config file
with open(os.path.join(os.path.dirname(os.getcwd()), args.config_json), "r") as config_json:
with open(os.path.join(os.path.dirname(os.getcwd()), FLAGS.config_json), "r") as config_json:
config = json.load(config_json)

# set seeds
Expand All @@ -50,14 +79,14 @@ def main(args):

# run model
model = Model(config)
if args.phase == "train":
if FLAGS.phase == "train":
model.train()
elif args.phase == "test":
elif FLAGS.phase == "test":
model.test()
else:
sys.exit("Invalid training phase.")


if __name__ == '__main__':
args = get_parser()
main(args)
FLAGS = get_parser()
main(FLAGS)
38 changes: 20 additions & 18 deletions code/model2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
class Model:
def __init__(self, config):
self.config = config
self.network = None
self.loss_fn = None
self.optimizer = None

def read_config(self):
print(f"{datetime.datetime.now()}: Reading configuration file...")
Expand Down Expand Up @@ -63,16 +60,6 @@ def read_config(self):

print(f"{datetime.datetime.now()}: Reading configuration file complete.")

@tf.function
def train_step(self, image, label, epoch):
with tf.GradientTape() as tape:
out_seg, out_tanh = self.network(image)
self.loss_fn.set_epoch(epoch)
loss = self.loss_fn(label, out_seg, out_tanh)
grads = tape.gradient(loss, self.network.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.network.trainable_weights))
return loss

def train(self):
# read config file
self.read_config()
Expand Down Expand Up @@ -104,8 +91,8 @@ def train(self):
)

# instantiate VNet model, loss function, optimizer, LR decay
self.network = VNet(self.input_shape)
self.loss_fn = DTCLoss(
network = VNet(self.input_shape)
loss_fn = DTCLoss(
self.sigmoid_k,
self.beta,
self.consistency,
Expand All @@ -119,11 +106,25 @@ def train(self):
decay_rate=self.learning_rate_decay,
staircase=True
)
self.optimizer = tfa.optimizers.SGDW(
optimizer = tfa.optimizers.SGDW(
self.weight_decay,
lr_schedule,
self.momentum
)
# optimizer = tf.keras.optimizers.SGD(
# learning_rate=lr_schedule,
# momentum=self.momentum
# )

@tf.function
def train_step(image, label, epoch):
with tf.GradientTape() as tape:
out_seg, out_tanh = network(image)
loss_fn.set_epoch(epoch)
loss = loss_fn(label, out_seg, out_tanh)
grads = tape.gradient(loss, network.trainable_weights)
optimizer.apply_gradients(zip(grads, network.trainable_weights))
return loss

# train model
epoch_size = len(train_generator)
Expand All @@ -132,7 +133,7 @@ def train(self):
for epoch in tqdm.tqdm(range(epochs)):
for batch_idx in range(epoch_size):
image, label = train_generator[batch_idx]
loss = self.train_step(image, label, current_iter)
loss = train_step(image, label, current_iter)
current_iter += 1
print(f"{datetime.datetime.now()}: Epoch {current_iter}: loss: {loss}")

Expand All @@ -141,6 +142,7 @@ def train(self):
break
else:
# get new batch of images
print("resetting indexes")
train_generator.on_epoch_end()
continue
break
Expand All @@ -149,7 +151,7 @@ def train(self):
if not os.path.isdir(self.model_save_dir):
Path(self.model_save_dir).mkdir(exist_ok=True)
complete_model_save_path = os.path.join(self.model_save_dir, f"DTC_{self.num_labeled}_labels")
self.network.save(complete_model_save_path)
network.save(complete_model_save_path)
print(f"{datetime.datetime.now()}: Trained model saved to {complete_model_save_path}.")

def test(self):
Expand Down
7 changes: 2 additions & 5 deletions configs/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ description: left atrial segmentation pipeline
modality: MRI
preprocess:
train:
- name: "RandomCrop"
variables:
output_size: [112, 112, 80]
seed: 2021
- name: "RandomRotation"
- name: "RandomFlip"
- name: "RandomCrop"
variables:
seed: 2021
output_size: [112, 112, 80]

test:
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ nibabel
numpy
medpy
scikit-image
tensorflow==2.5.0
tensorflow
tensorflow-addons

0 comments on commit 0a56ad6

Please sign in to comment.