forked from j-luo93/ASLI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_char_emb.py
47 lines (38 loc) · 1.84 KB
/
extract_char_emb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
from argparse import ArgumentParser

import pandas as pd
import torch

from dev_misc import g, get_tensor
from sound_law.main import setup
from sound_law.s2s.base_model import get_emb_params
from sound_law.s2s.module import PhonoEmbedding
from sound_law.train.manager import OneToManyManager
def check_out_name(out_name: str) -> None:
    """Reject an output path prefix that already carries a file extension.

    The script appends ``.tsv`` and ``.meta.tsv`` to ``out_name`` itself, so
    the prefix must not end in an extension. Only the final path component is
    inspected: prefixes whose directories contain dots (e.g. ``../runs/emb``
    or ``run.v2/emb``) are accepted, while ``emb.tsv`` is rejected.

    Raises:
        ValueError: if the last path component of ``out_name`` has a suffix.
    """
    if os.path.splitext(out_name)[1]:
        raise ValueError(f'No suffix should be included: {out_name!r}')


if __name__ == "__main__":
    # Export the character embeddings of a trained model's encoder as two
    # tab-separated files without header or index:
    #   <out_name>.tsv       -- one embedding vector per row
    #   <out_name>.meta.tsv  -- one unit of the alphabet `abc` per row
    # NOTE(review): rows of the two files are assumed to line up (row i of
    # char_embedding <-> i-th unit of `abc`) -- confirm against PhonoEmbedding.
    parser = ArgumentParser()
    parser.add_argument('saved_g_path', type=str, help='Path to the saved g.')
    parser.add_argument('saved_model_path', type=str, help='Path to the saved model.')
    parser.add_argument('out_name', type=str, help='Path to save the output. No suffix should be included.')
    args = parser.parse_args()
    check_out_name(args.out_name)

    # Presumably restores the saved global config `g`, which the data/model
    # helpers below read -- must run before them.
    initiator = setup()
    initiator.run(saved_g_path=args.saved_g_path)
    _, _, abc, _ = OneToManyManager.prepare_raw_data()
    # Only one alphabet (`abc`) is used to rebuild the encoder embedding, so a
    # shared source/target alphabet is required. An explicit raise (not
    # `assert`) keeps the check active under `python -O`.
    if not g.share_src_tgt_abc:
        raise ValueError('This script requires a model trained with `share_src_tgt_abc` enabled.')

    sd = torch.load(args.saved_model_path)
    emb_params = get_emb_params(len(abc),
                                phono_feat_mat=get_tensor(abc.pfm),
                                special_ids=get_tensor(abc.special_ids))
    emb = PhonoEmbedding.from_params(emb_params)
    # Copy only the encoder's embedding parameters out of the full model state dict.
    prefix = 'encoder.embedding'
    emb.load_state_dict({name: sd[f'{prefix}.{name}']
                         for name in ('weight', 'special_weight', 'special_mask', 'pfm')})
    # Requires CUDA. NOTE(review): the result is moved back to CPU right away;
    # this may only be needed if `char_embedding` depends on g's device -- confirm.
    emb.cuda()
    char_emb = emb.char_embedding.detach().cpu().numpy()

    pd.DataFrame(char_emb).to_csv(args.out_name + '.tsv', sep='\t', index=False, header=False)
    meta_df = pd.DataFrame(list(abc), columns=['unit'])
    meta_df.to_csv(args.out_name + '.meta.tsv', sep='\t', index=False, header=False)