Reorganize files
chriscyyeung committed Dec 6, 2021
1 parent 7f6617c commit 325b603
Showing 3 changed files with 91 additions and 387 deletions.
2 changes: 1 addition & 1 deletion code/main.py
@@ -4,7 +4,7 @@
import argparse
import numpy as np
import tensorflow as tf
from model2 import Model
from model import Model


def get_parser():
172 changes: 90 additions & 82 deletions code/model.py
@@ -12,20 +12,15 @@
from medpy import metric
from skimage.measure import label

import dataloader
import generator
from dataloader import LAHeart
from losses import DTCLoss
from vnet import VNet


class Model:
def __init__(self, config):
self.config = config
self.train_iterator = None
self.test_iterator = None
self.network = None
self.loss_fn = None
self.optimizer = None
self.current_iter = 0

def read_config(self):
print(f"{datetime.datetime.now()}: Reading configuration file...")
@@ -38,7 +33,7 @@ def read_config(self):

# model settings
self.input_shape = self.config["TrainingSettings"]["InputShape"]
self.epochs = self.config["TrainingSettings"]["Epochs"]
self.iterations = self.config["TrainingSettings"]["Iterations"]
self.batch_size = self.config["TrainingSettings"]["BatchSize"]
self.labeled_bs = self.config["TrainingSettings"]["LabeledBatchSize"]
self.num_labeled = self.config["TrainingSettings"]["NumberLabeledImages"]
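The keys read here imply the shape of the TrainingSettings block in configs/config.json, which is not part of this diff. A rough sketch with assumed illustrative values (note the key rename from Epochs to Iterations in this commit):

```python
# Hypothetical excerpt of configs/config.json, written as a Python dict.
# Only keys referenced in the lines above are listed; all values are assumptions.
config = {
    "TrainingSettings": {
        "InputShape": [112, 112, 80, 1],  # matches the (112, 112, 80) test patch size
        "Iterations": 6000,               # the training loop stops at 6000 iterations
        "BatchSize": 4,
        "LabeledBatchSize": 2,
        "NumberLabeledImages": 16,
    }
}
```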
@@ -65,20 +60,6 @@

print(f"{datetime.datetime.now()}: Reading configuration file complete.")

@tf.function
def train_step(self, next_element, epoch):
label = next_element[1]
with tf.GradientTape() as tape:
# get predictions
pred_tanh, pred = self.network(next_element[0])
# calculate loss
self.loss_fn.set_true(label[..., 0])
self.loss_fn.set_epoch(epoch)
loss = self.loss_fn(pred[..., 0], pred_tanh[..., 0])
grads = tape.gradient(loss, self.network.trainable_variables)
self.optimizer.apply_gradients(zip(grads, self.network.trainable_variables))
return loss

def train(self):
# read config file
self.read_config()
@@ -94,99 +75,97 @@ def train(self):
if pipeline["preprocess"]["train"] is not None:
for transform in pipeline["preprocess"]["train"]:
try:
tfm_class = getattr(dataloader, transform["name"])(*[], **transform["variables"])
tfm_class = getattr(generator, transform["name"])(*[], **transform["variables"])
except KeyError:
tfm_class = getattr(dataloader, transform["name"])()
tfm_class = getattr(generator, transform["name"])()
train_transforms.append(tfm_class)

# instantiate dataset iterator
self.train_iterator = dataloader.LAHeart(
train_generator = generator.DataGenerator(
data_dir=self.data_dir,
transforms=train_transforms,
train=True,
image_dims=self.input_shape[0:3],
batch_size=self.batch_size,
labeled_bs=self.labeled_bs,
num_labeled=self.num_labeled
)
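The new training loop below takes len(train_generator), indexes it per batch, and calls on_epoch_end(), which is the tf.keras.utils.Sequence contract. generator.py is not included in this diff; a skeleton consistent with that usage would be:

```python
# Hypothetical skeleton of generator.DataGenerator, inferred only from how the
# training loop uses it; the real implementation (code/generator.py) is not shown.
import tensorflow as tf

class DataGenerator(tf.keras.utils.Sequence):
    def __len__(self):
        # number of batches per epoch
        ...

    def __getitem__(self, batch_idx):
        # return one (image, label) batch
        ...

    def on_epoch_end(self):
        # reshuffle/reset the labeled and unlabeled sampling indexes
        ...
```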

# instantiate VNet model, loss function, optimizer
self.network = VNet(self.input_shape)
self.loss_fn = DTCLoss(
# instantiate VNet model, loss function, optimizer, LR decay
network = VNet(self.input_shape)
loss_fn = DTCLoss(
self.sigmoid_k,
self.beta,
self.consistency,
self.consistency_rampup,
self.consistency_interval,
self.labeled_bs
)
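losses.py is also outside this diff, so the exact DTCLoss definition is not visible here. As orientation for the constructor arguments above, a dual-task consistency loss in the general DTC style combines a supervised term on the labeled sub-batch with a ramped consistency term that ties the segmentation head to a sharpened-sigmoid transform of the tanh (level-set) head. A minimal sketch under those assumptions, with the current epoch passed explicitly rather than via set_epoch:

```python
# Illustrative only: the repository's DTCLoss is not shown in this commit, so the
# formulation, argument names, and constants below are assumptions.
import tensorflow as tf

def dtc_loss_sketch(label, out_seg, out_tanh, epoch,
                    sigmoid_k=1500.0, consistency=0.1,
                    consistency_rampup=40.0, labeled_bs=2):
    seg_prob = tf.sigmoid(out_seg)
    # task transform: turn the predicted level-set output back into a soft mask
    sdf_as_seg = tf.sigmoid(sigmoid_k * out_tanh)

    # supervised soft-Dice term on the labeled sub-batch (the full DTC loss would
    # also regress out_tanh against a ground-truth signed distance map, weighted
    # by beta; omitted here for brevity)
    y = tf.cast(label[:labeled_bs], tf.float32)
    p = seg_prob[:labeled_bs]
    dice = 1.0 - (2.0 * tf.reduce_sum(y * p) + 1e-5) / \
           (tf.reduce_sum(y) + tf.reduce_sum(p) + 1e-5)

    # unsupervised consistency between the two heads over the whole batch,
    # with a sigmoid ramp-up of its weight over training
    consistency_loss = tf.reduce_mean(tf.square(seg_prob - sdf_as_seg))
    t = tf.clip_by_value(tf.cast(epoch, tf.float32) / consistency_rampup, 0.0, 1.0)
    weight = consistency * tf.exp(-5.0 * tf.square(1.0 - t))

    return dice + weight * consistency_loss
```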
self.optimizer = tfa.optimizers.SGDW(
self.weight_decay,
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
self.initial_learning_rate,
decay_steps=self.lr_decay_interval,
decay_rate=self.learning_rate_decay,
staircase=True
)
optimizer = tfa.optimizers.SGDW(
self.weight_decay,
lr_schedule,
self.momentum
)
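With staircase=True, the new schedule computes lr = initial_learning_rate * decay_rate ** (step // decay_steps), which is exactly the manual step decay removed from the old training loop further down. A quick check with assumed illustrative values:

```python
# Equivalence check with assumed values (0.01 base LR, decay by 0.1 every 2500 steps);
# the actual numbers come from the config file, which is not in this diff.
import tensorflow as tf

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01, decay_steps=2500, decay_rate=0.1, staircase=True)
print(schedule(2499).numpy())  # ~0.01   -> still in the first interval
print(schedule(2500).numpy())  # ~0.001  -> 0.01 * 0.1 ** 1
print(schedule(5000).numpy())  # ~0.0001 -> 0.01 * 0.1 ** 2
```

tfa.optimizers.SGDW also applies decoupled weight decay, which the commented-out plain SGD alternative below would drop.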

# train the network
max_epochs = round(self.epochs / len(self.train_iterator)) # number of passes through dataset
print(f"{datetime.datetime.now()}: Beginning training...")
for epoch in tqdm.tqdm(range(max_epochs + 1)):
for sampled_batch in self.train_iterator():
current_iter = tf.convert_to_tensor(self.current_iter, dtype=tf.int64)
loss = self.train_step(sampled_batch, current_iter)
self.current_iter += 1
print(f"{datetime.datetime.now()}: Epoch {self.current_iter}: loss: {loss}")

# adjust learning rate
if self.current_iter % self.lr_decay_interval == 0:
new_lr = self.initial_learning_rate * self.learning_rate_decay ** \
(self.current_iter // self.lr_decay_interval)
self.optimizer.lr.assign(new_lr)
print(f"{datetime.datetime.now()}: Learning rate decayed to {new_lr}")

# break when 6000 iterations reached
if self.current_iter >= self.epochs:
# optimizer = tf.keras.optimizers.SGD(
# learning_rate=lr_schedule,
# momentum=self.momentum
# )

@tf.function
def train_step(image, label, epoch):
with tf.GradientTape() as tape:
out_seg, out_tanh = network(image)
loss_fn.set_epoch(epoch)
loss = loss_fn(label, out_seg, out_tanh)
grads = tape.gradient(loss, network.trainable_weights)
optimizer.apply_gradients(zip(grads, network.trainable_weights))
return loss

# train model
epoch_size = len(train_generator)
epochs = self.iterations // epoch_size + 1
current_iter = tf.convert_to_tensor(0, dtype=tf.int64)
for epoch in tqdm.tqdm(range(epochs)):
for batch_idx in range(epoch_size):
image, label = train_generator[batch_idx]
loss = train_step(image, label, current_iter)
current_iter += 1
print(f"{datetime.datetime.now()}: Iteration {current_iter}: loss: {loss}")

# stop loop when 6000 iterations reached
if current_iter >= self.iterations:
break
else:
# get new batch of images
print("resetting indexes")
train_generator.on_epoch_end()
continue
break
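Two details in the loop above are easy to misread: current_iter is a tf.int64 tensor rather than a Python int, presumably so the epoch argument passed to the @tf.function-decorated train_step does not trigger retracing on every new value, and the inner break only stops the outer loop through Python's for/else idiom, where the else branch (reset the generator and continue) runs only when the inner loop completes without breaking. A minimal illustration of that control flow:

```python
# Minimal illustration of the for/else/break pattern used above (not the training
# loop itself): the outer break fires only after the inner loop breaks.
for outer in range(3):
    for inner in range(5):
        if outer == 1 and inner == 2:
            break        # stop the inner loop early
    else:
        continue         # inner loop finished normally: start the next outer pass
    break                # inner loop broke: stop the outer loop too
print(outer, inner)      # 1 2
```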

# save model
if not os.path.isdir(self.model_save_dir):
Path(self.model_save_dir).mkdir(exist_ok=True)
complete_model_save_path = os.path.join(self.model_save_dir, f"DTC_{self.num_labeled}_labels")
self.network.save(complete_model_save_path)
network.save(complete_model_save_path)
print(f"{datetime.datetime.now()}: Trained model saved to {complete_model_save_path}.")

def test(self):
# read config file
self.read_config()

# get images/labels and apply data augmentation
with tf.device("/cpu:0"):
print(f"{datetime.datetime.now()}: Loading images and applying transformations...")
# load pipeline from yaml
with open(os.path.join(os.path.dirname(os.getcwd()), self.training_pipeline), "r") as f:
pipeline = yaml.load(f, Loader=yaml.FullLoader)

# get list of transforms
test_transforms = []
if pipeline["preprocess"]["test"] is not None:
for transform in pipeline["preprocess"]["test"]:
try:
tfm_class = getattr(dataloader, transform["name"])(*[], **transform["variables"])
except KeyError:
tfm_class = getattr(dataloader, transform["name"])()
test_transforms.append(tfm_class)

self.test_iterator = dataloader.LAHeart(
data_dir=self.data_dir,
transforms=test_transforms,
train=False,
)

print(f"{datetime.datetime.now()}: Image loading and transformation complete.")

# load saved model
complete_model_save_path = os.path.join(self.model_save_dir, f"DTC_{self.num_labeled}_labels")
self.network = tf.keras.models.load_model(complete_model_save_path)
print(f"{datetime.datetime.now()}: Model loaded from {complete_model_save_path}.")

self.test_iterator = LAHeart(
self.data_dir
)
avg_metric = self.test_all_cases(
self.test_iterator(),
stride_xy=self.stride_xy,
@@ -210,11 +189,11 @@ def test_all_cases(self, test_iterator, num_classes=1, patch_size=(112, 112, 80)
single_metric = self.calculate_metric_per_case(pred, label[:].numpy())

if metric_detail:
print(f"Image number: {idx}\n, "
f"\tDice Coefficient: {single_metric[0]:.5f}\n, "
f"\tJaccard: {single_metric[1]:.5f}\n, "
f"\t95% Hausdorff Distance (HD): {single_metric[2]:.5f}\n, "
f"\tAverage Surface Distance (ASD): {single_metric[3]:.5f}\n, ")
print(f"Image number: {idx},\n "
f"\tDice Coefficient: {single_metric[0]:.5f},\n "
f"\tJaccard: {single_metric[1]:.5f},\n "
f"\t95% Hausdorff Distance (HD): {single_metric[2]:.5f},\n "
f"\tAverage Surface Distance (ASD): {single_metric[3]:.5f}")

total_metric += np.asarray(single_metric)

@@ -288,6 +267,35 @@ def calculate_metric_per_case(pred, gt):
return dice, jc, hd, asd
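Only the return statement of calculate_metric_per_case is visible in this hunk. Given the medpy import at the top of the file and the four values reported in test_all_cases, a plausible body (a sketch, not the repository's actual code) would be:

```python
# Sketch only: one plausible implementation consistent with the returned tuple,
# using medpy's binary metrics.
from medpy import metric

def calculate_metric_per_case(pred, gt):
    dice = metric.binary.dc(pred, gt)   # Dice coefficient
    jc = metric.binary.jc(pred, gt)     # Jaccard index
    hd = metric.binary.hd95(pred, gt)   # 95% Hausdorff distance
    asd = metric.binary.asd(pred, gt)   # average surface distance
    return dice, jc, hd, asd
```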


class EpochCallback(tf.keras.callbacks.Callback):
"""Class to get the current epoch during training."""
def __init__(self, loss):
self.loss = loss
self.epoch = 0

def on_epoch_begin(self, epoch, logs=None):
if self.epoch < 1:
self.epoch = epoch # starts at 0

def on_train_batch_begin(self, batch, logs=None):
self.epoch += 1
self.loss.set_epoch(self.epoch)


class MaxIterationsCallback(tf.keras.callbacks.Callback):
"""Class to stop training when max iterations are reached."""
def __init__(self, max_iters):
self.max_iters = max_iters
self.batch = None

def on_batch_end(self, batch, logs=None):
self.batch = batch + 1

def on_epoch_end(self, epoch, logs=None):
if (epoch + 1) * self.batch >= self.max_iters:
self.model.stop_training = True


if __name__ == '__main__':
with open(os.path.join(os.path.dirname(os.getcwd()), "configs/config.json"), "r") as config_json:
config = json.load(config_json)