Skip to content

Commit

Permalink
Add L2 regularization (#94)
Browse files Browse the repository at this point in the history
* add L2 regularization

* add unit test

* format

* add regularization to refs page in docs

* fix typo

* remove input arguments for regularize! template

* clarify docstrings
  • Loading branch information
JoshuaLampert authored Oct 22, 2024
1 parent f713f46 commit 8c2b6b0
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 37 deletions.
7 changes: 7 additions & 0 deletions docs/src/ref.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ Modules = [KernelInterpolation]
Pages = ["interpolation.jl"]
```

## Regularization

```@autodocs
Modules = [KernelInterpolation]
Pages = ["regularization.jl"]
```

## [Differential Operators](@id api-diffops)

```@autodocs
Expand Down
26 changes: 26 additions & 0 deletions examples/interpolation/regularization_2d.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using KernelInterpolation
using Plots

# interpolate Franke function
# Franke's test function evaluated at a point `x`; only `x[1]` and `x[2]` are read.
# It is the sum of three Gaussian-like bumps minus one dip.
function f(x)
    sx = 9 * x[1]
    sy = 9 * x[2]
    bump1 = 0.75 * exp(-0.25 * ((sx - 2)^2 + (sy - 2)^2))
    bump2 = 0.75 * exp(-(sx + 1)^2 / 49 - (sy + 1) / 10)
    bump3 = 0.5 * exp(-0.25 * ((sx - 7)^2 + (sy - 3)^2))
    dip = 0.2 * exp(-(sx - 4)^2 - (sy - 7)^2)
    return bump1 + bump2 + bump3 - dip
end

# Sample the Franke function at `n` nodes and perturb the samples with
# normally distributed noise scaled by 0.03.
# NOTE(review): `random_hypercube` presumably draws random nodes in the 2D unit
# hypercube - confirm against its docstring.
n = 1089
nodeset = random_hypercube(n; dim = 2)
values = f.(nodeset) .+ 0.03 * randn(n)

# Interpolate the same noisy data twice with a thin plate spline kernel:
# once with L2 regularization (parameter 1e-2) and once without regularization.
kernel = ThinPlateSplineKernel{dim(nodeset)}()
itp_reg = interpolate(nodeset, values, kernel, reg = L2Regularization(1e-2))
itp = interpolate(nodeset, values, kernel)

# Nodes at which both interpolants are plotted.
N = 40
many_nodes = homogeneous_hypercube(N; dim = 2)

# Surface plots of the regularized (p1) and the unregularized (p2) interpolant.
p1 = plot(many_nodes, itp_reg, st = :surface, training_nodes = false)
p2 = plot(many_nodes, itp, st = :surface, training_nodes = false)

# Stack both plots in a layout with two rows and one column.
plot(p1, p2, layout = (2, 1))
4 changes: 3 additions & 1 deletion src/KernelInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module KernelInterpolation

using DiffEqCallbacks: PeriodicCallback, PeriodicCallbackAffect
using ForwardDiff: ForwardDiff
using LinearAlgebra: Symmetric, norm, tr, muladd, dot
using LinearAlgebra: Symmetric, norm, tr, muladd, dot, diagind
using Printf: @sprintf
using ReadVTK: VTKFile, get_points, get_point_data, get_data
using RecipesBase: RecipesBase, @recipe, @series
Expand All @@ -34,6 +34,7 @@ include("nodes.jl")
include("differential_operators.jl")
include("equations.jl")
include("kernel_matrices.jl")
include("regularization.jl")
include("interpolation.jl")
include("discretization.jl")
include("callbacks_step/callbacks_step.jl")
Expand All @@ -52,6 +53,7 @@ export PartialDerivative, Gradient, Laplacian, EllipticOperator
export PoissonEquation, EllipticEquation, AdvectionEquation, HeatEquation,
AdvectionDiffusionEquation
export SpatialDiscretization, Semidiscretization, semidiscretize
export NoRegularization, L2Regularization
export NodeSet, empty_nodeset, separation_distance, dim, eachdim, values_along_dim,
distance_matrix, random_hypercube, random_hypercube_boundary, homogeneous_hypercube,
homogeneous_hypercube_boundary, random_hypersphere, random_hypersphere_boundary
Expand Down
52 changes: 19 additions & 33 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ Return the system matrix, i.e., the matrix ``A`` in the linear system
Ac = f,
```
where ``c`` are the coefficients of the kernel interpolant and ``f`` the vector
of known values. The exact form of ``A`` differs depending on whether classical interpolation
or collocation is used.
of known values. The exact form of ``A`` differs depending on which method is used.
"""
function system_matrix(itp::Interpolation)
    return itp.system_matrix
end

@doc raw"""
interpolate(nodeset, centers = nodeset, values, kernel = GaussKernel{dim(nodeset)}(), m = order(kernel))
interpolate(nodeset, centers = nodeset, values, kernel = GaussKernel{dim(nodeset)}();
m = order(kernel), reg = NoRegularization())
Interpolate the `values` evaluated at the nodes in the `nodeset` to a function using the kernel `kernel`
and polynomials up to an order `m` (i.e., a maximum degree of `m - 1`), i.e., determine the coefficients ``c_j`` and ``d_k`` in the expansion
Expand All @@ -136,51 +136,37 @@ maximum degree of `m - 1`. If `m = 0`, no polynomial is added. The additional co
\sum_{j = 1}^N c_jp_k(x_j) = 0, \quad k = 1,\ldots, Q = \begin{pmatrix}m - 1 + d\\d\end{pmatrix}
```
are enforced. Returns an [`Interpolation`](@ref) object.
If `centers` is provided, the interpolant is a least squares approximation with the centers used for the basis.
"""
function interpolate(nodeset::NodeSet{Dim, RealT},
                     values::Vector{RealT},
                     kernel = GaussKernel{Dim}();
                     m = order(kernel)) where {Dim, RealT}
    @assert dim(kernel) == Dim
    n = length(nodeset)
    @assert length(values) == n

    # Polynomial augmentation: monomials of degree 0 to m - 1 in `Dim` variables
    xx = polyvars(Dim)
    ps = monomials(xx, 0:(m - 1))
    n_poly = length(ps)

    # Assemble the symmetric saddle point system [K P; P' 0]
    K = kernel_matrix(nodeset, kernel)
    P = polynomial_matrix(nodeset, ps)
    A = Symmetric([K P
                   P' zeros(n_poly, n_poly)])

    # Right-hand side: the given values followed by zeros for the polynomial rows
    rhs = vcat(values, zeros(n_poly))
    coeffs = A \ rhs
    return Interpolation(kernel, nodeset, nodeset, coeffs, A, ps, xx)
end
If `centers` is provided, the interpolant is a least squares approximation with the centers used for the basis.
# Least squares approximation
A regularization can be applied to the kernel matrix using the `reg` argument, cf. [`regularize!`](@ref).
"""
function interpolate(nodeset::NodeSet{Dim, RealT}, centers::NodeSet{Dim, RealT},
values::Vector{RealT},
kernel = GaussKernel{Dim}();
m = order(kernel)) where {Dim, RealT}
values::Vector{RealT}, kernel = GaussKernel{Dim}();
m = order(kernel), reg = NoRegularization()) where {Dim, RealT}
@assert dim(kernel) == Dim
n = length(nodeset)
@assert length(values) == n
xx = polyvars(Dim)
ps = monomials(xx, 0:(m - 1))
q = length(ps)

k_matrix = kernel_matrix(nodeset, centers, kernel)
p_matrix1 = polynomial_matrix(nodeset, ps)
p_matrix2 = polynomial_matrix(centers, ps)
system_matrix = [k_matrix p_matrix1
p_matrix2' zeros(q, q)]
if nodeset == centers
system_matrix = interpolation_matrix(nodeset, kernel, ps, reg)
else
system_matrix = least_squares_matrix(nodeset, centers, kernel, ps, reg)
end
b = [values; zeros(q)]
c = system_matrix \ b
return Interpolation(kernel, nodeset, centers, c, system_matrix, ps, xx)
end

# Convenience method without explicit `centers`: forwards to the method taking
# `centers`, using the `nodeset` itself as centers. All keyword arguments are
# passed through unchanged.
function interpolate(nodeset::NodeSet{Dim, RealT},
                     values::Vector{RealT}, kernel = GaussKernel{Dim}();
                     kwargs...) where {Dim, RealT}
    return interpolate(nodeset, nodeset, values, kernel; kwargs...)
end

# Evaluate interpolant
function (itp::Interpolation)(x)
s = 0
Expand Down
42 changes: 42 additions & 0 deletions src/kernel_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,48 @@ function polynomial_matrix(nodeset, ps)
return A
end

@doc raw"""
    interpolation_matrix(nodeset, kernel, ps, reg)

Return the interpolation matrix for the `nodeset`, `kernel`, polynomials `ps`, and regularization `reg`.
The interpolation matrix is defined as
```math
A = \begin{pmatrix}K & P\\P^T & 0\end{pmatrix},
```
where ``K`` is the [`regularize!`](@ref)d [`kernel_matrix`](@ref) and ``P`` the [`polynomial_matrix`](@ref).
The result is returned wrapped in `Symmetric`.
"""
function interpolation_matrix(nodeset, kernel, ps, reg)
    q = length(ps)
    k_matrix = kernel_matrix(nodeset, kernel)
    # The regularization acts only on the kernel block and modifies it in place
    regularize!(k_matrix, reg)
    p_matrix = polynomial_matrix(nodeset, ps)
    system_matrix = [k_matrix p_matrix
                     p_matrix' zeros(q, q)]
    return Symmetric(system_matrix)
end

@doc raw"""
    least_squares_matrix(nodeset, centers, kernel, ps, reg)

Return the least squares matrix for the `nodeset`, `centers`, `kernel`, polynomials `ps`, and regularization `reg`.
The least squares matrix is defined as
```math
A = \begin{pmatrix}K & P_1\\P_2^T & 0\end{pmatrix},
```
where ``K`` is the [`regularize!`](@ref)d [`kernel_matrix`](@ref) between `nodeset` and `centers`,
``P_1`` the [`polynomial_matrix`](@ref) for the `nodeset` and ``P_2`` the [`polynomial_matrix`](@ref)
for the `centers`.
"""
function least_squares_matrix(nodeset, centers, kernel, ps, reg)
    q = length(ps)
    k_matrix = kernel_matrix(nodeset, centers, kernel)
    # The regularization acts only on the kernel block and modifies it in place
    regularize!(k_matrix, reg)
    p_matrix1 = polynomial_matrix(nodeset, ps)
    p_matrix2 = polynomial_matrix(centers, ps)
    system_matrix = [k_matrix p_matrix1
                     p_matrix2' zeros(q, q)]
    return system_matrix
end

@doc raw"""
pde_matrix(diff_op_or_pde, nodeset1, nodeset2, kernel)
Expand Down
39 changes: 39 additions & 0 deletions src/regularization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
    AbstractRegularization

An abstract supertype of regularizations. A regularization implements a method of
[`regularize!`](@ref) that takes a matrix and regularizes it in place.
"""
abstract type AbstractRegularization end

"""
    regularize!(A, reg::AbstractRegularization)

Apply the regularization `reg` to the matrix `A` in place. The methods defined in
this file mutate `A` and return `nothing`.
"""
function regularize! end

"""
    NoRegularization()

The trivial regularization: applying it via [`regularize!`](@ref) leaves the matrix unchanged.
"""
struct NoRegularization <: AbstractRegularization end

# Nothing to do: the matrix `A` is left as it is.
regularize!(A, ::NoRegularization) = nothing

"""
    L2Regularization(regularization_parameter::Real)

A regularization that adds `regularization_parameter` to every diagonal entry of the
input matrix, i.e., adds a multiple of the identity matrix to it.
"""
struct L2Regularization{RealT <: Real} <: AbstractRegularization
    regularization_parameter::RealT
end

function regularize!(A, reg::L2Regularization)
    diagonal = diagind(A)
    shift = reg.regularization_parameter
    # Shift the diagonal entries of `A` in place
    A[diagonal] .+= shift
    return nothing
end
11 changes: 11 additions & 0 deletions test/test_examples_interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,14 @@ end
l2_ls=0.5375130503454387, linf_ls=0.06810374254243684,
interpolation_test=false, least_square_test=true)
end

# Runs the example `regularization_2d.jl` and compares error norms against the
# stored reference values below.
@testitem "regularization_2d.jl" setup=[
Setup,
AdditionalImports,
InterpolationExamples
] begin
# `regularization_test=true` enables the checks of the regularized interpolant
# `itp_reg` against `l2_reg` and `linf_reg`.
# NOTE(review): `l2`/`linf` presumably refer to the unregularized interpolant `itp` - confirm in `test_util.jl`.
@test_include_example(joinpath(EXAMPLES_DIR, "regularization_2d.jl"),
l2=1.2759520194191292, linf=0.19486087346749836,
l2_reg=0.40908417484986226, linf_reg=0.034926306286874445,
interpolation_test=false, regularization_test=true)
end
17 changes: 17 additions & 0 deletions test/test_unit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ end
@test isapprox(kernel_norm(itp), 2.5193566316951626)

# Conditionally positive definite kernel
# Interpolation
kernel = ThinPlateSplineKernel{dim(nodes)}()
itp = @test_nowarn interpolate(nodes, ff, kernel)
expected_coefficients = [
Expand All @@ -679,6 +680,22 @@ end
@test isapprox(itp([0.5, 0.5]), 1.0)
@test isapprox(kernel_norm(itp), 0.0)

# Regularization
itp = @test_nowarn interpolate(nodes, ff, kernel, reg = L2Regularization(1e-3))
coeffs = coefficients(itp)
@test length(coeffs) == length(expected_coefficients)
for i in eachindex(coeffs)
@test isapprox(coeffs[i], expected_coefficients[i], atol = 1e-15)
end
@test order(itp) == order(kernel)
@test length(kernel_coefficients(itp)) == length(nodes)
@test length(polynomial_coefficients(itp)) == order(itp) + 1
@test length(polynomial_basis(itp)) ==
binomial(order(itp) - 1 + dim(nodes), dim(nodes))
@test system_matrix(itp) isa Symmetric
@test size(system_matrix(itp)) == (7, 7)
@test isapprox(itp([0.5, 0.5]), 1.0)

# Least squares approximation
centers = NodeSet([0.0 0.0
1.0 0.0
Expand Down
17 changes: 14 additions & 3 deletions test/test_util.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
test_include_example(example; l2=nothing, linf=nothing,
l2_ls=nothing, linf_ls=nothing,
atol=1e-12, rtol=sqrt(eps()),
kwargs...)
Expand All @@ -12,16 +11,21 @@ macro test_include_example(example, args...)
local linf = get_kwarg(args, :linf, nothing)
local l2_ls = get_kwarg(args, :l2_ls, nothing)
local linf_ls = get_kwarg(args, :linf_ls, nothing)
local l2_reg = get_kwarg(args, :l2_reg, nothing)
local linf_reg = get_kwarg(args, :linf_reg, nothing)
local atol = get_kwarg(args, :atol, 1e-12)
local rtol = get_kwarg(args, :rtol, sqrt(eps()))
local interpolation_test = get_kwarg(args, :interpolation_test, true)
local least_square_test = get_kwarg(args, :least_square_test, false)
local regularization_test = get_kwarg(args, :regularization_test, false)
local pde_test = get_kwarg(args, :pde_test, false)
local kwargs = Pair{Symbol, Any}[]
for arg in args
if (arg.head == :(=) &&
!(arg.args[1] in (:l2, :linf, :l2_ls, :linf_ls, :atol, :rtol,
:interpolation_test, :least_square_test, :pde_test)))
!(arg.args[1] in (:l2, :linf, :l2_ls, :linf_ls, :l2_reg, :linf_reg,
:atol, :rtol,
:interpolation_test, :least_square_test, :regularization_test,
:pde_test)))
push!(kwargs, Pair(arg.args...))
end
end
Expand Down Expand Up @@ -55,6 +59,13 @@ macro test_include_example(example, args...)
@test isapprox(norm(many_values .- many_values_ls, Inf), $linf_ls;
atol = $atol, rtol = $rtol)
end
if $regularization_test
many_values_reg = itp_reg.(many_nodes)
@test isapprox(norm(many_values .- many_values_reg), $l2_reg;
atol = $atol, rtol = $rtol)
@test isapprox(norm(many_values .- many_values_reg, Inf), $linf_reg;
atol = $atol, rtol = $rtol)
end
else
# PDE test
# assumes `many_nodes`, `nodes_inner` and `nodeset_boundary` are defined in the example
Expand Down

0 comments on commit 8c2b6b0

Please sign in to comment.