-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
103 lines (81 loc) · 2.68 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import hydra
import wandb
from hydra.utils import call, instantiate
from omegaconf import DictConfig, OmegaConf
from bert_ru_sentiment_emotion.model.models import get_model
from bert_ru_sentiment_emotion.trainer.eval import eval
from bert_ru_sentiment_emotion.trainer.train import train
from bert_ru_sentiment_emotion.utils.utils import push_to_hub, save_model
# Environment tweaks applied at import time: silence transformers advisory
# warnings and tokenizer-parallelism chatter, and make Hydra print full
# tracebacks instead of its abbreviated error summary.
_ENV_SETTINGS = {
    "TRANSFORMERS_NO_ADVISORY_WARNINGS": "true",
    "TOKENIZERS_PARALLELISM": "false",
    "HYDRA_FULL_ERROR": "1",
}
os.environ.update(_ENV_SETTINGS)
# Service logins (enable when credentials are required):
# wandb.login()
# notebook_login()
def training(cfg: DictConfig):
    """Train a model as described by *cfg*, then optionally save/push it.

    Builds the tokenizer/model pair, instantiates the dataloaders and
    optimizer from the Hydra config, runs the training loop on GPU, and —
    after an interactive confirmation — saves the model locally and pushes
    it to the Hugging Face Hub.

    Args:
        cfg: Hydra/OmegaConf config with ``model``, ``dataset``, ``trainer``,
            ``optimizer``, ``task``, ``project_name`` and ``log_wandb`` keys.
    """
    tokenizer, model = get_model(
        cfg.model.encoder,
        cfg.dataset.labels,
        cfg.dataset.num_labels,
        cfg.trainer.problem_type,
        cfg.task,
    )
    # cfg.dataset.dataloader is a Hydra-instantiable callable returning the
    # (train, val, test) dataloader triple for the chosen dataset.
    train_dataloader, val_dataloader, test_dataloader = call(
        cfg.dataset.dataloader, tokenizer=tokenizer
    )
    if cfg.log_wandb:
        # resolve=True materializes interpolations; throw_on_missing surfaces
        # unset config keys before the run starts logging.
        wandb.init(
            project=f"{cfg.project_name}-{cfg.model.name}-{cfg.dataset.name}",
            config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
        )
    optimizer = instantiate(cfg.optimizer, params=model.parameters())
    # NOTE(review): hard-codes GPU placement; fails on CPU-only hosts.
    model.cuda()
    train(
        model=model,
        train_dataloader=train_dataloader,
        optimizer=optimizer,
        epochs=cfg.trainer.num_epochs,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader,
        labels=cfg.dataset.labels,
        problem_type=cfg.trainer.problem_type,
        log_wandb=cfg.log_wandb,
    )
    if cfg.log_wandb:
        wandb.finish()
    ask = input("Upload to hub?: ")
    # Fix: accept "Y", "Yes", and stray whitespace, not only exact "y"/"yes".
    if ask.strip().lower() in ("y", "yes"):
        save_model(
            model,
            tokenizer,
            f"models/{cfg.model.name}-{cfg.dataset.name}-ep={cfg.trainer.num_epochs}-lr={cfg.trainer.lr}",
        )
        push_to_hub(model, tokenizer, f"{cfg.model.name}-{cfg.dataset.name}")
def evaluation(cfg: DictConfig):
    """Evaluate a published checkpoint on the dataset's test split.

    Loads the ``seara/<model>-<dataset>`` checkpoint, builds the dataloaders
    from the config, and runs the evaluation loop on the test split only.

    Args:
        cfg: Hydra/OmegaConf config with ``model``, ``dataset``, ``trainer``
            and ``task`` keys.
    """
    hub_id = f"seara/{cfg.model.name}-{cfg.dataset.name}"
    tokenizer, model = get_model(
        hub_id,
        cfg.dataset.labels,
        cfg.dataset.num_labels,
        cfg.trainer.problem_type,
        cfg.task,
    )
    # The dataloader factory yields (train, val, test); only test is needed here.
    _, _, test_dataloader = call(cfg.dataset.dataloader, tokenizer=tokenizer)
    eval(
        model=model,
        test_dataloader=test_dataloader,
        labels=cfg.dataset.labels,
        problem_type=cfg.trainer.problem_type,
    )
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: dispatch to training or evaluation via ``cfg.task``.

    Raises:
        ValueError: if ``cfg.task`` is neither ``"train"`` nor ``"eval"``.
    """
    if cfg.task == "train":
        training(cfg)
    elif cfg.task == "eval":
        evaluation(cfg)
    else:
        # Fix: an unknown task previously fell through silently and the
        # program exited having done nothing; fail loudly instead.
        raise ValueError(f"Unknown task {cfg.task!r}; expected 'train' or 'eval'")


if __name__ == "__main__":
    main()