-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
77 lines (62 loc) · 2.14 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from email import header
# -*- encoding: utf-8 -*-
'''
@File : utils.py
@Time : 2022/05/21 16:26:23
@Author : Fei Gao
@Contact : [email protected]
BNU, Beijing, China
'''
import numpy as np
from numba.typed import Dict
from numba import types, njit
def convert2nb_dict(node_neighbors_dict, tri_neighbors_dict):
    """Convert plain Python neighbor mappings into numba typed dictionaries.

    Parameters
    ----------
    node_neighbors_dict : mapping
        Maps a node id (anything ``int()`` accepts) to a sequence of ids.
        Each non-empty value is stored as a 1-D int64 array.
    tri_neighbors_dict : mapping
        Maps a node id to a collection that is materialised with ``list(v)``
        and stored as a 2-D int64 array. Presumably a set of equal-length
        index tuples (e.g. triangles) -- TODO confirm against callers; ragged
        entries cannot form the 2-D array the declared value type requires.

    Returns
    -------
    tuple
        ``(node_neighbors_dict_nb, tri_neighbors_dict_nb)``: numba typed
        ``Dict`` objects with int64 keys. Keys whose value is empty are not
        stored.
    """
    node_neighbors_dict_nb = Dict.empty(key_type=types.int64, value_type=types.int64[:])
    tri_neighbors_dict_nb = Dict.empty(key_type=types.int64, value_type=types.int64[:, :])

    for k, v in node_neighbors_dict.items():
        if len(v) == 0:
            continue  # empty entries are omitted from the result
        # Force int64: np.array on Python ints otherwise picks the platform
        # default integer (int32 on Windows with NumPy < 2), which does not
        # match the declared int64[:] value type.
        node_neighbors_dict_nb[int(k)] = np.array(v, dtype=np.int64)

    for k, v in tri_neighbors_dict.items():
        if len(v) == 0:
            continue  # empty entries are omitted from the result
        # list(v) so that set-like collections become a 2-D array.
        tri_neighbors_dict_nb[int(k)] = np.array(list(v), dtype=np.int64)

    return node_neighbors_dict_nb, tri_neighbors_dict_nb
def parser_mc_results(rho, cut:bool=True):
    """Average the arrays stored in *rho*, producing one value per key.

    Parameters
    ----------
    rho : dict
        Maps sortable keys to 1-D arrays of equal length. The arrays are
        stacked in ascending key order and transposed, so the working array
        has one column per key and one row per array element (presumably
        one row per Monte Carlo run -- confirm against callers).
    cut : bool, default True
        If True, locate ``cut_point``: the first column with more than one
        nonzero entry. From that column onwards, zero entries are treated
        as missing (NaN) and excluded from the mean. Earlier columns are
        averaged as-is. If no column has more than one nonzero entry,
        nothing is masked and the result equals the plain mean.
        If False, return the plain column mean.

    Returns
    -------
    numpy.ndarray
        1-D array with one averaged value per key, in ascending key order.
        With ``cut=True``, a column at or after ``cut_point`` that contains
        only zeros averages to NaN (NumPy emits a RuntimeWarning).
    """
    rhos_array = np.vstack([rho[key] for key in sorted(rho)]).T
    if not cut:
        return np.mean(rhos_array, axis=0)

    qualifying = np.flatnonzero(np.count_nonzero(rhos_array, axis=0) > 1)
    # No qualifying column: put the cut past the last column so nothing is
    # masked (instead of failing on min() of an empty sequence).
    cut_point = qualifying[0] if qualifying.size else rhos_array.shape[1]

    cleaned = rhos_array.astype(float)  # copy; NaN needs a float dtype
    missing = cleaned == 0
    missing[:, :cut_point] = False  # columns before the cut keep their zeros
    cleaned[missing] = np.nan
    return np.nanmean(cleaned, axis=0)
def parser_results(rho):
    """Collect the values of *rho* into an array ordered by ascending key.

    Parameters
    ----------
    rho : dict
        Mapping with mutually comparable keys. Its values are passed to
        ``np.array`` in key order.

    Returns
    -------
    numpy.ndarray
        ``np.array`` built from the values, sorted by their keys.
    """
    ordered_keys = sorted(rho)
    ordered_values = [rho[key] for key in ordered_keys]
    return np.array(ordered_values)
@njit()
def ifConverge(rho:np.ndarray, N:int, threshold:float=1e-3, window:int=100)->bool:
    """Report whether the trailing values of *rho* have settled.

    Parameters
    ----------
    rho : np.ndarray
        Array of recorded values; the last *window* entries along its first
        axis are examined (presumably a time series -- confirm with callers).
    N : int
        Factor multiplied onto the standard deviation before the comparison.
    threshold : float, default 1e-3
        Tolerance for the scaled standard deviation.
    window : int, default 100
        Number of trailing entries to examine. Must be positive; the default
        matches the previously hard-coded value.

    Returns
    -------
    bool
        False while fewer than *window* entries exist; otherwise whether
        ``np.std(rho[-window:]) * N < threshold``.
    """
    if len(rho) < window:
        return False
    # np.std uses the population standard deviation (ddof=0).
    std = np.std(rho[-window:]) * N
    return std < threshold
import os
def checkFolder(path):
    """Ensure *path* exists as a directory, creating it if necessary.

    Missing parent directories are created as well; an existing directory
    is left untouched.

    Raises
    ------
    FileExistsError
        If *path* exists but is not a directory.
    OSError
        If the directory cannot be created (e.g. insufficient permissions).
    """
    # A single makedirs call avoids the check-then-create race of
    # isdir() followed by mkdir(), and also creates missing parents.
    os.makedirs(path, exist_ok=True)