Skip to content

Commit

Permalink
[ENH] Tiny changes to dss.dss0() (#71)
Browse files (browse the repository at this point in the history)
  • Loading branch information
nbara authored Nov 27, 2023
1 parent 89034f5 commit dc3c29c
Show file tree
Hide file tree
Showing 16 changed files with 72 additions and 68 deletions.
6 changes: 3 additions & 3 deletions citation.cff
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ authors:
given-names: "Nicolas"
orcid: "https://orcid.org/0000-0003-1495-561X"
title: "MEEGkit"
version: 0.1.4
doi: 10.5281/zenodo.5643659
date-released: 2021-10-15
version: 0.1.5
doi: 10.5281/zenodo.10210992
date-released: 2023-11-27
url: "https://github.com/nbara/python-meegkit"
2 changes: 1 addition & 1 deletion meegkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""M/EEG denoising utilities in python."""
__version__ = "0.1.4"
__version__ = "0.1.5"

from . import asr, cca, detrend, dss, lof, ress, sns, star, trca, tspca, utils

Expand Down
2 changes: 1 addition & 1 deletion meegkit/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
pyriemann = None


class ASR():
class ASR:
"""Artifact Subspace Reconstruction.
Artifact subspace reconstruction (ASR) is an automatic, online,
Expand Down
6 changes: 3 additions & 3 deletions meegkit/cca.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def mcca(C, n_channels, n_keep=[]):
W = whiten_nt(CC, keep=True)
A[ix0:ix1, ix0:ix1] = W

C = A.T.dot(C.dot(A))
C = A.T @ C @ A

# final PCA
V, d = pca(C, thresh=None) # don't threshold the PCA to keep n_channels
A = A.dot(V)
C = V.T.dot(C.dot(V))
A = A @ V
C = V.T @ C @ V
scores = np.diag(C)

AA = []
Expand Down
2 changes: 1 addition & 1 deletion meegkit/detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False):
yy = demean(x)

# PCA
V, _ = pca(rr.T.dot(rr), thresh=threshold)
V, _ = pca(rr.T @ rr, thresh=threshold)
rrr = rr.dot(V)

# Regression (OLS)
Expand Down
51 changes: 27 additions & 24 deletions meegkit/dss.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def dss1(X, weights=None, keep1=None, keep2=1e-12):
return todss, fromdss, pwr0, pwr1


def dss0(c0, c1, keep1=None, keep2=1e-9):
def dss0(c0, c1, keep1=None, keep2=1e-9, return_unmixing=True):
"""DSS base function.
This function allows specifying arbitrary bias functions (as compared to
Expand All @@ -84,13 +84,16 @@ def dss0(c0, c1, keep1=None, keep2=1e-9):
Number of PCs to retain (default=None, which keeps all).
keep2: float
Ignore PCs smaller than keep2 (default=1e-9).
return_unmixing : bool
If True (default), return the unmixing matrix.
Returns
-------
todss: array, shape=(n_dss_components, n_chans)
todss: array, shape=(n_chans, n_dss_components)
Matrix to convert X to normalized DSS components.
fromdss : array, shape=()
Matrix to transform back to original space.
fromdss : array, shape=(n_dss_components, n_chans)
Matrix to transform back to original space. Only returned if
``return_unmixing`` is True.
pwr0: array
Power per component (baseline).
pwr1: array
Expand All @@ -101,46 +104,46 @@ def dss0(c0, c1, keep1=None, keep2=1e-9):
The data mean is NOT removed prior to processing.
"""
if c0 is None or c1 is None:
raise AttributeError("dss0 needs at least two arguments")
if c0.shape != c1.shape:
raise AttributeError("c0 and c1 should have same size")
if c0.shape[0] != c0.shape[1]:
raise AttributeError("c0 should be square")
if np.any(np.isnan(c0)) or np.any(np.isinf(c0)):
raise ValueError("NaN or INF in c0")
if np.any(np.isnan(c1)) or np.any(np.isinf(c1)):
raise ValueError("NaN or INF in c1")
# Check size and squareness
assert c0.shape == c1.shape == (c0.shape[0], c0.shape[0]), \
"c0 and c1 should have the same size, and be square"

# Check for NaN or INF
assert not (np.any(np.isnan(c0)) or np.any(np.isinf(c0))), "NaN or INF in c0"
assert not (np.any(np.isnan(c1)) or np.any(np.isinf(c1))), "NaN or INF in c1"

# derive PCA and whitening matrix from unbiased covariance
eigvec0, eigval0 = pca(c0, max_comps=keep1, thresh=keep2)

# apply whitening and PCA matrices to the biased covariance
# (== covariance of bias whitened data)
W = np.sqrt(1. / eigval0) # diagonal of whitening matrix
W = np.diag(np.sqrt(1. / eigval0)) # diagonal of whitening matrix

# c1 is projected into whitened PCA space of data channels
c2 = (W * eigvec0).T.dot(c1).dot(eigvec0) * W
c2 = (eigvec0 @ W).T @ c1 @ (eigvec0 @ W)

# proj. matrix from whitened data space to a space maximizing bias
eigvec2, eigval2 = pca(c2, max_comps=keep1, thresh=keep2)

# DSS matrix (raw data to normalized DSS)
todss = (W[np.newaxis, :] * eigvec0).dot(eigvec2)
fromdss = linalg.pinv(todss)
todss = eigvec0 @ W @ eigvec2

# Normalise DSS matrix
N = np.sqrt(1. / np.diag(np.dot(np.dot(todss.T, c0), todss)))
todss = todss * N
N = np.sqrt(np.diag(todss.T @ c0 @ todss))
todss /= N

pwr0 = np.sqrt(np.sum(np.dot(c0, todss) ** 2, axis=0))
pwr1 = np.sqrt(np.sum(np.dot(c1, todss) ** 2, axis=0))
pwr0 = np.sqrt(np.sum((c0 @ todss) ** 2, axis=0))
pwr1 = np.sqrt(np.sum((c1 @ todss) ** 2, axis=0))

# Return data
# next line equiv. to: np.array([np.dot(todss, ep) for ep in data])
# dss_data = np.einsum('ij,hjk->hik', todss, data)

return todss, fromdss, pwr0, pwr1
if return_unmixing:
fromdss = linalg.pinv(todss)
return todss, fromdss, pwr0, pwr1
else:
return todss, pwr0, pwr1


def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
Expand Down Expand Up @@ -373,7 +376,7 @@ def nan_basic_interp(array):
ax.flat[3].plot(np.arange(iterations + 1), aggr_resid, marker="o")
ax.flat[3].set_title("Iterations")

f.set_tight_layout(True)
plt.tight_layout()
plt.savefig(f"{prefix}_{iterations:03}.png")
plt.close("all")

Expand Down
2 changes: 1 addition & 1 deletion meegkit/lof.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sklearn.neighbors import LocalOutlierFactor


class LOF():
class LOF:
"""Local Outlier Factor.
Local Outlier Factor (LOF) is an automatic, density-based outlier detection
Expand Down
2 changes: 0 additions & 2 deletions meegkit/trca.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ def predict(self, X):
----------
X: array, shape=(n_samples, n_chans[, n_trials])
Test data.
model: dict
Fitted model to be used in testing phase.
Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions meegkit/tspca.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12):
# TSPCA: clean x by removing regression on time-shifted refs
y = np.zeros((n_samples_X, n_chans_X, n_trials_X))
for t in np.arange(n_trials_X):
r = multishift(R[..., t], shifts, reshape=True)
y[..., t] = X[:z.shape[0], :, t] - (r @ regression)
z = multishift(R[..., t], shifts, reshape=True) @ regression
y[..., t] = X[:z.shape[0], :, t] - z

y, mean2 = demean(y, wX, return_mean=True, inplace=True)
idx = np.arange(offset1, initial_samples - offset2)
Expand Down
2 changes: 1 addition & 1 deletion meegkit/utils/auditory.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def erbspace(flow, fhigh, n):
return y, bw


class GammatoneFilterbank():
class GammatoneFilterbank:
"""Gammatone Filterbank.
This class computes the filter coefficients for a bank of Gammatone
Expand Down
3 changes: 1 addition & 2 deletions meegkit/utils/covariances.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,7 @@ def pca(cov, max_comps=None, thresh=0):

var = 100 * d.sum() / p0
if var < 99:
print("[PCA] Explained variance of selected components : {:.2f}%".
format(var))
print(f"[PCA] Explained variance of selected components : {var:.2f}%")

return V, d

Expand Down
27 changes: 14 additions & 13 deletions meegkit/utils/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def find_outlier_trials(X, thresh=None, show=True):
thresh : float or array of floats
Keep trials less than thresh from mean.
show : bool
If true plot trial deviations before and after.
If True (default), plot trial deviations before and after.
Returns
-------
Expand All @@ -173,9 +173,11 @@ def find_outlier_trials(X, thresh=None, show=True):
"""
if thresh is None:
thresh = [np.inf]
elif isinstance(thresh, float) or isinstance(thresh, int):
thresh = [thresh]
thresh = np.array([np.inf])
elif isinstance(thresh, (float, int)):
thresh = np.array([thresh])
else:
thresh = np.asarray(thresh)

if X.ndim > 3:
raise ValueError("X should be 2D or 3D")
Expand Down Expand Up @@ -206,7 +208,7 @@ def find_outlier_trials(X, thresh=None, show=True):
ax1.axhline(y=thresh[0], color="grey", linestyle=":")
ax1.set_xlabel("Trial #")
ax1.set_ylabel("Normalized deviation from mean")
ax1.set_title("Before, " + str(len(d)), fontsize=10)
ax1.set_title("Before, " + str(len(d)))
ax1.set_xlim(0, len(d) + 1)
plt.draw()

Expand All @@ -215,19 +217,18 @@ def find_outlier_trials(X, thresh=None, show=True):
_, dd = find_outlier_trials(X[:, idx], None, False)
ax2.plot(dd, ls="-")
ax2.set_xlabel("Trial #")
ax2.set_title("After, " + str(len(idx)), fontsize=10)
ax2.set_title("After, " + str(len(idx)))
ax2.yaxis.tick_right()
ax2.set_xlim(0, len(idx) + 1)
plt.show()

thresh.pop(0)
thresh = thresh[1:]
if thresh:
bads2, _ = find_outlier_trials(X[:, idx], thresh, show)
idx2 = idx[bads2]
idx = np.setdiff1d(idx, idx2)
bads, _ = find_outlier_trials(X[:, idx], thresh, show)
idx = np.setdiff1d(idx, idx[bads])

bads = []
bads_accumulated = []
if len(idx) < n_trials:
bads = np.setdiff1d(range(n_trials), idx)
bads_accumulated = np.setdiff1d(range(n_trials), idx)

return bads, d
return bads_accumulated, d
7 changes: 2 additions & 5 deletions meegkit/utils/sig.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,7 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False, sho
plt.plot(hz, fx, "o-")
plt.xlim([0, None])

title = "Requested: {}, {} Hz\nEmpirical: {}, {} Hz".format(
f, fwhm, empVals[0], empVals[1]
)
title = f"Requested: {f}, {fwhm} Hz\nEmpirical: {empVals[0]}, {empVals[1]} Hz"
plt.title(title)
plt.xlabel("Frequency (Hz)")
plt.ylabel("Amplitude gain")
Expand Down Expand Up @@ -515,8 +513,7 @@ def stmcb(x, u_in=None, q=None, p=None, niter=5, a_in=None):
else:
if len(u_in) != len(x):
raise ValueError(
"stmcb: u_in and x must be of the same size: {} != {}".format(
len(u_in), len(x)))
f"stmcb: u_in and x must be of the same size: {len(u_in)} != {len(x)}")
if a_in is None:
q = 0
_, a_in = prony(x, q, p)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_ress.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False):
n_chans=n_chans, freq=target, sfreq=sfreq,
show=False)
r = ress.RESS(sfreq=sfreq, peak_freq=target, neig_freq=neig_freq,
peak_width=peak_width, neig_width=neig_width, n_keep=n_keep, compute_unmixing=True)
peak_width=peak_width, neig_width=neig_width, n_keep=n_keep,
compute_unmixing=True)
out = r.fit_transform(data)

nfft = 500
Expand Down
19 changes: 12 additions & 7 deletions tests/test_tspca.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ def test_tspca_sns_dss(): # TODO

# remove means
noisy_data = demean(data)
demean(ref)

# Apply TSPCA
# -------------------------------------------------------------------------
# shifts = np.arange(-50, 51)
# print('TSPCA...')
# y_tspca, idx = tspca.tsr(noisy_data, noisy_ref, shifts)[0:2]
# print('\b OK!')
shifts = np.arange(-50, 51)
print("TSPCA...")
y_tspca, idx = tspca.tsr(noisy_data, ref, shifts)[0:2]
print("\b OK!")
y_tspca = noisy_data

# Apply SNS
Expand Down Expand Up @@ -84,7 +83,11 @@ def test_tsr(show=True):
ax[0].plot(x[:500, 0], ":", label="real signal")
ax[1].plot((y - x)[:500], label="residual")
ax[0].legend()
ax[1].legend()
ax[1].set_xlabel("time (samples)")
ax[0].set_title("signals")
ax[1].set_title("residuals")
plt.tight_layout()
# ax[1].legend()
# plt.show()

# Test residual almost 0.0
Expand All @@ -103,9 +106,11 @@ def test_tsr(show=True):
ax[1].plot(x[:500, 0], "grey", label="real signal")
ax[1].plot(y[:500, 0], ":", label="recovered signal")
ax[2].plot((signal - y)[:500, 0], label="before - after")
ax[0].legend()
# ax[0].legend()
ax[1].legend()
ax[2].legend()
ax[1].set_xlabel("time (samples)")
plt.tight_layout()
plt.show()

if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_outliers(show=False):
idx, _ = find_outlier_trials(x, 2, show=show)
np.testing.assert_array_equal(idx, np.arange(5))

idx, _ = find_outlier_trials(x, [2, 2], show=show)
idx, _ = find_outlier_trials(x, [2, 2], show=True)
np.testing.assert_array_equal(idx, np.arange(5))

idx = find_outlier_samples(x, 5)
Expand Down

0 comments on commit dc3c29c

Please sign in to comment.