Skip to content

Commit

Permalink
Feat/mdm predict distances (#248)
Browse files Browse the repository at this point in the history
* - move workaround on _predict_distances from distance to classification
- override only the instance method and not the class function

* rename to "distance_logeuclid_cpm" to "distance_logeuclid_to_convex_hull_cpm"
add "weights_euclid_cpm"

* separate weights from distance functions

* complete tests

* add warnings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update pyriemann_qiskit/utils/distance.py

Co-authored-by: Quentin Barthélemy <[email protected]>

* Update pyriemann_qiskit/utils/distance.py

Co-authored-by: Quentin Barthélemy <[email protected]>

* Update pyriemann_qiskit/utils/distance.py

Co-authored-by: Quentin Barthélemy <[email protected]>

* replace "_cpm"

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* fix lint

* set _weights_distance as private
complete api.rst

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update pyriemann_qiskit/utils/mean.py

Co-authored-by: Quentin Barthélemy <[email protected]>

* Update pyriemann_qiskit/utils/distance.py

Co-authored-by: Quentin Barthélemy <[email protected]>

* Update pyriemann_qiskit/classification.py

Co-authored-by: Quentin Barthélemy <[email protected]>

* Update pyriemann_qiskit/utils/distance.py

Co-authored-by: Quentin Barthélemy <[email protected]>

* update api.rst

* replace pyQiskitOptimizer by :class:`pyriemann_qiskit.utils.docplex.pyQiskitOptimizer`

* move and rename _predict_distance

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* create module utils.utils

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* improve code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Co-authored-by: Gregoire Cattan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Quentin Barthélemy <[email protected]>
  • Loading branch information
4 people authored Mar 5, 2024
1 parent 045a36a commit a89478b
Show file tree
Hide file tree
Showing 12 changed files with 243 additions and 145 deletions.
4 changes: 2 additions & 2 deletions benchmarks/light_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@
)

pipelines["QMDM_mean"] = QuantumMDMWithRiemannianPipeline(
metric={"mean": "euclid_cpm", "distance": "euclid"},
metric={"mean": "qeuclid", "distance": "euclid"},
quantum=True,
regularization=Shrinkage(shrinkage=0.9),
)

pipelines["QMDM_dist"] = QuantumMDMWithRiemannianPipeline(
metric={"mean": "logeuclid", "distance": "logeuclid_cpm"}, quantum=True
metric={"mean": "logeuclid", "distance": "qlogeuclid_hull"}, quantum=True
)

pipelines["RG_LDA"] = make_pipeline(
Expand Down
19 changes: 15 additions & 4 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,17 @@ Ensemble
Utils function
--------------

Utils functions are low level functions for the `classification` module.
Utils functions are low level functions for the `classification` and `pipelines` module.

Utils
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. _hyper_params_factory_api:
.. currentmodule:: pyriemann_qiskit.utils.utils

.. autosummary::
:toctree: generated/

is_qfunction

Hyper-parameters generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -96,8 +106,8 @@ Mean
.. autosummary::
:toctree: generated/

mean_euclid_cpm
mean_logeuclid_cpm
qmean_euclid
qmean_logeuclid

Distance
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -107,7 +117,8 @@ Distance
.. autosummary::
:toctree: generated/

distance_logeuclid_cpm
qdistance_logeuclid_to_convex_hull
weights_logeuclid_to_convex_hull

Docplex
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
6 changes: 3 additions & 3 deletions examples/ERP/classify_P300_bi_quantum_mdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@

pipelines = {}

pipelines["mean=logeuclid_cpm/distance=logeuclid"] = QuantumMDMWithRiemannianPipeline(
pipelines["mean=qlogeuclid/distance=logeuclid"] = QuantumMDMWithRiemannianPipeline(
metric="mean", quantum=quantum
)

pipelines["mean=logeuclid/distance=logeuclid_cpm"] = QuantumMDMWithRiemannianPipeline(
pipelines["mean=logeuclid/distance=qlogeuclid"] = QuantumMDMWithRiemannianPipeline(
metric="distance", quantum=quantum
)

pipelines["Voting logeuclid_cpm"] = QuantumMDMVotingClassifier(quantum=quantum)
pipelines["Voting qlogeuclid"] = QuantumMDMVotingClassifier(quantum=quantum)

##############################################################################
# Run evaluation
Expand Down
67 changes: 50 additions & 17 deletions pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from scipy.special import softmax
import logging
import numpy as np
from warnings import warn

from pyriemann.classification import MDM
from pyriemann_qiskit.datasets import get_feature_dimension
Expand All @@ -16,6 +17,8 @@
NaiveQAOAOptimizer,
set_global_optimizer,
)
from pyriemann_qiskit.utils.distance import distance_functions
from pyriemann_qiskit.utils.utils import is_qfunction
from qiskit.utils import QuantumInstance
from qiskit.utils.quantum_instance import logger
from qiskit_ibm_provider import IBMProvider, least_busy
Expand Down Expand Up @@ -583,8 +586,8 @@ class QuanticMDM(QuanticClassifierBase):

"""Quantum-enhanced MDM classifier
This class is a quantic implementation of the Minimum Distance to Mean (MDM)
[1]_, which can run with quantum optimization.
This class is a quantic implementation of the Minimum Distance to Mean
(MDM) [1]_, which can run with quantum optimization.
Only log-Euclidean distance between trial and class prototypes is supported
at the moment, but any type of metric can be used for centroid estimation.
Expand All @@ -600,37 +603,36 @@ class QuanticMDM(QuanticClassifierBase):
Parameters
----------
metric : string | dict, default={"mean": 'logeuclid', "distance": 'cpm'}
metric : string | dict, default={"mean": 'logeuclid', \
"distance": 'qlogeuclid_hull'}
The type of metric used for centroid and distance estimation.
see `mean_covariance` for the list of supported metric.
the metric could be a dict with two keys, `mean` and `distance` in
order to pass different metrics for the centroid estimation and the
distance estimation. Typical usecase is to pass 'logeuclid' metric for
the mean in order to boost the computional speed and 'riemann' for the
distance in order to keep the good sensitivity for the classification.
quantum : bool (default: True)
Only applies if `metric` contains a cpm distance or mean.
distance estimation.
quantum : bool, default=True
Only applies if `metric` contains a quantic distance or mean.
- If true will run on local or remote backend
(depending on q_account_token value),
- If false, will perform classical computing instead.
q_account_token : string (default:None)
q_account_token : string, default=None
If `quantum` is True and `q_account_token` provided,
the classification task will be running on a IBM quantum backend.
If `load_account` is provided, the classifier will use the previous
token saved with `IBMProvider.save_account()`.
verbose : bool (default:True)
verbose : bool, default=True
If true, will output all intermediate results and logs.
shots : int (default:1024)
shots : int, default=1024
Number of repetitions of each circuit, for sampling.
seed: int | None (default: None)
seed : int | None, default=None
Random seed for the simulation
upper_bound : int (default: 7)
upper_bound : int, default=7
The maximum integer value for matrix normalization.
regularization: MixinTransformer (defulat: None)
regularization : MixinTransformer, default=None
Additional post-processing to regularize means.
classical_optimizer : OptimizationAlgorithm
An instance of OptimizationAlgorithm [3]_
classical_optimizer : OptimizationAlgorithm, default=CobylaOptimizer()
An instance of OptimizationAlgorithm [3]_.
See Also
--------
Expand All @@ -655,7 +657,7 @@ class QuanticMDM(QuanticClassifierBase):

def __init__(
self,
metric={"mean": "logeuclid", "distance": "logeuclid_cpm"},
metric={"mean": "logeuclid", "distance": "qlogeuclid_hull"},
quantum=True,
q_account_token=None,
verbose=True,
Expand All @@ -673,9 +675,40 @@ def __init__(
self.regularization = regularization
self.classical_optimizer = classical_optimizer

@staticmethod
def _override_predict_distance(mdm):
"""Override _predict_distances method of MDM.
We override the _predict_distances method inside MDM to allow the use
of qdistance.
This is due to the fact the the signature of qdistances is different
from the usual distance functions.
"""

def _predict_distances(X):
if is_qfunction(mdm.metric_dist):
if "hull" in mdm.metric_dist:
warn("qdistances to hull should not be use inside MDM")
else:
warn(
"q-distances for MDM are toy functions.\
Use pyRiemann distances instead."
)
distance = distance_functions[mdm.metric_dist]
centroids = np.array(mdm.covmeans_)
weights = [distance(centroids, x) for x in X]
return 1 - np.array(weights)
else:
return MDM._predict_distances(mdm, X)

return _predict_distances

def _init_algo(self, n_features):
self._log("Quantic MDM initiating algorithm")
classifier = MDM(metric=self.metric)
classifier._predict_distances = QuanticMDM._override_predict_distance(
classifier
)
if self.quantum:
self._log("Using NaiveQAOAOptimizer")
self._optimizer = NaiveQAOAOptimizer(
Expand Down
16 changes: 8 additions & 8 deletions pyriemann_qiskit/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pyriemann.estimation import XdawnCovariances, ERPCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.preprocessing import Whitening
from pyriemann_qiskit.utils.mean import is_cpm_mean
from pyriemann_qiskit.utils.utils import is_qfunction
from pyriemann_qiskit.utils.filtering import NoDimRed
from pyriemann_qiskit.utils.hyper_params_factory import (
# gen_zz_feature_map,
Expand Down Expand Up @@ -312,7 +312,7 @@ class QuantumMDMWithRiemannianPipeline(BasePipeline):
Parameters
----------
metric : string | dict, default={"mean": 'logeuclid', "distance": 'logeuclid_cpm'}
metric : string | dict, default={"mean": 'logeuclid', "distance": 'qlogeuclid'}
The type of metric used for centroid and distance estimation.
quantum : bool (default: True)
- If true will run on local or remote backend
Expand Down Expand Up @@ -361,7 +361,7 @@ class QuantumMDMWithRiemannianPipeline(BasePipeline):

def __init__(
self,
metric={"mean": "logeuclid", "distance": "logeuclid_cpm"},
metric={"mean": "logeuclid", "distance": "qlogeuclid_hull"},
quantum=True,
q_account_token=None,
verbose=True,
Expand All @@ -384,7 +384,7 @@ def __init__(
def _create_pipe(self):
print(self.metric)
print(self.metric["mean"])
if is_cpm_mean(self.metric["mean"]):
if is_qfunction(self.metric["mean"]):
if self.quantum:
covariances = XdawnCovariances(
nfilter=1, estimator="scm", xdawn_estimator="lwf"
Expand Down Expand Up @@ -418,8 +418,8 @@ class QuantumMDMVotingClassifier(BasePipeline):
Voting classifier with two configurations of
QuantumMDMWithRiemannianPipeline:
- with mean = euclid_cpm and distance = euclid,
- with mean = logeuclid and distance = logeuclid_cpm.
- with mean = qeuclid and distance = euclid,
- with mean = logeuclid and distance = qlogeuclid.
Parameters
----------
Expand Down Expand Up @@ -472,15 +472,15 @@ def __init__(

def _create_pipe(self):
clf_mean_logeuclid_dist_cpm = QuantumMDMWithRiemannianPipeline(
{"mean": "logeuclid", "distance": "logeuclid_cpm"},
{"mean": "logeuclid", "distance": "qlogeuclid_hull"},
self.quantum,
self.q_account_token,
self.verbose,
self.shots,
self.upper_bound,
)
clf_mean_cpm_dist_euclid = QuantumMDMWithRiemannianPipeline(
{"mean": "euclid_cpm", "distance": "euclid"},
{"mean": "qeuclid", "distance": "euclid"},
self.quantum,
self.q_account_token,
self.verbose,
Expand Down
2 changes: 2 additions & 0 deletions pyriemann_qiskit/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from . import distance
from . import mean
from . import utils

__all__ = [
"hyper_params_factory",
Expand All @@ -45,4 +46,5 @@
"filter_subjects_by_incomplete_results",
"add_moabb_dataframe_results_to_caches",
"convert_caches_to_dataframes",
"utils",
]
Loading

0 comments on commit a89478b

Please sign in to comment.