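"""Example DANN training script.

Trains a domain-adversarial model (DANNModel) on a labelled source domain
while adapting to an unlabelled target domain, validating on both. The
dataset, domain pair and hyperparameters come from configs/dann_config.py;
progress is printed to stdout and logged to Weights & Biases.
"""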
import os

import torch
import wandb

import configs.dann_config as dann_config
from dataloader import create_data_generators
from loss import loss_DANN
from metrics import AccuracyScoreFromLogits
from models import DANNModel
from trainer import Trainer
from utils.callbacks import print_callback, ModelSaver, HistorySaver, WandbCallback
from utils.schedulers import LRSchedulerSGD

# os.environ['CUDA_VISIBLE_DEVICES'] = '4, 5'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__ == '__main__':
    print('source_domain is {}, target_domain is {}'.format(dann_config.SOURCE_DOMAIN,
                                                            dann_config.TARGET_DOMAIN))

    # Source domain: 90/10 train/validation split; the training stream is infinite.
    train_gen_s, val_gen_s, _ = create_data_generators(dann_config.DATASET,
                                                       dann_config.SOURCE_DOMAIN,
                                                       batch_size=dann_config.BATCH_SIZE,
                                                       infinite_train=True,
                                                       image_size=dann_config.IMAGE_SIZE,
                                                       split_ratios=[0.9, 0.1, 0],
                                                       num_workers=dann_config.NUM_WORKERS,
                                                       device=device)

    # Target domain: every image goes to the (unlabelled) infinite training stream.
    train_gen_t, _, _ = create_data_generators(dann_config.DATASET,
                                               dann_config.TARGET_DOMAIN,
                                               batch_size=dann_config.BATCH_SIZE,
                                               infinite_train=True,
                                               split_ratios=[1, 0, 0],
                                               image_size=dann_config.IMAGE_SIZE,
                                               num_workers=dann_config.NUM_WORKERS,
                                               device=device)

    # A second, finite pass over the full target domain, used for validation.
    val_gen_t, _, _ = create_data_generators(dann_config.DATASET,
                                             dann_config.TARGET_DOMAIN,
                                             batch_size=dann_config.BATCH_SIZE,
                                             infinite_train=False,
                                             split_ratios=[1, 0, 0],
                                             image_size=dann_config.IMAGE_SIZE,
                                             num_workers=dann_config.NUM_WORKERS,
                                             device=device)
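
    # DANNModel (defined in models.py) is presumably the standard
    # domain-adversarial layout (Ganin et al., 2016): a shared feature
    # extractor with a label-classifier head and a gradient-reversal
    # domain-classifier head.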
    model = DANNModel().to(device)
    acc = AccuracyScoreFromLogits()
    scheduler = LRSchedulerSGD(blocks_with_smaller_lr=dann_config.BLOCKS_WITH_SMALLER_LR)
    tr = Trainer(model, loss_DANN)

    experiment_name = 'test_run_after_merge'
    details_name = ''
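
    # loss_DANN (defined in loss.py) is expected to combine the
    # label-classification loss on source batches with the
    # domain-discrimination loss on source and target batches;
    # fit() draws from both training streams.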
    tr.fit(train_gen_s, train_gen_t,
           n_epochs=dann_config.N_EPOCHS,
           validation_data=[val_gen_s, val_gen_t],
           metrics=[acc],
           steps_per_epoch=dann_config.STEPS_PER_EPOCH,
           val_freq=dann_config.VAL_FREQ,
           opt='sgd',
           opt_kwargs={'lr': 0.01, 'momentum': 0.9},
           lr_scheduler=scheduler,
           callbacks=[
               # Print the main training/validation losses and metrics.
               print_callback(watch=['loss', 'domain_loss', 'val_loss',
                                     'val_domain_loss', 'trg_metrics', 'src_metrics']),
               # Checkpoint the model that scores best on validation accuracy.
               ModelSaver(experiment_name + '_' + dann_config.SOURCE_DOMAIN + '_' +
                          dann_config.TARGET_DOMAIN + '_' + details_name,
                          dann_config.SAVE_MODEL_FREQ,
                          save_by_schedule=False,
                          save_best=True,
                          eval_metric='accuracy'),
               # Mirror the run to Weights & Biases.
               WandbCallback(config=dann_config,
                             name=dann_config.SOURCE_DOMAIN + '_' +
                                  dann_config.TARGET_DOMAIN + '_' + details_name,
                             group=experiment_name),
               # Dump the loss/metric history to disk every VAL_FREQ epochs.
               HistorySaver(experiment_name + '_' + dann_config.SOURCE_DOMAIN + '_' +
                            dann_config.TARGET_DOMAIN + '_' + details_name,
                            dann_config.VAL_FREQ,
                            path='_log/' + experiment_name + '_' + details_name,
                            # extra_losses={'domain_loss': ['domain_loss', 'val_domain_loss'],
                            #               'train_domain_loss': ['domain_loss_on_src',
                            #                                     'domain_loss_on_trg']}
                            )
           ])

    # Flush and close the Weights & Biases run.
    wandb.join()