-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy patharguments.py
128 lines (111 loc) · 4.87 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from typing import Optional, List
from dataclasses import dataclass, field
from transformers import TrainingArguments
@dataclass
class BaseTrainingArguments:
    """Options shared by every participant in the collaborative training run."""

    # Required: a unique experiment name, used as a key prefix for metadata on the DHT.
    experiment_prefix: str = field(
        metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
    )
    # Zero or more bootstrap peers; an empty list means this node starts a fresh DHT.
    initial_peers: List[str] = field(
        metadata={"help": "One or more peers (comma-separated) that will welcome you into the collaboration"},
        default_factory=list,
    )
    # Interface for inbound DHT traffic; default binds every IPv6 interface on any free port.
    dht_listen_on: str = field(
        metadata={"help": "Network interface used for incoming DHT communication. Default: all ipv6"},
        default="[::]:*",
    )
@dataclass
class AveragerArguments:
    """Tuning knobs for decentralized parameter/gradient averaging."""

    # How long (seconds) an averaging group waits for slow peers before proceeding.
    averaging_expiration: float = field(
        metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}, default=5.0
    )
    # Hard deadline (seconds) after which an averaging round is abandoned.
    averaging_timeout: float = field(
        metadata={"help": "Give up on averaging step after this many seconds"}, default=30.0
    )
    # Interface for inbound averager traffic; default binds all IPv6 interfaces, any port.
    listen_on: str = field(
        metadata={"help": "Network interface used for incoming averager communication. Default: all ipv6"},
        default="[::]:*",
    )
    # Lower/upper/default bounds (seconds) on how often collaboration state is re-fetched.
    min_refresh_period: float = field(
        metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"},
        default=0.5,
    )
    max_refresh_period: float = field(
        metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"},
        default=30,
    )
    default_refresh_period: float = field(
        metadata={"help": "Attempt to fetch collaboration state every this often until successful"}, default=3
    )
    # Assumptions about peer churn, used when estimating collaboration progress.
    expected_drift_peers: float = field(
        metadata={"help": "Trainer assumes that this many new peers can join per step"}, default=3
    )
    expected_drift_rate: float = field(
        metadata={"help": "Trainer assumes that this fraction of current size can join per step"}, default=0.2
    )
    # Smoothing factor for the exponential moving average of training throughput.
    performance_ema_alpha: float = field(
        metadata={"help": "Uses this alpha for moving average estimate of samples per second"}, default=0.1
    )
    # Cap on the number of peers in a single all-reduce group.
    target_group_size: int = field(metadata={"help": "Maximum group size for all-reduce"}, default=256)
    # Stale-peer eviction window (seconds).
    metadata_expiration: float = field(
        metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}, default=30
    )
@dataclass
class CollaborativeOptimizerArguments:
    """Options for the collaborative (DHT-synchronized) optimizer."""

    # Global batch size accumulated across ALL peers before each optimizer step.
    target_batch_size: int = field(
        default=4096,
        metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
    )
    client_mode: bool = field(
        default=False,
        # Fixed help-text typo: "Of True" -> "If True".
        metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"},
    )
    # Head start (in samples) for forming an averaging group before target_batch_size is reached.
    batch_size_lead: int = field(
        default=0,
        metadata={"help": "Optional: begin looking for group in advance, this many samples before target_batch_size"},
    )
    bandwidth: float = field(
        default=100.0,
        metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
    )
    # Compression scheme name applied to tensors exchanged during averaging.
    compression: str = field(
        default="FLOAT16", metadata={"help": "Use this compression when averaging parameters/gradients"}
    )
@dataclass
class CollaborationArguments(AveragerArguments, CollaborativeOptimizerArguments, BaseTrainingArguments):
    """Aggregates every collaboration option, plus monitoring/connectivity extras."""

    # How long (seconds) a peer's reported statistics remain visible before expiring.
    statistics_expiration: float = field(
        metadata={"help": "Statistics will be removed if not updated in this many seconds"}, default=600
    )
    # Externally reachable address to announce for inbound connections; None = not set.
    endpoint: Optional[str] = field(
        metadata={"help": "This node's IP for inbound connections, used when running from behind a proxy"},
        default=None,
    )
@dataclass
class DatasetArguments:
    """Filesystem / URL locations of the dataset, tokenizer, and model config."""

    # Pre-tokenized dataset directory.
    dataset_path: Optional[str] = field(
        metadata={"help": "Path to the tokenized dataset"}, default="data/albert_tokenized_wikitext"
    )
    tokenizer_path: Optional[str] = field(metadata={"help": "Path to the tokenizer"}, default="data/tokenizer")
    # May be a local path or a remote URL (default points at the hosted albert-large-v2 config).
    config_path: Optional[str] = field(
        metadata={"help": "Path to the model config"},
        default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
    )
    cache_dir: Optional[str] = field(metadata={"help": "Path to the cache"}, default="data")
@dataclass
class AlbertTrainingArguments(TrainingArguments):
    """Preset of HuggingFace TrainingArguments for collaborative ALBERT pretraining.

    Overrides upstream defaults with values tuned for this run. NOTE(review):
    `seq_length` and `clamp_value` are not standard TrainingArguments fields —
    presumably they are consumed by the training script; confirm against the caller.
    """

    dataloader_num_workers: int = 4
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 2
    seq_length: int = 512
    max_steps: int = 1_000_000  # Albert is actually ready after 125000 steps
    learning_rate: float = 0.00176
    warmup_steps: int = 5000
    adam_epsilon: float = 1e-6
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0  # gradient-clipping threshold
    clamp_value: float = 10000.0  # NOTE(review): meaning not visible here — verify in the optimizer/trainer code
    fp16: bool = True  # mixed-precision training enabled
    fp16_opt_level: str = "O2"  # Apex AMP optimization level
    do_train: bool = True
    logging_steps: int = 100
    save_total_limit: int = 2  # keep at most this many checkpoints on disk
    save_steps: int = 500
    output_dir: str = "outputs"