Skip to content

Commit

Permalink
Symstr (#45)
Browse files Browse the repository at this point in the history
* penalties for basic regressions can be specified with  strings like sklearn

* patch release
  • Loading branch information
tlienart authored Dec 27, 2019
1 parent 289a373 commit aa7c4a9
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJLinearModels"
uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
authors = ["Thibaut Lienart <[email protected]>"]
version = "0.2.3"
version = "0.2.4"

This comment has been minimized.

Copy link
@tlienart

tlienart Dec 27, 2019

Author Collaborator

[deps]
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down
8 changes: 4 additions & 4 deletions src/mlj/classifiers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
@with_kw_noshow mutable struct LogisticClassifier <: MLJBase.Probabilistic
lambda::Real = 1.0
gamma::Real = 0.0
penalty::Symbol = :l2
penalty::SymStr = :l2
fit_intercept::Bool = true
penalize_intercept::Bool = false
solver::Option{Solver} = nothing
multi_class::Bool = false
end

glr(m::LogisticClassifier) = LogisticRegression(m.lambda, m.gamma; penalty=m.penalty,
glr(m::LogisticClassifier) = LogisticRegression(m.lambda, m.gamma; penalty=Symbol(m.penalty),
multi_class=m.multi_class,
fit_intercept=m.fit_intercept,
penalize_intercept=m.penalize_intercept)
Expand All @@ -26,13 +26,13 @@ descr(::Type{LogisticClassifier}) = "Classifier corresponding to the loss functi
@with_kw_noshow mutable struct MultinomialClassifier <: MLJBase.Probabilistic
lambda::Real = 1.0
gamma::Real = 0.0
penalty::Symbol = :l2
penalty::SymStr = :l2
fit_intercept::Bool = true
penalize_intercept::Bool = false
solver::Option{Solver} = nothing
end

glr(m::MultinomialClassifier) = MultinomialRegression(m.lambda, m.gamma; penalty=m.penalty,
glr(m::MultinomialClassifier) = MultinomialRegression(m.lambda, m.gamma; penalty=Symbol(m.penalty),
fit_intercept=m.fit_intercept,
penalize_intercept=m.penalize_intercept)

Expand Down
2 changes: 2 additions & 0 deletions src/mlj/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ export LinearRegressor, RidgeRegressor, LassoRegressor, ElasticNetRegressor,
RobustRegressor, HuberRegressor, QuantileRegressor, LADRegressor,
LogisticClassifier, MultinomialClassifier

const SymStr = Union{Symbol,String}

include("regressors.jl")
include("classifiers.jl")

Expand Down
17 changes: 9 additions & 8 deletions src/mlj/regressors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ descr(::Type{ElasticNetRegressor}) = "Regression with objective function ``|Xθ
rho::RobustRho = HuberRho(0.1)
lambda::Real = 1.0
gamma::Real = 0.0
penalty::Symbol = :l2
penalty::SymStr = :l2
fit_intercept::Bool = true
penalize_intercept::Bool = false
solver::Option{Solver} = nothing
end

glr(m::RobustRegressor) = RobustRegression(m.rho, m.lambda, m.gamma; penalty=m.penalty,
glr(m::RobustRegressor) = RobustRegression(m.rho, m.lambda, m.gamma; penalty=Symbol(m.penalty),
fit_intercept=m.fit_intercept,
penalize_intercept=m.penalize_intercept)

Expand All @@ -89,13 +89,13 @@ descr(::Type{RobustRegressor}) = "Robust regression with objective ``∑ρ(Xθ -
delta::Real = 0.5
lambda::Real = 1.0
gamma::Real = 0.0
penalty::Symbol = :l2
penalty::SymStr = :l2
fit_intercept::Bool = true
penalize_intercept::Bool = false
solver::Option{Solver} = nothing
end

glr(m::HuberRegressor) = HuberRegression(m.delta, m.lambda, m.gamma; penalty=m.penalty,
glr(m::HuberRegressor) = HuberRegression(m.delta, m.lambda, m.gamma; penalty=Symbol(m.penalty),
fit_intercept=m.fit_intercept,
penalize_intercept=m.penalize_intercept)

Expand All @@ -109,13 +109,14 @@ descr(::Type{HuberRegressor}) = "Robust regression with objective ``∑ρ(Xθ -
delta::Real = 0.5
lambda::Real = 1.0
gamma::Real = 0.0
penalty::Symbol = :l2
penalty::SymStr = :l2
fit_intercept::Bool = true
penalize_intercept::Bool = false
solver::Option{Solver} = nothing
end

glr(m::QuantileRegressor) = QuantileRegression(m.delta, m.lambda, m.gamma; penalty=m.penalty,
glr(m::QuantileRegressor) = QuantileRegression(m.delta, m.lambda, m.gamma;
penalty=Symbol(m.penalty),
fit_intercept=m.fit_intercept,
penalize_intercept=m.penalize_intercept)

Expand All @@ -128,13 +129,13 @@ descr(::Type{QuantileRegressor}) = "Robust regression with objective ``∑ρ(Xθ
@with_kw_noshow mutable struct LADRegressor <: MLJBase.Deterministic
lambda::Real = 1.0
gamma::Real = 0.0
penalty::Symbol = :l2
penalty::SymStr = :l2
fit_intercept::Bool = true
penalize_intercept::Bool = false
solver::Option{Solver} = nothing
end

glr(m::LADRegressor) = LADRegression(m.lambda, m.gamma; penalty=m.penalty,
glr(m::LADRegressor) = LADRegression(m.lambda, m.gamma; penalty=Symbol(m.penalty),
fit_intercept=m.fit_intercept,
penalize_intercept=m.penalize_intercept)

Expand Down
9 changes: 9 additions & 0 deletions test/interface/fitpredict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ end
mcr = MLJBase.misclassification_rate(ŷ, yc)
@test mcr 0.2
end

# see issue https://github.com/alan-turing-institute/MLJ.jl/issues/387
@testset "String-Symbol" begin
model = LogisticClassifier(penalty="l1")
@test model.penalty == "l1"
gr = MLJLinearModels.glr(model)
@test gr isa GLR
@test gr.penalty isa ScaledPenalty{L1Penalty}
end

1 comment on commit aa7c4a9

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/7199

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.4 -m "<description of version>" aa7c4a93047e0fb8e9ff99e8105bfca79e87cba0
git push origin v0.2.4

Please sign in to comment.