-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathairbench96_faster.py
711 lines (598 loc) · 28.7 KB
/
airbench96_faster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
# A variant of airbench optimized for time-to-96%.
# 27.3s runtime on an A100; 3.1 PFLOPs.
# Evidence: 96.00 average accuracy in n=200 runs.
#
# We recorded the above runtime on an NVIDIA A100-SXM4-40GB with the following nvidia-smi:
# NVIDIA-SMI 515.105.01 Driver Version: 515.105.01 CUDA Version: 11.7
# torch.__version__ == '2.4.0+cu121'
#############################################
# Setup/Hyperparameters #
#############################################
import os
import sys
import uuid
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
torch.backends.cudnn.benchmark = True
hyp = {
'opt': {
'train_epochs': 45.0,
'batch_size': 1024,
'batch_size_masked': 512,
'lr': 9.0, # learning rate per 1024 examples
'momentum': 0.85,
'weight_decay': 0.012, # weight decay per 1024 examples (decoupled from learning rate)
'bias_scaler': 64.0, # scales up learning rate (but not weight decay) for BatchNorm biases
'label_smoothing': 0.2,
'whiten_bias_epochs': 3, # how many epochs to train the whitening layer bias before freezing
},
'aug': {
'flip': True,
'translate': 4,
'cutout': 12,
},
'proxy': {
'widths': {
'block1': 32,
'block2': 64,
'block3': 64,
},
'depth': 2,
'scaling_factor': 1/9,
},
'net': {
'widths': {
'block1': 128,
'block2': 384,
'block3': 512,
},
'depth': 3,
'scaling_factor': 1/9,
'tta_level': 2, # the level of test-time augmentation: 0=none, 1=mirror, 2=mirror+translate
},
}
#############################################
# DataLoader #
#############################################
CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465))
CIFAR_STD = torch.tensor((0.2470, 0.2435, 0.2616))
def batch_flip_lr(inputs):
flip_mask = (torch.rand(len(inputs), device=inputs.device) < 0.5).view(-1, 1, 1, 1)
return torch.where(flip_mask, inputs.flip(-1), inputs)
def batch_crop(images, crop_size):
r = (images.size(-1) - crop_size)//2
shifts = torch.randint(-r, r+1, size=(len(images), 2), device=images.device)
images_out = torch.empty((len(images), 3, crop_size, crop_size), device=images.device, dtype=images.dtype)
# The two cropping methods in this if-else produce equivalent results, but the second is faster for r > 2.
if r <= 2:
for sy in range(-r, r+1):
for sx in range(-r, r+1):
mask = (shifts[:, 0] == sy) & (shifts[:, 1] == sx)
images_out[mask] = images[mask, :, r+sy:r+sy+crop_size, r+sx:r+sx+crop_size]
else:
images_tmp = torch.empty((len(images), 3, crop_size, crop_size+2*r), device=images.device, dtype=images.dtype)
for s in range(-r, r+1):
mask = (shifts[:, 0] == s)
images_tmp[mask] = images[mask, :, r+s:r+s+crop_size, :]
for s in range(-r, r+1):
mask = (shifts[:, 1] == s)
images_out[mask] = images_tmp[mask, :, :, r+s:r+s+crop_size]
return images_out
def make_random_square_masks(inputs, size):
is_even = int(size % 2 == 0)
n,c,h,w = inputs.shape
# seed top-left corners of squares to cutout boxes from, in one dimension each
corner_y = torch.randint(0, h-size+1, size=(n,), device=inputs.device)
corner_x = torch.randint(0, w-size+1, size=(n,), device=inputs.device)
# measure distance, using the center as a reference point
corner_y_dists = torch.arange(h, device=inputs.device).view(1, 1, h, 1) - corner_y.view(-1, 1, 1, 1)
corner_x_dists = torch.arange(w, device=inputs.device).view(1, 1, 1, w) - corner_x.view(-1, 1, 1, 1)
mask_y = (corner_y_dists >= 0) * (corner_y_dists < size)
mask_x = (corner_x_dists >= 0) * (corner_x_dists < size)
final_mask = mask_y * mask_x
return final_mask
def batch_cutout(inputs, size):
cutout_masks = make_random_square_masks(inputs, size)
return inputs.masked_fill(cutout_masks, 0)
def set_random_state(seed, state):
if seed is None:
# If we don't get a data seed, then make sure to randomize the state using independent generator, since
# it might have already been set by the model seed.
import random
torch.manual_seed(random.randint(0, 2**63))
else:
seed1 = 1000 * seed + state # just don't do more than 1000 epochs or else there will be overlap
torch.manual_seed(seed1)
class InfiniteCifarLoader:
"""
CIFAR-10 loader which constructs every input to be used during training during the call to __iter__.
The purpose is to support cross-epoch batches (in case the batch size does not divide the number of train examples),
and support stochastic iteration counts in order to preserve perfect linearity/independence.
"""
def __init__(self, path, train=True, batch_size=500, aug=None, altflip=True, subset_mask=None, aug_seed=None, order_seed=None):
data_path = os.path.join(path, 'train.pt' if train else 'test.pt')
if not os.path.exists(data_path):
dset = torchvision.datasets.CIFAR10(path, download=True, train=train)
images = torch.tensor(dset.data)
labels = torch.tensor(dset.targets)
torch.save({'images': images, 'labels': labels, 'classes': dset.classes}, data_path)
data = torch.load(data_path, map_location='cuda')
self.images, self.labels, self.classes = data['images'], data['labels'], data['classes']
# It's faster to load+process uint8 data than to load preprocessed fp16 data
self.images = (self.images.half() / 255).permute(0, 3, 1, 2).to(memory_format=torch.channels_last)
self.normalize = T.Normalize(CIFAR_MEAN, CIFAR_STD)
self.aug = aug or {}
for k in self.aug.keys():
assert k in ['flip', 'translate', 'cutout'], 'Unrecognized key: %s' % k
self.batch_size = batch_size
self.altflip = altflip
self.subset_mask = subset_mask if subset_mask is not None else torch.tensor([True]*len(self.images)).cuda()
self.train = train
self.aug_seed = aug_seed
self.order_seed = order_seed
def __iter__(self):
# Preprocess
images0 = self.normalize(self.images)
# Pre-randomly flip images in order to do alternating flip later.
if self.aug.get('flip', False) and self.altflip:
set_random_state(self.aug_seed, 0)
images0 = batch_flip_lr(images0)
# Pre-pad images to save time when doing random translation
pad = self.aug.get('translate', 0)
if pad > 0:
images0 = F.pad(images0, (pad,)*4, 'reflect')
labels0 = self.labels
# Iterate forever
epoch = 0
batch_size = self.batch_size
# In the below while-loop, we will repeatedly build a batch and then yield it.
num_examples = self.subset_mask.sum().item()
current_pointer = num_examples
batch_images = torch.empty(0, 3, 32, 32, dtype=images0.dtype, device=images0.device)
batch_labels = torch.empty(0, dtype=labels0.dtype, device=labels0.device)
batch_indices = torch.empty(0, dtype=labels0.dtype, device=labels0.device)
while True:
# Assume we need to generate more data to add to the batch.
assert len(batch_images) < batch_size
# If we have already exhausted the current epoch, then begin a new one.
if current_pointer >= num_examples:
# If we already reached the end of the last epoch then we need to generate
# a new augmented epoch of data (using random crop and alternating flip).
epoch += 1
set_random_state(self.aug_seed, epoch)
if pad > 0:
images1 = batch_crop(images0, 32)
if self.aug.get('flip', False):
if self.altflip:
images1 = images1 if epoch % 2 == 0 else images1.flip(-1)
else:
images1 = batch_flip_lr(images1)
if self.aug.get('cutout', 0) > 0:
images1 = batch_cutout(images1, self.aug['cutout'])
set_random_state(self.order_seed, epoch)
indices = (torch.randperm if self.train else torch.arange)(len(self.images), device=images0.device)
# The effect of doing subsetting in this manner is as follows. If the permutation wants to show us
# our four examples in order [3, 2, 0, 1], and the subset mask is [True, False, True, False],
# then we will be shown the examples [2, 0]. It is the subset of the ordering.
# The purpose is to minimize the interaction between the subset mask and the randomness.
# So that the subset causes not only a subset of the total examples to be shown, but also a subset of
# the actual sequence of examples which is shown during training.
indices_subset = indices[self.subset_mask[indices]]
current_pointer = 0
# Now we are sure to have more data in this epoch remaining.
# This epoch's remaining data is given by (images1[current_pointer:], labels0[current_pointer:])
# We add more data to the batch, up to whatever is needed to make a full batch (but it might not be enough).
remaining_size = batch_size - len(batch_images)
# Given that we want `remaining_size` more training examples, we construct them here, using
# the remaining available examples in the epoch.
extra_indices = indices_subset[current_pointer:current_pointer+remaining_size]
extra_images = images1[extra_indices]
extra_labels = labels0[extra_indices]
current_pointer += remaining_size
batch_indices = torch.cat([batch_indices, extra_indices])
batch_images = torch.cat([batch_images, extra_images])
batch_labels = torch.cat([batch_labels, extra_labels])
# If we have a full batch ready then yield it and reset.
if len(batch_images) == batch_size:
assert len(batch_images) == len(batch_labels)
yield (batch_indices, batch_images, batch_labels)
batch_images = torch.empty(0, 3, 32, 32, dtype=images0.dtype, device=images0.device)
batch_labels = torch.empty(0, dtype=labels0.dtype, device=labels0.device)
batch_indices = torch.empty(0, dtype=labels0.dtype, device=labels0.device)
############################################
# Evaluation #
############################################
def infer(model, loader, tta_level=0):
# Test-time augmentation strategy (for tta_level=2):
# 1. Flip/mirror the image left-to-right (50% of the time).
# 2. Translate the image by one pixel either up-and-left or down-and-right (50% of the time,
# i.e. both happen 25% of the time).
#
# This creates 6 views per image (left/right times the two translations and no-translation),
# which we evaluate and then weight according to the given probabilities.
def infer_basic(inputs, net):
return net(inputs).clone()
def infer_mirror(inputs, net):
return 0.5 * net(inputs) + 0.5 * net(inputs.flip(-1))
def infer_mirror_translate(inputs, net):
logits = infer_mirror(inputs, net)
pad = 1
padded_inputs = F.pad(inputs, (pad,)*4, 'reflect')
inputs_translate_list = [
padded_inputs[:, :, 0:32, 0:32],
padded_inputs[:, :, 2:34, 2:34],
]
logits_translate_list = [infer_mirror(inputs_translate, net)
for inputs_translate in inputs_translate_list]
logits_translate = torch.stack(logits_translate_list).mean(0)
return 0.5 * logits + 0.5 * logits_translate
model.eval()
test_images = loader.normalize(loader.images)
infer_fn = [infer_basic, infer_mirror, infer_mirror_translate][tta_level]
with torch.no_grad():
return torch.cat([infer_fn(inputs, model) for inputs in test_images.split(2000)])
def evaluate(model, loader, tta_level=0):
logits = infer(model, loader, tta_level)
return (logits.argmax(1) == loader.labels).float().mean().item()
#############################################
# Network Components #
#############################################
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class Mul(nn.Module):
def __init__(self, scale):
super().__init__()
self.scale = scale
def forward(self, x):
return x * self.scale
class BatchNorm(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-12,
weight=False, bias=True):
super().__init__(num_features, eps=eps)
self.weight.requires_grad = weight
self.bias.requires_grad = bias
# Note that PyTorch already initializes the weights to one and bias to zero
class Conv(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False):
super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
def reset_parameters(self):
super().reset_parameters()
if self.bias is not None:
self.bias.data.zero_()
w = self.weight.data
torch.nn.init.dirac_(w[:w.size(1)])
class ConvGroup(nn.Module):
def __init__(self, channels_in, channels_out, depth):
super().__init__()
assert depth in (2, 3)
self.depth = depth
self.conv1 = Conv(channels_in, channels_out)
self.pool = nn.MaxPool2d(2)
self.norm1 = BatchNorm(channels_out)
self.conv2 = Conv(channels_out, channels_out)
self.norm2 = BatchNorm(channels_out)
if depth == 3:
self.conv3 = Conv(channels_out, channels_out)
self.norm3 = BatchNorm(channels_out)
self.activ = nn.GELU()
def forward(self, x):
x = self.conv1(x)
x = self.pool(x)
x = self.norm1(x)
x = self.activ(x)
if self.depth == 3:
x0 = x
x = self.conv2(x)
x = self.norm2(x)
x = self.activ(x)
if self.depth == 3:
x = self.conv3(x)
x = self.norm3(x)
x = x + x0
x = self.activ(x)
return x
#############################################
# Network Definition #
#############################################
def make_net(hyp):
widths = hyp['widths']
scaling_factor = hyp['scaling_factor']
depth = hyp['depth']
whiten_kernel_size = 2
whiten_width = 2 * 3 * whiten_kernel_size**2
net = nn.Sequential(
Conv(3, whiten_width, whiten_kernel_size, padding=0, bias=True),
nn.GELU(),
ConvGroup(whiten_width, widths['block1'], depth),
ConvGroup(widths['block1'], widths['block2'], depth),
ConvGroup(widths['block2'], widths['block3'], depth),
nn.MaxPool2d(3),
Flatten(),
nn.Linear(widths['block3'], 10, bias=False),
Mul(scaling_factor),
)
net[0].weight.requires_grad = False
net = net.half().cuda()
net = net.to(memory_format=torch.channels_last)
for mod in net.modules():
if isinstance(mod, BatchNorm):
mod.float()
return net
def reinit_net(model):
for m in model.modules():
if type(m) in (Conv, BatchNorm, nn.Linear):
m.reset_parameters()
#############################################
# Whitening Conv Initialization #
#############################################
def get_patches(x, patch_shape):
c, (h, w) = x.shape[1], patch_shape
return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float()
def get_whitening_parameters(patches):
n,c,h,w = patches.shape
patches_flat = patches.view(n, -1)
est_patch_covariance = (patches_flat.T @ patches_flat) / n
eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U')
return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w,c,h,w).flip(0)
def init_whitening_conv(layer, train_set, eps=5e-4):
patches = get_patches(train_set, patch_shape=layer.weight.data.shape[2:])
eigenvalues, eigenvectors = get_whitening_parameters(patches)
eigenvectors_scaled = eigenvectors / torch.sqrt(eigenvalues + eps)
layer.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled))
############################################
# Lookahead #
############################################
class LookaheadState:
def __init__(self, net):
self.net_ema = {k: v.clone() for k, v in net.state_dict().items()}
def update(self, net, decay):
for ema_param, net_param in zip(self.net_ema.values(), net.state_dict().values()):
if net_param.dtype in (torch.half, torch.float):
ema_param.lerp_(net_param, 1-decay)
net_param.copy_(ema_param)
############################################
# Logging #
############################################
def print_columns(columns_list, is_head=False, is_final_entry=False):
print_string = ''
for col in columns_list:
print_string += '| %s ' % col
print_string += '|'
if is_head:
print('-'*len(print_string))
print(print_string)
if is_head or is_final_entry:
print('-'*len(print_string))
logging_columns_list = ['run ', 'epoch', 'train_loss', 'train_acc', 'val_acc', 'tta_val_acc', 'total_time_seconds']
def print_training_details(variables, is_final_entry):
formatted = []
for col in logging_columns_list:
var = variables.get(col.strip(), None)
if type(var) in (int, str):
res = str(var)
elif type(var) is float:
res = '{:0.4f}'.format(var)
else:
assert var is None
res = ''
formatted.append(res.rjust(len(col)))
print_columns(formatted, is_final_entry=is_final_entry)
############################################
# Training #
############################################
def train_proxy(hyp, model, data_seed):
batch_size = hyp['opt']['batch_size']
epochs = hyp['opt']['train_epochs']
momentum = hyp['opt']['momentum']
kilostep_scale = 1024 * (1 + 1 / (1 - momentum))
lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD
wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale
lr_biases = lr * hyp['opt']['bias_scaler']
loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none')
train_loader = InfiniteCifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug'],
aug_seed=data_seed, order_seed=data_seed)
steps_per_epoch = len(train_loader.images) // batch_size
total_train_steps = ceil(steps_per_epoch * epochs)
set_random_state(None, 0)
reinit_net(model)
print('Proxy parameters:', sum(p.numel() for p in model.parameters()))
current_steps = 0
norm_biases = [p for k, p in model.named_parameters() if 'norm' in k and p.requires_grad]
other_params = [p for k, p in model.named_parameters() if 'norm' not in k and p.requires_grad]
param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases),
dict(params=other_params, lr=lr, weight_decay=wd/lr)]
optimizer = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True)
def get_lr(step):
warmup_steps = int(total_train_steps * 0.1)
warmdown_steps = total_train_steps - warmup_steps
if step < warmup_steps:
frac = step / warmup_steps
return 0.2 * (1 - frac) + 1.0 * frac
else:
frac = (total_train_steps - step) / warmdown_steps
return frac
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr)
# Initialize the whitening layer using training images
train_images = train_loader.normalize(train_loader.images[:5000])
init_whitening_conv(model._orig_mod[0], train_images)
masks = []
for indices, inputs, labels in train_loader:
if current_steps % steps_per_epoch == 0:
epoch = current_steps // steps_per_epoch
model.train()
# Skip every other backward pass
if current_steps % 4 == 0:
outputs = model(inputs)
loss1 = loss_fn(outputs, labels)
mask = torch.zeros(len(inputs)).cuda().bool()
mask[loss1.argsort()[-hyp['opt']['batch_size_masked']:]] = True
masks.append(mask)
loss = (loss1 * mask.float()).sum()
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
else:
with torch.no_grad():
outputs = model(inputs)
loss1 = loss_fn(outputs, labels)
mask = torch.zeros(len(inputs)).cuda().bool()
mask[loss1.argsort()[-hyp['opt']['batch_size_masked']:]] = True
masks.append(mask)
scheduler.step()
current_steps += 1
if current_steps == total_train_steps:
break
return masks
def main(run, hyp, model_proxy, model_trainbias, model_freezebias):
batch_size = hyp['opt']['batch_size']
epochs = hyp['opt']['train_epochs']
momentum = hyp['opt']['momentum']
# Assuming gradients are constant in time, for Nesterov momentum, the below ratio is how much
# larger the default steps will be than the underlying per-example gradients. We divide the
# learning rate by this ratio in order to ensure steps are the same scale as gradients, regardless
# of the choice of momentum.
kilostep_scale = 1024 * (1 + 1 / (1 - momentum))
lr = hyp['opt']['lr'] / kilostep_scale # un-decoupled learning rate for PyTorch SGD
wd = hyp['opt']['weight_decay'] * batch_size / kilostep_scale
lr_biases = lr * hyp['opt']['bias_scaler']
set_random_state(None, 0)
import random
data_seed = random.randint(0, 2**50)
loss_fn = nn.CrossEntropyLoss(label_smoothing=hyp['opt']['label_smoothing'], reduction='none')
test_loader = InfiniteCifarLoader('cifar10', train=False, batch_size=2000)
train_loader = InfiniteCifarLoader('cifar10', train=True, batch_size=batch_size, aug=hyp['aug'],
aug_seed=data_seed, order_seed=data_seed)
steps_per_epoch = len(train_loader.images) // batch_size
total_train_steps = ceil(steps_per_epoch * epochs)
set_random_state(None, 0)
reinit_net(model_trainbias)
print('Main model parameters:', sum(p.numel() for p in model_trainbias.parameters()))
current_steps = 0
norm_biases = [p for k, p in model_trainbias.named_parameters() if 'norm' in k]
other_params = [p for k, p in model_trainbias.named_parameters() if 'norm' not in k]
param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases),
dict(params=other_params, lr=lr, weight_decay=wd/lr)]
optimizer_trainbias = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True)
norm_biases = [p for k, p in model_freezebias.named_parameters() if 'norm' in k]
other_params = [p for k, p in model_freezebias.named_parameters() if 'norm' not in k]
param_configs = [dict(params=norm_biases, lr=lr_biases, weight_decay=wd/lr_biases),
dict(params=other_params, lr=lr, weight_decay=wd/lr)]
optimizer_freezebias = torch.optim.SGD(param_configs, momentum=momentum, nesterov=True)
def get_lr(step):
warmup_steps = int(total_train_steps * 0.1)
warmdown_steps = total_train_steps - warmup_steps
if step < warmup_steps:
frac = step / warmup_steps
return 0.2 * (1 - frac) + 1.0 * frac
else:
frac = (total_train_steps - step) / warmdown_steps
return frac
scheduler_trainbias = torch.optim.lr_scheduler.LambdaLR(optimizer_trainbias, get_lr)
scheduler_freezebias = torch.optim.lr_scheduler.LambdaLR(optimizer_freezebias, get_lr)
alpha_schedule = 0.95**5 * (torch.arange(total_train_steps+1) / total_train_steps)**3
lookahead_state = LookaheadState(model_trainbias)
# For accurately timing GPU code
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
total_time_seconds = 0.0
# Initialize the whitening layer using training images
starter.record()
train_images = train_loader.normalize(train_loader.images[:5000])
init_whitening_conv(model_trainbias._orig_mod[0], train_images)
ender.record()
torch.cuda.synchronize()
total_time_seconds += 1e-3 * starter.elapsed_time(ender)
# Do a small proxy run to collect masks for use in fullsize run
print('Training small proxy...')
starter.record()
masks = iter(train_proxy(hyp, model_proxy, data_seed))
ender.record()
torch.cuda.synchronize()
total_time_seconds += 1e-3 * starter.elapsed_time(ender)
for indices, inputs, labels in train_loader:
# After training the whiten bias for some epochs, swap in the compiled model with frozen bias
if current_steps == 0:
model = model_trainbias
optimizer = optimizer_trainbias
scheduler = scheduler_trainbias
elif epoch == hyp['opt']['whiten_bias_epochs'] * steps_per_epoch:
model = model_freezebias
optimizer = optimizer_freezebias
scheduler = scheduler_freezebias
model.load_state_dict(model_trainbias.state_dict())
optimizer.load_state_dict(optimizer_trainbias.state_dict())
scheduler.load_state_dict(scheduler_trainbias.state_dict())
####################
# Training #
####################
if current_steps % steps_per_epoch == 0:
epoch = current_steps // steps_per_epoch
starter.record()
model.train()
mask = next(masks)
inputs = inputs[mask]
labels = labels[mask]
outputs = model(inputs)
loss = loss_fn(outputs, labels).sum()
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
scheduler.step()
current_steps += 1
if current_steps % 5 == 0:
lookahead_state.update(model, decay=alpha_schedule[current_steps].item())
if current_steps == total_train_steps:
if lookahead_state is not None:
lookahead_state.update(model, decay=1.0)
if (current_steps % steps_per_epoch == 0) or (current_steps == total_train_steps):
ender.record()
torch.cuda.synchronize()
total_time_seconds += 1e-3 * starter.elapsed_time(ender)
####################
# Evaluation #
####################
# Save the accuracy and loss from the last training batch of the epoch
train_acc = (outputs.detach().argmax(1) == labels).float().mean().item()
train_loss = loss.item() / batch_size
val_acc = evaluate(model, test_loader, tta_level=0)
print_training_details(locals(), is_final_entry=False)
run = None # Only print the run number once
if current_steps == total_train_steps:
break
####################
# TTA Evaluation #
####################
starter.record()
tta_val_acc = evaluate(model, test_loader, tta_level=hyp['net']['tta_level'])
ender.record()
torch.cuda.synchronize()
total_time_seconds += 1e-3 * starter.elapsed_time(ender)
epoch = 'eval'
print_training_details(locals(), is_final_entry=True)
return tta_val_acc
if __name__ == "__main__":
with open(sys.argv[0]) as f:
code = f.read()
model_proxy = make_net(hyp['proxy'])
model_proxy[0].bias.requires_grad = False
model_trainbias = make_net(hyp['net'])
model_freezebias = make_net(hyp['net'])
model_freezebias[0].bias.requires_grad = False
model_proxy = torch.compile(model_proxy, mode='max-autotune')
model_trainbias = torch.compile(model_trainbias, mode='max-autotune')
model_freezebias = torch.compile(model_freezebias, mode='max-autotune')
print_columns(logging_columns_list, is_head=True)
accs = torch.tensor([main(run, hyp, model_proxy, model_trainbias, model_freezebias)
for run in range(200)])
print('Mean: %.4f Std: %.4f' % (accs.mean(), accs.std()))
log = {'code': code, 'accs': accs}
log_dir = os.path.join('logs', str(uuid.uuid4()))
os.makedirs(log_dir, exist_ok=True)
log_path = os.path.join(log_dir, 'log.pt')
print(os.path.abspath(log_path))
torch.save(log, os.path.join(log_dir, 'log.pt'))