diff --git a/src/evaluation/cross_validation.jl b/src/evaluation/cross_validation.jl
index 75c291b..4fe2d80 100644
--- a/src/evaluation/cross_validation.jl
+++ b/src/evaluation/cross_validation.jl
@@ -3,27 +3,28 @@ export cross_validation, leave_one_out
 """
     cross_validation(
         n_folds::Integer,
-        metric::RankingMetric,
+        metric::Union{RankingMetric, AggregatedMetric, Coverage, Novelty},
         topk::Integer,
         recommender_type::Type{<:Recommender},
         data::DataAccessor,
-        recommender_args...
+        recommender_args...;
+        allow_repeat::Bool=false
     )
 
-Conduct `n_folds` cross validation for a combination of recommender `recommender_type` and ranking metric `metric`. A recommender is initialized with `recommender_args` and runs top-`k` recommendation.
+Conduct `n_folds` cross validation for a combination of recommender `recommender_type` and evaluation metric `metric`. A recommender is initialized with `recommender_args` and runs top-`k` recommendation.
 """
-function cross_validation(n_folds::Integer, metric::RankingMetric, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
-    cross_validation(n_folds, [metric], topk, recommender_type, data, recommender_args...)[1]
+function cross_validation(n_folds::Integer, metric::Metric, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...; allow_repeat::Bool=false)
+    cross_validation(n_folds, [metric], topk, recommender_type, data, recommender_args...; allow_repeat=allow_repeat)[1]
 end
 
 function cross_validation(n_folds::Integer, metrics::AbstractVector{T}, topk::Integer,
-                          recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...
-                          ) where T<:RankingMetric
+                          recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...; allow_repeat::Bool=false
+                          ) where T<:Metric
     accum_accuracy = zeros(length(metrics))
     for (train_data, truth_data) in split_data(data, n_folds)
         recommender = recommender_type(train_data, recommender_args...)
         fit!(recommender)
-        accum_accuracy += evaluate(recommender, truth_data, metrics, topk)
+        accum_accuracy += evaluate(recommender, truth_data, metrics, topk; allow_repeat=allow_repeat)
     end
     accum_accuracy / n_folds
 end
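
For reference, a minimal usage sketch of the extended signature (not part of the patch): the `Event`/`DataAccessor` construction, the `MF` recommender, and the `Recall`/`Precision` metrics follow my reading of the package API and should be treated as assumptions of this illustration.

```julia
using Recommendation

# Toy implicit-feedback events (user, item, value); these constructors are
# assumptions of this sketch, not part of the diff above.
events = [Event(1, 1, 1), Event(1, 2, 1), Event(2, 2, 1),
          Event(2, 3, 1), Event(3, 1, 1), Event(3, 3, 1)]
data = DataAccessor(events, 3, 3)  # 3 users, 3 items

# 3-fold cross validation of top-2 recommendation with a matrix factorization
# recommender; the trailing `2` is a hypothetical number of latent factors
# forwarded through `recommender_args...` to `MF(train_data, 2)`.
# `allow_repeat=true` keeps already-rated items as recommendation candidates.
recall = cross_validation(3, Recall(), 2, MF, data, 2; allow_repeat=true)

# The vector-of-metrics method averages several metrics in one pass.
scores = cross_validation(3, [Recall(), Precision()], 2, MF, data, 2;
                          allow_repeat=true)
```

Keeping `allow_repeat` behind the `;` means it can never be captured by the positional `recommender_args...` splat, which is why both methods take it as a keyword argument rather than as another positional argument.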