-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmaml_utils.py
49 lines (39 loc) · 1.44 KB
/
maml_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import datetime
import os
import torch
import maml_config
def std_checkpoint_fct(ckpt_dir,
                       current_episode,
                       current_loss,
                       params,
                       buffers,
                       train_data,
                       test_data,
                       maml_hparams: "maml_config.MamlHyperParameters",
                       env_config: "maml_config.EnvConfig",
                       other_config: dict,
                       *,
                       ckpt_every: int = 1000):
    """Save a training checkpoint periodically and at the final episode.

    Nothing is written unless ``current_episode`` is a multiple of
    ``ckpt_every`` or equals ``maml_hparams.n_episodes - 1``. Otherwise one
    file named ``ep{current_episode}_loss{current_loss}.pt`` is written into
    ``ckpt_dir`` with ``torch.save``.

    Args:
        ckpt_dir: Directory to write into. This function does not create it.
        current_episode: Episode index that is checked against the save
            schedule and embedded in the filename.
        current_loss: Loss value stored in the checkpoint and formatted into
            the filename with ``str()``.
        params: Mapping merged with ``buffers`` into ``model_state_dict``.
        buffers: Mapping merged over ``params``. On a key clash the
            ``buffers`` entry wins.
        train_data: Stored verbatim under ``'train_data'``.
        test_data: Stored verbatim under ``'test_data'``.
        maml_hparams: Object whose ``vars()`` are flattened into the
            checkpoint. It must expose ``n_episodes``.
        env_config: Object whose ``vars()`` are flattened into the checkpoint.
        other_config: Dict flattened into the checkpoint.
        ckpt_every: Periodic save interval in episodes. Defaults to 1000, the
            value that was previously hard-coded.

    Raises:
        ValueError: If ``ckpt_every`` is smaller than 1.
    """
    if ckpt_every < 1:
        # Guard against a confusing ZeroDivisionError / negative modulus.
        raise ValueError(f"ckpt_every must be >= 1, got {ckpt_every!r}")

    is_periodic = current_episode % ckpt_every == 0
    is_last = current_episode == maml_hparams.n_episodes - 1
    if not (is_periodic or is_last):
        return

    ckpt_name = os.path.join(
        ckpt_dir, f'ep{current_episode}_loss{current_loss}.pt')
    # Right operand of `|` wins, so buffers override same-named params.
    state_dict = params | buffers
    torch.save(
        {
            'current_date': datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
            'current_episode': current_episode,
            'current_loss': current_loss,
            'model_state_dict': state_dict,
            'train_data': train_data,
            'test_data': test_data,
            # Config fields are flattened into the top level. On a key clash a
            # later source silently overrides earlier keys (including the
            # fixed keys above): hparams < env_config < other_config.
            **vars(maml_hparams),
            **vars(env_config),
            **other_config,
        },
        ckpt_name,
    )
def get_ckpt_dir(base_dir, anil, run_name):
    """Return the checkpoint directory for a run, creating it if missing.

    The layout is ``base_dir/run_name``, or ``base_dir/anil/run_name`` when
    ``anil`` is truthy. When ``run_name`` is ``None``, the current local time
    formatted as ``%Y_%m_%d-%H_%M_%S`` is used as the run name. An existing
    directory is reused (``exist_ok=True``), not treated as an error.

    Args:
        base_dir: Root directory under which run directories are placed.
        anil: If truthy, nest the run under an ``anil`` subdirectory.
        run_name: Directory name for the run, or ``None`` to use a timestamp.

    Returns:
        The path of the (now existing) checkpoint directory.
    """
    name = (run_name if run_name is not None
            else datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S"))
    parts = [base_dir, 'anil', name] if anil else [base_dir, name]
    path = os.path.join(*parts)
    os.makedirs(path, exist_ok=True)
    return path