-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstep_002_get_wpte_DBS.py
139 lines (93 loc) · 3.42 KB
/
step_002_get_wpte_DBS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 28 02:47:57 2022
copied 15/03/2022 to debug and refactoring
@author: cagdas
"""
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 2 14:39:53 2022
@author: cagdas
"""
import matplotlib.pyplot as plt
import numpy as np
import pylab as py
import pandas as pd
from kcsd.KCSD import oKCSD3D
from mayavi import mlab
import nibabel as nib
from scipy.signal import filtfilt, butter, iirnotch, spectrogram, hilbert
from scipy.fft import fftshift
from scipy import stats
import scipy.io as sio
import matplotlib.pyplot as plt
import pywt
from dtw import *
# %matplotlib qt
# import pywt
# from password_MK import password
# import os
# from pymef.mef_session import MefSession
#%% load converted file 4 kHz 32 bit
# Recording-session identifiers (they look like YYMMDD dates - verify);
# the index below selects which session this run processes.
sub_name_list = [
    '210319', '210413', '210505', '210527',
    '210708', '210805', '210909',
]
sub_name = sub_name_list[6]
# Pre-filtered macro recording for the selected session.
data_path = ('C:/WNencki/processing/dbs_macro/dbs_macro_filtered/DBS_sub_'
             + sub_name + '_macro_4khz_9sec_' + 'ENCODE' + '_filt.npy')
data = np.load(data_path)
#%% functions
def wpte_recon(data, wavelet='db5', maxlevel=11, n_bands=516, order='natural'):
    """Split one channel into time-domain wavelet-packet sub-band signals.

    The signal is decomposed with a wavelet packet transform down to
    ``maxlevel`` (2**maxlevel leaf nodes, each nominally FS/2/2**maxlevel
    wide). Each of the first ``n_bands`` leaf nodes is then reconstructed on
    its own back to the time domain.

    Parameters
    ----------
    data : 1-D array_like
        Single-channel signal.
    wavelet : str
        Wavelet name passed to ``pywt.WaveletPacket`` (default 'db5').
    maxlevel : int
        Decomposition depth (default 11).
    n_bands : int
        Number of leaf nodes to reconstruct (default 516). Rows beyond the
        number of available leaves stay zero.
    order : {'natural', 'freq'}
        Order in which leaves are taken. 'natural' (default, the historical
        behaviour) is pywt's tree order, which is NOT monotonic in frequency.
        'freq' orders the leaves by frequency band.

    Returns
    -------
    np.ndarray of float32, shape (n_bands, len(data))
        Row i is the reconstruction of the i-th selected leaf node,
        time-aligned with ``data``.

    Raises
    ------
    ValueError
        If ``order`` is not 'natural' or 'freq'.
    """
    data = np.asarray(data)
    n_samples = len(data)
    wp = pywt.WaveletPacket(data=data, wavelet=wavelet, maxlevel=maxlevel,
                            mode='symmetric')
    if order == 'natural':
        leaves = wp.get_leaf_nodes(True)
    elif order == 'freq':
        leaves = wp.get_level(maxlevel, order='freq')
    else:
        raise ValueError("order must be 'natural' or 'freq', got %r" % (order,))
    paths = [leaf.path for leaf in leaves][:n_bands]

    recons_wp = np.zeros((n_bands, n_samples), dtype=np.float32)
    for indx, path in enumerate(paths):
        # A fresh, empty tree holding only this leaf: reconstructing it yields
        # the contribution of this single sub-band.
        single = pywt.WaveletPacket(data=None, wavelet=wavelet,
                                    maxlevel=maxlevel, mode='symmetric')
        single[path] = wp[path].data
        rec = single.reconstruct(update=False)
        # With no original length to trim against, the reconstruction is
        # longer than the input. The surplus accumulates at the END: each
        # idwt level only appends trailing samples, which is why pywt trims
        # with data[:data_size]. Keep the leading samples so the band signal
        # stays time-aligned with `data`. Trimming both ends would shift it.
        recons_wp[indx] = rec[:n_samples]
    return recons_wp
def wpte_window_fast(data, win_size=4000, win_inc=160):
    """Compute the moving-window mean power of every wavelet-packet band.

    For each ``data[chan, epoch, :]`` signal, the sub-band signals returned by
    ``wpte_recon`` are squared and averaged over windows of ``win_size``
    samples. Window ``w`` covers samples ``[w*win_inc, w*win_inc + win_size)``.

    Parameters
    ----------
    data : np.ndarray, shape (K, L, M)
        Indexed as ``data[chan, epoch, sample]``.
    win_size : int
        Window length in samples (default 4000).
    win_inc : int
        Hop between consecutive window starts, in samples (default 160).

    Returns
    -------
    np.ndarray of float32, shape (K, L, 516, num_win)
        ``num_win = (M - win_size) // win_inc + 1``.

    Raises
    ------
    ValueError
        If ``win_size`` or ``win_inc`` is not positive, or ``M < win_size``.
    """
    n_chan, n_epoch, n_samp = np.shape(data)
    win_size = int(win_size)
    win_inc = int(win_inc)
    if win_size <= 0 or win_inc <= 0:
        raise ValueError('win_size and win_inc must be positive, got %d and %d'
                         % (win_size, win_inc))
    if n_samp < win_size:
        raise ValueError('signal length %d is shorter than win_size %d'
                         % (n_samp, win_size))
    num_win = (n_samp - win_size) // win_inc + 1
    # Window boundaries are the same for every channel/epoch: compute once.
    win_start = win_inc * np.arange(num_win)
    win_end = win_start + win_size

    # 516 must match the number of bands produced by wpte_recon's default.
    wpte = np.empty((n_chan, n_epoch, 516, num_win), dtype='float32')
    # Running sum with a leading zero column, so that
    # sum(power[:, a:b]) == csum[:, b] - csum[:, a]. This gives an O(M) moving
    # average instead of materialising a (516, num_win, win_size) window
    # array (about 1.7 GB per signal at the default sizes).
    csum = np.zeros((516, n_samp + 1), dtype=np.float64)
    for chan in range(n_chan):
        for epoch in range(n_epoch):
            band_power = np.square(wpte_recon(data[chan, epoch, :]),
                                   dtype=np.float64)
            csum[:, 1:] = np.cumsum(band_power, axis=1)
            wpte[chan, epoch] = (csum[:, win_end] - csum[:, win_start]) / win_size
    return wpte
#%% wpte test
# Quick timing / sanity run on a 2-channel x 2-epoch subset before the full run.
import time
# perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
start_time = time.perf_counter()
# sampd = data[0:3,0:3,:]
sampd = np.float32(data[0:2, 0:2, :])
wpte_test = wpte_window_fast(sampd)
print("--- %s seconds ---" % (time.perf_counter() - start_time))
# Moving-window power of band row 10 for channel 0, epoch 0.
plt.plot(wpte_test[0, 0, 10, :])
#%% test
# import time
# start_time = time.time()
# sampd = data[0:1,0:1,:]
# wpte = wpte_window(sampd)
# print("--- %s seconds ---" % (time.time() - start_time))
#%% zscore corrected
# plt.plot(stats.zscore(wpte[2,1,44,:], axis=0, ddof=1))
#%% calculate wpte
import time
# perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
start_time = time.perf_counter()
wpte = wpte_window_fast(data)
print("--- %s seconds ---" % (time.perf_counter() - start_time))
#%% save it
# wpte_window_fast already returns float32. np.asarray only converts when
# needed, whereas np.float32(wpte) builds a second full-size array.
wpte = np.asarray(wpte, dtype=np.float32)
np.save('C:/WNencki/processing/dbs_macro/dbs_macro_wpte/DBS_sub_'+sub_name+'_macro_4khz_9sec_'+'ENCODE'+'_filt_wpte',wpte)