From 42e6b7b3827cd9c570faf9ef47c7927f11751755 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Nov 2024 10:39:54 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks --- moabb/evaluations/metasplitters.py | 7 +++++-- moabb/evaluations/splitters.py | 19 ++++++++++++------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/moabb/evaluations/metasplitters.py b/moabb/evaluations/metasplitters.py index a3b0ae75b..5e95fb626 100644 --- a/moabb/evaluations/metasplitters.py +++ b/moabb/evaluations/metasplitters.py @@ -2,6 +2,7 @@ from moabb.evaluations.utils import sort_group + class PseudoOnlineSplit(BaseCrossValidator): """Pseudo-online split for evaluation test data. @@ -69,7 +70,9 @@ def split(self, indices, y, metadata=None): break # Take the fist run as calibration else: if self.calib_size is None: - raise ValueError('Data contains just one run. Need to provide calibration size.') + raise ValueError( + "Data contains just one run. Need to provide calibration size." + ) # Take first #calib_size samples as calibration calib_size = self.calib_size calib_ix = group[:calib_size].index @@ -79,6 +82,6 @@ def split(self, indices, y, metadata=None): else: if self.calib_size is None: - raise ValueError('Need to provide calibration size.') + raise ValueError("Need to provide calibration size.") calib_size = self.calib_size yield list(indices[:calib_size]), list(indices[calib_size:]) diff --git a/moabb/evaluations/splitters.py b/moabb/evaluations/splitters.py index fe72ba037..a0b507d51 100644 --- a/moabb/evaluations/splitters.py +++ b/moabb/evaluations/splitters.py @@ -1,4 +1,3 @@ -import copy from sklearn.model_selection import BaseCrossValidator, StratifiedKFold from sklearn.utils import check_random_state @@ -71,12 +70,12 @@ class WithinSessionSplitter(BaseCrossValidator): def __init__( self, - cv = StratifiedKFold, - custom_cv = False, + cv=StratifiedKFold, + custom_cv=False, n_folds: int = 5, random_state: int = 42, shuffle: bool = True, - calib_size: int = None + calib_size: int = None, ): self.n_folds = n_folds self.shuffle = shuffle @@ -87,7 +86,11 @@ def __init__( def get_n_splits(self, metadata): num_sessions_subjects = metadata.groupby(["subject", "session"]).ngroups - return self.cv.get_n_splits(metadata) if self.custom_cv else self.n_folds * num_sessions_subjects + return ( + self.cv.get_n_splits(metadata) + if self.custom_cv + else self.n_folds * num_sessions_subjects + ) def split(self, y, metadata, **kwargs): all_index = metadata.index.values @@ -115,7 +118,9 @@ def split(self, y, metadata, **kwargs): # Handle custom splitter if isinstance(self.cv(), PseudoOnlineSplit): splitter = self.cv(calib_size=self.calib_size) - for calib_ix, test_ix in splitter.split(indices, group_y, subject_metadata[session_mask]): + for calib_ix, test_ix in splitter.split( + indices, group_y, subject_metadata[session_mask] + ): yield calib_ix, test_ix else: # Handle standard CV like StratifiedKFold @@ -125,4 +130,4 @@ def split(self, y, metadata, **kwargs): random_state=self.random_state.randint(0, 2**10), ) for train_ix, test_ix in splitter.split(indices, group_y): - yield indices[train_ix], indices[test_ix] \ No newline at end of file + yield indices[train_ix], indices[test_ix]