Skip to content

Commit

Permalink
Let cross validators take instantiated metric(s) as argument
Browse files Browse the repository at this point in the history
because metrics don't need to be re-instantiated for every validation fold,
unlike the recommender instance.
  • Loading branch information
takuti committed Nov 25, 2022
1 parent 1b59992 commit ec030cd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 27 deletions.
36 changes: 24 additions & 12 deletions src/evaluation/cross_validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ export cross_validation, leave_one_out
"""
    cross_validation(
        n_folds::Integer,
        metric::RankingMetric,
        topk::Integer,
        recommender_type::Type{<:Recommender},
        data::DataAccessor,
        recommender_args...
    )

Conduct `n_folds` cross validation for a combination of recommender
`recommender_type` and ranking metric `metric`. A recommender is initialized
with `recommender_args` and runs top-`k` recommendation. Return the metric
value averaged over the folds.
"""
function cross_validation(n_folds::Integer, metric::RankingMetric, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
    # Delegate to the vector-of-metrics method and unwrap the single result.
    cross_validation(n_folds, [metric], topk, recommender_type, data, recommender_args...)[1]
end

"""
    cross_validation(
        n_folds::Integer,
        metrics::AbstractVector{<:RankingMetric},
        topk::Integer,
        recommender_type::Type{<:Recommender},
        data::DataAccessor,
        recommender_args...
    )

Conduct `n_folds` cross validation for multiple ranking metrics at once, fitting
the recommender only once per fold. Return a vector of fold-averaged values,
one element per metric, in the same order as `metrics`.
"""
function cross_validation(n_folds::Integer, metrics::AbstractVector{T}, topk::Integer,
                          recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...
                         ) where T<:RankingMetric
    # One accumulator slot per metric.
    accum_accuracy = zeros(length(metrics))
    for (train_data, truth_data) in split_data(data, n_folds)
        recommender = recommender_type(train_data, recommender_args...)
        fit!(recommender)
        # `evaluate` returns one value per metric; accumulate element-wise across folds.
        accum_accuracy += evaluate(recommender, truth_data, metrics, topk)
    end
    accum_accuracy / n_folds
end

"""
    cross_validation(
        n_folds::Integer,
        metric::AccuracyMetric,
        recommender_type::Type{<:Recommender},
        data::DataAccessor,
        recommender_args...
    )

Conduct `n_folds` cross validation for a combination of recommender
`recommender_type` and accuracy metric `metric`. A recommender is initialized
with `recommender_args`. Return the metric value averaged over the folds.
"""
function cross_validation(n_folds::Integer, metric::AccuracyMetric, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
    # Delegate to the vector-of-metrics method and unwrap the single result.
    cross_validation(n_folds, [metric], recommender_type, data, recommender_args...)[1]
end

"""
    cross_validation(
        n_folds::Integer,
        metrics::AbstractVector{<:AccuracyMetric},
        recommender_type::Type{<:Recommender},
        data::DataAccessor,
        recommender_args...
    )

Conduct `n_folds` cross validation for multiple accuracy metrics at once,
fitting the recommender only once per fold. Return a vector of fold-averaged
values, one element per metric, in the same order as `metrics`.
"""
function cross_validation(n_folds::Integer, metrics::AbstractVector{T},
                          recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...
                         ) where T<:AccuracyMetric
    # One accumulator slot per metric.
    accum_accuracy = zeros(length(metrics))
    for (train_data, truth_data) in split_data(data, n_folds)
        recommender = recommender_type(train_data, recommender_args...)
        fit!(recommender)
        # Bug fix: accumulate with `+=` (was `=`, which discarded all but the
        # last fold's result before dividing by `n_folds`). This now matches
        # the RankingMetric variant of `cross_validation`.
        accum_accuracy += evaluate(recommender, truth_data, metrics)
    end
    accum_accuracy / n_folds
end

"""
    leave_one_out(
        metric::RankingMetric,
        topk::Integer,
        recommender_type::Type{<:Recommender},
        data::DataAccessor,
        recommender_args...
    )

Conduct leave-one-out cross validation (LOOCV) for a combination of recommender
`recommender_type` and ranking metric `metric`. A recommender is initialized
with `recommender_args` and runs top-`k` recommendation.
"""
function leave_one_out(metric::RankingMetric, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
    # LOOCV is n-fold cross validation with one fold per observed event.
    cross_validation(length(data.events), metric, topk, recommender_type, data, recommender_args...)
end

"""
    leave_one_out(
        metric::AccuracyMetric,
        recommender_type::Type{<:Recommender},
        data::DataAccessor,
        recommender_args...
    )

Conduct leave-one-out cross validation (LOOCV) for a combination of recommender
`recommender_type` and accuracy metric `metric`. A recommender is initialized
with `recommender_args`.
"""
function leave_one_out(metric::AccuracyMetric, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
    # LOOCV is n-fold cross validation with one fold per observed event.
    cross_validation(length(data.events), metric, recommender_type, data, recommender_args...)
end
30 changes: 15 additions & 15 deletions test/evaluation/test_cross_validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@ function test_cross_validation_accuracy(v)
data = DataAccessor(isa(v, Unknown) ? m : sparse(m))

# 1-fold cross validation is invalid
@test_throws ErrorException cross_validation(1, MAE, MF, data, 2)
@test_throws ErrorException cross_validation(1, MAE(), MF, data, 2)

# in n-fold cross validation, n must be smaller than or equal to the number of all samples
@test_throws ErrorException cross_validation(100, MAE, MF, data, 2)
@test_throws ErrorException cross_validation(100, MAE(), MF, data, 2)

fold = 5

# MF(data, 2)
@test 0.0 < cross_validation(fold, MAE, MF, data, 2) <= 2.5
@test 0.0 < cross_validation(fold, MAE(), MF, data, 2) <= 2.5

# UserMean(data)
@test 0.0 < cross_validation(fold, MAE, UserMean, data) <= 2.5
@test 0.0 < cross_validation(fold, MAE(), UserMean, data) <= 2.5

# ItemMean(data)
@test 0.0 < cross_validation(fold, MAE, ItemMean, data) <= 2.5
@test 0.0 < cross_validation(fold, MAE(), ItemMean, data) <= 2.5

# UserKNN(data, 2, true)
@test 0.0 < cross_validation(fold, MAE, UserKNN, data, 2, true) <= 2.5
@test 0.0 < cross_validation(fold, MAE(), UserKNN, data, 2, true) <= 2.5

# leave-one-out cross validation with MF(data, 2)
n_samples = length(data.events)
@test 0.0 < cross_validation(n_samples, MAE, MF, data, 2) <= 2.5
@test 0.0 < leave_one_out(MAE, MF, data, 2) <= 2.5
@test 0.0 < cross_validation(n_samples, MAE(), MF, data, 2) <= 2.5
@test 0.0 < leave_one_out(MAE(), MF, data, 2) <= 2.5
end

function test_cross_validation_ranking(v)
Expand All @@ -41,27 +41,27 @@ function test_cross_validation_ranking(v)
k = 4

# 1-fold cross validation is invalid
@test_throws ErrorException cross_validation(1, Recall, k, MF, data, 2)
@test_throws ErrorException cross_validation(1, Recall(), k, MF, data, 2)

# in n-fold cross validation, n must be smaller than or equal to the number of all samples
@test_throws ErrorException cross_validation(100, Recall, k, MF, data, 2)
@test_throws ErrorException cross_validation(100, Recall(), k, MF, data, 2)

# 3-fold cross validation
fold = 3

# MF(data, 2)
@test 0.0 <= cross_validation(fold, Recall, k, MF, data, 2) <= 1.0
@test 0.0 <= cross_validation(fold, Recall(), k, MF, data, 2) <= 1.0

# MostPopular(data)
@test 0.0 <= cross_validation(fold, Recall, k, MostPopular, data) <= 1.0
@test 0.0 <= cross_validation(fold, Recall(), k, MostPopular, data) <= 1.0

# UserKNN(data, 2, true)
@test 0.0 <= cross_validation(fold, Recall, k, UserKNN, data, 2, true) <= 1.0
@test 0.0 <= cross_validation(fold, Recall(), k, UserKNN, data, 2, true) <= 1.0

# leave-one-out cross validation with MF(data, 2)
n_samples = length(data.events)
@test 0.0 <= cross_validation(n_samples, Recall, k, MF, data, 2) <= 1.0
@test 0.0 <= leave_one_out(Recall, k, MF, data, 2) <= 1.0
@test 0.0 <= cross_validation(n_samples, Recall(), k, MF, data, 2) <= 1.0
@test 0.0 <= leave_one_out(Recall(), k, MF, data, 2) <= 1.0
end

println("-- Testing cross validation with accuracy metrics")
Expand Down

0 comments on commit ec030cd

Please sign in to comment.