Creating data splitters for moabb evaluation #624

Open
wants to merge 29 commits into base: develop

Commits (29)
bacedc5
Creating new splitters and base evaluation
brunaafl Jun 6, 2024
419b2ca
Adding metasplitters
brunaafl Jun 7, 2024
d6e795d
Fixing LazyEvaluation
brunaafl Jun 10, 2024
140670c
Merge branch 'NeuroTechX:develop' into eval_splitters
brunaafl Jun 10, 2024
d724674
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
a278026
More optimized version of TimeSeriesSplit
brunaafl Jun 10, 2024
300a6b9
More optimized version of TimeSeriesSplit
brunaafl Jun 10, 2024
7cb79f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
55db70f
Addressing some comments: documentation, types, inconsistencies
brunaafl Jun 10, 2024
2851a15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
c73dd1a
Addressing some comments: optimizing code, adjusts
brunaafl Jun 12, 2024
2b0e735
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2024
cf4b709
Adding examples
brunaafl Jun 26, 2024
177bf65
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 26, 2024
a6b5772
Adding: Pytests for evaluation splitters, and examples for meta split…
brunaafl Aug 15, 2024
26b13d5
Changing: name of TimeSeriesSplit to PseudoOnlineSplit
brunaafl Sep 30, 2024
e6661c4
Merge branch 'develop' into eval_splitters
brunaafl Sep 30, 2024
430e3a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
698e539
Fixing pre-commit
brunaafl Sep 30, 2024
0fff053
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Sep 30, 2024
98d12ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
558d27b
Adding some tests for metasplitters
brunaafl Oct 1, 2024
34ea645
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
b435bf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
d8f26a3
Fixing pre-commit
brunaafl Oct 1, 2024
eaf0fb9
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
e5159f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2024
516a5e8
Fixing pre-commit
brunaafl Oct 1, 2024
b29ecd2
Merge remote-tracking branch 'origin/eval_splitters' into eval_splitters
brunaafl Oct 1, 2024
Binary file added docs/source/images/crosssess.pdf
Binary file not shown.
Binary file added docs/source/images/crosssubj.pdf
Binary file not shown.
Binary file added docs/source/images/withinsess.pdf
Binary file not shown.
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
@@ -75,6 +75,7 @@ Enhancements
- Add new dataset :class:`moabb.datasets.Rodrigues2017` dataset (:gh:`602` by `Gregoire Cattan`_ and `Pedro L. C. Rodrigues`_)
- Change unittest to pytest (:gh:`618` by `Bruno Aristimunha`_)
- Remove tensorflow import warning (:gh:`622` by `Bruno Aristimunha`_)
- Add data splitter classes (:gh:`612` by `Bruna Lopes`_)

Bugs
~~~~
352 changes: 352 additions & 0 deletions moabb/evaluations/metasplitters.py

Large diffs are not rendered by default.

319 changes: 319 additions & 0 deletions moabb/evaluations/splitters.py
@@ -0,0 +1,319 @@
import numpy as np
from sklearn.model_selection import (
BaseCrossValidator,
GroupKFold,
LeaveOneGroupOut,
StratifiedKFold,
)


class WithinSessionSplitter(BaseCrossValidator):
"""Data splitter for within session evaluation.

Within-session evaluation uses k-fold cross-validation to determine train
and test sets within each session, for each subject. This splitter assumes that
all data from all subjects is already known and loaded.

.. image:: images/withinsess.pdf
:alt: The schematic diagram of the WithinSession split
:align: center

Parameters
----------
n_folds : int
Number of folds. Must be at least 2.

Examples
-----------

>>> import pandas as pd
>>> import numpy as np
>>> from moabb.evaluations.splitters import WithinSessionSplitter
>>> X = np.array([[1, 2], [3, 4], [5, 6], [1,4], [7, 4], [5, 8], [0,3], [2,4]])
>>> y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
>>> subjects = np.array([1, 1, 1, 1, 1, 1, 1, 1])
>>> sessions = np.array(['T', 'T', 'E', 'E', 'T', 'T', 'E', 'E'])
>>> metadata = pd.DataFrame(data={'subject': subjects, 'session': sessions})
>>> csess = WithinSessionSplitter(2)
>>> csess.get_n_splits(metadata)
4
>>> for i, (train_index, test_index) in enumerate(csess.split(X, y, metadata)):
... print(f"Fold {i}:")
... print(f" Train: index={train_index}, group={subjects[train_index]}, session={sessions[train_index]}")
... print(f" Test: index={test_index}, group={subjects[test_index]}, sessions={sessions[test_index]}")
Fold 0:
Train: index=[2 7], group=[1 1], session=['E' 'E']
Test: index=[3 6], group=[1 1], sessions=['E' 'E']
Fold 1:
Train: index=[3 6], group=[1 1], session=['E' 'E']
Test: index=[2 7], group=[1 1], sessions=['E' 'E']
Fold 2:
Train: index=[4 5], group=[1 1], session=['T' 'T']
Test: index=[0 1], group=[1 1], sessions=['T' 'T']
Fold 3:
Train: index=[0 1], group=[1 1], session=['T' 'T']
Test: index=[4 5], group=[1 1], sessions=['T' 'T']


"""

def __init__(self, n_folds=5):
Collaborator:
The thing that would be nice is to add a shuffle option to yield these splits in a random order (random over the subjects/sessions and folds).

This would allow subsampling a number of folds with diverse subjects/sessions if we don't want to run the full CV procedure.
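A minimal sketch of what the suggested shuffle could look like — the helper name and signature here are hypothetical, not part of this PR:

```python
import numpy as np


def shuffled_fold_order(subjects, sessions, n_folds, random_state=None):
    """Sketch of the reviewer's suggestion: enumerate every
    (subject, session, fold) combination, then yield them in a seeded
    random order so that a subsample of folds still covers diverse
    subjects and sessions."""
    rng = np.random.default_rng(random_state)
    combos = [
        (subj, sess, fold)
        for subj in subjects
        for sess in sessions
        for fold in range(n_folds)
    ]
    rng.shuffle(combos)  # in-place, reproducible for a fixed random_state
    return combos


order = shuffled_fold_order([1, 2], ["T", "E"], 2, random_state=0)
assert len(order) == 8  # 2 subjects x 2 sessions x 2 folds
```

Taking only the first few entries of `order` would then give a random but reproducible subset of the full CV procedure.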

self.n_folds = n_folds

def get_n_splits(self, metadata):
sessions_subjects = metadata.groupby(["subject", "session"]).ngroups
return self.n_folds * sessions_subjects

def split(self, X, y, metadata, **kwargs):

assert isinstance(self.n_folds, int)

subjects = metadata.subject.values
cv = StratifiedKFold(n_splits=self.n_folds, shuffle=True, **kwargs)
Collaborator:
We need to be able to set a random_state for each of this CV to be able to have reproducible splits.
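A quick standalone illustration of the reviewer's point (plain sklearn, not the PR's code): with `shuffle=True`, fixing `random_state` is what makes the folds reproducible across runs.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.arange(16).reshape(8, 2)
y = np.array([0, 1, 0, 1, 0, 1, 0, 1])

# Two independently constructed CV objects with the same seed...
cv_a = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
cv_b = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
splits_a = [(tr.tolist(), te.tolist()) for tr, te in cv_a.split(X, y)]
splits_b = [(tr.tolist(), te.tolist()) for tr, te in cv_b.split(X, y)]

# ...produce identical folds; without random_state each run would differ.
assert splits_a == splits_b
```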


for subject in np.unique(subjects):
mask = subjects == subject
X_, y_, meta_ = (
X[mask],
y[mask],
metadata[mask],
)

sessions = meta_.session.values

for session in np.unique(sessions):
mask_s = sessions == session
X_s, y_s, _ = (
X_[mask_s],
y_[mask_s],
meta_[mask_s],
)

for ix_train, ix_test in cv.split(X_s, y_s):

ix_train_global = np.where(mask)[0][np.where(mask_s)[0][ix_train]]
ix_test_global = np.where(mask)[0][np.where(mask_s)[0][ix_test]]
yield ix_train_global, ix_test_global


class IndividualWithinSessionSplitter(BaseCrossValidator):
"""Data splitter for within session evaluation.

Within-session evaluation uses k-fold cross-validation to determine train
and test sets within each session, for each subject. This splitter does not assume
that all data and metadata from all subjects is already loaded. If X, y and metadata
are from a single subject, it returns the data split for this subject only.

It can be used as a basis for WithinSessionSplitter or to avoid downloading all data
at once when that is not needed.

Parameters
----------
n_folds : int
Number of folds. Must be at least 2.

"""

def __init__(self, n_folds: int):
self.n_folds = n_folds

def get_n_splits(self, metadata):
return self.n_folds

def split(self, X, y, metadata, **kwargs):

brunaafl marked this conversation as resolved.
Show resolved Hide resolved
assert len(np.unique(metadata.subject)) == 1
assert isinstance(self.n_folds, int)

sessions = metadata.session.values
cv = StratifiedKFold(n_splits=self.n_folds, shuffle=True, **kwargs)

for session in np.unique(sessions):
mask = sessions == session
X_, y_, _ = (
X[mask],
y[mask],
metadata[mask],
)

for ix_train, ix_test in cv.split(X_, y_):
# map the per-session indices back to positions in the full arrays
yield np.where(mask)[0][ix_train], np.where(mask)[0][ix_test]
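Since this class only exists in the PR branch, the same per-session k-fold logic can be sketched standalone with sklearn (data and seed below are made up for illustration):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# One subject, two sessions of four trials each.
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [0, 1], [2, 3], [4, 5], [6, 7]])
y = np.array([0, 1, 0, 1, 0, 1, 0, 1])
metadata = pd.DataFrame({"subject": [1] * 8,
                         "session": ["T", "T", "T", "T", "E", "E", "E", "E"]})

sessions = metadata.session.values
cv = StratifiedKFold(n_splits=2, shuffle=True, random_state=0)

folds = []
for session in np.unique(sessions):
    mask = sessions == session
    for ix_train, ix_test in cv.split(X[mask], y[mask]):
        # map per-session indices back to positions in the full arrays
        folds.append((np.where(mask)[0][ix_train], np.where(mask)[0][ix_test]))

assert len(folds) == 4  # 2 sessions x 2 folds
```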


class CrossSessionSplitter(BaseCrossValidator):
"""Data splitter for cross session evaluation.

Cross-session evaluation uses a Leave-One-Group-Out cross-validation to
evaluate performance across sessions, but for a single subject. This splitter
assumes that all data from all subjects is already known and loaded.

.. image:: images/crosssess.pdf
:alt: The schematic diagram of the CrossSession split
:align: center

Parameters
----------
n_folds :
Not used. Kept for compatibility with other cross-validation splitters.
Default: None

Examples
----------
>>> import numpy as np
>>> import pandas as pd
>>> from moabb.evaluations.splitters import CrossSessionSplitter
>>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [8, 9], [5, 4], [2, 5], [1, 7]])
>>> y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
>>> subjects = np.array([1, 1, 1, 1, 2, 2, 2, 2])
>>> sessions = np.array(['T', 'T', 'E', 'E', 'T', 'T', 'E', 'E'])
>>> metadata = pd.DataFrame(data={'subject': subjects, 'session': sessions})
>>> csess = CrossSessionSplitter()
>>> csess.get_n_splits(metadata)
4
>>> for i, (train_index, test_index) in enumerate(csess.split(X, y, metadata)):
... print(f"Fold {i}:")
... print(f" Train: index={train_index}, group={subjects[train_index]}, session={sessions[train_index]}")
... print(f" Test: index={test_index}, group={subjects[test_index]}, sessions={sessions[test_index]}")
Fold 0:
Train: index=[0 1], group=[1 1], session=['T' 'T']
Test: index=[2 3], group=[1 1], sessions=['E' 'E']
Fold 1:
Train: index=[2 3], group=[1 1], session=['E' 'E']
Test: index=[0 1], group=[1 1], sessions=['T' 'T']
Fold 2:
Train: index=[4 5], group=[2 2], session=['T' 'T']
Test: index=[6 7], group=[2 2], sessions=['E' 'E']
Fold 3:
Train: index=[6 7], group=[2 2], session=['E' 'E']
Test: index=[4 5], group=[2 2], sessions=['T' 'T']

"""

def __init__(self, n_folds=None):
self.n_folds = n_folds

def get_n_splits(self, metadata):
sessions_subjects = len(metadata.groupby(["subject", "session"]).first())
return sessions_subjects

def split(self, X, y, metadata):

subjects = metadata.subject.values
split = IndividualCrossSessionSplitter()

for subject in np.unique(subjects):
mask = subjects == subject
X_, y_, meta_ = (
X[mask],
y[mask],
metadata[mask],
)

for ix_train, ix_test in split.split(X_, y_, meta_):
ix_train = np.where(mask)[0][ix_train]
ix_test = np.where(mask)[0][ix_test]
yield ix_train, ix_test


class IndividualCrossSessionSplitter(BaseCrossValidator):
"""Data splitter for cross session evaluation.

Cross-session evaluation uses a Leave-One-Group-Out cross-validation to
evaluate performance across sessions, but for a single subject. This splitter does
not assume that all data and metadata from all subjects is already loaded. If X, y
and metadata are from a single subject, it returns the data split for this subject only.

It can be used as a basis for CrossSessionSplitter or to avoid downloading all data
at once when that is not needed.

Parameters
----------
n_folds :
Not used. Kept for compatibility with other cross-validation splitters.
Default: None

"""

def __init__(self, n_folds=None):
self.n_folds = n_folds

def get_n_splits(self, metadata):
sessions = metadata.session.values
return len(np.unique(sessions))

def split(self, X, y, metadata):
assert len(np.unique(metadata.subject)) == 1

cv = LeaveOneGroupOut()
sessions = metadata.session.values

for ix_train, ix_test in cv.split(X, y, groups=sessions):
yield ix_train, ix_test
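The core of this splitter is LeaveOneGroupOut with the session labels as groups; a standalone sketch for one subject (example data made up):

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

# One subject, two sessions of two trials each.
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([0, 1, 0, 1])
sessions = np.array(["T", "T", "E", "E"])

cv = LeaveOneGroupOut()
folds = list(cv.split(X, y, groups=sessions))

assert len(folds) == 2  # one fold per held-out session
for ix_train, ix_test in folds:
    # train and test never share a session
    assert set(sessions[ix_train]).isdisjoint(sessions[ix_test])
```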


class CrossSubjectSplitter(BaseCrossValidator):
"""Data splitter for cross session evaluation.

Cross-session evaluation uses a Leave-One-Group-Out cross-validation to
evaluate performance across sessions, but for a single subject. This splitter
assumes that all data from all subjects is already known and loaded.

.. image:: images/crosssubj.pdf
:alt: The schematic diagram of the CrossSubj split
:align: center

Parameters
----------
n_groups : int or None
If None, Leave-One-Subject-Out is performed.
If int, Leave-k-Subjects-Out is performed.

Examples
--------
>>> import numpy as np
>>> import pandas as pd
>>> from moabb.evaluations.splitters import CrossSubjectSplitter
>>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [8, 9], [5, 4], [2, 5], [1, 7]])
>>> y = np.array([1, 2, 1, 2, 1, 2, 1, 2])
>>> subjects = np.array([1, 1, 2, 2, 3, 3, 4, 4])
>>> metadata = pd.DataFrame(data={'subject': subjects})
>>> csubj = CrossSubjectSplitter()
>>> csubj.get_n_splits(metadata)
4
>>> for i, (train_index, test_index) in enumerate(csubj.split(X, y, metadata)):
... print(f"Fold {i}:")
... print(f" Train: index={train_index}, group={subjects[train_index]}")
... print(f" Test: index={test_index}, group={subjects[test_index]}")
Fold 0:
Train: index=[2 3 4 5 6 7], group=[2 2 3 3 4 4]
Test: index=[0 1], group=[1 1]
Fold 1:
Train: index=[0 1 4 5 6 7], group=[1 1 3 3 4 4]
Test: index=[2 3], group=[2 2]
Fold 2:
Train: index=[0 1 2 3 6 7], group=[1 1 2 2 4 4]
Test: index=[4 5], group=[3 3]
Fold 3:
Train: index=[0 1 2 3 4 5], group=[1 1 2 2 3 3]
Test: index=[6 7], group=[4 4]


"""

def __init__(self, n_groups=None):
self.n_groups = n_groups

def get_n_splits(self, metadata):
return len(metadata.subject.unique())

def split(self, X, y, metadata):

groups = metadata.subject.values

# Define split
if self.n_groups is None:
cv = LeaveOneGroupOut()
else:
cv = GroupKFold(n_splits=self.n_groups)

for ix_train, ix_test in cv.split(metadata, groups=groups):
yield ix_train, ix_test
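The `n_groups` branch (Leave-k-Subjects-Out via GroupKFold) can be illustrated standalone — the data below is made up, but the mechanics match the branch above:

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Four subjects, two trials each.
X = np.arange(16).reshape(8, 2)
y = np.array([0, 1, 0, 1, 0, 1, 0, 1])
subjects = np.array([1, 1, 2, 2, 3, 3, 4, 4])

# With 4 subjects and n_groups=2, each fold holds out 2 whole subjects.
cv = GroupKFold(n_splits=2)
folds = list(cv.split(X, y, groups=subjects))

assert len(folds) == 2
for ix_train, ix_test in folds:
    # no subject ever appears on both sides of a split
    assert set(subjects[ix_train]).isdisjoint(subjects[ix_test])
```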
13 changes: 13 additions & 0 deletions moabb/evaluations/utils.py
@@ -1,9 +1,11 @@
from __future__ import annotations

import re
from pathlib import Path
from pickle import HIGHEST_PROTOCOL, dump
from typing import Sequence

import numpy as np
from numpy import argmax
from sklearn.pipeline import Pipeline

@@ -222,6 +224,17 @@ def create_save_path(
print("No hdf5_path provided, models will not be saved.")


def sort_group(groups):
runs_sort = []
pattern = r"([0-9]+)(|[a-zA-Z]+[a-zA-Z0-9]*)"
for group in groups:
index, _ = re.fullmatch(pattern, group).groups()
runs_sort.append(int(index))
sorted_ix = np.argsort(runs_sort)
return groups[sorted_ix]
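For illustration, a self-contained copy of `sort_group` with a usage example — it sorts run labels by their numeric prefix, where plain lexicographic sorting would put "10run" before "2run":

```python
import re

import numpy as np


def sort_group(groups):
    """Sort an array of run labels like '1run', '2run', '10run'
    by their leading integer rather than lexicographically."""
    pattern = r"([0-9]+)(|[a-zA-Z]+[a-zA-Z0-9]*)"
    keys = [int(re.fullmatch(pattern, g).groups()[0]) for g in groups]
    return groups[np.argsort(keys)]


runs = np.array(["10run", "2run", "1run"])
assert list(sort_group(runs)) == ["1run", "2run", "10run"]
```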


def _convert_sklearn_params_to_optuna(param_grid: dict) -> dict:
"""
Function to convert the parameter in Optuna format. This function will