Skip to content

Commit

Permalink
Add leave-one-out CV as a synonym of k-fold CV
Browse files Browse the repository at this point in the history
  • Loading branch information
takuti committed Apr 3, 2022
1 parent 45025aa commit 4cd0ced
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/src/evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Pages = ["evaluation.md"]

```@docs
cross_validation
leave_one_out
```

## Rating metrics
Expand Down
31 changes: 30 additions & 1 deletion src/evaluation/cross_validation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export cross_validation
export cross_validation, leave_one_out

"""
cross_validation(
Expand Down Expand Up @@ -114,3 +114,32 @@ function cross_validation(n_folds::Integer, metric::Type{<:AccuracyMetric}, reco

accum / n_folds
end

"""
leave_one_out(
metric::Type{<:RankingMetric},
topk::Integer,
recommender_type::Type{<:Recommender},
data::DataAccessor,
recommender_args...
)
Conduct leave-one-out cross validation (LOOCV) for a combination of recommender `recommender_type` and accuracy metric `metric`. A recommender is initialized with `recommender_args` and runs top-`k` recommendation.
"""
function leave_one_out(metric::Type{<:RankingMetric}, topk::Integer, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
cross_validation(length(data.events), metric, topk, recommender_type, data, recommender_args...)
end

"""
leave_one_out(
metric::Type{<:AccuracyMetric},
recommender_type::Type{<:Recommender},
data::DataAccessor,
recommender_args...
)
Conduct leave-one-out cross validation (LOOCV) for a combination of recommender `recommender_type` and accuracy metric `metric`. A recommender is initialized with `recommender_args`.
"""
function leave_one_out(metric::Type{<:AccuracyMetric}, recommender_type::Type{<:Recommender}, data::DataAccessor, recommender_args...)
cross_validation(length(data.events), metric, recommender_type, data, recommender_args...)
end
2 changes: 2 additions & 0 deletions test/evaluation/test_cross_validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ function test_cross_validation_accuracy(v)
# leave-one-out cross validation with MF(data, 2)
n_samples = length(data.events)
@test 0.0 < cross_validation(n_samples, MAE, MF, data, 2) <= 2.5
@test 0.0 < leave_one_out(MAE, MF, data, 2) <= 2.5
end

function test_cross_validation_ranking(v)
Expand Down Expand Up @@ -60,6 +61,7 @@ function test_cross_validation_ranking(v)
# leave-one-out cross validation with MF(data, 2)
n_samples = length(data.events)
@test 0.0 <= cross_validation(n_samples, Recall, k, MF, data, 2) <= 1.0
@test 0.0 <= leave_one_out(Recall, k, MF, data, 2) <= 1.0
end

println("-- Testing cross validation with accuracy metrics")
Expand Down

0 comments on commit 4cd0ced

Please sign in to comment.