Skip to content

Commit

Permalink
add cross_validation
Browse files Browse the repository at this point in the history
  • Loading branch information
HDembinski committed May 13, 2024
1 parent 7b65ff7 commit 7f40d3b
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 16 deletions.
71 changes: 61 additions & 10 deletions doc/tutorial/leave-one-out-cross-validation.ipynb

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ profile = "black"
multi_line_output = 3

[tool.ruff]
select = ["E", "F", "W", "D"]
ignore = ["D212", "D203"]
unfixable = ["ERA"]
extend-select = ["E", "F", "W", "D"]
lint.ignore = ["D212", "D203"]
lint.unfixable = ["ERA"]

[tool.ruff.lint.per-file-ignores]
"test_*.py" = ["D"]

[tool.mypy]
strict = true
Expand Down
39 changes: 38 additions & 1 deletion src/resample/jackknife.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any, Callable, Collection, Generator, List

import numpy as np
from numpy.typing import ArrayLike
from numpy.typing import ArrayLike, floating


def resample(
Expand Down Expand Up @@ -319,3 +319,40 @@ def variance(
thetas = jackknife(fn, sample, *args)
n = len(sample)
return (n - 1) * np.var(thetas, ddof=0, axis=0)


def cross_validation(
    predict: Callable[..., float], x: "ArrayLike", y: "ArrayLike", *args: "ArrayLike"
) -> np.floating[Any]:
    """
    Estimate the prediction error of a model with leave-one-out cross-validation.

    For each i, the model is fitted to the sample with the i-th point removed
    and asked to predict y at x[i]. The returned value is the variance of the
    residuals y[i] - prediction(x[i]), computed with ``np.var`` (ddof=0).

    Wikipedia:
    https://en.wikipedia.org/wiki/Cross-validation_(statistics)

    Parameters
    ----------
    predict : callable
        Function with the signature (x_in, y_in, x_out, *args). It takes x_in,
        y_in, which are arrays with the same length. x_out should be one
        element of the x array. *args are further optional arguments for the
        function. The function should return the prediction y(x_out).
    x : array-like
        Explanatory variable. Must be an array of shape (N, ...), where N is
        the number of samples.
    y : array-like
        Observations. Must be an array of shape (N, ...).
    *args :
        Optional arguments which are passed unmodified to predict.

    Returns
    -------
    float
        Variance of the difference (y[i] - predict(..., x[i], *args)).
    """
    # NOTE(review): the annotation is spelled np.floating directly, because
    # ``numpy.typing`` does not export ``floating`` — the file-level
    # ``from numpy.typing import ArrayLike, floating`` should be fixed to
    # import floating from numpy itself; verify that import line.
    deltas = [
        y[i] - predict(x_in, y_in, x[i], *args)
        for i, (x_in, y_in) in enumerate(resample(x, y, copy=False))
    ]
    return np.var(deltas)
30 changes: 28 additions & 2 deletions tests/test_jackknife.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import numpy as np
import pytest
from numpy.testing import assert_almost_equal, assert_equal

from resample.jackknife import bias, bias_corrected, jackknife, resample, variance
from scipy.optimize import curve_fit
from resample.jackknife import (
bias,
bias_corrected,
jackknife,
resample,
variance,
cross_validation,
)


def test_resample_1d():
Expand Down Expand Up @@ -120,3 +127,22 @@ def test_resample_deprecation():
with pytest.warns(VisibleDeprecationWarning):
with pytest.raises(ValueError): # too many arguments
resample(data, True, 1)


@pytest.mark.filterwarnings("ignore:Covariance")
def test_cross_validation():
    # Toy data on a straight line: y = x + 2.
    x = [1, 2, 3]
    y = [3, 4, 5]

    def predict(x_train, y_train, x_test, npar):
        # Polynomial model with npar free coefficients, fitted by least squares.
        def model(xv, *par):
            return np.polyval(par, xv)

        start = np.zeros(npar)
        popt, _ = curve_fit(model, x_train, y_train, p0=start)
        return model(x_test, *popt)

    # A straight line (two parameters) reproduces the data exactly.
    assert cross_validation(predict, x, y, 2) == pytest.approx(0)

    # A constant (one parameter) cannot; the residual variance is 1.5.
    assert cross_validation(predict, x, y, 1) == pytest.approx(1.5)

0 comments on commit 7f40d3b

Please sign in to comment.