-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
executable file
·72 lines (59 loc) · 2.02 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import absolute_import, division, print_function
import os
import sys
import glob
import time
import json
import torch
import pickle
import random
import logging
import subprocess
import numpy as np
import jsbeautifier
from args import load_args
from logger import create_logger
from train import baseline_training
from evaluate import baseline_eval, FewShotBenchmark
from utils import set_seed, dist_training, backup_codes_to_project_folder
# Default jsbeautifier options, used in main() to pretty-print the JSON dump of args.
opts = jsbeautifier.default_options()
def main():
    """Run the experiment pipeline selected by the command-line arguments.

    Parses arguments (optionally replacing them with a previously saved args
    object), creates a per-process logger, logs the full configuration, sets
    up distributed state and the random seed, then runs, in this order and
    each only if its flag is set: evaluation (``do_eval``), the few-shot
    benchmark (``do_few_shot_benchamrk``) and training (``do_train``).
    Finally logs the total wall-clock runtime as HH:MM:SS, where the hours
    field is not capped at 24.
    """
    start_time = time.time()

    args = load_args()
    if args.load_args is not None:
        # Replace the CLI namespace with a saved args object (presumably
        # written with torch.save -- confirm). torch.load unpickles the
        # file, so it must come from a trusted source.
        saved_args_path = args.load_args
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle arbitrary Python objects such as an args namespace.
            args = torch.load(saved_args_path, weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the weights_only keyword.
            args = torch.load(saved_args_path)

    # One log file per process: explicit logger_id if given, else the local
    # rank (clamped so the non-distributed rank -1 maps to 0).
    logger_id = args.logger_id if args.logger_id is not None else max(args.local_rank, 0)
    logger = create_logger(
        os.path.join(args.output_dir, "{}-{}.log".format(args.task_name, logger_id)),
        process_rank=logger_id,
    )

    # Stringify the device (if any), presumably so that json.dumps below can
    # serialize args.__dict__.
    args.device = str(args.device) if hasattr(args, 'device') else None
    logger.info("{}".format(jsbeautifier.beautify(json.dumps(args.__dict__), opts)))

    # dist_training (from utils) returns the updated args; it presumably fills
    # in device / n_gpu for (distributed) training -- see utils.
    args = dist_training(args)
    set_seed(args.seed)

    logger.warning(
        " Process rank: {}, Python Version : {}\n"
        " torch version {}, cuda version: {}, CuDNN version : {}\n"
        " device: {}, n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            args.local_rank,
            sys.version.replace("\n", " "),
            torch.__version__,
            str(torch.version.cuda),
            torch.backends.cudnn.version(),
            args.device,
            args.n_gpu,
            args.local_rank != -1,
            args.fp16,
        )
    )

    if args.do_eval:
        baseline_eval(args, logger)
    # NOTE: attribute name spelling must match the argument definition.
    if args.do_few_shot_benchamrk:
        FewShotBenchmark(args, logger)
    if args.do_train:
        baseline_training(args, logger)

    # Format elapsed time manually: time.gmtime + "%H" wraps at 24 hours, so
    # long runs would otherwise be reported with the days silently dropped.
    elapsed = int(time.time() - start_time)
    hours, remainder = divmod(elapsed, 3600)
    minutes, seconds = divmod(remainder, 60)
    logger.info("Total runtime : {:02d}:{:02d}:{:02d}".format(hours, minutes, seconds))
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()