-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathstatistic.py
76 lines (60 loc) · 2.08 KB
/
statistic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import glob
import json
import numpy as np
import pickle as pkl
from tqdm import tqdm
# Per-symbol co-occurrence counts gathered from the training labels:
#   statis[sym] = {'num': <number of label lines containing sym>,
#                  'chars': {other_sym: <number of those lines also containing other_sym>}}
statis = {}

# Symbol vocabulary: one token per line. A token's row/column in the output
# matrix is its 0-based line number in this file (a later duplicate wins).
# UTF-8 is explicit so decoding does not depend on the platform locale.
with open('datasets/CROHME/words_dict.txt', encoding='utf-8') as f:
    chars = {token.strip(): index for index, token in enumerate(f)}

# Training labels: each line is a sample name followed by whitespace-separated
# tokens.
with open('datasets/CROHME/train_labels.txt', encoding='utf-8') as f:
    lines = f.readlines()
# Tokens that only encode LaTeX structure (grouping, super-/subscript). They
# are dropped as whole tokens so they never enter the statistics.
_STRUCTURAL_TOKENS = {'{', '}', '^', '_'}


def _strip_sqrt_index_brackets(symbols):
    r"""Return *symbols* without the ``[`` / ``]`` that delimit a ``\sqrt`` index.

    ``\sqrt [ 3 ] { x }`` becomes ``\sqrt 3 { x }``. Brackets outside any root
    index are kept as ordinary symbols; a ``[`` seen while an index is open is
    treated as index syntax and dropped.
    """
    kept = []
    open_indices = 0  # number of \sqrt [ ... ] indices currently open
    for pos, sym in enumerate(symbols):
        if sym == r'\sqrt' and symbols[pos + 1:pos + 2] == ['[']:
            open_indices += 1
            kept.append(sym)
        elif sym == '[' and open_indices:
            continue
        elif sym == ']' and open_indices:
            open_indices -= 1  # closes the innermost root index; not kept
        else:
            # Includes a literal '[' with no open index, which previously
            # crashed on tmp[-1] of an empty list.
            kept.append(sym)
    return kept


for line in tqdm(lines):
    parts = line.split()
    if not parts:
        continue  # tolerate blank lines, e.g. a trailing newline at EOF
    # The first field is the sample name; the rest are the label tokens.
    symbols = _strip_sqrt_index_brackets(parts[1:])
    symbols = [sym for sym in symbols if sym not in _STRUCTURAL_TOKENS]
    # Count each symbol at most once per label. dict.fromkeys keeps
    # first-occurrence order, so the JSON key order is reproducible (set()
    # iteration order changes between runs under string-hash randomisation).
    symbols = list(dict.fromkeys(symbols))
    for item in symbols:
        if item == '\\':
            print(line.rstrip('\n'))  # diagnostic: bare backslash token in a label
        entry = statis.setdefault(item, {'num': 0, 'chars': {}})
        entry['num'] += 1
        co_counts = entry['chars']
        for other in symbols:
            if other != item:
                co_counts[other] = co_counts.get(other, 0) + 1
# Conditional co-occurrence matrix over the vocabulary: entry (a, b) is
# (#label lines containing both a and b) / (#label lines containing a).
matrix = np.zeros((len(chars), len(chars)))
for item, entry in statis.items():
    for sym, count in entry['chars'].items():
        matrix[chars[item], chars[sym]] = count / entry['num']

# Symmetrise by averaging each (a, b) / (b, a) pair. The right-hand side is
# fully computed before being written back in place, so the original array
# object (and its memory layout) is what gets pickled afterwards.
matrix[...] = (matrix + matrix.T) / 2
# Persist the raw co-occurrence counts as JSON and the symmetric matrix as a
# pickle. ensure_ascii=False writes non-ASCII symbols verbatim, so the JSON
# file is opened as UTF-8 explicitly rather than in the platform locale.
with open('symbol_statistic_v1.json', 'w', encoding='utf-8') as f:
    json.dump(statis, f, ensure_ascii=False)
with open('symbol_statistic_v1.pkl', 'wb') as f:
    pkl.dump(matrix, f)