Merge pull request #65 from takuti/cv-benchmark
Update cross validation interfaces per recent updates on `evaluate()`
takuti authored Nov 27, 2022
2 parents 1b59992 + 78121a6 commit 25797dc
Showing 4 changed files with 81 additions and 68 deletions.
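In short: metrics are now passed to `cross_validation` and `leave_one_out` as instances (e.g. `MAE()`) instead of types (e.g. `MAE`), and `cross_validation` additionally accepts a vector of metrics so several measures are computed over the same folds, mirroring the updated `evaluate()` API. A minimal before/after sketch, assuming the `load_movielens_100k` loader and `MF` recommender that appear in the diff below:

    using Recommendation

    data = load_movielens_100k()

    # before this commit: metric passed as a type
    # mae = cross_validation(5, MAE, MF, data, 2)

    # after this commit: metric passed as an instance...
    mae = cross_validation(5, MAE(), MF, data, 2)

    # ...or as a vector, evaluating all metrics on the same folds
    rmse, mae = cross_validation(5, [RMSE(), MAE()], MF, data, 2)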
76 changes: 38 additions & 38 deletions examples/benchmark.jl
@@ -20,9 +20,10 @@ using Recommendation
value_prediction_recommenders = [
ItemMean => [],
UserMean => [],
SVD => [4],
SVD => [8],
SVD => [16],
SVD => [32],
SVD => [64],
# BPRMatrixFactorization => [],
# FactorizationMachines => [],
# MatrixFactorization => [],
@@ -41,12 +41,16 @@ rank_by_score_recommenders = [
# TfIdf => [],
]

accuracy_metrics = [
function instantiate(metrics::AbstractVector{DataType})
[metric() for metric in metrics]
end

accuracy_metrics = instantiate([
RMSE,
MAE,
]
])

topk_metrics = [
topk_metrics = instantiate([
Recall,
Precision,
AUC,
@@ -56,7 +61,14 @@ topk_metrics = [
AggregatedDiversity,
ShannonEntropy,
GiniIndex,
]
])

# * IntraListSimilarity can be calculated with an item-item similarity matrix, which can be built by ItemKNN.
# * Serendipity requires context-specific definition of relevance and unexpectedness.
intra_list_metrics = instantiate([
Coverage,
Novelty
])

datasets = [
load_movielens_100k,
@@ -65,52 +77,40 @@ datasets = [
# load_lastfm
]

test_ratio = 0.2
topk = 10

function eval(instantiated_recommender::Recommender, truth_data::DataAccessor,
metrics::AbstractVector{T}, topk=nothing) where T
instantiated_metrics = [metric() for metric in metrics]
if isnothing(topk)
# accuracy metrics
results = evaluate(instantiated_recommender, truth_data, instantiated_metrics)
else
# intra-list metrics:
# * IntraListSimilarity can be calculated with an item-item similarity matrix, which can be built by ItemKNN.
# * Serendipity requires context-specific definition of relevance and unexpectedness.
coverage, novelty = evaluate(instantiated_recommender, truth_data,
[Coverage(), Novelty()], topk, allow_repeat=true)
@info " Coverage = $coverage"
@info " Novelty = $novelty"

# ranking / aggregated metrics
results = evaluate(instantiated_recommender, truth_data, instantiated_metrics, topk)
end

for (metric, res) in zip(metrics, results)
@info " $metric = $res"
end
end
n_folds = 5

for dataset in datasets
@info "Dataset: $dataset"
data = dataset()
train_data, truth_data = split_data(data, test_ratio)

@info "Evaluating value prediction-based recommenders"
for (recommender, params) in value_prediction_recommenders
@info "Recommender: $recommender($params...)"
r = recommender(train_data, params...)
fit!(r)
eval(r, truth_data, accuracy_metrics)
eval(r, truth_data, topk_metrics, topk)
results = cross_validation(n_folds, accuracy_metrics, recommender, data, params...)
for (metric, res) in zip(accuracy_metrics, results)
@info " $metric = $res"
end
results = cross_validation(n_folds, topk_metrics, topk, recommender, data, params...)
for (metric, res) in zip(topk_metrics, results)
@info " $metric = $res"
end
results = cross_validation(n_folds, intra_list_metrics, topk, recommender, data, params..., allow_repeat=true)
for (metric, res) in zip(intra_list_metrics, results)
@info " $metric = $res"
end
end

@info "Evaluating custom ranking score-based recommenders"
for (recommender, params) in rank_by_score_recommenders
@info "Recommender: $recommender($params...)"
r = recommender(train_data, params...)
fit!(r)
eval(r, truth_data, topk_metrics, topk)
results = cross_validation(n_folds, topk_metrics, topk, recommender, data, params...)
for (metric, res) in zip(topk_metrics, results)
@info " $metric = $res"
end
results = cross_validation(n_folds, intra_list_metrics, topk, recommender, data, params..., allow_repeat=true)
for (metric, res) in zip(intra_list_metrics, results)
@info " $metric = $res"
end
end
end
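For reference, the new `instantiate` helper just maps metric types to zero-argument instances, e.g. `instantiate([RMSE, MAE])` is equivalent to `[RMSE(), MAE()]`, matching the instance-based metric arguments that `cross_validation` now expects.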
39 changes: 26 additions & 13 deletions src/evaluation/cross_validation.jl
@@ -3,49 +3,62 @@ export cross_validation, leave_one_out
"""
cross_validation(
n_folds::Integer,
metric::Type{<:RankingMetric},
metric::Union{RankingMetric, AggregatedMetric, Coverage, Novelty},
topk::Integer,
recommender_type::Type{<:Recommender},
data::DataAccessor,
recommender_args...
recommender_args...;
allow_repeat::Bool=false
)
Conduct `n_folds`-fold cross validation for a combination of recommender `recommender_type` and ranking metric `metric`. A recommender is initialized with `recommender_args` and runs top-`k` recommendation.
"""
function cross_validation(n_folds::Integer, metric::Type{<:RankingMetric}, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
accum_accuracy = 0.0
function cross_validation(n_folds::Integer, metric::Metric, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...; allow_repeat=false)
cross_validation(n_folds, [metric], topk, recommender_type, data, recommender_args...; allow_repeat=allow_repeat)[1]
end

function cross_validation(n_folds::Integer, metrics::AbstractVector{T}, topk::Integer,
recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...; allow_repeat::Bool=false
) where T<:Metric
accum_accuracy = zeros(length(metrics))
for (train_data, truth_data) in split_data(data, n_folds)
recommender = recommender_type(train_data, recommender_args...)
fit!(recommender)
accum_accuracy += evaluate(recommender, truth_data, metric(), topk)
accum_accuracy += evaluate(recommender, truth_data, metrics, topk; allow_repeat=allow_repeat)
end
accum_accuracy / n_folds
end

"""
cross_validation(
n_folds::Integer,
metric::Type{<:AccuracyMetric},
metric::AccuracyMetric,
recommender_type::Type{<:Recommender},
data::DataAccessor,
recommender_args...
)
Conduct `n_folds`-fold cross validation for a combination of recommender `recommender_type` and accuracy metric `metric`. A recommender is initialized with `recommender_args`.
"""
function cross_validation(n_folds::Integer, metric::Type{<:AccuracyMetric}, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
accum_accuracy = 0.0
function cross_validation(n_folds::Integer, metric::AccuracyMetric, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
cross_validation(n_folds, [metric], recommender_type, data, recommender_args...)[1]
end

function cross_validation(n_folds::Integer, metrics::AbstractVector{T},
recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...
) where T<:AccuracyMetric
accum_accuracy = zeros(length(metrics))
for (train_data, truth_data) in split_data(data, n_folds)
recommender = recommender_type(train_data, recommender_args...)
fit!(recommender)
accum_accuracy = evaluate(recommender, truth_data, metric())
accum_accuracy += evaluate(recommender, truth_data, metrics)
end
accum_accuracy / n_folds
end

"""
leave_one_out(
metric::Type{<:RankingMetric},
metric::RankingMetric,
topk::Integer,
recommender_type::Type{<:Recommender},
data::DataAccessor,
@@ -54,20 +67,20 @@ end
Conduct leave-one-out cross validation (LOOCV) for a combination of recommender `recommender_type` and ranking metric `metric`. A recommender is initialized with `recommender_args` and runs top-`k` recommendation.
"""
function leave_one_out(metric::Type{<:RankingMetric}, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
function leave_one_out(metric::RankingMetric, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
cross_validation(length(data.events), metric, topk, recommender_type, data, recommender_args...)
end

"""
leave_one_out(
metric::Type{<:AccuracyMetric},
metric::AccuracyMetric,
recommender_type::Type{<:Recommender},
data::DataAccessor,
recommender_args...
)
Conduct leave-one-out cross validation (LOOCV) for a combination of recommender `recommender_type` and accuracy metric `metric`. A recommender is initialized with `recommender_args`.
"""
function leave_one_out(metric::Type{<:AccuracyMetric}, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
function leave_one_out(metric::AccuracyMetric, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
cross_validation(length(data.events), metric, recommender_type, data, recommender_args...)
end
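A usage sketch of the updated interfaces; the recommender, fold counts, and hyperparameters are illustrative, and all names appear elsewhere in this diff:

    using Recommendation

    data = load_movielens_100k()

    # ranking metrics add a top-k argument (k = 10 here); a vector of
    # metrics is evaluated over the same folds
    recall, precision = cross_validation(5, [Recall(), Precision()], 10, MF, data, 2)

    # intra-list metrics accept allow_repeat, as in examples/benchmark.jl
    coverage, novelty = cross_validation(5, [Coverage(), Novelty()], 10, MF, data, 2, allow_repeat=true)

    # LOOCV is cross validation with one fold per sample
    recall_loo = leave_one_out(Recall(), 10, MF, data, 2)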
4 changes: 2 additions & 2 deletions src/evaluation/evaluate.jl
@@ -16,7 +16,7 @@ function evaluate(recommender::Recommender, truth_data::DataAccessor,
end

function evaluate(recommender::Recommender, truth_data::DataAccessor,
metric::Metric, topk::Integer; allow_repeat=false)
metric::Metric, topk::Integer; allow_repeat::Bool=false)
evaluate(recommender, truth_data, [metric], topk, allow_repeat=allow_repeat)[1]
end

@@ -28,7 +28,7 @@ function check_metrics_type(metrics::AbstractVector{T},
end

function evaluate(recommender::Recommender, truth_data::DataAccessor,
metrics::AbstractVector{T}, topk::Integer; allow_repeat=false) where {T<:Metric}
metrics::AbstractVector{T}, topk::Integer; allow_repeat::Bool=false) where {T<:Metric}
validate(recommender, truth_data)
check_metrics_type(metrics, Union{RankingMetric, AggregatedMetric, Coverage, Novelty})

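For reference, a sketch of calling the tightened `evaluate` signatures directly; the `MostPopular` recommender and 0.2 split ratio are illustrative, mirroring the `eval` helper removed from examples/benchmark.jl above:

    data = load_movielens_100k()
    train_data, truth_data = split_data(data, 0.2)

    recommender = MostPopular(train_data)
    fit!(recommender)

    # single metric plus topk: returns a scalar
    recall = evaluate(recommender, truth_data, Recall(), 10)

    # vector of metrics with allow_repeat, as in the removed benchmark helper
    coverage, novelty = evaluate(recommender, truth_data, [Coverage(), Novelty()], 10, allow_repeat=true)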
30 changes: 15 additions & 15 deletions test/evaluation/test_cross_validation.jl
@@ -6,29 +6,29 @@ function test_cross_validation_accuracy(v)
data = DataAccessor(isa(v, Unknown) ? m : sparse(m))

# 1-fold cross validation is invalid
@test_throws ErrorException cross_validation(1, MAE, MF, data, 2)
@test_throws ErrorException cross_validation(1, MAE(), MF, data, 2)

# in n-fold cross validation, n must be smaller than or equal to the number of all samples
@test_throws ErrorException cross_validation(100, MAE, MF, data, 2)
@test_throws ErrorException cross_validation(100, MAE(), MF, data, 2)

fold = 5

# MF(data, 2)
@test 0.0 < cross_validation(fold, MAE, MF, data, 2) <= 2.5
@test 0.0 < cross_validation(fold, MAE(), MF, data, 2) <= 2.5

# UserMean(data)
@test 0.0 < cross_validation(fold, MAE, UserMean, data) <= 2.5
@test 0.0 < cross_validation(fold, MAE(), UserMean, data) <= 2.5

# ItemMean(data)
@test 0.0 < cross_validation(fold, MAE, ItemMean, data) <= 2.5
@test 0.0 < cross_validation(fold, MAE(), ItemMean, data) <= 2.5

# UserKNN(data, 2, true)
@test 0.0 < cross_validation(fold, MAE, UserKNN, data, 2, true) <= 2.5
@test 0.0 < cross_validation(fold, MAE(), UserKNN, data, 2, true) <= 2.5

# leave-one-out cross validation with MF(data, 2)
n_samples = length(data.events)
@test 0.0 < cross_validation(n_samples, MAE, MF, data, 2) <= 2.5
@test 0.0 < leave_one_out(MAE, MF, data, 2) <= 2.5
@test 0.0 < cross_validation(n_samples, MAE(), MF, data, 2) <= 2.5
@test 0.0 < leave_one_out(MAE(), MF, data, 2) <= 2.5
end

function test_cross_validation_ranking(v)
@@ -41,27 +41,27 @@ function test_cross_validation_ranking(v)
k = 4

# 1-fold cross validation is invalid
@test_throws ErrorException cross_validation(1, Recall, k, MF, data, 2)
@test_throws ErrorException cross_validation(1, Recall(), k, MF, data, 2)

# in n-fold cross validation, n must be smaller than or equal to the number of all samples
@test_throws ErrorException cross_validation(100, Recall, k, MF, data, 2)
@test_throws ErrorException cross_validation(100, Recall(), k, MF, data, 2)

# 3-fold cross validation
fold = 3

# MF(data, 2)
@test 0.0 <= cross_validation(fold, Recall, k, MF, data, 2) <= 1.0
@test 0.0 <= cross_validation(fold, Recall(), k, MF, data, 2) <= 1.0

# MostPopular(data)
@test 0.0 <= cross_validation(fold, Recall, k, MostPopular, data) <= 1.0
@test 0.0 <= cross_validation(fold, Recall(), k, MostPopular, data) <= 1.0

# UserKNN(data, 2, true)
@test 0.0 <= cross_validation(fold, Recall, k, UserKNN, data, 2, true) <= 1.0
@test 0.0 <= cross_validation(fold, Recall(), k, UserKNN, data, 2, true) <= 1.0

# leave-one-out cross validation with MF(data, 2)
n_samples = length(data.events)
@test 0.0 <= cross_validation(n_samples, Recall, k, MF, data, 2) <= 1.0
@test 0.0 <= leave_one_out(Recall, k, MF, data, 2) <= 1.0
@test 0.0 <= cross_validation(n_samples, Recall(), k, MF, data, 2) <= 1.0
@test 0.0 <= leave_one_out(Recall(), k, MF, data, 2) <= 1.0
end

println("-- Testing cross validation with accuracy metrics")
