-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathpretraining_omniglot.py
73 lines (52 loc) · 2.09 KB
/
pretraining_omniglot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm
import datasets.datasetfactory as df
import configs.classification.pretraining_parser as params
import model.learner as learner
import model.modelfactory as mf
import utils
from experiment.experiment import experiment
import logging
logger = logging.getLogger('experiment')
def main():
    """Pretrain a classifier on the background/training split of a dataset.

    Steps, in order:

    1. Parse the command line and select the configuration for this
       process's ``rank`` via ``utils.get_run``.
    2. Seed the run and create the ``experiment`` object that provides the
       output path.
    3. Build the dataset, a shuffling ``DataLoader`` (batch size 256) and
       the model, and move the model to the selected device.
    4. Train with Adam and cross-entropy for ``args["epoch"]`` epochs,
       logging the training accuracy of each epoch and saving the model to
       ``my_experiment.path + "model.net"``.
    """
    p = params.Parser()

    # Parse the command line once and reuse the namespace; the arguments do
    # not change between calls, so re-parsing for every field is wasted work.
    parsed = p.parse_known_args()[0]
    rank = parsed.rank
    all_args = vars(parsed)
    print("All args = ", all_args)

    # NOTE(review): assumes utils.get_run only reads the dict it is given and
    # returns the per-rank configuration -- confirm against utils.
    args = utils.get_run(all_args, rank)

    utils.set_seed(args['seed'])

    my_experiment = experiment(args['name'], args, "../results/", commit_changes=False, rank=0, seed=1)

    if torch.cuda.is_available():
        # Spread ranks round-robin over the configured number of GPUs. The
        # modulo is evaluated only here so that a CPU-only run configured
        # with gpus == 0 does not fail on a value it never uses.
        gpu_to_use = rank % args["gpus"]
        device = torch.device('cuda:' + str(gpu_to_use))
        logger.info("Using gpu : %s", 'cuda:' + str(gpu_to_use))
    else:
        device = torch.device('cpu')

    dataset = df.DatasetFactory.get_dataset(args['dataset'], background=True, train=True, path=args["path"], all=True)

    iterator = torch.utils.data.DataLoader(dataset, batch_size=256,
                                           shuffle=True, num_workers=0)

    logger.info(str(args))

    config = mf.ModelFactory.get_model("na", args["dataset"])

    maml = learner.Learner(config).to(device)

    # Show which parameters will receive gradient updates.
    for k, v in maml.named_parameters():
        print(k, v.requires_grad)

    opt = torch.optim.Adam(maml.parameters(), lr=args["lr"])

    for e in range(args["epoch"]):
        # Count correct predictions and samples seen, rather than averaging
        # per-batch accuracies: the final batch is usually smaller than 256
        # and would otherwise be over-weighted in the epoch accuracy.
        correct = 0
        seen = 0
        for img, y in tqdm(iterator):
            img = img.to(device)
            y = y.to(device)
            pred = maml(img)

            opt.zero_grad()
            loss = F.cross_entropy(pred, y.long())
            loss.backward()
            opt.step()

            # .float() keeps the later division a true division regardless
            # of how the installed torch version divides integer tensors.
            correct += (pred.argmax(1) == y).sum().float()
            seen += len(y)

        logger.info("Accuracy at epoch %d = %s", e, str(correct / seen))

        # torch.save on the module object pickles the whole model, so loading
        # it later needs the model class to be importable.
        # NOTE(review): the indentation of this file was lost in the copy
        # under review; the save is placed inside the epoch loop so the file
        # is overwritten after every epoch -- confirm against upstream.
        torch.save(maml, my_experiment.path + "model.net")
# Run the pretraining only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()