Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed Nov 29, 2024
1 parent b87cf25 commit 42e6b7b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
7 changes: 5 additions & 2 deletions moabb/evaluations/metasplitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from moabb.evaluations.utils import sort_group


class PseudoOnlineSplit(BaseCrossValidator):
"""Pseudo-online split for evaluation test data.
Expand Down Expand Up @@ -69,7 +70,9 @@ def split(self, indices, y, metadata=None):
break # Take the fist run as calibration
else:
if self.calib_size is None:
raise ValueError('Data contains just one run. Need to provide calibration size.')
raise ValueError(
"Data contains just one run. Need to provide calibration size."
)
# Take first #calib_size samples as calibration
calib_size = self.calib_size
calib_ix = group[:calib_size].index
Expand All @@ -79,6 +82,6 @@ def split(self, indices, y, metadata=None):

else:
if self.calib_size is None:
raise ValueError('Need to provide calibration size.')
raise ValueError("Need to provide calibration size.")
calib_size = self.calib_size
yield list(indices[:calib_size]), list(indices[calib_size:])
19 changes: 12 additions & 7 deletions moabb/evaluations/splitters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy

from sklearn.model_selection import BaseCrossValidator, StratifiedKFold
from sklearn.utils import check_random_state
Expand Down Expand Up @@ -71,12 +70,12 @@ class WithinSessionSplitter(BaseCrossValidator):

def __init__(
self,
cv = StratifiedKFold,
custom_cv = False,
cv=StratifiedKFold,
custom_cv=False,
n_folds: int = 5,
random_state: int = 42,
shuffle: bool = True,
calib_size: int = None
calib_size: int = None,
):
self.n_folds = n_folds
self.shuffle = shuffle
Expand All @@ -87,7 +86,11 @@ def __init__(

def get_n_splits(self, metadata):
num_sessions_subjects = metadata.groupby(["subject", "session"]).ngroups
return self.cv.get_n_splits(metadata) if self.custom_cv else self.n_folds * num_sessions_subjects
return (
self.cv.get_n_splits(metadata)
if self.custom_cv
else self.n_folds * num_sessions_subjects
)

def split(self, y, metadata, **kwargs):
all_index = metadata.index.values
Expand Down Expand Up @@ -115,7 +118,9 @@ def split(self, y, metadata, **kwargs):
# Handle custom splitter
if isinstance(self.cv(), PseudoOnlineSplit):
splitter = self.cv(calib_size=self.calib_size)
for calib_ix, test_ix in splitter.split(indices, group_y, subject_metadata[session_mask]):
for calib_ix, test_ix in splitter.split(
indices, group_y, subject_metadata[session_mask]
):
yield calib_ix, test_ix
else:
# Handle standard CV like StratifiedKFold
Expand All @@ -125,4 +130,4 @@ def split(self, y, metadata, **kwargs):
random_state=self.random_state.randint(0, 2**10),
)
for train_ix, test_ix in splitter.split(indices, group_y):
yield indices[train_ix], indices[test_ix]
yield indices[train_ix], indices[test_ix]

0 comments on commit 42e6b7b

Please sign in to comment.