Commit f690be4
Re #44: move save_frequency to command-line argument and reorganize output folder structure
chriscyyeung committed May 18, 2023
1 parent c43526f commit f690be4
Showing 2 changed files with 54 additions and 50 deletions.
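
For reference, a hypothetical command line using the new arguments introduced in this commit (all paths are placeholders):

python UltrasoundSegmentation/train.py \
    --train-data-folder /path/to/train_data \
    --val-data-folder /path/to/val_data \
    --output-dir /path/to/runs \
    --save-log \
    --save-ckpt-freq 5

With --save-ckpt-freq 5, intermediate checkpoints are written to a ckpts subfolder of the run directory every 5 epochs, and --save-log writes train.log to the same run folder.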
101 changes: 53 additions & 48 deletions UltrasoundSegmentation/train.py
@@ -48,12 +48,13 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--train-data-folder", type=str)
parser.add_argument("--val-data-folder", type=str)
parser.add_argument("--config-file", type=str)
parser.add_argument("--output-dir", type=str)
parser.add_argument("--config-file", type=str, default="train_config.yaml")
parser.add_argument("--wandb-project-name", type=str, default="aigt_ultrasound_segmentation")
parser.add_argument("--wandb-exp-name", type=str)
parser.add_argument("--log-level", type=str, default="INFO")
parser.add_argument("--log-file", type=str)
parser.add_argument("--save-log", action="store_true")
parser.add_argument("--save-ckpt-freq", type=int, default=0)
try:
return parser.parse_args()
except SystemExit as err:
@@ -62,42 +63,10 @@ def parse_args():


def main(args):
# Make sure output folder exists and save a copy of the configuration file
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)

# Set up logging into file or console
if args.log_file is not None:
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
log_file = os.path.join(args.output_dir, args.log_file)
logging.basicConfig(filename=log_file, filemode="w", level=args.log_level)
print(f"Logging to file {log_file}.")
else:
logging.basicConfig(stream=sys.stdout, level=logging.INFO) # Log to console

# If config file is not given, use default train_config.yaml.
# If config file is given but not an absolute path, assume it is in the same folder as train.py.

if args.config_file is None:
args.config_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), "train_config.yaml")
else:
if not os.path.isabs(args.config_file):
args.config_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), args.config_file)

# Load config file and save a copy in the output folder where the trained model will be saved

# Load config file
with open(args.config_file, "r") as f:
config = yaml.safe_load(f)

with open(os.path.join(args.output_dir, "train_config.yaml"), "w") as f:
yaml.dump(config, f)

# Set up device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f"Using device {device}.")

# Initialize Weights & Biases
wandb.login()
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
Expand All @@ -120,6 +89,44 @@ def main(args):
# Log values of the config dictionary on Weights & Biases
wandb.config.update(config)

# Create directory for run output
run_dir = os.path.join(args.output_dir, experiment_name)
os.makedirs(run_dir, exist_ok=True)

# Remove loggers from other libraries
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)

# Set up logging to console
log_formatter = logging.Formatter("%(asctime)s [%(levelname)-8s] %(message)s")
root_logger = logging.getLogger()
root_logger.setLevel(args.log_level)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(log_formatter)
root_logger.addHandler(console_handler)
# Log to file if specified
if args.save_log:
log_file = os.path.join(run_dir, "train.log")
file_handler = logging.FileHandler(log_file, mode="w")
file_handler.setFormatter(log_formatter)
root_logger.addHandler(file_handler)
logging.info(f"Logging to file {log_file}.")

# Save copy of config file in run folder
config_copy_path = os.path.join(run_dir, "train_config.yaml")
with open(config_copy_path, "w") as f:
yaml.dump(config, f)
logging.info(f"Saved config file to {config_copy_path}.")

# Create checkpoints folder if needed
if args.save_ckpt_freq > 0:
ckpt_dir = os.path.join(run_dir, "ckpts")
os.makedirs(ckpt_dir, exist_ok=True)

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f"Using device {device}.")

# Set seed for reproducibility
random.seed(config["seed"])
np.random.seed(config["seed"])
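
The hunk is cut off here, so the rest of the seeding block is not visible. For full reproducibility the torch RNG is usually seeded as well; a minimal sketch under that assumption:

import random
import numpy as np
import torch

random.seed(config["seed"])        # Python RNG (shown in the diff)
np.random.seed(config["seed"])     # NumPy RNG (shown in the diff)
torch.manual_seed(config["seed"])  # assumption: torch seeding sits in the truncated part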
@@ -166,7 +173,7 @@ def main(args):
train_dataloader = DataLoader(
train_dataset,
batch_size=config["batch_size"],
shuffle=True,
shuffle=False,
generator=g
)
val_dataset = UltrasoundDataset(args.val_data_folder, transform=val_transform)
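
Note that with shuffle=False the generator g no longer affects batch order, since samples are drawn sequentially. For context, a minimal sketch of seeded shuffling (assumes g is a torch.Generator created earlier in the script, seeded from the config):

import torch
from torch.utils.data import DataLoader

g = torch.Generator()
g.manual_seed(config["seed"])  # assumption: mirrors the seed key in train_config.yaml
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,   # with shuffle=True, the generator drives the sampling order
    generator=g
)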
@@ -340,22 +347,20 @@ def main(args):
"examples": wandb.Image(fig)})

plt.close(fig)

# Save model after every Nth epoch as specified in the config file
if config["save_frequency"]>0 and (epoch + 1) % config["save_frequency"] == 0:
torch.save(model.state_dict(), os.path.join(args.output_dir, f"model_{epoch+1}.pt"))

# Save final trained model in a self-contained way.
# Generate a filename for the saved model that contains the run ID so that we can easily find the model corresponding to a given run.

# model_filename = f"model_{run.id}.pt"
# torch.save(model.state_dict(), os.path.join(args.output_dir, model_filename))
# Save model checkpoint (if not the last epoch)
if (args.save_ckpt_freq > 0
and (epoch + 1) % args.save_ckpt_freq == 0
and (epoch + 1) < config["num_epochs"]):
ckpt_model_path = os.path.join(ckpt_dir, f"model_{epoch:03d}.pt")
torch.save(model.state_dict(), ckpt_model_path)
logging.info(f"Saved model checkpoint to {ckpt_model_path}")

# Save the final model also under the name "model.pt" so that we can easily find it later.
# This is useful if we want to use the model for inference without having to specify the model filename.

model_filename = "model.pt"
torch.save(model.state_dict(), os.path.join(args.output_dir, model_filename))
model_path = os.path.join(run_dir, "model.pt")
torch.save(model.state_dict(), model_path)
logging.info(f"Saved model to {model_path}.")
run.finish()
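
A minimal sketch of restoring the saved weights afterwards (build_model is a hypothetical placeholder; a state dict only loads into a matching architecture):

import os
import torch

model = build_model()  # hypothetical: rebuild the same architecture used for training
state_dict = torch.load(os.path.join(run_dir, "model.pt"), map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode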


3 changes: 1 addition & 2 deletions UltrasoundSegmentation/train_config.yaml
@@ -4,8 +4,7 @@ model_name: "monai_unet"
loss_function: "monai_dice"
in_channels: !!int 1
out_channels: !!int 2
save_frequency: !!int 0
num_epochs: !!int 40
num_epochs: !!int 10
batch_size: !!int 32
learning_rate: !!float 0.004
learning_rate_decay_factor: !!float 0.5
