Skip to content

Commit

Permalink
Merge pull request #28 from takuti/compat-fit
Browse files Browse the repository at this point in the history
Rename `build!` to `fit!` to align with the ML standard
  • Loading branch information
takuti authored Feb 2, 2022
2 parents f56a4a5 + d85fece commit d270e4e
Show file tree
Hide file tree
Showing 31 changed files with 125 additions and 81 deletions.
Binary file modified docs/src/assets/images/overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions docs/src/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ recommender = MostPopular(data)
and building a recommendation engine should be easy:

```julia
build!(recommender)
fit!(recommender)
```

Personalized recommenders sometimes require us to specify the hyperparameters:
Expand All @@ -70,7 +70,7 @@ help?> Recommendation.MatrixFactorization

```julia
recommender = MatrixFactorization(data, 2)
build!(recommender, learning_rate=15e-4, max_iter=100)
fit!(recommender, learning_rate=15e-4, max_iter=100)
```

Once a recommendation engine has been built successfully, top-`2` recommendation for a user `4` is performed as follows:
Expand Down
4 changes: 3 additions & 1 deletion src/Recommendation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ using Random
include("types.jl")
include("utils.jl")

include("base_recommender.jl")
include("data_accessor.jl")
include("base_recommender.jl")

include("baseline/user_mean.jl")
include("baseline/item_mean.jl")
Expand All @@ -33,4 +33,6 @@ include("metric/ranking.jl")
include("evaluation/evaluate.jl")
include("evaluation/cross_validation.jl")

include("compat.jl")

end # module
27 changes: 20 additions & 7 deletions src/base_recommender.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,31 @@
export Recommender
export isbuilt, check_build_status, build!, recommend, predict, ranking
export isdefined, validate, fit!, recommend, predict, ranking

abstract type Recommender end

function check_build_status(recommender::Recommender)
if !isbuilt(recommender)
function validate(recommender::Recommender)
if !isdefined(recommender)
error("Recommender $(typeof(recommender)) is not built before making recommendation")
end
end

isbuilt(recommender::Recommender) = true
function validate(recommender::Recommender, data::DataAccessor)
validate(recommender)

function build!(recommender::Recommender; kwargs...)
error("build! is not implemented for recommender type $(typeof(recommender))")
n_rec_user, n_rec_item = size(recommender.data.R)
n_data_user, n_data_item = size(data.R)

if n_rec_user != n_data_user
error("number of users is mismatched: (recommender, target) = ($(n_rec_user), $(n_data_user)")
elseif n_rec_item != n_data_item
error("number of items is mismatched: (recommender, target) = ($(n_rec_item), $(n_data_item)")
end
end

isdefined(recommender::Recommender) = true

function fit!(recommender::Recommender; kwargs...)
error("fit! is not implemented for recommender type $(typeof(recommender))")
end

function recommend(recommender::Recommender, u::Integer, k::Integer, candidates::Array{T}) where {T<:Integer}
Expand All @@ -32,6 +45,6 @@ end

# Return a ranking score of item i for user u
function ranking(recommender::Recommender, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)
predict(recommender, u, i)
end
6 changes: 3 additions & 3 deletions src/baseline/co_occurrence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ struct CoOccurrence <: Recommender
end
end

isbuilt(recommender::CoOccurrence) = isfilled(recommender.scores)
isdefined(recommender::CoOccurrence) = isfilled(recommender.scores)

function build!(recommender::CoOccurrence)
function fit!(recommender::CoOccurrence)
n_item = size(recommender.data.R, 2)

v_ref = recommender.data.R[:, recommender.i_ref]
Expand All @@ -37,6 +37,6 @@ function build!(recommender::CoOccurrence)
end

function ranking(recommender::CoOccurrence, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)
recommender.scores[i]
end
6 changes: 3 additions & 3 deletions src/baseline/item_mean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ struct ItemMean <: Recommender
end
end

isbuilt(recommender::ItemMean) = isfilled(recommender.scores)
isdefined(recommender::ItemMean) = isfilled(recommender.scores)

function build!(recommender::ItemMean)
function fit!(recommender::ItemMean)
n_item = size(recommender.data.R, 2)

for i in 1:n_item
Expand All @@ -27,6 +27,6 @@ function build!(recommender::ItemMean)
end

function predict(recommender::ItemMean, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)
recommender.scores[i]
end
6 changes: 3 additions & 3 deletions src/baseline/most_popular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ struct MostPopular <: Recommender
end
end

isbuilt(recommender::MostPopular) = isfilled(recommender.scores)
isdefined(recommender::MostPopular) = isfilled(recommender.scores)

function build!(recommender::MostPopular)
function fit!(recommender::MostPopular)
n_item = size(recommender.data.R, 2)

for i in 1:n_item
Expand All @@ -27,6 +27,6 @@ function build!(recommender::MostPopular)
end

function ranking(recommender::MostPopular, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)
recommender.scores[i]
end
6 changes: 3 additions & 3 deletions src/baseline/threshold_percentage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ struct ThresholdPercentage <: Recommender
end
end

isbuilt(recommender::ThresholdPercentage) = isfilled(recommender.scores)
isdefined(recommender::ThresholdPercentage) = isfilled(recommender.scores)

function build!(recommender::ThresholdPercentage)
function fit!(recommender::ThresholdPercentage)
n_item = size(recommender.data.R, 2)

for i in 1:n_item
Expand All @@ -32,6 +32,6 @@ function build!(recommender::ThresholdPercentage)
end

function ranking(recommender::ThresholdPercentage, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)
recommender.scores[i]
end
6 changes: 3 additions & 3 deletions src/baseline/user_mean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ struct UserMean <: Recommender
end
end

isbuilt(recommender::UserMean) = isfilled(recommender.scores)
isdefined(recommender::UserMean) = isfilled(recommender.scores)

function build!(recommender::UserMean)
function fit!(recommender::UserMean)
n_user = size(recommender.data.R, 1)

for u in 1:n_user
Expand All @@ -27,6 +27,6 @@ function build!(recommender::UserMean)
end

function predict(recommender::UserMean, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)
recommender.scores[u]
end
8 changes: 8 additions & 0 deletions src/compat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export build!

function bridge_fit!(recommender::Recommender; kwargs...)
@warn "`build!`` is deprecated and renamed to `fit!`"
fit!(recommender; kwargs...)
end

build!(recommender::Recommender; kwargs...) = bridge_fit!(recommender; kwargs...)
4 changes: 2 additions & 2 deletions src/evaluation/cross_validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function cross_validation(n_fold::Integer, metric::Type{<:RankingMetric}, k::Int

# get recommender from the specified data type
recommender = recommender_type(train_data, recommender_args...)
build!(recommender)
fit!(recommender)

accuracy = evaluate(recommender, truth_data, metric(), k)
if isnan(accuracy); continue; end
Expand Down Expand Up @@ -75,7 +75,7 @@ function cross_validation(n_fold::Integer, metric::Type{<:AccuracyMetric}, recom

# get recommender from the specified data type
recommender = recommender_type(train_data, recommender_args...)
build!(recommender)
fit!(recommender)

accuracy = evaluate(recommender, truth_data, metric())
if isnan(accuracy); continue; end
Expand Down
21 changes: 4 additions & 17 deletions src/evaluation/evaluate.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
export evaluate

function validate_size(recommender::Recommender, truth_data::DataAccessor)
n_rec_user, n_rec_item = size(recommender.data.R)
n_truth_user, n_truth_item = size(truth_data.R)

if n_rec_user != n_truth_user
error("number of users is mismatched: (recommenre, truth) = ($(n_rec_user), $(n_truth_user)")
elseif n_rec_item != n_truth_item
error("number of items is mismatched: (recommenre, truth) = ($(n_rec_item), $(n_truth_item)")
end

n_truth_user, n_truth_item
end

function evaluate(recommender::Recommender, truth_data::DataAccessor,
metric::AccuracyMetric)
check_build_status(recommender)
n_user, n_item = validate_size(recommender, truth_data)
validate(recommender, truth_data)
n_user, n_item = size(truth_data.R)

accum = 0.0

Expand All @@ -34,8 +21,8 @@ end

function evaluate(recommender::Recommender, truth_data::DataAccessor,
metric::RankingMetric, k::Integer=0)
check_build_status(recommender)
n_user, n_item = validate_size(recommender, truth_data)
validate(recommender, truth_data)
n_user, n_item = size(truth_data.R)

accum = 0.0

Expand Down
6 changes: 3 additions & 3 deletions src/model/factorization_machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ end

FactorizationMachines(data::DataAccessor) = FactorizationMachines(data, 20)

isbuilt(recommender::FactorizationMachines) = isfilled(recommender.V)
isdefined(recommender::FactorizationMachines) = isfilled(recommender.V)

function build!(recommender::FactorizationMachines;
function fit!(recommender::FactorizationMachines;
reg_w0::Float64=1e-3,
reg_w::Float64=1e-3,
reg_V::Float64=1e-3,
Expand Down Expand Up @@ -116,7 +116,7 @@ function build!(recommender::FactorizationMachines;
end

function predict(recommender::FactorizationMachines, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)
n_user, n_item = size(recommender.data.R)

u_onehot = zeros(n_user)
Expand Down
6 changes: 3 additions & 3 deletions src/model/item_knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ end

ItemKNN(data::DataAccessor) = ItemKNN(data, 5)

isbuilt(recommender::ItemKNN) = isfilled(recommender.sim)
isdefined(recommender::ItemKNN) = isfilled(recommender.sim)

function build!(recommender::ItemKNN; adjusted_cosine::Bool=false)
function fit!(recommender::ItemKNN; adjusted_cosine::Bool=false)
# cosine similarity

R = copy(recommender.data.R)
Expand Down Expand Up @@ -74,7 +74,7 @@ function build!(recommender::ItemKNN; adjusted_cosine::Bool=false)
end

function predict(recommender::ItemKNN, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)

numer = denom = 0

Expand Down
6 changes: 3 additions & 3 deletions src/model/matrix_factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ const MF = MatrixFactorization

MF(data::DataAccessor) = MF(data, 20)

isbuilt(recommender::MF) = isfilled(recommender.P)
isdefined(recommender::MF) = isfilled(recommender.P)

function build!(recommender::MF;
function fit!(recommender::MF;
reg::Float64=1e-3, learning_rate::Float64=1e-3,
eps::Float64=1e-3, max_iter::Int=100,
random_init::Bool=false)
Expand Down Expand Up @@ -89,6 +89,6 @@ function build!(recommender::MF;
end

function predict(recommender::MF, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)
dot(recommender.P[u, :], recommender.Q[i, :])
end
6 changes: 3 additions & 3 deletions src/model/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ end

SVD(data::DataAccessor) = SVD(data, 20)

isbuilt(recommender::SVD) = isfilled(recommender.U)
isdefined(recommender::SVD) = isfilled(recommender.U)

function build!(recommender::SVD)
function fit!(recommender::SVD)
R = copy(recommender.data.R)

res = svd(R)
Expand All @@ -40,6 +40,6 @@ function build!(recommender::SVD)
end

function predict(recommender::SVD, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)
dot(recommender.U[u, :] .* recommender.S, recommender.Vt[:, i])
end
6 changes: 3 additions & 3 deletions src/model/user_knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ end
UserKNN(data::DataAccessor, k::Integer) = UserKNN(data, k, false)
UserKNN(data::DataAccessor) = UserKNN(data, 20, false)

isbuilt(recommender::UserKNN) = isfilled(recommender.sim)
isdefined(recommender::UserKNN) = isfilled(recommender.sim)

function build!(recommender::UserKNN)
function fit!(recommender::UserKNN)
# Pearson correlation

R = copy(recommender.data.R)
Expand Down Expand Up @@ -71,7 +71,7 @@ function build!(recommender::UserKNN)
end

function predict(recommender::UserKNN, u::Integer, i::Integer)
check_build_status(recommender)
validate(recommender)

numer = denom = 0

Expand Down
2 changes: 1 addition & 1 deletion test/baseline/test_co_occurrence.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function test_co_occurrence(data)
recommender = CoOccurrence(data, 1)
build!(recommender)
fit!(recommender)
@test ranking(recommender, 1, 1) == 100.0
@test ranking(recommender, 1, 2) == 50.0
@test ranking(recommender, 1, 3) == 0.0
Expand Down
2 changes: 1 addition & 1 deletion test/baseline/test_item_mean.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function test_item_mean(data)
recommender = ItemMean(data)
build!(recommender)
fit!(recommender)
actual = predict(recommender, 1, 1)

@test actual == 2.5
Expand Down
6 changes: 3 additions & 3 deletions test/baseline/test_most_popular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@ function test_most_popular()

data = DataAccessor([1 2 3; 4 5 nothing])
recommender = MostPopular(data)
build!(recommender)
fit!(recommender)
@test ranking(recommender, 1, 1) == 2.0
@test ranking(recommender, 1, 3) == 1.0

data = DataAccessor(sparse([1 2 3; 4 5 0]))
recommender = MostPopular(data)
build!(recommender)
fit!(recommender)
@test ranking(recommender, 1, 1) == 2.0
@test ranking(recommender, 1, 3) == 1.0

n_user, n_item = 5, 10
events = [Event(1, 2, 1), Event(3, 2, 1), Event(2, 6, 4)]
data = DataAccessor(events, n_user, n_item)
recommender = MostPopular(data)
build!(recommender)
fit!(recommender)
@test ranking(recommender, 1, 1) == 0.0
@test ranking(recommender, 1, 2) == 2.0
@test ranking(recommender, 1, 6) == 1.0
Expand Down
2 changes: 1 addition & 1 deletion test/baseline/test_threshold_percentage.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function test_threshold_percentage(data)
recommender = ThresholdPercentage(data, 2.0)
build!(recommender)
fit!(recommender)
@test ranking(recommender, 1, 1) == 50.0
@test ranking(recommender, 1, 2) == 100.0
end
Expand Down
Loading

0 comments on commit d270e4e

Please sign in to comment.