train_wrapper.py
"""Functionality for locally training files."""
import argparse
import os

from torch.utils.data import DataLoader

from app.ml_models.rnn.data_tools import WordLevelDataset
from app.ml_models.rnn.loaded_rnn_model import init_torch_model_from_path
from app.ml_models.rnn.rnn_model import RNNAnna
from app.ml_models.rnn.train import train
from app.ml_models.rnn.vocabulary import Vocabulary

def _run(**kwargs) -> None:
    """Run a simple training loop."""
    # Preliminary checks.
    if kwargs["batch_size"] != 1:
        raise ValueError(
            "Batch sizes other than 1 are not implemented. Want to help out? "
            "https://github.com/Sasafrass/straattaal/issues/22"
        )
    # Build or load the vocabulary.
    v = Vocabulary()
    try:
        v.load(prefix=kwargs["data_directory"], filename_vocab=kwargs["filename_vocab"])
    except FileNotFoundError:
        v.build(
            prefix=kwargs["data_directory"],
            filename_datasets=kwargs["filename_datasets"],
            filename_destination=kwargs["filename_vocab"],
        )
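    # On a fresh checkout the vocabulary file does not exist yet, so the
    # FileNotFoundError branch above builds it once from the dataset files.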
    # Build the training set and dataloader.
    dataset = WordLevelDataset(
        prefix=kwargs["data_directory"],
        filename_datasets=kwargs["filename_datasets"],
        vocabulary=v,
    )
if kwargs["num_training_steps"] is not None:
# Steps = epochs * (data set size / batch size)
# Hence Epochs = steps / (data set size / batch size)
kwargs["epochs"] = kwargs["num_training_steps"] // (
len(dataset) / kwargs["batch_size"]
)
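        # Worked example: num_training_steps=10_000 on a 2_000-word dataset
        # with batch_size=1 gives 10_000 // 2_000 = 5 epochs.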
    train_loader = DataLoader(dataset, batch_size=kwargs["batch_size"], shuffle=True)
    # Make sure the output directory exists.
    os.makedirs("models", exist_ok=True)
    # Build a fresh model or load one from a checkpoint.
    if kwargs["model_startpoint"] is not None:
        rnn = init_torch_model_from_path(
            kwargs["model_startpoint"], v, device=kwargs["device"]
        )
    else:
        rnn = RNNAnna(dataset.vocabulary.size, hidden_size=kwargs["hidden_size"])
    # Train. The remaining CLI arguments (learning_rate, momentum, save_every,
    # model_name, device, ...) reach train() through **kwargs.
    train(
        rnn,
        train_loader,
        dataset,
        print_every=3000,
        save_directory="models",
        **kwargs,
    )

if __name__ == "__main__":
    # Parse command-line arguments.
    PARSER = argparse.ArgumentParser()
    PARSER.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run the network on: cpu or cuda/gpu.",
    )
    PARSER.add_argument(
        "--filename_datasets",
        nargs="+",
        default=["dutch.txt"],
        help="Files to pull the datasets from. Can be passed as a space-separated list.",
    )
    PARSER.add_argument(
        "--filename_vocab",
        type=str,
        default="vocabulary.txt",
        help="File to read the vocabulary from. If it does not exist, it is built from the given training sets.",
    )
    PARSER.add_argument(
        "--data_directory",
        type=str,
        default="data",
        help="Directory in which the data and vocabulary are stored.",
    )
    PARSER.add_argument(
        "--model_name",
        type=str,
        default="my_own_model",
        help="Name used for storing intermediate models in the models/ directory.",
    )
    PARSER.add_argument(
        "--model_startpoint",
        type=str,
        default=None,
        help="If set, a model is loaded from this path as a starting point for training.",
    )
    PARSER.add_argument(
        "--num_epochs",
        type=int,
        default=10,
        help="Number of epochs for training.",
    )
    PARSER.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size. Currently only 1 is supported.",
    )
    PARSER.add_argument(
        "--num_training_steps",
        type=int,
        default=None,
        help="Number of training steps. Overrides --num_epochs if set.",
    )
    PARSER.add_argument(
        "--save_every",
        type=int,
        default=2,
        help="Number of epochs to complete before saving an intermediate model state dict.",
    )
    PARSER.add_argument(
        "--learning_rate",
        type=float,
        default=0.0005,
        help="Learning rate for the SGD optimizer.",
    )
    PARSER.add_argument(
        "--momentum",
        type=float,
        default=0.9,
        help="Momentum for the SGD optimizer.",
    )
    PARSER.add_argument(
        "--hidden_size",
        type=int,
        default=128,
        help="Hidden size of the recurrent model.",
    )
    _run(**vars(PARSER.parse_args()))
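
# Programmatic use, as a sketch: _run expects every key the parser would
# produce, so a direct call must mirror the full set of defaults above.
#
#   _run(
#       device="cpu",
#       filename_datasets=["dutch.txt"],
#       filename_vocab="vocabulary.txt",
#       data_directory="data",
#       model_name="my_own_model",
#       model_startpoint=None,
#       num_epochs=10,
#       batch_size=1,
#       num_training_steps=None,
#       save_every=2,
#       learning_rate=0.0005,
#       momentum=0.9,
#       hidden_size=128,
#   )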