forked from j-luo93/ASLI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path · get_run_df.py
64 lines (57 loc) · 2.66 KB
/
get_run_df.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import re
from argparse import ArgumentParser
from pathlib import Path
import pandas as pd
import torch
from sound_law.utils import (load_event, load_stats, read_distance_metrics,
read_matching_metrics)
# Dataset-name stem for each non-'run' mode; the i-th dataset lives in
# '<data_root>/<stem><i>' for i = 1..num_datasets.
DATASET_STEMS = {
    'irreg': 'pgmc-rand',
    'regress': 'pgmc-rand-regress',
    'merger': 'pgmc-rand-merger',
}


def list_run_dirs(folder):
    """Return the immediate sub-directories of ``folder``, sorted.

    ``Path.glob('*/')`` only restricts matches to directories on Python 3.11+;
    on older versions it also yields plain files, which would later break
    ``torch.load``. Filtering with ``is_dir()`` behaves the same on every
    version, and sorting makes the row order of the output TSVs reproducible.
    """
    return sorted(path for path in Path(folder).glob('*') if path.is_dir())


def get_data_folders(mode, num_datasets=50, data_root='data/wikt'):
    """Return the data-folder paths (as strings) to aggregate for ``mode``.

    Args:
        mode: One of the keys of ``DATASET_STEMS``.
        num_datasets: Number of numbered datasets; indices run 1..num_datasets.
        data_root: Directory that contains the dataset folders.

    Raises:
        ValueError: If ``mode`` has no known dataset stem.
    """
    try:
        stem = DATASET_STEMS[mode]
    except KeyError:
        raise ValueError(f'Unknown mode {mode!r}; expected one of {sorted(DATASET_STEMS)}.') from None
    return [f'{data_root}/{stem}{i}' for i in range(1, num_datasets + 1)]


def collect_run_tables(run_dirs):
    """Load metrics and metadata for every run directory.

    Returns an ordered dict mapping a table name ('match', 'event', 'dist',
    'meta') to the concatenated DataFrame; every row is tagged with the
    originating run path in a ``run`` column. ``run_dirs`` must be non-empty
    (``pd.concat`` rejects an empty list).
    """
    match_dfs, event_dfs, dist_dfs, meta_data = [], [], [], []
    for run_dir in run_dirs:
        run_name = str(run_dir)
        # Each saved hparam is an object exposing `.value`; keep only the values.
        # NOTE(review): torch>=2.6 defaults `weights_only=True`, which may reject
        # these pickled objects; if loading fails, pass `weights_only=False`
        # (only for trusted run directories).
        saved_dict = {k: v.value for k, v in torch.load(run_dir / 'hparams.pth').items()}
        # `best_run` holds the best epoch number as text; int() tolerates whitespace.
        best_epoch = int((run_dir / 'best_run').read_text())
        meta_record = {'best_epoch': best_epoch, 'run': run_name}
        # Hparams are merged last, so they win on key collisions (as before).
        meta_record.update(saved_dict)
        meta_data.append(meta_record)
        # FIXME(j_luo) check we are not missing any entry for the dfs.
        match_dfs.append(read_matching_metrics(run_dir).assign(run=run_name))
        event_dfs.append(load_event(run_dir).assign(run=run_name))
        dist_dfs.append(read_distance_metrics(run_dir).assign(run=run_name))
    return {
        'match': pd.concat(match_dfs, ignore_index=True),
        'event': pd.concat(event_dfs, ignore_index=True),
        'dist': pd.concat(dist_dfs, ignore_index=True),
        'meta': pd.DataFrame(meta_data),
    }


def collect_stats_table(data_folders):
    """Concatenate ``load_stats`` output for each folder, tagged by ``data_folder``."""
    stats_dfs = [load_stats(folder).assign(data_folder=str(folder)) for folder in data_folders]
    return pd.concat(stats_dfs, ignore_index=True)


def write_tsv(df, prefix, suffix):
    """Write ``df`` to '<prefix>_<suffix>.tsv' as tab-separated values without an index."""
    df.to_csv(f'{prefix}_{suffix}.tsv', sep='\t', index=False)


def main():
    """Parse the command line and write the requested aggregate TSV files."""
    parser = ArgumentParser(description='Aggregate per-run metrics or per-dataset statistics into TSV files.')
    parser.add_argument('folder', type=str, help="Directory of run sub-directories (used in 'run' mode).")
    parser.add_argument('mode', choices=['run', 'irreg', 'regress', 'merger'], type=str)
    parser.add_argument('prefix', type=str, help='Output path prefix for the TSV files.')
    parser.add_argument('--num-datasets', type=int, default=50,
                        help="Number of numbered datasets to aggregate in non-'run' modes.")
    parser.add_argument('--data-root', type=str, default='data/wikt',
                        help="Directory containing the datasets in non-'run' modes.")
    args = parser.parse_args()

    if args.mode == 'run':
        run_dirs = list_run_dirs(args.folder)
        if not run_dirs:
            parser.error(f'No run directories found under {args.folder!r}.')
        # Dict order fixes the output order: match, event, dist, meta.
        for suffix, df in collect_run_tables(run_dirs).items():
            write_tsv(df, args.prefix, suffix)
    else:
        if args.num_datasets < 1:
            parser.error('--num-datasets must be at least 1.')
        data_folders = get_data_folders(args.mode, args.num_datasets, args.data_root)
        write_tsv(collect_stats_table(data_folders), args.prefix, 'stats')


if __name__ == '__main__':
    main()