diff --git a/UltrasoundSegmentation/train.py b/UltrasoundSegmentation/train.py
index 7e5ec66..c96b8db 100644
--- a/UltrasoundSegmentation/train.py
+++ b/UltrasoundSegmentation/train.py
@@ -48,12 +48,13 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--train-data-folder", type=str)
     parser.add_argument("--val-data-folder", type=str)
-    parser.add_argument("--config-file", type=str)
     parser.add_argument("--output-dir", type=str)
+    parser.add_argument("--config-file", type=str, default="train_config.yaml")
     parser.add_argument("--wandb-project-name", type=str, default="aigt_ultrasound_segmentation")
     parser.add_argument("--wandb-exp-name", type=str)
     parser.add_argument("--log-level", type=str, default="INFO")
-    parser.add_argument("--log-file", type=str)
+    parser.add_argument("--save-log", action="store_true")
+    parser.add_argument("--save-ckpt-freq", type=int, default=0)
     try:
         return parser.parse_args()
     except SystemExit as err:
@@ -62,42 +63,10 @@ def parse_args():
 
 
 def main(args):
-    # Make sure output folder exists and save a copy of the configuration file
-    if not os.path.exists(args.output_dir):
-        os.makedirs(args.output_dir)
-
-    # Set up logging into file or console
-    if args.log_file is not None:
-        for handler in logging.root.handlers[:]:
-            logging.root.removeHandler(handler)
-        log_file = os.path.join(args.output_dir, args.log_file)
-        logging.basicConfig(filename=log_file, filemode="w", level=args.log_level)
-        print(f"Logging to file {log_file}.")
-    else:
-        logging.basicConfig(stream=sys.stdout, level=logging.INFO)  # Log to console
-
-    # If config file is not given, use default train_config.yaml.
-    # If config file is given but not an absolute path, assume it is in the same folder as train.py.
-
-    if args.config_file is None:
-        args.config_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), "train_config.yaml")
-    else:
-        if not os.path.isabs(args.config_file):
-            args.config_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), args.config_file)
-
-    # Load config file and save a copy in the output folder where the trained model will be saved
-
+    # Load config file
     with open(args.config_file, "r") as f:
         config = yaml.safe_load(f)
-
-    with open(os.path.join(args.output_dir, "train_config.yaml"), "w") as f:
-        yaml.dump(config, f)
-
-    # Set up device
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    logging.info(f"Using device {device}.")
-
 
     # Initialize Weights & Biases
     wandb.login()
     timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
@@ -120,6 +89,44 @@ def main(args):
     # Log values of the config dictionary on Weights & Biases
     wandb.config.update(config)
 
+    # Create directory for run output
+    run_dir = os.path.join(args.output_dir, experiment_name)
+    os.makedirs(run_dir, exist_ok=True)
+
+    # Remove loggers from other libraries
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+
+    # Set up logging to console
+    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-8s] %(message)s")
+    root_logger = logging.getLogger()
+    root_logger.setLevel(args.log_level)
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(log_formatter)
+    root_logger.addHandler(console_handler)
+    # Log to file if specified
+    if args.save_log:
+        log_file = os.path.join(run_dir, "train.log")
+        file_handler = logging.FileHandler(log_file, mode="w")
+        file_handler.setFormatter(log_formatter)
+        root_logger.addHandler(file_handler)
+        logging.info(f"Logging to file {log_file}.")
+
+    # Save copy of config file in run folder
+    config_copy_path = os.path.join(run_dir, "train_config.yaml")
+    with open(config_copy_path, "w") as f:
+        yaml.dump(config, f)
+    logging.info(f"Saved config file to {config_copy_path}.")
+
+    # Create checkpoints folder if needed
+    if args.save_ckpt_freq > 0:
+        ckpt_dir = os.path.join(run_dir, "ckpts")
+        os.makedirs(ckpt_dir, exist_ok=True)
+
+    # Set up device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    logging.info(f"Using device {device}.")
+
     # Set seed for reproducibility
     random.seed(config["seed"])
     np.random.seed(config["seed"])
@@ -166,7 +173,7 @@ def main(args):
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=config["batch_size"],
-        shuffle=True,
+        shuffle=False,
         generator=g
     )
     val_dataset = UltrasoundDataset(args.val_data_folder, transform=val_transform)
@@ -340,22 +347,20 @@ def main(args):
                        "examples": wandb.Image(fig)})
             plt.close(fig)
 
-        # Save model after every Nth epoch as specified in the config file
-        if config["save_frequency"]>0 and (epoch + 1) % config["save_frequency"] == 0:
-            torch.save(model.state_dict(), os.path.join(args.output_dir, f"model_{epoch+1}.pt"))
-
-    # Save final trained model in a self-contained way.
-    # Generate a filename for the saved model that contains the run ID so that we can easily find the model corresponding to a given run.
-
-    # model_filename = f"model_{run.id}.pt"
-    # torch.save(model.state_dict(), os.path.join(args.output_dir, model_filename))
+        # Save model checkpoint (if not the last epoch)
+        if (args.save_ckpt_freq > 0
+            and (epoch + 1) % args.save_ckpt_freq == 0
+            and (epoch + 1) < config["num_epochs"]):
+            ckpt_model_path = os.path.join(ckpt_dir, f"model_{epoch:03d}.pt")
+            torch.save(model.state_dict(), ckpt_model_path)
+            logging.info(f"Saved model checkpoint to {ckpt_model_path}")
 
     # Save the final model also under the name "model.pt" so that we can easily find it later.
     # This is useful if we want to use the model for inference without having to specify the model filename.
-
-    model_filename = "model.pt"
-    torch.save(model.state_dict(), os.path.join(args.output_dir, model_filename))
+    model_path = os.path.join(run_dir, "model.pt")
+    torch.save(model.state_dict(), model_path)
+    logging.info(f"Saved model to {model_path}.")
 
     run.finish()

diff --git a/UltrasoundSegmentation/train_config.yaml b/UltrasoundSegmentation/train_config.yaml
index b2eea32..31363c0 100644
--- a/UltrasoundSegmentation/train_config.yaml
+++ b/UltrasoundSegmentation/train_config.yaml
@@ -4,8 +4,7 @@ model_name: "monai_unet"
 loss_function: "monai_dice"
 in_channels: !!int 1
 out_channels: !!int 2
-save_frequency: !!int 0
-num_epochs: !!int 40
+num_epochs: !!int 10
 batch_size: !!int 32
 learning_rate: !!float 0.004
 learning_rate_decay_factor: !!float 0.5
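
Reviewer note (not part of the patch): a minimal sketch of the checkpoint cadence implied by the new --save-ckpt-freq flag, assuming num_epochs is read from the config as in the patch. checkpoint_epochs is a hypothetical helper written only to illustrate the condition in the epoch loop.

def checkpoint_epochs(num_epochs: int, save_ckpt_freq: int) -> list[int]:
    """Return the zero-based epochs that write a checkpoint under the new rule."""
    if save_ckpt_freq <= 0:  # 0, the argparse default, disables checkpointing
        return []
    return [
        epoch for epoch in range(num_epochs)
        if (epoch + 1) % save_ckpt_freq == 0  # every Nth completed epoch...
        and (epoch + 1) < num_epochs          # ...except the last, covered by model.pt
    ]

# With the new default num_epochs: 10 and --save-ckpt-freq 3, checkpoints
# are written after epochs 3, 6 and 9 and named with the zero-based epoch
# index: ckpts/model_002.pt, ckpts/model_005.pt, ckpts/model_008.pt.
print(checkpoint_epochs(10, 3))  # [2, 5, 8]

Note that the filenames use the zero-based epoch while the cadence tests epoch + 1, so --save-ckpt-freq 3 produces model_002.pt rather than model_003.pt.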