From a0241cecd972e64a84772099b4ccd69d769149b8 Mon Sep 17 00:00:00 2001 From: toncho11 Date: Mon, 18 Mar 2024 11:12:39 +0100 Subject: [PATCH] Implementation of NCH (Nearest Convex Hull) classifier (#253) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Initial version of NearestConvexHull. * Added script for testing. * First version that runs. * Improved code. * Added support for parallel processing. It gives an error: AttributeError: Pipeline has none of the following attributes: decision_function. * renamed * New version that uses a new class that implements a NCH classifier. * small update * Updated to newest code - the new version of the distance function. Added an example that runs on a small number of test samples, so that we can get results quicker. * [pre-commit.ci] auto fixes from pre-commit.com hooks * reinforce constraint on weights * - remove constraints on weights - limite size of training set - change to slsqp optimizer * [pre-commit.ci] auto fixes from pre-commit.com hooks * Added n_max_hull parameter. MOABB support tested. * [pre-commit.ci] auto fixes from pre-commit.com hooks * added multiple hulls. * [pre-commit.ci] auto fixes from pre-commit.com hooks * Code cleanups. Added second parameter that specifies the number of hulls. * [pre-commit.ci] auto fixes from pre-commit.com hooks * Improved code. Added support for transform(). Added a new pipeline [NCH+LDA] * [pre-commit.ci] auto fixes from pre-commit.com hooks * updated default parameters * General improvements. Improvements requested by GC. * [pre-commit.ci] auto fixes from pre-commit.com hooks * removed commented code * Small adjustments. * Better class separation. * [pre-commit.ci] auto fixes from pre-commit.com hooks * Added support for n_samples_per_hull = -1 which takes all the samples for a class. * [pre-commit.ci] auto fixes from pre-commit.com hooks * Update pyriemann_qiskit/classification.py Set of SPD matrices. 
Co-authored-by: Quentin Barthélemy * Update pyriemann_qiskit/classification.py Added new lines to before Parameters Co-authored-by: Quentin Barthélemy * Update pyriemann_qiskit/classification.py [y == c, :, :] => [y == c] Co-authored-by: Quentin Barthélemy * Update pyriemann_qiskit/classification.py NearestConvexHull text change Co-authored-by: Quentin Barthélemy * Improvements proposed by Quentin. * [pre-commit.ci] auto fixes from pre-commit.com hooks * Added comment for the optimizer. * [pre-commit.ci] auto fixes from pre-commit.com hooks * Added some comments in classification. Changes about the global optimizer so, that it is more evident that a global one is used. * [pre-commit.ci] auto fixes from pre-commit.com hooks * Implemented min hull. Added support for both "min-hull" and "random-hull" using the constructor parameter "hull-type". * [pre-commit.ci] auto fixes from pre-commit.com hooks * Reverted to previous version as requested by Gregoire. * fix lint issues * [pre-commit.ci] auto fixes from pre-commit.com hooks --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: gcattan Co-authored-by: Gregoire Cattan Co-authored-by: Quentin Barthélemy --- examples/ERP/classify_P300_nch.py | 150 +++++++++++++ pyriemann_qiskit/classification.py | 335 ++++++++++++++++++++++++++++- pyriemann_qiskit/utils/distance.py | 1 + 3 files changed, 484 insertions(+), 2 deletions(-) create mode 100644 examples/ERP/classify_P300_nch.py diff --git a/examples/ERP/classify_P300_nch.py b/examples/ERP/classify_P300_nch.py new file mode 100644 index 00000000..2876edc1 --- /dev/null +++ b/examples/ERP/classify_P300_nch.py @@ -0,0 +1,150 @@ +""" +==================================================================== +Classification of P300 datasets from MOABB using NCH +==================================================================== + +Demonstrates classification with QuanticNCH. +Evaluation is done using MOABB. 
+ +If parameter "shots" is None then a classical SVM is used similar to the one +in scikit learn. +If "shots" is not None and IBM Quantum token is provided with "q_account_token" +then a real Quantum computer will be used. +You also need to adjust the "n_components" in the PCA procedure to the number +of qubits supported by the real quantum computer you are going to use. +A list of real quantum computers is available in your IBM quantum account. + +""" +# Author: Anton Andreev +# Modified from plot_classify_EEG_tangentspace.py of pyRiemann +# License: BSD (3-clause) + +from pyriemann.estimation import XdawnCovariances +from sklearn.pipeline import make_pipeline +from matplotlib import pyplot as plt +import warnings +import seaborn as sns +from moabb import set_log_level +from moabb.datasets import bi2013a +from moabb.evaluations import WithinSessionEvaluation +from moabb.paradigms import P300 +from pyriemann_qiskit.classification import QuanticNCH +from pyriemann.classification import MDM + +print(__doc__) + +############################################################################## +# getting rid of the warnings about the future +warnings.simplefilter(action="ignore", category=FutureWarning) +warnings.simplefilter(action="ignore", category=RuntimeWarning) + +warnings.filterwarnings("ignore") + +set_log_level("info") + +############################################################################## +# Create Pipelines +# ---------------- +# +# Pipelines must be a dict of sklearn pipeline transformer. 
+ +############################################################################## +# We have to do this because the classes are called 'Target' and 'NonTarget' +# but the evaluation function uses a LabelEncoder, transforming them +# to 0 and 1 +labels_dict = {"Target": 1, "NonTarget": 0} + +paradigm = P300(resample=128) + +datasets = [bi2013a()] # MOABB provides several other P300 datasets + +# reduce the number of subjects, the Quantum pipeline takes a lot of time +# if executed on the entire dataset +n_subjects = 1 +for dataset in datasets: + dataset.subject_list = dataset.subject_list[0:n_subjects] + +overwrite = True # set to True if we want to overwrite cached results + +pipelines = {} + +pipelines["NCH+RANDOM_HULL"] = make_pipeline( + # applies XDawn and calculates the covariance matrix, output it matrices + XdawnCovariances( + nfilter=3, + classes=[labels_dict["Target"]], + estimator="lwf", + xdawn_estimator="scm", + ), + QuanticNCH( + n_hulls_per_class=1, + n_samples_per_hull=3, + n_jobs=12, + hull_type="random-hull", + quantum=False, + ), +) + +pipelines["NCH+MIN_HULL"] = make_pipeline( + # applies XDawn and calculates the covariance matrix, output it matrices + XdawnCovariances( + nfilter=3, + classes=[labels_dict["Target"]], + estimator="lwf", + xdawn_estimator="scm", + ), + QuanticNCH( + n_hulls_per_class=1, + n_samples_per_hull=3, + n_jobs=12, + hull_type="min-hull", + quantum=False, + ), +) + +# this is a non quantum pipeline +pipelines["XD+MDM"] = make_pipeline( + XdawnCovariances( + nfilter=3, + classes=[labels_dict["Target"]], + estimator="lwf", + xdawn_estimator="scm", + ), + MDM(), +) + +print("Total pipelines to evaluate: ", len(pipelines)) + +evaluation = WithinSessionEvaluation( + paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite +) + +results = evaluation.process(pipelines) + +print("Averaging the session performance:") +print(results.groupby("pipeline").mean("score")[["score", "time"]]) + 
+############################################################################## +# Plot Results +# ---------------- +# +# Here we plot the results to compare the two pipelines + +fig, ax = plt.subplots(facecolor="white", figsize=[8, 4]) + +sns.stripplot( + data=results, + y="score", + x="pipeline", + ax=ax, + jitter=True, + alpha=0.5, + zorder=1, + palette="Set1", +) +sns.pointplot(data=results, y="score", x="pipeline", ax=ax, palette="Set1") + +ax.set_ylabel("ROC AUC") +ax.set_ylim(0.3, 1) + +plt.show() diff --git a/pyriemann_qiskit/classification.py b/pyriemann_qiskit/classification.py index ea2f0b45..1cc94ecf 100644 --- a/pyriemann_qiskit/classification.py +++ b/pyriemann_qiskit/classification.py @@ -10,6 +10,7 @@ import numpy as np from warnings import warn +from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin from pyriemann.classification import MDM from pyriemann_qiskit.datasets import get_feature_dimension from pyriemann_qiskit.utils import ( @@ -24,12 +25,18 @@ from qiskit_ibm_provider import IBMProvider, least_busy from qiskit_machine_learning.algorithms import QSVC, VQC, PegasosQSVC from qiskit_machine_learning.kernels.quantum_kernel import QuantumKernel -from qiskit_optimization.algorithms import CobylaOptimizer -from sklearn.base import BaseEstimator, ClassifierMixin +from qiskit_optimization.algorithms import ( + CobylaOptimizer, + # ADMMOptimizer, + SlsqpOptimizer, +) from sklearn.svm import SVC from .utils.hyper_params_factory import gen_zz_feature_map, gen_two_local, get_spsa from .utils import get_provider, get_devices, get_simulator +from .utils.distance import qdistance_logeuclid_to_convex_hull, distance_logeuclid +from joblib import Parallel, delayed +import random logger.level = logging.WARNING @@ -743,3 +750,327 @@ def predict(self, X): """ labels = self._predict(X) return self._map_indices_to_classes(labels) + + +class NearestConvexHull(BaseEstimator, ClassifierMixin, TransformerMixin): + + """Nearest Convex Hull 
Classifier (NCH) + + In NCH, for each class a convex hull is produced by the set of matrices + corresponding to each class. There is no training. Calculating a distance + to a hull is an optimization problem and it is calculated for each testing + sample (SPD matrix) and each hull/class. The minimal distance defines the + predicted class. + + Notes + ----- + .. versionadded:: 0.2.0 + + Parameters + ---------- + n_jobs : int, (default=6) + The number of jobs to use for the computation. This works by computing + each of the hulls in parallel. + n_hulls_per_class: int, (default 3) + The number of hulls used per class. + n_samples_per_hull: int, (default 10) + Defines how many samples are used to build a hull. + hull_type: string, (default "min-hull") + Selects how the hull is constructed. Possible values are + "min-hull" and "random-hull" + + References + ---------- + .. [1] \ + K. Zhao, A. Wiliem, S. Chen, and B. C. Lovell, + ‘Convex Class Model on Symmetric Positive Definite Manifolds’, + Image and Vision Computing, 2019. + """ + + def __init__( + self, n_jobs=6, n_hulls_per_class=3, n_samples_per_hull=10, hull_type="min-hull" + ): + """Init.""" + self.n_jobs = n_jobs + self.n_samples_per_hull = n_samples_per_hull + self.n_hulls_per_class = n_hulls_per_class + self.matrices_per_class_ = {} + self.debug = False + self.hull_type = hull_type + + if hull_type not in ["min-hull", "random-hull"]: + raise Exception("Error: Unknown hull type.") + + def fit(self, X, y): + """Fit (store the training data). + + Parameters + ---------- + X : ndarray, shape (n_matrices, n_channels, n_channels) + Set of SPD matrices. + y : ndarray, shape (n_matrices,) + Labels for each matrix. + sample_weight : None + Not used, here for compatibility with sklearn API. + + Returns + ------- + self : NearestConvexHull instance + The NearestConvexHull instance. 
+ """ + + if self.debug: + print("Start NCH Train") + self.classes_ = np.unique(y) + + for c in self.classes_: + self.matrices_per_class_[c] = X[y == c] + + if self.debug: + print("Samples per class:") + for c in self.classes_: + print("Class: ", c, " Count: ", self.matrices_per_class_[c].shape[0]) + + print("End NCH Train") + + def _process_sample_min_hull(self, test_sample): + """Finds the closes N covmats and uses them to build a single hull per class""" + distances = [] + + for c in self.classes_: + distances_to_covs = [ + distance_logeuclid(test_sample, cov) + for cov in self.matrices_per_class_[c] + ] + + # take the first N min distances + indexes = np.argsort(np.array(distances_to_covs))[ + 0 : self.n_samples_per_hull + ] + + if self.debug: + print("Distances to test sample: ", distances_to_covs) + print("Smallest N distances indexes:", indexes) + print("Smallest N distances: ") + for pp in indexes: + print(distances_to_covs[pp]) + + d = qdistance_logeuclid_to_convex_hull( + self.matrices_per_class_[c][indexes], test_sample + ) + + if self.debug: + print("Final hull distance:", d) + + distances.append(d) + + return distances + + def _process_sample_random_hull(self, test_sample): + """Uses random samples to build a hull, can be several hulls per class""" + distances = [] + + for c in self.classes_: + total_distance = 0 + + # using multiple hulls + for i in range(0, self.n_hulls_per_class): + if self.n_samples_per_hull == -1: # use all data per class + hull_data = self.matrices_per_class_[c] + else: # use a subset of the data per class + random_samples = random.sample( + range(self.matrices_per_class_[c].shape[0]), + k=self.n_samples_per_hull, + ) + hull_data = self.matrices_per_class_[c][random_samples, :, :] + + distance = qdistance_logeuclid_to_convex_hull(hull_data, test_sample) + total_distance = total_distance + distance + + distances.append(total_distance) + + return distances + + def _predict_distances(self, X): + """Helper to predict the distance. 
Equivalent to transform.""" + dist = [] + + if self.debug: + print("Total test samples:", X.shape[0]) + + if self.hull_type == "min-hull": + self._process_sample = self._process_sample_min_hull + elif self.hull_type == "random-hull": + self._process_sample = self._process_sample_random_hull + else: + raise Exception("Error: Unknown hull type.") + + parallel = self.n_jobs > 1 + + if self.debug: + if parallel: + print("Running in parallel") + else: + print("Not running in parallel") + + if parallel: + dist = Parallel(n_jobs=self.n_jobs)( + delayed(self._process_sample)(test_sample) for test_sample in X + ) + + else: + for test_sample in X: + dist_sample = self._process_sample(test_sample) + dist.append(dist_sample) + + return dist + + def predict(self, X): + """Get the predictions. + Parameters + ---------- + X : ndarray, shape (n_matrices, n_channels, n_channels) + Set of SPD matrices. + Returns + ------- + pred : ndarray of int, shape (n_matrices,) + Predictions for each matrix according to the closest convex hull. + """ + if self.debug: + print("Start NCH Predict") + dist = self._predict_distances(X) + + predictions = [ + self.classes_[min(range(len(values)), key=values.__getitem__)] + for values in dist + ] + + if self.debug: + print("End NCH Predict") + + return predictions + + def transform(self, X): + """Get the distance to each convex hull. + + Parameters + ---------- + X : ndarray, shape (n_matrices, n_channels, n_channels) + Set of SPD matrices. + + Returns + ------- + dist : ndarray, shape (n_matrices, n_classes) + The distance to each convex hull. + """ + + if self.debug: + print("NCH Transform") + return self._predict_distances(X) + + +class QuanticNCH(QuanticClassifierBase): + + """A Quantum wrapper around the NCH algorithm. It allows both classical + and Quantum versions to be executed. + + Notes + ----- + .. versionadded:: 0.2.0 + + Parameters + ---------- + quantum : bool (default: True) + Only applies if `metric` contains a cpm distance or mean. 
+ + - If true will run on local or remote backend + (depending on q_account_token value), + - If false, will perform classical computing instead. + q_account_token : string (default:None) + If `quantum` is True and `q_account_token` provided, + the classification task will be running on an IBM quantum backend. + If `load_account` is provided, the classifier will use the previous + token saved with `IBMProvider.save_account()`. + verbose : bool (default:True) + If true, will output all intermediate results and logs. + shots : int (default:1024) + Number of repetitions of each circuit, for sampling. + seed: int | None (default: None) + Random seed for the simulation + upper_bound : int (default: 7) + The maximum integer value for matrix normalization. + regularization: TransformerMixin (default: None) + Additional post-processing to regularize means. + classical_optimizer : OptimizationAlgorithm + An instance of OptimizationAlgorithm [3]_ + n_jobs : int, (default=6) + The number of jobs to use for the computation. This works by computing + each of the hulls in parallel. + n_hulls_per_class: int, (default 3) + The number of hulls used per class. + n_samples_per_hull: int, (default 10) + Defines how many samples are used to build a hull. + hull_type: string, (default "min-hull") + Selects how the hull is constructed. 
Possible values are + "min-hull" and "random-hull" + + """ + + def __init__( + self, + quantum=True, + q_account_token=None, + verbose=True, + shots=1024, + seed=None, + upper_bound=7, + regularization=None, + n_jobs=6, + classical_optimizer=SlsqpOptimizer(), # set here new default optimizer + n_hulls_per_class=3, + n_samples_per_hull=10, + hull_type="min-hull", + ): + QuanticClassifierBase.__init__( + self, quantum, q_account_token, verbose, shots, None, seed + ) + self.upper_bound = upper_bound + self.regularization = regularization + self.classical_optimizer = classical_optimizer + self.n_hulls_per_class = n_hulls_per_class + self.n_samples_per_hull = n_samples_per_hull + self.n_jobs = n_jobs + self.hull_type = hull_type + + def _init_algo(self, n_features): + self._log("Nearest Convex Hull Classifier initiating algorithm") + + classifier = NearestConvexHull( + n_hulls_per_class=self.n_hulls_per_class, + n_samples_per_hull=self.n_samples_per_hull, + n_jobs=self.n_jobs, + hull_type=self.hull_type, + ) + + if self.quantum: + self._log("Using NaiveQAOAOptimizer") + self._optimizer = NaiveQAOAOptimizer( + quantum_instance=self._quantum_instance, upper_bound=self.upper_bound + ) + else: + self._log("Using ClassicalOptimizer") + self._optimizer = ClassicalOptimizer(self.classical_optimizer) + + # sets the optimizer for the distance functions + # used in NearestConvexHull class + set_global_optimizer(self._optimizer) + + return classifier + + def predict(self, X): + # self._log("QuanticNCH Predict") + return self._predict(X) + + def transform(self, X): + # self._log("QuanticNCH Transform") + return self._classifier.transform(X) diff --git a/pyriemann_qiskit/utils/distance.py b/pyriemann_qiskit/utils/distance.py index 28fdc732..eedbdfc5 100644 --- a/pyriemann_qiskit/utils/distance.py +++ b/pyriemann_qiskit/utils/distance.py @@ -124,6 +124,7 @@ def log_prod(m1, m2): prob.set_objective("min", objective) prob.add_constraint(prob.sum(w) == 1) + weights = 
optimizer.solve(prob, reshape=False) return weights