Skip to content

Commit

Permalink
Merge pull request #122 from lionelkusch/PR_estimation_threshold
Browse files Browse the repository at this point in the history
Estimation threshold(1/4): add comments and docstring of the functions
  • Loading branch information
bthirion authored Feb 20, 2025
2 parents d6e93f7 + 59df1ca commit 6d76835
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 101 deletions.
2 changes: 1 addition & 1 deletion doc_conf/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Functions
permutation_test
permutation_test_pval
reid
standardized_svr
empirical_thresholding
zscore_from_pval

Classes
Expand Down
12 changes: 12 additions & 0 deletions doc_conf/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,18 @@ @article{liuFastPowerfulConditional2021
keywords = {Statistics - Methodology},
}

@thesis{chevalier_statistical_2020,
title = {Statistical control of sparse models in high dimension},
url = {https://theses.hal.science/tel-03147200},
institution = {Université Paris-Saclay},
type = {phdthesis},
author = {Chevalier, Jérôme-Alexis},
urldate = {2024-10-17},
date = {2020-12-11},
langid = {english},
}

@article{benjamini1995controlling,
title={Controlling the false discovery rate: a practical and powerful approach to multiple testing},
author={Benjamini, Yoav and Hochberg, Yosef},
Expand Down
13 changes: 7 additions & 6 deletions examples/plot_fmri_data_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@
from hidimstat.adaptative_permutation_threshold_SVR import ada_svr
from hidimstat.clustered_inference import clustered_inference
from hidimstat.ensemble_clustered_inference import ensemble_clustered_inference
from hidimstat.empirical_thresholding import empirical_thresholding
from hidimstat.permutation_test import permutation_test, permutation_test_pval
from hidimstat.standardized_svr import standardized_svr
from hidimstat.stat_tools import pval_from_scale, zscore_from_pval

n_job = None
Expand Down Expand Up @@ -139,10 +139,11 @@ def preprocess_haxby(subject=2, memory=None):
# First, we try to recover the discriminative pattern by computing
# p-values from SVR decoder weights and a parametric approximation
# of the distribution of these weights.

# We precomputed the regularization parameter by CV (C = 0.1) to reduce the
# computation time of the example.
beta_hat, scale = standardized_svr(X, y, Cs=[0.1])
beta_hat, scale = empirical_thresholding(
X,
y,
linear_estimator=LinearSVR(),
)
pval_std_svr, _, one_minus_pval_std_svr, _ = pval_from_scale(beta_hat, scale)

#############################################################################
Expand Down Expand Up @@ -317,4 +318,4 @@ def plot_map(
# (EnCluDL) seems realistic as we recover the visual cortex and do not make
# spurious discoveries.

# show()
show()
4 changes: 2 additions & 2 deletions src/hidimstat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .noise_std import group_reid, reid
from .permutation_test import permutation_test, permutation_test_pval
from .scenario import multivariate_1D_simulation
from .standardized_svr import standardized_svr
from .empirical_thresholding import empirical_thresholding
from .stat_tools import zscore_from_pval
from .cpi import CPI
from .loco import LOCO
Expand All @@ -35,7 +35,7 @@
"permutation_test",
"permutation_test_pval",
"reid",
"standardized_svr",
"empirical_thresholding",
"zscore_from_pval",
"CPI",
"LOCO",
Expand Down
72 changes: 72 additions & 0 deletions src/hidimstat/empirical_thresholding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np
from numpy.linalg import norm
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVR


def empirical_thresholding(
    X,
    y,
    linear_estimator=None,
):
    """
    Fit a linear estimator and derive a common scale for its coefficients.

    This function fits a linear estimator to the input data and target,
    reads back the estimated coefficients and computes a single scale value
    (repeated for every coefficient) that can be used to keep only the
    extreme coefficients.
    For more details, see the section 6.3.2 of
    :cite:`chevalier_statistical_2020`.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        The input data.
    y : ndarray, shape (n_samples,)
        The target values.
    linear_estimator : estimator object or None, optional (default=None)
        The linear estimator to use for thresholding. It should be a
        scikit-learn estimator object that implements the `fit` method and
        has a `coef_` attribute or a `best_estimator_` attribute with a
        `coef_` attribute (e.g., a `GridSearchCV` object). The object is
        fitted in place. When None, a new
        ``GridSearchCV(LinearSVR(), param_grid={"C": np.logspace(-7, 1, 9)},
        n_jobs=None)`` is created for this call.

    Returns
    -------
    beta_hat : ndarray, shape (n_features,)
        The estimated coefficients of the linear estimator.
    scale : ndarray, shape (n_features,)
        The scale of the coefficients; the same value for each feature.

    Raises
    ------
    ValueError
        If, after fitting, the `linear_estimator` does not have a `coef_`
        attribute or a `best_estimator_` attribute with a `coef_` attribute.

    Notes
    -----
    The scale is the Euclidean norm of the estimated coefficients divided by
    the square root of the number of features, i.e. the root mean square of
    the coefficients. This equals their standard deviation under the
    assumption that the coefficients follow a distribution with mean zero.
    """
    if linear_estimator is None:
        # Build the default per call: a default instance created in the
        # signature would be shared (and re-fitted) across every call.
        linear_estimator = GridSearchCV(
            LinearSVR(), param_grid={"C": np.logspace(-7, 1, 9)}, n_jobs=None
        )

    _, n_features = X.shape

    linear_estimator.fit(X, y)

    if hasattr(linear_estimator, "coef_"):
        beta_hat = linear_estimator.coef_
    elif hasattr(linear_estimator, "best_estimator_") and hasattr(
        linear_estimator.best_estimator_, "coef_"
    ):
        # Search/CV wrappers expose the coefficients on the refitted model.
        beta_hat = linear_estimator.best_estimator_.coef_
    else:
        raise ValueError("linear estimator should be linear.")

    # Root mean square of the coefficients, broadcast to one value per
    # coefficient so it can be used element-wise with beta_hat.
    std = norm(beta_hat) / np.sqrt(n_features)
    scale = std * np.ones(beta_hat.size)

    return beta_hat, scale
49 changes: 0 additions & 49 deletions src/hidimstat/standardized_svr.py

This file was deleted.

78 changes: 78 additions & 0 deletions test/test_empirical_thresholding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
Test the empirical thresholding module
"""

import pytest
import numpy as np
from numpy.testing import assert_almost_equal

from hidimstat.scenario import multivariate_1D_simulation
from hidimstat.empirical_thresholding import empirical_thresholding
from hidimstat.stat_tools import pval_from_scale

from sklearn.linear_model import Lasso
from sklearn.tree import DecisionTreeRegressor


def test_emperical_thresholding():
    """Check the default estimator on an unstructured simulation.

    The simulated support has size 1. Using one-sided p-values, the corrected
    p-value is expected to be close to 0 for the first feature and close to
    0.5 for every other feature.
    """
    n_features = 50
    support_size = 1

    # Only the design matrix and the target are needed; the true weights and
    # the noise returned by the simulation are not used here.
    X_init, y, _, _ = multivariate_1D_simulation(
        n_samples=20,
        n_features=n_features,
        support_size=support_size,
        sigma=0.1,
        rho=0.0,
        shuffle=False,
        seed=3,
    )

    beta_hat, scale_hat = empirical_thresholding(X_init, y)
    _, pval_corr, _, _ = pval_from_scale(beta_hat, scale_hat)

    expected = np.full(n_features, 0.5)
    expected[:support_size] = 0.0

    assert_almost_equal(pval_corr, expected, decimal=1)


def test_emperical_thresholding_lasso():
    """Check user-supplied estimators on an unstructured simulation.

    The simulated support has size 1. An estimator without coefficients
    (a decision tree) must be rejected with a ValueError, while a Lasso must
    give a corrected p-value close to 0 for the first feature and close to
    0.5 for every other feature.
    """
    n_features = 50
    support_size = 1

    # Only the design matrix and the target are needed; the true weights and
    # the noise returned by the simulation are not used here.
    X_init, y, _, _ = multivariate_1D_simulation(
        n_samples=20,
        n_features=n_features,
        support_size=support_size,
        sigma=0.1,
        rho=0.0,
        shuffle=False,
        seed=3,
    )

    # A tree regressor exposes no `coef_`, so the call must fail.
    with pytest.raises(ValueError, match="linear estimator should be linear."):
        empirical_thresholding(X_init, y, linear_estimator=DecisionTreeRegressor())

    beta_hat, scale_hat = empirical_thresholding(X_init, y, linear_estimator=Lasso())
    _, pval_corr, _, _ = pval_from_scale(beta_hat, scale_hat)

    expected = np.full(n_features, 0.5)
    expected[:support_size] = 0.0

    assert_almost_equal(pval_corr, expected, decimal=1)
43 changes: 0 additions & 43 deletions test/test_standardized_svr.py

This file was deleted.

0 comments on commit 6d76835

Please sign in to comment.