diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index 527d51c20..ff5262738 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -22,17 +22,17 @@ get_filter_pipeline, get_resample_pipeline, ) +from moabb.utils import MoabbMetaClass log = logging.getLogger(__name__) -class BaseProcessing(metaclass=abc.ABCMeta): +class BaseProcessing(metaclass=MoabbMetaClass): """Base Processing. Please use one of the child classes - Parameters ---------- filters: list of list (defaults [[7, 35]]) @@ -500,7 +500,7 @@ class BaseParadigm(BaseProcessing): ---------- events: List of str | None (default None) - event to use for epoching. If None, default to all events defined in + events to use for epoching. If None, default to all events defined in the dataset. """ diff --git a/moabb/paradigms/motor_imagery.py b/moabb/paradigms/motor_imagery.py index 657a8e814..117ff13bc 100644 --- a/moabb/paradigms/motor_imagery.py +++ b/moabb/paradigms/motor_imagery.py @@ -14,43 +14,8 @@ class BaseMotorImagery(BaseParadigm): """Base Motor imagery paradigm. - Please use one of the child classes + Not to be instantiated. - Parameters - ---------- - - filters: list of list (defaults [[7, 35]]) - bank of bandpass filter to apply. - - events: List of str | None (default None) - event to use for epoching. If None, default to all events defined in - the dataset. - - tmin: float (default 0.0) - Start time (in second) of the epoch, relative to the dataset specific - task interval e.g. tmin = 1 would mean the epoch will start 1 second - after the beginning of the task as defined by the dataset. - - tmax: float | None, (default None) - End time (in second) of the epoch, relative to the beginning of the - dataset specific task interval. tmax = 5 would mean the epoch will end - 5 second after the beginning of the task as defined in the dataset. If - None, use the dataset value. - - baseline: None | tuple of length 2 - The time interval to consider as “baseline” when applying baseline - correction. If None, do not apply baseline correction. - If a tuple (a, b), the interval is between a and b (in seconds), - including the endpoints. - Correction is applied by computing the mean of the baseline period - and subtracting it from the data (see mne.Epochs) - - channels: list of str | None (default None) - list of channel to select. If None, use all EEG channels available in - the dataset. - - resample: float | None (default None) - If not None, resample the eeg data with the sampling rate provided. """ def __init__( @@ -74,9 +39,9 @@ def __init__( ) def is_valid(self, dataset): - ret = True - if not (dataset.paradigm == "imagery"): - ret = False + ret = dataset.paradigm == "imagery" + if not ret: + return ret # check if dataset has required events if self.events: @@ -105,78 +70,28 @@ def scoring(self): return "accuracy" -class SinglePass(BaseMotorImagery): - """Single Bandpass filter motor Imagery. - - Motor imagery paradigm with only one bandpass filter (default 8 to 32 Hz) +class LeftRightImagery(BaseMotorImagery): + """Motor Imagery for left hand/right hand classification. Parameters - ---------- + ----------- + fmin: float (default 8) - cutoff frequency (Hz) for the high pass filter + cutoff frequency (Hz) for the high pass filter. fmax: float (default 32) - cutoff frequency (Hz) for the low pass filter - - events: List of str | None (default None) - event to use for epoching. If None, default to all events defined in - the dataset. - - tmin: float (default 0.0) - Start time (in second) of the epoch, relative to the dataset specific - task interval e.g. tmin = 1 would mean the epoch will start 1 second - after the beginning of the task as defined by the dataset. - - tmax: float | None, (default None) - End time (in second) of the epoch, relative to the beginning of the - dataset specific task interval. tmax = 5 would mean the epoch will end - 5 second after the beginning of the task as defined in the dataset. If - None, use the dataset value. - - baseline: None | tuple of length 2 - The time interval to consider as “baseline” when applying baseline - correction. If None, do not apply baseline correction. - If a tuple (a, b), the interval is between a and b (in seconds), - including the endpoints. - Correction is applied by computing the mean of the baseline period - and subtracting it from the data (see mne.Epochs) - - channels: list of str | None (default None) - list of channel to select. If None, use all EEG channels available in - the dataset. - - resample: float | None (default None) - If not None, resample the eeg data with the sampling rate provided. - """ - - def __init__(self, fmin=8, fmax=32, **kwargs): - if "filters" in kwargs.keys(): - raise (ValueError("MotorImagery does not take argument filters")) - super().__init__(filters=[[fmin, fmax]], **kwargs) - + cutoff frequency (Hz) for the low pass filter. -class FilterBank(BaseMotorImagery): - """Filter Bank MI.""" - - def __init__( - self, - filters=([8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 32]), - **kwargs, - ): - """init.""" - super().__init__(filters=filters, **kwargs) - - -class LeftRightImagery(SinglePass): - """Motor Imagery for left hand/right hand classification. - - Metric is 'roc_auc' """ - def __init__(self, **kwargs): + def __init__(self, fmin=8, fmax=32, **kwargs): if "events" in kwargs.keys(): raise (ValueError("LeftRightImagery dont accept events")) - super().__init__(events=["left_hand", "right_hand"], **kwargs) + if "filters" in kwargs.keys(): + raise (ValueError("LeftRightImagery does not take argument filters")) + super().__init__( + filters=[[fmin, fmax]], events=["left_hand", "right_hand"], **kwargs + ) def used_events(self, dataset): return {ev: dataset.event_id[ev] for ev in self.events} @@ -186,63 +101,59 @@ def scoring(self): return "roc_auc" -class FilterBankLeftRightImagery(FilterBank): - """Filter Bank Motor Imagery for left hand/right hand classification. - - Metric is 'roc_auc' - """ +class FilterBankLeftRightImagery(LeftRightImagery): + """Filter Bank Motor Imagery for left/right hand classification.""" - def __init__(self, **kwargs): + def __init__( + self, + filters=([8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 32]), + **kwargs, + ): if "events" in kwargs.keys(): raise (ValueError("LeftRightImagery dont accept events")) - super().__init__(events=["left_hand", "right_hand"], **kwargs) - - def used_events(self, dataset): - return {ev: dataset.event_id[ev] for ev in self.events} - - @property - def scoring(self): - return "roc_auc" - + super(LeftRightImagery, self).__init__( + filters=filters, events=["left_hand", "right_hand"], **kwargs + ) -class FilterBankMotorImagery(FilterBank): - """Filter bank n-class motor imagery. - Metric is 'roc-auc' if 2 classes and 'accuracy' if more +class MotorImagery(BaseMotorImagery): + """N-class Motor Imagery. Parameters ----------- - events: List of str - event labels used to filter datasets (e.g. if only motor imagery is - desired). + fmin: float (default 8) + cutoff frequency (Hz) for the high pass filter. + + fmax: float (default 32) + cutoff frequency (Hz) for the low pass filter. + + n_classes: int (default number of available classes) + number of MotorImagery classes/events to select. - n_classes: int, - number of classes each dataset must have. If events is given, - requires all imagery sorts to be within the events list. """ - def __init__(self, n_classes=2, **kwargs): - "docstring" - super().__init__(**kwargs) + def __init__(self, fmin=8, fmax=32, n_classes=None, **kwargs): + if "filters" in kwargs.keys(): + raise (ValueError("MotorImagery does not take argument filters")) self.n_classes = n_classes - if self.events is None: log.warning("Choosing from all possible events") - else: + elif self.n_classes is not None: assert n_classes <= len(self.events), "More classes than events specified" + super().__init__(filters=[[fmin, fmax]], **kwargs) def is_valid(self, dataset): - ret = True - if not dataset.paradigm == "imagery": - ret = False - if self.events is None: - if not len(dataset.event_id) >= self.n_classes: - ret = False - else: + ret = dataset.paradigm == "imagery" + if not ret: + return ret + + if self.events is None and self.n_classes: + ret = len(dataset.event_id) >= self.n_classes + elif self.events and self.n_classes: overlap = len(set(self.events) & set(dataset.event_id.keys())) - if not overlap >= self.n_classes: - ret = False + ret = overlap >= self.n_classes + return ret def used_events(self, dataset): @@ -250,8 +161,8 @@ def used_events(self, dataset): if self.events is None: for k, v in dataset.event_id.items(): out[k] = v - if len(out) == self.n_classes: - break + if self.n_classes is None: + self.n_classes = len(out) else: for event in self.events: if event in dataset.event_id.keys(): @@ -288,120 +199,28 @@ def scoring(self): return "accuracy" -class MotorImagery(SinglePass): - """N-class motor imagery. - - Metric is 'roc-auc' if 2 classes and 'accuracy' if more +class FilterBankMotorImagery(MotorImagery): + """Filter bank N-class motor imagery. Parameters ----------- - events: List of str - event labels used to filter datasets (e.g. if only motor imagery is - desired). - - n_classes: int, - number of classes each dataset must have. If events is given, - requires all imagery sorts to be within the events list. - - fmin: float (default 8) - cutoff frequency (Hz) for the high pass filter - - fmax: float (default 32) - cutoff frequency (Hz) for the low pass filter - - tmin: float (default 0.0) - Start time (in second) of the epoch, relative to the dataset specific - task interval e.g. tmin = 1 would mean the epoch will start 1 second - after the beginning of the task as defined by the dataset. - - tmax: float | None, (default None) - End time (in second) of the epoch, relative to the beginning of the - dataset specific task interval. tmax = 5 would mean the epoch will end - 5 second after the beginning of the task as defined in the dataset. If - None, use the dataset value. - - baseline: None | tuple of length 2 - The time interval to consider as “baseline” when applying baseline - correction. If None, do not apply baseline correction. - If a tuple (a, b), the interval is between a and b (in seconds), - including the endpoints. - Correction is applied by computing the mean of the baseline period - and subtracting it from the data (see mne.Epochs) - - channels: list of str | None (default None) - list of channel to select. If None, use all EEG channels available in - the dataset. - - resample: float | None (default None) - If not None, resample the eeg data with the sampling rate provided. + n_classes: int (default number of available classes) + number of MotorImagery classes/events to select. """ - def __init__(self, n_classes=None, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + filters=([8, 12], [12, 16], [16, 20], [20, 24], [24, 28], [28, 32]), + n_classes=None, + **kwargs, + ): self.n_classes = n_classes - if self.events is None: log.warning("Choosing from all possible events") elif self.n_classes is not None: assert n_classes <= len(self.events), "More classes than events specified" - - def is_valid(self, dataset): - ret = True - if not dataset.paradigm == "imagery": - ret = False - elif self.n_classes is None and self.events is None: - pass - elif self.events is None: - if not len(dataset.event_id) >= self.n_classes: - ret = False - else: - overlap = len(set(self.events) & set(dataset.event_id.keys())) - if self.n_classes is not None and not overlap >= self.n_classes: - ret = False - return ret - - def used_events(self, dataset): - out = {} - if self.events is None: - for k, v in dataset.event_id.items(): - out[k] = v - if self.n_classes is None: - self.n_classes = len(out) - else: - for event in self.events: - if event in dataset.event_id.keys(): - out[event] = dataset.event_id[event] - if len(out) == self.n_classes: - break - if len(out) < self.n_classes: - raise ( - ValueError( - f"Dataset {dataset.code} did not have enough " - f"events in {self.events} to run analysis" - ) - ) - return out - - @property - def datasets(self): - if self.tmax is None: - interval = None - else: - interval = self.tmax - self.tmin - return utils.dataset_search( - paradigm="imagery", - events=self.events, - interval=interval, - has_all_events=False, - ) - - @property - def scoring(self): - if self.n_classes == 2: - return "roc_auc" - else: - return "accuracy" + super(MotorImagery, self).__init__(filters=filters, **kwargs) class FakeImageryParadigm(LeftRightImagery): diff --git a/moabb/utils.py b/moabb/utils.py index 04dc3bf83..d5f4fd260 100644 --- a/moabb/utils.py +++ b/moabb/utils.py @@ -1,5 +1,6 @@ """Util functions for moabb.""" +import abc import inspect import logging import os @@ -10,6 +11,7 @@ from typing import TYPE_CHECKING import numpy as np +from docstring_inheritance import NumpyDocstringInheritanceInitMeta from mne import get_config, set_config from mne import set_log_level as sll @@ -227,3 +229,7 @@ def depreciated_func(*args, **kwargs): return func return factory + + +class MoabbMetaClass(abc.ABCMeta, NumpyDocstringInheritanceInitMeta): + pass