forked from NeuroTechX/moabb
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_within_session_p300.py
137 lines (107 loc) · 3.9 KB
/
plot_within_session_p300.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
===========================
Within Session P300
===========================
This example shows how to perform a within-session analysis on a P300
dataset (BNCI2014009, restricted to its first two subjects).
We will compare two pipelines:
- Riemannian Geometry
- xDawn with Linear Discriminant Analysis
We will use the P300 paradigm, which uses the AUC as metric.
"""
# Authors: Pedro Rodrigues <[email protected]>
#
# License: BSD (3-clause)
import warnings
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pyriemann.estimation import Xdawn, XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline
import moabb
from moabb.datasets import BNCI2014009
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import P300
##############################################################################
# getting rid of the warnings about the future
# Install "ignore" filters for FutureWarning and RuntimeWarning so they do
# not clutter the output of the example.
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", RuntimeWarning)

# Set MOABB's log level to "info".
moabb.set_log_level("info")
##############################################################################
# This is an auxiliary transformer that allows one to vectorize data
# structures in a pipeline. For instance, in the case of an X with dimensions
# Nt x Nc x Ns, one might be interested in a new data structure with
# dimensions Nt x (Nc.Ns)
class Vectorizer(BaseEstimator, TransformerMixin):
    """Flatten every sample of ``X`` into a single feature vector.

    The transformer is stateless: ``fit`` learns nothing, and ``transform``
    keeps the first axis of ``X`` while collapsing all remaining axes into
    one, e.g. ``(Nt, Nc, Ns)`` becomes ``(Nt, Nc * Ns)``.
    """

    def fit(self, X, y=None):
        """Do nothing and return the transformer itself.

        Parameters
        ----------
        X : array-like
            Ignored.
        y : array-like, optional
            Ignored. It defaults to ``None`` so that the transformer can also
            be fitted without labels, following the scikit-learn convention.

        Returns
        -------
        self : Vectorizer
            The unchanged transformer.
        """
        return self

    def transform(self, X):
        """Reshape ``X`` to two dimensions.

        Parameters
        ----------
        X : ndarray
            Input data; its ``shape`` attribute is read, and the first axis
            is preserved.

        Returns
        -------
        ndarray
            ``X`` reshaped to ``(X.shape[0], -1)``.
        """
        return np.reshape(X, (X.shape[0], -1))
##############################################################################
# Create pipelines
# ----------------
#
# Pipelines must be a dict of sklearn pipeline transformers.
pipelines = {}

##############################################################################
# The classes are called 'Target' and 'NonTarget', but the evaluation function
# uses a LabelEncoder, which transforms them to 1 and 0 respectively, so we
# keep the same mapping here.
labels_dict = dict(Target=1, NonTarget=0)

# Register both pipelines in one go; insertion order is "RG+LDA", "Xdw+LDA".
pipelines.update(
    {
        # Riemannian geometry: xDawn covariances -> tangent space -> LDA
        "RG+LDA": make_pipeline(
            XdawnCovariances(
                nfilter=2,
                classes=[labels_dict["Target"]],
                estimator="lwf",
                xdawn_estimator="scm",
            ),
            TangentSpace(),
            LDA(solver="lsqr", shrinkage="auto"),
        ),
        # xDawn spatial filtering -> flattened features -> LDA
        "Xdw+LDA": make_pipeline(
            Xdawn(nfilter=2, estimator="scm"),
            Vectorizer(),
            LDA(solver="lsqr", shrinkage="auto"),
        ),
    }
)
##############################################################################
# Evaluation
# ----------
#
# We define the paradigm (P300) and use the BNCI2014009 dataset, restricted to
# its first two subjects.
# The evaluation will return a dataframe containing a single AUC score for
# each subject / session of the dataset, and for each pipeline.
#
# Results are saved into the database, so that if you add a new pipeline, it
# will not run the evaluation again unless a parameter has changed. Results can
# be overwritten if necessary.
# P300 paradigm, with the signals resampled to 128.
paradigm = P300(resample=128)

# Keep only the first two subjects of the dataset.
dataset = BNCI2014009()
dataset.subject_list = dataset.subject_list[:2]
datasets = [dataset]

# Set to True if we want to overwrite cached results.
overwrite = True

evaluation = WithinSessionEvaluation(
    paradigm=paradigm,
    datasets=datasets,
    suffix="examples",
    overwrite=overwrite,
)
# Run every pipeline through the evaluation and collect the scores.
results = evaluation.process(pipelines)
##############################################################################
# Plot Results
# ----------------
#
# Here we plot the results.
fig, ax = plt.subplots(facecolor="white", figsize=[8, 4])

# One semi-transparent, jittered dot per row of ``results``, grouped by
# pipeline.
sns.stripplot(
    data=results,
    y="score",
    x="pipeline",
    ax=ax,
    jitter=True,
    alpha=0.5,
    zorder=1,
    palette="Set1",
)
# Overlay the per-pipeline point estimate on the same axes.
sns.pointplot(data=results, y="score", x="pipeline", ax=ax, zorder=1, palette="Set1")

ax.set_ylabel("ROC AUC")
ax.set_ylim(0.5, 1)

# Use plt.show() rather than fig.show(): Figure.show() does not run a GUI
# event loop, so when this file is executed as a plain script the window may
# close immediately or never appear.
plt.show()