Commit
refactor: separate files (#25)
2018007956 authored Jan 29, 2024
1 parent 5f2ccf7 commit ca76b7a
Showing 3 changed files with 20 additions and 15 deletions.
4 changes: 4 additions & 0 deletions code/optimizer.py
@@ -0,0 +1,4 @@
import torch

def optim(args, trainable_params):
    return getattr(torch.optim, args.optimizer)(trainable_params, lr=args.learning_rate)
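Note: optim() looks the optimizer class up by name on torch.optim, so the value of --optimizer has to match a class name exactly (e.g. 'Adam', 'AdamW'). A minimal usage sketch, with a hypothetical args namespace built here only for illustration:

from argparse import Namespace
import torch
from optimizer import optim

# hypothetical args; 'Adam' must be the exact torch.optim class name
args = Namespace(optimizer='Adam', learning_rate=1e-3)
model = torch.nn.Linear(4, 2)
optimizer = optim(args, model.parameters())  # equivalent to torch.optim.Adam(model.parameters(), lr=1e-3)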
8 changes: 8 additions & 0 deletions code/scheduler.py
@@ -0,0 +1,8 @@
from torch.optim import lr_scheduler

def sched(args, optimizer):
    if args.scheduler == "multistep":
        return lr_scheduler.MultiStepLR(optimizer, milestones=[args.max_epoch // 2, args.max_epoch // 2 * 2], gamma=0.1)
    elif args.scheduler == "cosine":
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0)
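Note: sched() implicitly returns None for any name other than the two handled above, so train.py restricts --scheduler through argparse choices. A small sketch of driving both helpers together, with hypothetical args values for illustration:

from argparse import Namespace
import torch
from optimizer import optim
from scheduler import sched

# hypothetical values for illustration
args = Namespace(optimizer='Adam', learning_rate=1e-3, scheduler='cosine', max_epoch=100)
model = torch.nn.Linear(4, 2)
optimizer = optim(args, model.parameters())
scheduler = sched(args, optimizer)  # CosineAnnealingLR; call scheduler.step() once per epoch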

23 changes: 8 additions & 15 deletions code/train.py
@@ -9,7 +9,8 @@
import torch
from torch import cuda
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from optimizer import optim
from scheduler import sched
from tqdm import tqdm
import wandb

@@ -26,7 +27,7 @@ def parse_args():
    parser = ArgumentParser()

    # path to the pkl dataset
    parser.add_argument('--train_dataset_dir', type=str, default="/data/ephemeral/home/level2-cv-datacentric-cv-01/data/medical/pickle/[2048]_cs[1024]_aug['CJ', 'GB', 'N']/train")
    parser.add_argument('--train_dataset_dir', type=str, default="/data/ephemeral/home/level2-cv-datacentric-cv-01/data/medical/pickle/[1024, 1536, 2048]_cs[256, 512, 1024, 2048]_aug['CJ', 'GB', 'N']/train")
    parser.add_argument('--data_dir', type=str, default=os.environ.get('SM_CHANNEL_TRAIN', '../data/medical'))
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR', 'trained_models'))
    parser.add_argument('--seed', type=int, default=137)
@@ -43,7 +44,7 @@ def parse_args():
    parser.add_argument('-m', '--mode', type=str, default='on', help='wandb logging mode (on: online, off: disabled)')
    parser.add_argument('-p', '--project', type=str, default='datacentric', help='wandb project name')
    parser.add_argument('-d', '--data', default='pickle', type=str, help='description about dataset', choices=['original', 'pickle'])
    parser.add_argument("--optimizer", type=str, default='adam', choices=['adam', 'adamW'])
    parser.add_argument("--optimizer", type=str, default='Adam', choices=['Adam', 'AdamW'])
    parser.add_argument("--scheduler", type=str, default='cosine', choices=['multistep', 'cosine'])
    parser.add_argument("--resume", type=str, default=None, choices=[None, 'resume', 'finetune'])

@@ -123,18 +124,10 @@ def do_training(args):
        checkpoint = torch.load(osp.join(args.save_dir, "best.pth"))
        model.load_state_dict(checkpoint)

    ### Optimizer ###
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif args.optimizer == "adamW":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    ### Scheduler ###
    if args.scheduler == "multistep":
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[args.max_epoch // 2, args.max_epoch // 2 * 2], gamma=0.1)
    elif args.scheduler == "cosine":
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0)

    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim(args, trainable_params)
    scheduler = sched(args, optimizer)

    ### WandB ###
    if args.mode == 'on':
        wandb.init(
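For context, the objects returned by optim() and sched() feed the usual PyTorch epoch loop later in do_training. A compressed sketch, with train_loader and compute_loss as stand-ins for the file's actual loader and loss (illustrative, not the file's code):

for epoch in range(args.max_epoch):
    for batch in train_loader:             # hypothetical DataLoader over the pickled dataset
        loss = compute_loss(model, batch)  # hypothetical loss helper
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()                       # advance the LR schedule once per epoch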
