Skip to content

Commit

Permalink
Add specific tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lionelkusch committed Feb 25, 2025
1 parent 793c86f commit 8c953e2
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions test/test_adaptative_permutation_threshold_SVR.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import numpy as np
from numpy.testing import assert_almost_equal
from sklearn.svm import SVR

from hidimstat.adaptative_permutation_threshold_SVR import ada_svr, ada_svr_pvalue
from hidimstat.scenario import multivariate_1D_simulation
from hidimstat.permutation_test import permutation_test


def test_ada_svr():
Expand Down Expand Up @@ -42,3 +44,54 @@ def test_ada_svr():

assert_almost_equal(pval[:support_size], expected[:support_size], decimal=1)
assert_almost_equal(pval_corr[support_size:], expected[support_size:], decimal=1)


def test_ada_svr_rcond():
"""
Testing the effect of rcond
"""
# create dataset
X, y, beta, _ = multivariate_1D_simulation(
n_samples=20,
n_features=50,
support_size=3,
sigma=0.1,
shuffle=False,
seed=42,
)
X[:10] *= 1e-5
beta_hat, scale = ada_svr(X, y)
beta_hat_2, scale_2 = ada_svr(X, y, rcond=1e-15)
assert np.max(np.abs(beta_hat - beta_hat_2)) > 1
assert np.max(np.abs(scale - scale_2)) > 1


def test_ada_svr_vs_permutation():
"""
Validate the adaptive permutation threshold procedure against a permutation
test. The adaptive permutation threshold procedure should good approciation
of the proba of the permutation test.
"""
# create dataset
X, y, beta, _ = multivariate_1D_simulation(
n_samples=10,
n_features=100,
support_size=1,
sigma=0.1,
shuffle=False,
seed=42,
)
beta_hat, scale = ada_svr(X, y)
# fit a SVR to get the coefficients
estimator = SVR(kernel="linear", epsilon=0.0, gamma="scale", C=1.0)
estimator.fit(X, y)
beta_hat_svr = estimator.coef_

# compare that the coefficiants are the same that the one of SVR
assert np.max(np.abs(beta_hat - beta_hat_svr.T[:, 0])) < 2e-4

proba = permutation_test(
X, y, estimator=estimator, n_permutations=10000, n_jobs=8, seed=42, proba=True
)
assert np.max(np.abs(np.mean(proba, axis=0))) < 1e-3
assert np.max(np.abs(scale - np.std(proba, axis=0)) / scale) < 1e-1

0 comments on commit 8c953e2

Please sign in to comment.