Updates for latest StatsLearnModels.jl
juliohm committed Feb 1, 2025
1 parent d10c45f commit 886d66d
Showing 12 changed files with 94 additions and 66 deletions.
Project.toml — 2 changes: 1 addition & 1 deletion

```diff
@@ -25,5 +25,5 @@ GeoStatsTransforms = "0.10"
 GeoTables = "1.21"
 LossFunctions = "1.0"
 Meshes = "0.47 - 0.52"
-StatsLearnModels = "1.0"
+StatsLearnModels = "1.1"
 julia = "1.9"
```
src/GeoStatsValidation.jl — 18 changes: 14 additions & 4 deletions

```diff
@@ -11,19 +11,29 @@ using DensityRatioEstimation
 
 using GeoStatsModels: GeoStatsModel
 using StatsLearnModels: StatsLearnModel
-using StatsLearnModels: Learn, input, output
-using GeoStatsTransforms: Interpolate, InterpolateNeighbors
+using StatsLearnModels: Learn
+using GeoStatsTransforms: Interpolate
+using GeoStatsTransforms: InterpolateNeighbors
 
 using ColumnSelectors: selector
 using GeoStatsBase: weight, folds, mean
 using GeoStatsBase: WeightingMethod, DensityRatioWeighting, UniformWeighting
 using GeoStatsBase: FoldingMethod, BallFolding, BlockFolding, OneFolding, UniformFolding
 using LossFunctions: L2DistLoss, MisclassLoss
-using LossFunctions.Traits: SupervisedLoss
 
 include("utils.jl")
 include("cverror.jl")
 
-export cverror, LeaveOneOut, LeaveBallOut, KFoldValidation, BlockValidation, WeightedValidation, DensityRatioValidation
+export
+  # estimators
+  LeaveOneOut,
+  LeaveBallOut,
+  KFoldValidation,
+  BlockValidation,
+  DensityRatioValidation,
+  WeightedValidation,
+
+  # main function
+  cverror
 
 end
```
src/cverror.jl — 29 changes: 14 additions & 15 deletions

```diff
@@ -5,16 +5,16 @@
 """
     ErrorMethod
 
-A method for estimating cross-validatory error.
+A method for estimating cross-validation error.
 """
 abstract type ErrorMethod end
 
 abstract type ErrorSetup end
 
-struct LearnSetup{M} <: ErrorSetup
+struct LearnSetup{M,F,T} <: ErrorSetup
   model::M
-  input::Vector{Symbol}
-  output::Vector{Symbol}
+  feats::F
+  targs::T
 end
 
 struct InterpSetup{I,M,K} <: ErrorSetup
@@ -25,16 +25,15 @@ end
 """
     cverror(model::GeoStatsModel, geotable, method; kwargs...)
 
-Estimate error of `model` in a given `geotable` with
-error estimation `method` using `Interpolate` or
-`InterpolateNeighbors` depending on the passed
-`kwargs`.
+Estimate cross-validation error of geostatistical `model`
+on given `geotable` with error estimation `method` using
+`Interpolate` or `InterpolateNeighbors` depending on `kwargs`.
 
     cverror(model::StatsLearnModel, geotable, method)
-    cverror((model, invars => outvars), geotable, method)
+    cverror((model, feats => targs), geotable, method)
 
-Estimate error of `model` in a given `geotable` with
-error estimation `method` using the `Learn` transform.
+Estimate cross-validation error of statistical learning `model`
+on given `geotable` with error estimation `method`.
 """
 function cverror end
 
@@ -43,9 +42,9 @@ cverror((model, cols)::Tuple{Any,Pair}, geotable::AbstractGeoTable, method::Erro
 
 function cverror(model::StatsLearnModel, geotable::AbstractGeoTable, method::ErrorMethod)
   names = setdiff(propertynames(geotable), [:geometry])
-  invars = input(model)(names)
-  outvars = output(model)(names)
-  setup = LearnSetup(model, invars, outvars)
+  feats = model.feats(names)
+  targs = model.targs(names)
+  setup = LearnSetup(model, feats, targs)
   cverror(setup, geotable, method)
 end
 
@@ -65,5 +64,5 @@ include("cverrors/loo.jl")
 include("cverrors/lbo.jl")
 include("cverrors/kfv.jl")
 include("cverrors/bcv.jl")
-include("cverrors/wcv.jl")
 include("cverrors/drv.jl")
+include("cverrors/wcv.jl")
```
src/cverrors/bcv.jl — 13 changes: 7 additions & 6 deletions

```diff
@@ -6,9 +6,8 @@
     BlockValidation(sides; loss=Dict())
 
 Cross-validation with blocks of given `sides`. Optionally,
-specify `loss` function from `LossFunctions.jl` for some
-of the variables. If only one side is provided, then blocks
-become cubes.
+specify a dictionary with `loss` functions from `LossFunctions.jl`
+for some of the variables.
 
 ## References
 
@@ -19,12 +18,14 @@ become cubes.
 of spatial models via spatial k-fold cross-validation]
 (https://www.tandfonline.com/doi/full/10.1080/13658816.2017.1346255)
 """
-struct BlockValidation{S} <: ErrorMethod
+struct BlockValidation{S,L} <: ErrorMethod
   sides::S
-  loss::Dict{Symbol,SupervisedLoss}
+  loss::L
 end
 
-BlockValidation(sides; loss=Dict()) = BlockValidation{typeof(sides)}(sides, loss)
+BlockValidation(sides::Tuple; loss=Dict()) = BlockValidation(sides, assymbol(loss))
+
+BlockValidation(sides::Number...; kwargs...) = BlockValidation(sides; kwargs...)
 
 function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::BlockValidation)
   # uniform weights
```
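Usage note: the single constructor is split in two above, one taking the sides as a `Tuple` (normalizing `loss` keys with `assymbol`) and a convenience one splatting plain numbers. A short sketch of both spellings; the sizes and loss dictionary are arbitrary:

```julia
using GeoStatsValidation
using LossFunctions: L2DistLoss

# equivalent ways to request rectangular blocks with sides 10.0 × 20.0
method1 = BlockValidation((10.0, 20.0))
method2 = BlockValidation(10.0, 20.0)

# string keys in the loss dictionary are converted to symbols internally
method3 = BlockValidation(10.0, 20.0, loss=Dict("y" => L2DistLoss()))
```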
src/cverrors/drv.jl — 16 changes: 7 additions & 9 deletions

```diff
@@ -14,6 +14,7 @@ ratio estimation, and then used in `k`-fold weighted cross-validation.
 * `estimator` - Density ratio estimator (default to `LSIF()`)
 * `optlib` - Optimization library (default to `default_optlib(estimator)`)
 * `lambda` - Power of density ratios (default to `1.0`)
+* `loss` - Dictionary with loss functions (default to `Dict()`)
 
 Please see [DensityRatioEstimation.jl]
 (https://github.com/JuliaEarth/DensityRatioEstimation.jl)
@@ -24,33 +25,30 @@ for a list of supported estimators.
 * Hoffimann et al. 2020. [Geostatistical Learning: Challenges and Opportunities]
 (https://arxiv.org/abs/2102.08791)
 """
-struct DensityRatioValidation{T,E,O} <: ErrorMethod
+struct DensityRatioValidation{T,E,O,L} <: ErrorMethod
   k::Int
   shuffle::Bool
   lambda::T
   dre::E
   optlib::O
-  loss::Dict{Symbol,SupervisedLoss}
+  loss::L
 end
 
 function DensityRatioValidation(
   k::Int;
   shuffle=true,
   lambda=1.0,
-  loss=Dict(),
   estimator=LSIF(),
-  optlib=default_optlib(estimator)
+  optlib=default_optlib(estimator),
+  loss=Dict()
 )
   @assert k > 0 "number of folds must be positive"
   @assert 0 ≤ lambda ≤ 1 "lambda must lie in [0,1]"
-  T = typeof(lambda)
-  E = typeof(estimator)
-  O = typeof(optlib)
-  DensityRatioValidation{T,E,O}(k, shuffle, lambda, estimator, optlib, loss)
+  DensityRatioValidation(k, shuffle, lambda, estimator, optlib, assymbol(loss))
 end
 
 function cverror(setup::LearnSetup, geotable::AbstractGeoTable, method::DensityRatioValidation)
-  vars = setup.input
+  vars = setup.feats
 
   # density-ratio weights
   weighting = DensityRatioWeighting(geotable, vars, estimator=method.dre, optlib=method.optlib)
```
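Usage note: given the option list above, a hedged construction sketch (the values are illustrative; `LSIF()` remains the documented default estimator, so it is not passed explicitly):

```julia
using GeoStatsValidation
using LossFunctions: MisclassLoss

# 10-fold importance-weighted CV; loss keys may be strings or symbols
method = DensityRatioValidation(10, shuffle=true, lambda=1.0, loss=Dict("y" => MisclassLoss()))
```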
src/cverrors/kfv.jl — 10 changes: 5 additions & 5 deletions

```diff
@@ -6,8 +6,8 @@
     KFoldValidation(k; shuffle=true, loss=Dict())
 
 `k`-fold cross-validation. Optionally, `shuffle` the
-data, and specify `loss` function from `LossFunctions.jl`
-for some of the variables.
+data, and specify a dictionary with `loss` functions
+from `LossFunctions.jl` for some of the variables.
 
 ## References
 
@@ -17,13 +17,13 @@ for some of the variables.
 cross-validation and the repeated learning-testing methods]
 (https://www.jstor.org/stable/2336116)
 """
-struct KFoldValidation <: ErrorMethod
+struct KFoldValidation{L} <: ErrorMethod
   k::Int
   shuffle::Bool
-  loss::Dict{Symbol,SupervisedLoss}
+  loss::L
 end
 
-KFoldValidation(k::Int; shuffle=true, loss=Dict()) = KFoldValidation(k, shuffle, loss)
+KFoldValidation(k::Int; shuffle=true, loss=Dict()) = KFoldValidation(k, shuffle, assymbol(loss))
 
 function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::KFoldValidation)
   # uniform weights
```
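Usage note: a quick sketch of the constructor, with and without the new dictionary-valued `loss` option (values are arbitrary):

```julia
using GeoStatsValidation
using LossFunctions: MisclassLoss

method1 = KFoldValidation(10)                                          # shuffled 10-fold CV
method2 = KFoldValidation(5, shuffle=false, loss=Dict("y" => MisclassLoss()))
```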
src/cverrors/lbo.jl — 13 changes: 6 additions & 7 deletions

```diff
@@ -6,9 +6,8 @@
     LeaveBallOut(ball; loss=Dict())
 
 Leave-`ball`-out (a.k.a. spatial leave-one-out) validation.
-Optionally, specify `loss` function from the
-[LossFunctions.jl](https://github.com/JuliaML/LossFunctions.jl)
-package for some of the variables.
+Optionally, specify a dictionary with `loss` functions from
+`LossFunctions.jl` for some of the variables.
 
     LeaveBallOut(radius; loss=Dict())
 
@@ -20,14 +19,14 @@ By default, use Euclidean ball of given `radius` in space.
 for variable selection in the presence of spatial autocorrelation]
 (https://onlinelibrary.wiley.com/doi/full/10.1111/geb.12161)
 """
-struct LeaveBallOut{B<:MetricBall} <: ErrorMethod
+struct LeaveBallOut{B,L} <: ErrorMethod
   ball::B
-  loss::Dict{Symbol,SupervisedLoss}
+  loss::L
 end
 
-LeaveBallOut(ball; loss=Dict()) = LeaveBallOut{typeof(ball)}(ball, loss)
+LeaveBallOut(ball; loss=Dict()) = LeaveBallOut(ball, assymbol(loss))
 
-LeaveBallOut(radius::Number; loss=Dict()) = LeaveBallOut(MetricBall(radius), loss=loss)
+LeaveBallOut(radius::Number; loss=Dict()) = LeaveBallOut(MetricBall(radius); loss=loss)
 
 function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::LeaveBallOut)
   # uniform weights
```
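Usage note: the two constructors accept either a plain radius or an explicit ball; a sketch, assuming `MetricBall` from Meshes.jl is in scope as in the module imports:

```julia
using GeoStatsValidation
using Meshes: MetricBall

method1 = LeaveBallOut(0.1)              # Euclidean ball of radius 0.1
method2 = LeaveBallOut(MetricBall(0.1))  # equivalent explicit MetricBall
```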
src/cverrors/loo.jl — 10 changes: 5 additions & 5 deletions

```diff
@@ -5,19 +5,19 @@
 """
     LeaveOneOut(; loss=Dict())
 
-Leave-one-out validation. Optionally, specify `loss` function
-from `LossFunctions.jl` for some of the variables.
+Leave-one-out validation. Optionally, specify a dictionary of
+`loss` functions from `LossFunctions.jl` for some of the variables.
 
 ## References
 
 * Stone. 1974. [Cross-Validatory Choice and Assessment of Statistical Predictions]
 (https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/j.2517-6161.1974.tb00994.x)
 """
-struct LeaveOneOut <: ErrorMethod
-  loss::Dict{Symbol,SupervisedLoss}
+struct LeaveOneOut{L} <: ErrorMethod
+  loss::L
 end
 
-LeaveOneOut(; loss=Dict()) = LeaveOneOut(loss)
+LeaveOneOut(; loss=Dict()) = LeaveOneOut(assymbol(loss))
 
 function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::LeaveOneOut)
   # uniform weights
```
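Usage note: when no `loss` entry is given for a variable, the package falls back to `defaultloss` from src/utils.jl (L2 distance for continuous variables, misclassification for categorical ones). A sketch of both spellings:

```julia
using GeoStatsValidation
using LossFunctions: MisclassLoss

method1 = LeaveOneOut()                                  # per-variable default losses
method2 = LeaveOneOut(loss=Dict("y" => MisclassLoss()))  # explicit loss for y
```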
src/cverrors/wcv.jl — 21 changes: 9 additions & 12 deletions

```diff
@@ -8,8 +8,8 @@
 An error estimation method which samples are weighted with
 `weighting` method and split into folds with `folding` method.
 Weights are raised to `lambda` power in `[0,1]`. Optionally,
-specify `loss` function from `LossFunctions.jl` for some of
-the variables.
+specify a dictionary with `loss` functions from `LossFunctions.jl`
+for some of the variables.
 
 ## References
 
@@ -18,20 +18,17 @@ the variables.
 * Sugiyama et al. 2007. [Covariate shift adaptation by importance weighted
 cross validation](http://www.jmlr.org/papers/volume8/sugiyama07a/sugiyama07a.pdf)
 """
-struct WeightedValidation{W<:WeightingMethod,F<:FoldingMethod,T<:Real} <: ErrorMethod
+struct WeightedValidation{W<:WeightingMethod,F<:FoldingMethod,T,L} <: ErrorMethod
   weighting::W
   folding::F
   lambda::T
-  loss::Dict{Symbol,SupervisedLoss}
-
-  function WeightedValidation{W,F,T}(weighting, folding, lambda, loss) where {W,F,T}
-    @assert 0 ≤ lambda ≤ 1 "lambda must lie in [0,1]"
-    new(weighting, folding, lambda, loss)
-  end
+  loss::L
 end
 
-WeightedValidation(weighting::W, folding::F; lambda::T=1.0, loss=Dict()) where {W,F,T} =
-  WeightedValidation{W,F,T}(weighting, folding, lambda, loss)
+function WeightedValidation(weighting, folding; lambda=1.0, loss=Dict())
+  @assert 0 ≤ lambda ≤ 1 "lambda must lie in [0,1]"
+  WeightedValidation(weighting, folding, lambda, loss)
+end
 
 function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::WeightedValidation)
   ovars = _outputs(setup, geotable)
@@ -80,7 +77,7 @@ end
 
 # output variables
 _outputs(::InterpSetup, gtb) = setdiff(propertynames(gtb), [:geometry])
-_outputs(s::LearnSetup, gtb) = s.output
+_outputs(s::LearnSetup, gtb) = s.targs
 
 # prediction for a given fold
 function _prediction(s::InterpSetup{I}, geotable, f) where {I}
```
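Usage note: the other validation methods appear to delegate to `WeightedValidation` with a suitable weighting/folding pair. A hedged sketch of direct construction, assuming the `UniformWeighting` and `OneFolding` constructors imported from GeoStatsBase at the top of the module take no arguments:

```julia
using GeoStatsValidation
using GeoStatsBase: UniformWeighting, OneFolding

# uniform weights + one-point folds ≈ leave-one-out validation
method = WeightedValidation(UniformWeighting(), OneFolding(), lambda=1.0)
```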
src/utils.jl — 2 changes: 2 additions & 0 deletions

```diff
@@ -5,3 +5,5 @@
 defaultloss(val) = defaultloss(scitype(val))
 defaultloss(::Type{Continuous}) = L2DistLoss()
 defaultloss(::Type{Categorical}) = MisclassLoss()
+
+assymbol(obj) = Dict(Symbol.(keys(obj)) .=> values(obj))
```
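The new helper underpins the constructor changes throughout this commit: users may key loss dictionaries by strings (as in the new tests) while the internals keep `Symbol` keys. A self-contained sketch of its behavior:

```julia
using LossFunctions: MisclassLoss

# copy of the helper added above
assymbol(obj) = Dict(Symbol.(keys(obj)) .=> values(obj))

assymbol(Dict("y" => MisclassLoss()))  # Dict(:y => MisclassLoss())
assymbol(Dict(:y => MisclassLoss()))   # unchanged: Symbol(:y) === :y
```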
test/Project.toml — 1 change: 1 addition & 0 deletions

```diff
@@ -1,6 +1,7 @@
 [deps]
 GeoStatsModels = "ad987403-13c5-47b5-afee-0a48f6ac4f12"
 GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0"
+LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
 Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StatsLearnModels = "c146b59d-1589-421c-8e09-a22e554fd05c"
```
test/runtests.jl — 25 changes: 23 additions & 2 deletions

```diff
@@ -1,6 +1,7 @@
 using GeoStatsValidation
 using StatsLearnModels
 using GeoStatsModels
+using LossFunctions
 using GeoTables
 using Meshes
 using Random
@@ -16,8 +17,28 @@ using Test
 model = DecisionTreeClassifier()
 
 # dummy classifier → 0.5 misclassification rate
-for method in
-    [LeaveOneOut(), LeaveBallOut(0.1), KFoldValidation(10), BlockValidation(0.1), DensityRatioValidation(10)]
+for method in [
+  # methods
+  LeaveOneOut(),
+  LeaveBallOut(0.1),
+  KFoldValidation(10),
+  BlockValidation(0.1),
+  DensityRatioValidation(10)
+]
   e = cverror((model, :x => :y), gtb, method)
   @test isapprox(e[:y], 0.5, atol=0.06)
 end
+
+# test with custom loss
+loss = Dict("y" => MisclassLoss())
+for method in [
+  # methods
+  LeaveOneOut(loss=loss),
+  LeaveBallOut(0.1, loss=loss),
+  KFoldValidation(10, loss=loss),
+  BlockValidation(0.1, loss=loss),
+  DensityRatioValidation(10, loss=loss)
+]
+  e = cverror((model, :x => :y), gtb, method)
+  @test isapprox(e[:y], 0.5, atol=0.06)
+end
```
