diff --git a/contexts/ssvep_kalunga_norest.yml b/contexts/ssvep_kalunga_norest.yml
new file mode 100644
index 000000000..fbbb03b75
--- /dev/null
+++ b/contexts/ssvep_kalunga_norest.yml
@@ -0,0 +1,6 @@
+SSVEP:
+  events:
+    - "13"
+    - "17"
+    - "21"
+  n_classes: 3
diff --git a/contexts/ssvep_resample.yml b/contexts/ssvep_resample.yml
new file mode 100644
index 000000000..5152b898b
--- /dev/null
+++ b/contexts/ssvep_resample.yml
@@ -0,0 +1,2 @@
+SSVEP:
+  resample: 250.0
diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst
index 16f3d6bfb..49f6652d5 100644
--- a/docs/source/whats_new.rst
+++ b/docs/source/whats_new.rst
@@ -21,7 +21,7 @@ Enhancements
 Bugs
 ~~~~
 
-- None
+- Correct :class:`moabb.pipelines.classification.SSVEP_CCA` and :class:`moabb.pipelines.classification.SSVEP_TRCA` behavior (:gh:`XXX` by `Sylvain Chevallier`_)
 
 API changes
 ~~~~~~~~~~~
diff --git a/moabb/benchmark.py b/moabb/benchmark.py
index f2493cfa3..ddfd3b264 100644
--- a/moabb/benchmark.py
+++ b/moabb/benchmark.py
@@ -31,6 +31,15 @@
 log = logging.getLogger(__name__)
 
 
+def _ppl_needs_epochs(pn):
+    """Check if the pipeline needs MNE epochs as input."""
+    ppl_with_epochs = ["braindecode", "Keras", "SSVEP CCA", "TRCA-SSVEP", "MsetCCA-SSVEP"]
+    if any(s in pn for s in ppl_with_epochs):
+        return True
+    else:
+        return False
+
+
 def benchmark(  # noqa: C901
     pipelines="./pipelines/",
     evaluations=None,
@@ -165,7 +174,7 @@ def benchmark(  # noqa: C901
         ppl_with_epochs, ppl_with_array = {}, {}
         for pn, pv in prdgms[paradigm].items():
-            if "braindecode" in pn or "Keras" in pn:
+            if _ppl_needs_epochs(pn):
                 ppl_with_epochs[pn] = pv
             else:
                 ppl_with_array[pn] = pv
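For reference, the routing helper can be exercised directly; only the substrings listed in ppl_with_epochs matter, and the pipeline names below are illustrative:

from moabb.benchmark import _ppl_needs_epochs

# Names containing one of the epoch-based substrings are routed to the
# MNE-Epochs branch of benchmark(); everything else is evaluated on numpy arrays.
for name in ["SSVEP CCA", "TRCA-SSVEP", "MsetCCA-SSVEP", "Keras_ShallowNet", "CSP + LDA"]:
    print(name, "->", "epochs" if _ppl_needs_epochs(name) else "array")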
""" - def __init__(self, interval, freqs, n_harmonics=3): + def __init__(self, n_harmonics=3): self.Yf = dict() self.cca = CCA(n_components=1) - self.interval = interval - self.slen = interval[1] - interval[0] - self.freqs = freqs self.n_harmonics = n_harmonics self.classes_ = [] - self.one_hot = {} - for i, k in enumerate(freqs.keys()): - self.classes_.append(i) - self.one_hot[k] = i + self.one_hot_ = {} + self.le_, self.slen_, self.freqs_ = None, None, [] def fit(self, X, y, sample_weight=None): """Compute reference sinusoid signal. These sinusoid are generated for each frequency in the dataset - """ - n_times = X.shape[2] - for f in self.freqs: + Parameters + ---------- + X : MNE Epochs + The training data as MNE Epochs object. + y : Unused, + Only for compatibility with scikit-learn + sample_weight : Unused, + Only for compatibility with scikit-learn + + Returns + ------- + self: SSVEP_CCA object + Instance of classifier. + """ + if not isinstance(X, BaseEpochs): + raise ValueError("X should be an MNE Epochs object.") + + self.slen_ = X.times[-1] - X.times[0] + n_times = len(X.times) + self.freqs_ = list(X.event_id.keys()) + self.le_ = LabelEncoder().fit(self.freqs_) + self.classes_ = self.le_.transform(self.freqs_) + for i, k in zip(self.freqs_, self.le_.transform(self.freqs_)): + self.one_hot_[i] = k + + for f in self.freqs_: if f.replace(".", "", 1).isnumeric(): freq = float(f) yf = [] for h in range(1, self.n_harmonics + 1): yf.append( - np.sin(2 * np.pi * freq * h * np.linspace(0, self.slen, n_times)) + np.sin(2 * np.pi * freq * h * np.linspace(0, self.slen_, n_times)) ) yf.append( - np.cos(2 * np.pi * freq * h * np.linspace(0, self.slen, n_times)) + np.cos(2 * np.pi * freq * h * np.linspace(0, self.slen_, n_times)) ) self.Yf[f] = np.array(yf) return self def predict(self, X): - """Predict is made by taking the maximum correlation coefficient.""" + """Predict is made by taking the maximum correlation coefficient. + + Parameters + ---------- + X : MNE Epochs + The data to predict as MNE Epochs object. + + Returns + ------- + y : list of int + Predicted labels. + """ y = [] for x in X: corr_f = {} - for f in self.freqs: + for f in self.freqs_: if f.replace(".", "", 1).isnumeric(): S_x, S_y = self.cca.fit_transform(x.T, self.Yf[f].T) corr_f[f] = np.corrcoef(S_x.T, S_y.T)[0, 1] - y.append(self.one_hot[max(corr_f, key=corr_f.get)]) + y.append(self.one_hot_[max(corr_f, key=corr_f.get)]) return y def predict_proba(self, X): - """Probability could be computed from the correlation coefficient.""" - P = np.zeros(shape=(len(X), len(self.freqs))) + """Probability could be computed from the correlation coefficient. + + Parameters + ---------- + X : MNE Epochs + The data to predict as MNE Epochs object. + + Returns + ------- + proba : ndarray of shape (n_trials, n_classes) + probability of each class for each trial. + """ + P = np.zeros(shape=(len(X), len(self.freqs_))) for i, x in enumerate(X): - for j, f in enumerate(self.freqs): + for j, f in enumerate(self.freqs_): if f.replace(".", "", 1).isnumeric(): S_x, S_y = self.cca.fit_transform(x.T, self.Yf[f].T) P[i, j] = np.corrcoef(S_x.T, S_y.T)[0, 1] @@ -102,22 +151,13 @@ def predict_proba(self, X): class SSVEP_TRCA(BaseEstimator, ClassifierMixin): - """Classifier based on the Task-Related Component Analysis method [1]_ for - SSVEP. + """Task-Related Component Analysis method [1]_ for SSVEP. Parameters ---------- - sfreq : float - Sampling frequency of the data to be analyzed. 
@@ -102,22 +151,13 @@ def predict_proba(self, X):
 
 
 class SSVEP_TRCA(BaseEstimator, ClassifierMixin):
-    """Classifier based on the Task-Related Component Analysis method [1]_ for
-    SSVEP.
+    """Task-Related Component Analysis method [1]_ for SSVEP.
 
     Parameters
     ----------
-    sfreq : float
-        Sampling frequency of the data to be analyzed.
-
-    freqs : dict with n_classes keys
-        Frequencies corresponding to the SSVEP components. These are
-        necessary to design the filterbank bands.
-
-    downsample: int, default=1
-        Factor by which downsample the data. A downsample value of N will result
-        on a sampling frequency of (sfreq // N) by taking one sample every N of
-        the original data. In the original TRCA paper [1]_ data are at 250Hz.
+    n_fbands: int, default=5
+        Number of sub-bands into which the SSVEP frequencies are divided by the
+        filterbank approach.
 
     is_ensemble: bool, default=False
         If True, predict on new data using the Ensemble-TRCA method described
@@ -142,8 +182,6 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin):
        is used. So method='original' and regul='scm' is similar to original
        implementation.
 
-
-
     Attributes
     ----------
     fb_coefs : list of len (n_fbands)
@@ -164,8 +202,11 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin):
         Weight coefficients for the different electrodes which are used
         as spatial filters for the data.
 
+    freqs_: list of str
+        List of unique frequencies present in the training data.
+
     Reference
-    ----------
+    ---------
 
     .. [1] M. Nakanishi, Y. Wang, X. Chen, Y. -T. Wang, X. Gao, and T.-P. Jung,
        "Enhancing detection of SSVEPs for a high-speed brain speller using
@@ -179,27 +220,27 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin):
     Notes
     -----
     .. versionadded:: 0.4.4
+
+    .. versionchanged:: 1.1.1
+        TRCA implementation now takes an MNE Epochs object as input and fixes the label encoding issue.
     """
 
     def __init__(
         self,
-        interval,
-        freqs,
-        downsample=1,
+        n_fbands=5,
         is_ensemble=True,
         method="original",
         estimator="scm",
     ):
-        self.freqs = freqs
-        self.peaks = np.array([float(f) for f in freqs.keys()])
-        self.n_fbands = len(self.peaks)
-        self.downsample = downsample
-        self.interval = interval
-        self.slen = interval[1] - interval[0]
         self.is_ensemble = is_ensemble
-        self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)]
         self.estimator = estimator
         self.method = method
+        self.n_fbands = n_fbands
+        self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)]
+        self.one_hot_, self.one_inv_ = {}, {}
+        self.sfreq_, self.freqs_, self.peaks_ = None, None, None
+        self.le_, self.classes_, self.n_classes = None, None, None
+        self.templates_, self.weights_ = None, None
 
     def _Q_S_estim(self, data):
         # Check if X is a single trial (test data) or not
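Moving the data-dependent quantities out of __init__ and into fit (as trailing-underscore attributes) keeps the estimator compatible with scikit-learn's parameter handling; a quick sketch of what that buys, with illustrative constructor values:

from sklearn.base import clone
from moabb.pipelines import SSVEP_TRCA

# Because __init__ now only stores its own arguments, get_params()/clone()
# behave as scikit-learn expects, which is what GridSearchCV and the MOABB
# evaluations rely on when they copy pipelines.
clf = SSVEP_TRCA(n_fbands=3, is_ensemble=True, method="original", estimator="scm")
print(clone(clf).get_params())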
""" - # Downsample data - X = X[:, :, :: self.downsample] - - # Get shape of X and labels - n_trials, n_channels, n_samples = X.shape + if not isinstance(X, BaseEpochs): + raise ValueError("X should be an MNE Epochs object.") - self.sfreq = int(n_samples / self.slen) - self.sfreq = self.sfreq / self.downsample - - self.classes_ = np.unique(y) + n_channels, n_samples = X.info["nchan"], len(X.times) + self.sfreq_ = X.info["sfreq"] + self.freqs_ = list(X.event_id.keys()) + self.peaks_ = np.array([float(f) for f in self.freqs_]) + self.fb_coefs = [(x + 1) ** (-1.25) + 0.25 for x in range(self.n_fbands)] + self.le_ = LabelEncoder().fit(self.freqs_) + self.classes_ = self.le_.transform(self.freqs_) self.n_classes = len(self.classes_) + for i, k in zip(self.freqs_, self.classes_): + self.one_hot_[i] = k + self.one_inv_[k] = i + if self.n_fbands > len(self.peaks_): + log.warning("Try with lower n_fbands if there is an error.") # Initialize the final arrays self.templates_ = np.zeros((self.n_classes, self.n_fbands, n_channels, n_samples)) self.weights_ = np.zeros((self.n_fbands, self.n_classes, n_channels)) - for class_idx in self.classes_: - X_cal = X[y == class_idx] # Select data with a specific label + # for class_idx in self.classes_: + for freq, k in self.one_hot_.items(): + X_cal = X[freq] # Select data with a specific label + # Filterbank approach for band_n in range(self.n_fbands): # Filter the data and compute TRCA - X_filter = filterbank(X_cal, self.sfreq, band_n, self.peaks) + X_filter = filterbank( + X_cal.get_data(copy=False), self.sfreq_, band_n, self.peaks_ + ) w_best, _ = self._compute_trca(X_filter) # Get template by averaging trials and take the best filter for this band - self.templates_[class_idx, band_n, :, :] = np.mean(X_filter, axis=0) - self.weights_[band_n, class_idx, :] = w_best + self.templates_[k, band_n, :, :] = np.mean(X_filter, axis=0) + self.weights_[band_n, k, :] = w_best return self @@ -381,8 +431,8 @@ def predict(self, X): Parameters ---------- - X : ndarray of shape (n_trials, n_channels, n_samples) - Testing data. This will be divided in self.n_fbands using the filter- bank approach, + X : MNE Epochs + Testing data. This will be divided in self.n_fbands using the filterbank approach, then it will be transformed by the different spatial filters and compared to the previously fit templates according to the selected method for analysis (ensemble or not). Finally, correlation scores for all sub-bands of each class will be combined, @@ -398,34 +448,21 @@ def predict(self, X): # Check is fit had been called check_is_fitted(self) - # Check if X is a single trial or not - if X.ndim == 2: - X = X[np.newaxis, ...] 
@@ -381,8 +431,8 @@ def predict(self, X):
 
         Parameters
         ----------
-        X : ndarray of shape (n_trials, n_channels, n_samples)
-            Testing data. This will be divided in self.n_fbands using the filter- bank approach,
+        X : MNE Epochs
+            Testing data. This will be divided into self.n_fbands sub-bands using the filterbank approach,
             then it will be transformed by the different spatial filters and compared to
             the previously fit templates according to the selected method for analysis (ensemble or
             not). Finally, correlation scores for all sub-bands of each class will be combined,
@@ -398,34 +448,21 @@ def predict(self, X):
         # Check is fit had been called
         check_is_fitted(self)
 
-        # Check if X is a single trial or not
-        if X.ndim == 2:
-            X = X[np.newaxis, ...]
-
-        # Downsample data
-        X = X[:, :, :: self.downsample]
-
-        # Get test data shape
-        n_trials, _, _ = X.shape
-
         # Initialize pred array
         y_pred = []
 
-        for trial_n in range(n_trials):
-            # Pick trial
-            X_test = X[trial_n, :, :]
-
+        for X_test in X:
             # Initialize correlations array
             corr_array = np.zeros((self.n_fbands, self.n_classes))
 
             # Filter the data in the corresponding band
             for band_n in range(self.n_fbands):
-                X_filter = filterbank(X_test, self.sfreq, band_n, self.peaks)
+                X_filter = filterbank(X_test, self.sfreq_, band_n, self.peaks_)
 
                 # Compute correlation with all the templates and bands
-                for class_idx in range(self.n_classes):
+                for freq, k in self.one_hot_.items():
                     # Get the corresponding template
-                    template = np.squeeze(self.templates_[class_idx, band_n, :, :])
+                    template = np.squeeze(self.templates_[k, band_n, :, :])
 
                     if self.is_ensemble:
                         w = np.squeeze(
@@ -433,7 +470,8 @@ def predict(self, X):
                         ).T  # (n_classes, n_channel)
                     else:
                         w = np.squeeze(
-                            self.weights_[band_n, class_idx, :]
+                            # self.weights_[band_n, class_idx, :]
+                            self.weights_[band_n, k, :]
                         ).T  # (n_channel,)
 
                     # Compute 2D correlation of spatially filtered testdata with ref
                     r = np.corrcoef(
@@ -441,14 +479,14 @@ def predict(self, X):
                         np.dot(X_filter.T, w).flatten(),
                         np.dot(template.T, w).flatten(),
                     )
-                    corr_array[band_n, class_idx] = r[0, 1]
+                    corr_array[band_n, k] = r[0, 1]
 
             # Fusion for the filterbank analysis
-            rho = np.dot(self.fb_coefs, corr_array)
+            self.rho = np.dot(self.fb_coefs, corr_array)
 
-            # Select the maximal value and append to preddictions
-            tau = np.argmax(rho)
-            y_pred.append(tau)
+            # Select the maximal value and append to predictions
+            self.tau = np.argmax(self.rho)
+            y_pred.append(self.one_hot_[self.one_inv_[self.tau]])
 
         return y_pred
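The per-band scores are merged with the fixed filterbank weights before the argmax; a worked numeric sketch of that fusion step, with random correlations standing in for the real ones:

import numpy as np

# Filterbank weights a_n = (n + 1) ** -1.25 + 0.25 and score fusion, as in
# predict(): rho combines the per-band correlations, argmax picks the class.
n_fbands, n_classes = 5, 3
fb_coefs = np.array([(x + 1) ** (-1.25) + 0.25 for x in range(n_fbands)])
corr_array = np.random.rand(n_fbands, n_classes)   # illustrative correlations
rho = fb_coefs @ corr_array                        # shape (n_classes,)
predicted_class = int(np.argmax(rho))
print(fb_coefs.round(3), predicted_class)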
@@ -477,51 +515,37 @@ def predict_proba(self, X):
         # Check is fit had been called
         check_is_fitted(self)
 
-        # Check if X is a single trial or not
-        if X.ndim == 2:
-            X = X[np.newaxis, ...]
-
-        # Downsample data
-        X = X[:, :, :: self.downsample]
-
-        # Get test data shape
-        n_trials, _, _ = X.shape
+        n_trials = len(X)
 
         # Initialize pred array
-        y_pred = np.zeros((n_trials, len(self.peaks)))
-
-        for trial_n in range(n_trials):
-            # Pick trial
-            X_test = X[trial_n, :, :]
+        y_pred = np.zeros((n_trials, self.n_classes))
 
+        for trial_n, X_test in enumerate(X):
             # Initialize correlations array
             corr_array = np.zeros((self.n_fbands, self.n_classes))
 
             # Filter the data in the corresponding band
             for band_n in range(self.n_fbands):
-                X_filter = filterbank(X_test, self.sfreq, band_n, self.peaks)
+                X_filter = filterbank(X_test, self.sfreq_, band_n, self.peaks_)
 
                 # Compute correlation with all the templates and bands
-                for class_idx in range(self.n_classes):
+                for freq, k in self.one_hot_.items():
                     # Get the corresponding template
-                    template = np.squeeze(self.templates_[class_idx, band_n, :, :])
+                    template = np.squeeze(self.templates_[k, band_n, :, :])
 
                     if self.is_ensemble:
                         w = np.squeeze(
                             self.weights_[band_n, :, :]
                         ).T  # (n_class, n_channel)
                     else:
-                        w = np.squeeze(
-                            self.weights_[band_n, class_idx, :]
-                        ).T  # (n_channel,)
+                        w = np.squeeze(self.weights_[band_n, k, :]).T  # (n_channel,)
 
                     # Compute 2D correlation of spatially filtered testdata with ref
                     r = np.corrcoef(
                         np.dot(X_filter.T, w).flatten(),
                         np.dot(template.T, w).flatten(),
                     )
-                    corr_array[band_n, class_idx] = r[0, 1]
+                    corr_array[band_n, k] = r[0, 1]
 
             normalized_coefs = self.fb_coefs / (np.sum(self.fb_coefs))
 
             # Fusion for the filterbank analysis
@@ -562,15 +586,21 @@ class SSVEP_MsetCCA(BaseEstimator, ClassifierMixin):
 
     Parameters
     ----------
-    freqs : dict with n_classes keys
-        Frequencies corresponding to the SSVEP stimulation frequencies.
-        They are used to identify SSVEP classes presents in the data.
-
-    n_filters: int
+    n_filters: int, default=1
         Number of multisets spatial filters used per sample data.
         It corresponds to the number of eigen vectors taken the solution of the
         MAXVAR objective function as formulated in Eq.5 in [1]_.
 
+    n_jobs: int, default=1
+        Number of jobs to run whitening in parallel.
+
+    Attributes
+    ----------
+    classes_ : ndarray of shape (n_classes,)
+        Array with the class labels extracted at fit time.
+
+    freqs_: list of str
+        List of unique frequencies present in the training data.
+
     References
     ----------
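With the frequencies now read from the Epochs at fit time, only algorithmic options remain in the MsetCCA constructor. A minimal usage sketch follows, assuming the paradigm is asked to return MNE Epochs (return_epochs=True), which is what the classifier now requires; dataset and paradigm settings mirror the tests below:

from moabb.datasets.fake import FakeDataset
from moabb.paradigms import SSVEP
from moabb.pipelines import SSVEP_MsetCCA

dataset = FakeDataset(n_sessions=1, n_runs=1, n_subjects=1, paradigm="ssvep")
paradigm = SSVEP(n_classes=3)
epochs, y, _ = paradigm.get_data(dataset, return_epochs=True)

# Stimulation frequencies are read from epochs.event_id during fit().
clf = SSVEP_MsetCCA(n_filters=2, n_jobs=1).fit(epochs, y)
print(clf.freqs_, clf.classes_)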
+ """ + if not isinstance(X, BaseEpochs): + raise ValueError("X should be an MNE Epochs object.") + + self.freqs_ = list(X.event_id.keys()) + self.le_ = LabelEncoder().fit(self.freqs_) + self.classes_ = self.le_.transform(self.freqs_) + for i, k in zip(self.freqs_, self.le_.transform(self.freqs_)): + self.one_hot_[i] = k + n_trials, n_channels, n_times = len(X), X.info["nchan"], len(X.times) # Whiten signal in parallel if self.n_jobs == 1: @@ -630,14 +678,24 @@ def fit(self, X, y, sample_weight=None): Z = W.transpose((0, 2, 1)) @ X_white # Get Ym - self.Ym = dict() for m_class in self.classes_: self.Ym[m_class] = Z[y == m_class].transpose(2, 0, 1).reshape(-1, n_times) return self def predict(self, X): - """Predict is made by taking the maximum correlation coefficient.""" + """Predict is made by taking the maximum correlation coefficient. + + Parameters + ---------- + X : MNE Epochs + The data to predict as MNE Epochs object. + + Returns + ------- + y : list of int + Predicted labels. + """ # Check is fit had been called check_is_fitted(self) @@ -648,11 +706,22 @@ def predict(self, X): for f in self.classes_: S_x, S_y = self.cca.fit_transform(x.T, self.Ym[f].T) corr_f[f] = np.corrcoef(S_x.T, S_y.T)[0, 1] - y.append(self.one_hot[max(corr_f, key=corr_f.get)]) + y.append(max(corr_f, key=corr_f.get)) return y def predict_proba(self, X): - """Probability could be computed from the correlation coefficient.""" + """Probability could be computed from the correlation coefficient. + + Parameters + ---------- + X : MNE Epochs + The data to predict as MNE Epochs object. + + Returns + ------- + P : ndarray of shape (n_trials, n_classes) + Probability of each class for each trial. + """ # Check is fit had been called check_is_fitted(self) diff --git a/moabb/pipelines/utils.py b/moabb/pipelines/utils.py index ecb2b41fc..75f573550 100644 --- a/moabb/pipelines/utils.py +++ b/moabb/pipelines/utils.py @@ -316,27 +316,37 @@ def filterbank(X, sfreq, idx_fb, peaks): sfreq = sfreq / 2 peaks = np.sort(peaks) + min_freq = np.min(peaks) max_freq = np.max(peaks) if max_freq < 40: - top = 40 + top = 100 else: - top = 60 + top = 115 # Check for Nyquist if top >= sfreq: top = sfreq - 10 # Lowcut frequencies for the pass band (depends on the frequencies of SSVEP) # No more than 3dB loss in the passband - passband = [peaks[i] - 1 for i in range(len(peaks))] + diff = max_freq - min_freq + passband = [min_freq - 2 + x * diff for x in range(7)] + # passband = [peaks[i] - 1 for i in range(len(peaks))] # At least 40db attenuation in the stopband - stopband = [peaks[i] - 2 for i in range(len(peaks))] + if min_freq - 4 > 0: + stopband = [ + min_freq - 4 + x * (diff - 2) if x < 3 else min_freq - 4 + x * diff + for x in range(7) + ] + else: + stopband = [2 + x * (diff - 2) if x < 3 else 2 + x * diff for x in range(7)] + # stopband = [peaks[i] - 2 for i in range(len(peaks))] Wp = [passband[idx_fb] / sfreq, top / sfreq] - Ws = [stopband[idx_fb] / sfreq, (top + 20) / sfreq] + Ws = [stopband[idx_fb] / sfreq, (top + 7) / sfreq] - N, Wn = scp.cheb1ord(Wp, Ws, 3, 15) # Chebyshev type I filter order selection. + N, Wn = scp.cheb1ord(Wp, Ws, 3, 40) # Chebyshev type I filter order selection. 
diff --git a/moabb/tests/classification.py b/moabb/tests/classification.py
index 6e39dc229..3fc2fb7e5 100644
--- a/moabb/tests/classification.py
+++ b/moabb/tests/classification.py
@@ -4,7 +4,106 @@
 
 from moabb.datasets.fake import FakeDataset
 from moabb.paradigms import SSVEP
-from moabb.pipelines import SSVEP_MsetCCA
+from moabb.pipelines import SSVEP_CCA, SSVEP_TRCA, SSVEP_MsetCCA
+
+
+class TestSSVEP_CCA(unittest.TestCase):
+    def setUp(self):
+        # Use moabb generated dataset for test
+        dataset = FakeDataset(n_sessions=1, n_runs=1, n_subjects=1, paradigm="ssvep")
+        paradigm = SSVEP(n_classes=3)
+        X, y, _ = paradigm.get_data(dataset, return_epochs=True)
+        self.freqs = paradigm.used_events(dataset)
+        self.n_harmonics = 3
+        self.X = X
+        self.y = y
+        self.clf = SSVEP_CCA(n_harmonics=self.n_harmonics)
+
+    def test_fit(self):
+        self.clf.fit(self.X, self.y)
+        self.assertTrue(hasattr(self.clf, "freqs_"))
+        self.assertTrue(hasattr(self.clf, "classes_"))
+        self.assertTrue(hasattr(self.clf, "le_"))
+        self.assertTrue(hasattr(self.clf, "one_hot_"))
+        self.assertTrue(hasattr(self.clf, "slen_"))
+
+    def test_predict(self):
+        self.clf.fit(self.X, self.y)
+        y_pred = self.clf.predict(self.X)
+        self.assertEqual(len(y_pred), len(self.X))
+
+    def test_predict_proba(self):
+        self.clf.fit(self.X, self.y)
+        P = self.clf.predict_proba(self.X)
+        self.assertEqual(P.shape[0], len(self.X))
+        self.assertEqual(P.shape[1], len(self.freqs))
+
+    def test_fit_predict_is_fitted(self):
+        self.assertRaises(NotFittedError, self.clf.predict, self.X)
+        self.assertRaises(NotFittedError, self.clf.predict_proba, self.X)
+        self.clf.fit(self.X, self.y)
+        check_is_fitted(
+            self.clf, attributes=["classes_", "one_hot_", "slen_", "freqs_", "le_"]
+        )
+
+
+class TestSSVEP_TRCA(unittest.TestCase):
+    def setUp(self):
+        # Use moabb generated dataset for test
+        dataset = FakeDataset(n_sessions=1, n_runs=1, n_subjects=1, paradigm="ssvep")
+        self.n_classes = 3
+        paradigm = SSVEP(n_classes=self.n_classes)
+        X, y, _ = paradigm.get_data(dataset, return_epochs=True)
+        self.freqs = paradigm.used_events(dataset)
+        self.n_fbands = 3
+        self.X = X
+        self.y = y
+        self.clf = SSVEP_TRCA(n_fbands=self.n_fbands)
+
+    def test_fit(self):
+        for method in ["original", "riemann", "logeuclid"]:
+            for estimator in ["scm", "lwf", "oas"]:
+                self.clf = SSVEP_TRCA(
+                    n_fbands=self.n_fbands, method=method, estimator=estimator
+                )
+                self.clf.fit(self.X, self.y)
+                self.assertTrue(hasattr(self.clf, "freqs_"))
+                self.assertTrue(hasattr(self.clf, "peaks_"))
+                self.assertTrue(hasattr(self.clf, "classes_"))
+                self.assertTrue(hasattr(self.clf, "n_classes"))
+                self.assertTrue(hasattr(self.clf, "le_"))
+                self.assertTrue(hasattr(self.clf, "one_hot_"))
+                self.assertTrue(hasattr(self.clf, "one_inv_"))
+                self.assertTrue(hasattr(self.clf, "sfreq_"))
+
+    def test_predict(self):
+        self.clf.fit(self.X, self.y)
+        y_pred = self.clf.predict(self.X)
+        self.assertEqual(len(y_pred), len(self.X))
+
+    def test_predict_proba(self):
+        self.clf.fit(self.X, self.y)
+        P = self.clf.predict_proba(self.X)
+        self.assertEqual(P.shape[0], len(self.X))
+        self.assertEqual(P.shape[1], self.n_classes)
+
+    def test_fit_predict_is_fitted(self):
+        self.assertRaises(NotFittedError, self.clf.predict, self.X)
+        self.assertRaises(NotFittedError, self.clf.predict_proba, self.X)
+        self.clf.fit(self.X, self.y)
+        check_is_fitted(
+            self.clf,
+            attributes=[
+                "classes_",
+                "n_classes",
+                "peaks_",
+                "one_hot_",
+                "one_inv_",
+                "freqs_",
+                "le_",
+                "sfreq_",
+            ],
+        )
 
 
 class TestSSVEP_MsetCCA(unittest.TestCase):
@@ -17,12 +116,14 @@ def setUp(self):
         self.n_filters = 2
         self.X = X
         self.y = y
-        self.clf = SSVEP_MsetCCA(freqs=self.freqs, n_filters=self.n_filters)
+        self.clf = SSVEP_MsetCCA(n_filters=self.n_filters)
 
     def test_fit(self):
         self.clf.fit(self.X, self.y)
+        self.assertTrue(hasattr(self.clf, "freqs_"))
         self.assertTrue(hasattr(self.clf, "classes_"))
-        self.assertTrue(hasattr(self.clf, "one_hot"))
+        self.assertTrue(hasattr(self.clf, "le_"))
+        self.assertTrue(hasattr(self.clf, "one_hot_"))
         self.assertTrue(hasattr(self.clf, "Ym"))
 
     def test_predict(self):
@@ -34,13 +135,15 @@ def test_predict_proba(self):
         self.clf.fit(self.X, self.y)
         P = self.clf.predict_proba(self.X)
         self.assertEqual(P.shape[0], len(self.X))
-        self.assertEqual(P.shape[1], len(self.freqs))
+        self.assertEqual(P.shape[1], len(self.clf.classes_))
 
     def test_fit_predict_is_fitted(self):
         self.assertRaises(NotFittedError, self.clf.predict, self.X)
         self.assertRaises(NotFittedError, self.clf.predict_proba, self.X)
         self.clf.fit(self.X, self.y)
-        check_is_fitted(self.clf, attributes=["classes_", "one_hot", "Ym"])
+        check_is_fitted(
+            self.clf, attributes=["classes_", "one_hot_", "Ym", "freqs_", "le_"]
+        )
 
 
 if __name__ == "__main__":
diff --git a/moabb/tests/test_pipelines/SSVEP_CCA.yml b/moabb/tests/test_pipelines/SSVEP_CCA.yml
index cdf712018..9037fe763 100644
--- a/moabb/tests/test_pipelines/SSVEP_CCA.yml
+++ b/moabb/tests/test_pipelines/SSVEP_CCA.yml
@@ -8,5 +8,3 @@ pipeline:
     from: moabb.pipelines.classification
     parameters:
       n_harmonics: 3
-      interval: [1, 3]
-      freqs: {"13":0, "17":1}
diff --git a/pipelines/CCA-SSVEP.yml b/pipelines/CCA-SSVEP.yml
index bae157483..87ed8f9a1 100644
--- a/pipelines/CCA-SSVEP.yml
+++ b/pipelines/CCA-SSVEP.yml
@@ -8,5 +8,3 @@ pipeline:
     from: moabb.pipelines.classification
     parameters:
       n_harmonics: 3
-      interval: [2, 4]
-      freqs: {"13": 2, "17":3, "21":4}
diff --git a/pipelines/MsetCCA-SSVEP.yml b/pipelines/MsetCCA-SSVEP.yml
index cc4c8dc45..c49db2bf7 100644
--- a/pipelines/MsetCCA-SSVEP.yml
+++ b/pipelines/MsetCCA-SSVEP.yml
@@ -1,4 +1,5 @@
 name: MsetCCA-SSVEP
+
 paradigms:
   - SSVEP
 
@@ -9,4 +10,5 @@ pipeline:
   - name: SSVEP_MsetCCA
     from: moabb.pipelines.classification
     parameters:
-      freqs: {"13":2, "17":3, "21":4}
+      n_filters: 1
+      n_jobs: 1
diff --git a/pipelines/TRCA-SSVEP.yml b/pipelines/TRCA-SSVEP.yml
index 5592b0798..07938ac94 100644
--- a/pipelines/TRCA-SSVEP.yml
+++ b/pipelines/TRCA-SSVEP.yml
@@ -9,5 +9,6 @@ pipeline:
   - name: SSVEP_TRCA
     from: moabb.pipelines.classification
     parameters:
-      interval: [2, 4]
-      freqs: {"13":2, "17":3, "21":4}
+      n_fbands: 5
+      is_ensemble: True
+      method: "original"
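To exercise the updated YAML pipelines end to end, something along the following lines should work; the pipelines/evaluations arguments appear in the benchmark() signature shown above, while paradigms, contexts and overwrite are assumed from the documented benchmark options and may need adjusting to the installed version:

from moabb import benchmark

# The SSVEP pipelines no longer carry freqs/interval in their YAML parameters;
# stimulation frequencies are recovered from the Epochs at fit time. The
# context file added at the top of this diff restricts the paradigm to the
# "13"/"17"/"21" events.
results = benchmark(
    pipelines="./pipelines/",
    evaluations=["WithinSession"],
    paradigms=["SSVEP"],
    contexts="./contexts/ssvep_kalunga_norest.yml",
    overwrite=True,
)
print(results.head())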