Skip to content

Commit

Permalink
Enhance AugmentedDataset (#541)
Browse files Browse the repository at this point in the history
* enhance augmented dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* update whats new

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
qbarthelemy and pre-commit-ci[bot] authored Feb 2, 2024
1 parent 99f6e83 commit 2ca20aa
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Enhancements
- Option to interpolate channel in paradigms' `match_all` method (:gh:`480` by `Gregoire Cattan`_)
- Adding leave k-Subjects out evaluations (:gh:`470` by `Bruno Aristimunha`_)
- Update Braindecode dependency to 0.8 (:gh:`542` by `Pierre Guetschel`_)
- Improve transform function of AugmentedDataset (:gh:`541` by `Quentin Barthelemy`_)

Bugs
~~~~
Expand Down
15 changes: 6 additions & 9 deletions moabb/pipelines/features.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import mne
import numpy as np
import scipy.signal as signal
from numpy import concatenate, ndarray
from numpy import ndarray
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler

Expand Down Expand Up @@ -99,20 +99,17 @@ def fit(self, X: ndarray, y: ndarray):

def transform(self, X: ndarray):
if self.order == 1:
X_fin: ndarray = X
X_new: ndarray = X
else:
X_p = X[:, :, : -self.order * self.lag]
X_p = concatenate(
[X_p]
+ [
X_new = np.concatenate(
[
X[:, :, p * self.lag : -(self.order - p) * self.lag]
for p in range(1, self.order)
for p in range(0, self.order)
],
axis=1,
)
X_fin = X_p

return X_fin
return X_new


class StandardScaler_Epoch(BaseEstimator, TransformerMixin):
Expand Down

0 comments on commit 2ca20aa

Please sign in to comment.