-
Notifications
You must be signed in to change notification settings - Fork 187
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding: Pytests for evaluation splitters, and examples for meta split…
…ters
- Loading branch information
Showing
3 changed files
with
209 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import os | ||
import os.path as osp | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
from sklearn.model_selection import StratifiedKFold, LeaveOneGroupOut | ||
|
||
from moabb.evaluations.splitters import (CrossSessionSplitter, CrossSubjectSplitter, WithinSessionSplitter) | ||
from moabb.datasets.fake import FakeDataset | ||
from moabb.paradigms.motor_imagery import FakeImageryParadigm | ||
|
||
dataset = FakeDataset(["left_hand", "right_hand"], n_subjects=3, seed=12) | ||
paradigm = FakeImageryParadigm() | ||
|
||
|
||
# Split done for the Within Session evaluation | ||
def eval_split_within_session(): | ||
for subject in dataset.subject_list: | ||
X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject]) | ||
sessions = metadata.session | ||
for session in np.unique(sessions): | ||
ix = sessions == session | ||
cv = StratifiedKFold(5, shuffle=True, random_state=42) | ||
X_, y_ = X[ix], y[ix] | ||
for train, test in cv.split(X_, y_): | ||
yield X_[train], X_[test] | ||
|
||
# Split done for the Cross Session evaluation | ||
def eval_split_cross_session(): | ||
for subject in dataset.subject_list: | ||
X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject]) | ||
groups = metadata.session.values | ||
cv = LeaveOneGroupOut() | ||
for train, test in cv.split(X, y, groups): | ||
yield X[train], X[test] | ||
|
||
# Split done for the Cross Subject evaluation | ||
def eval_split_cross_subject(): | ||
X, y, metadata = paradigm.get_data(dataset=dataset) | ||
groups = metadata.subject.values | ||
cv = LeaveOneGroupOut() | ||
for train, test in cv.split(X, y, groups): | ||
yield X[train], X[test] | ||
|
||
|
||
def test_within_session(): | ||
X, y, metadata = paradigm.get_data(dataset=dataset) | ||
|
||
split = WithinSessionSplitter(n_folds=5) | ||
|
||
for ix, ((X_train_t, X_test_t), (train, test)) in enumerate( | ||
zip(eval_split_within_session(), split.split(X, y, metadata, random_state=42))): | ||
X_train, X_test = X[train], X[test] | ||
|
||
# Check if the output is the same as the input | ||
assert np.array_equal(X_train, X_train_t) | ||
assert np.array_equal(X_test, X_test_t) | ||
|
||
|
||
def test_cross_session(): | ||
X, y, metadata = paradigm.get_data(dataset=dataset) | ||
|
||
split = CrossSessionSplitter() | ||
|
||
for ix, ((X_train_t, X_test_t), (train, test)) in enumerate( | ||
zip(eval_split_cross_session(), split.split(X, y, metadata))): | ||
X_train, X_test = X[train], X[test] | ||
|
||
# Check if the output is the same as the input | ||
assert np.array_equal(X_train, X_train_t) | ||
assert np.array_equal(X_test, X_test_t) | ||
|
||
|
||
def test_cross_subject(): | ||
X, y, metadata = paradigm.get_data(dataset=dataset) | ||
|
||
split = CrossSubjectSplitter() | ||
|
||
for ix, ((X_train_t, X_test_t), (train, test)) in enumerate( | ||
zip(eval_split_cross_subject(), split.split(X, y, metadata))): | ||
X_train, X_test = X[train], X[test] | ||
|
||
# Check if the output is the same as the input | ||
assert np.array_equal(X_train, X_train_t) | ||
assert np.array_equal(X_test, X_test_t) | ||
|