From ca76b7a39ed3a27f811d8a9e704b3e109dc55477 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EA=B9=80=EC=B1=84=EC=95=84=5FT6043?=
 <48304130+2018007956@users.noreply.github.com>
Date: Tue, 30 Jan 2024 01:01:39 +0900
Subject: [PATCH] refactor: separate files (#25)

#21
---
 code/optimizer.py |  4 ++++
 code/scheduler.py |  8 ++++++++
 code/train.py     | 23 ++++++++---------------
 3 files changed, 20 insertions(+), 15 deletions(-)
 create mode 100644 code/optimizer.py
 create mode 100644 code/scheduler.py

diff --git a/code/optimizer.py b/code/optimizer.py
new file mode 100644
index 0000000..36fdd84
--- /dev/null
+++ b/code/optimizer.py
@@ -0,0 +1,4 @@
+import torch
+
+def optim(args, trainable_params):
+    return getattr(torch.optim, args.optimizer)(trainable_params, lr=args.learning_rate)
\ No newline at end of file
diff --git a/code/scheduler.py b/code/scheduler.py
new file mode 100644
index 0000000..ffc3f16
--- /dev/null
+++ b/code/scheduler.py
@@ -0,0 +1,8 @@
+from torch.optim import lr_scheduler
+
+def sched(args, optimizer):
+    if args.scheduler == "multistep":
+        return lr_scheduler.MultiStepLR(optimizer, milestones=[args.max_epoch // 2, args.max_epoch // 2 * 2], gamma=0.1)
+    elif args.scheduler == "cosine":
+        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0)
+
\ No newline at end of file
diff --git a/code/train.py b/code/train.py
index 7143436..cc1639a 100644
--- a/code/train.py
+++ b/code/train.py
@@ -9,7 +9,8 @@ import torch
 from torch import cuda
 from torch.utils.data import DataLoader
-from torch.optim import lr_scheduler
+from optimizer import optim
+from scheduler import sched
 from tqdm import tqdm
 
 import wandb
 
@@ -26,7 +27,7 @@ def parse_args():
     parser = ArgumentParser()
 
     # pkl dataset path
-    parser.add_argument('--train_dataset_dir', type=str, default="/data/ephemeral/home/level2-cv-datacentric-cv-01/data/medical/pickle/[2048]_cs[1024]_aug['CJ', 'GB', 'N']/train")
+    parser.add_argument('--train_dataset_dir', type=str, default="/data/ephemeral/home/level2-cv-datacentric-cv-01/data/medical/pickle/[1024, 1536, 2048]_cs[256, 512, 1024, 2048]_aug['CJ', 'GB', 'N']/train")
     parser.add_argument('--data_dir', type=str, default=os.environ.get('SM_CHANNEL_TRAIN', '../data/medical'))
     parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR', 'trained_models'))
     parser.add_argument('--seed', type=int, default=137)
@@ -43,7 +44,7 @@ def parse_args():
     parser.add_argument('-m', '--mode', type=str, default='on', help='wandb logging mode(on: online, off: disabled)')
     parser.add_argument('-p', '--project', type=str, default='datacentric', help='wandb project name')
     parser.add_argument('-d', '--data', default='pickle', type=str, help='description about dataset', choices=['original', 'pickle'])
-    parser.add_argument("--optimizer", type=str, default='adam', choices=['adam', 'adamW'])
+    parser.add_argument("--optimizer", type=str, default='Adam', choices=['Adam', 'AdamW'])
     parser.add_argument("--scheduler", type=str, default='cosine', choices=['multistep', 'cosine'])
     parser.add_argument("--resume", type=str, default=None, choices=[None, 'resume', 'finetune'])
 
@@ -123,18 +124,10 @@ def do_training(args):
         checkpoint = torch.load(osp.join(args.save_dir, "best.pth"))
         model.load_state_dict(checkpoint)
 
-    ### Optimizer ###
-    if args.optimizer == "adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-    elif args.optimizer == "adamW":
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
-
-    ### Scheduler ###
-    if args.scheduler == "multistep":
-        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[args.max_epoch // 2, args.max_epoch // 2 * 2], gamma=0.1)
-    elif args.scheduler == "cosine":
-        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0)
-
+    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
+    optimizer = optim(args, trainable_params)
+    scheduler = sched(args, optimizer)
+
     ### WandB ###
     if args.mode == 'on':
         wandb.init(