-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
General regularisation #66
base: master
Are you sure you want to change the base?
Conversation
@@ -92,7 +92,7 @@ sinkhorn2(μ, ν, C, ε) | |||
# resulting transport plan $\gamma$ is *sparse*. We take advantage of this and represent it as | |||
# a sparse matrix. | |||
|
|||
quadreg(μ, ν, C, ε; maxiter=500); | |||
ot_reg_plan(μ, ν, C, ε; reg_func = "L2", method = "lorenz", maxiter=500); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
ot_reg_plan(μ, ν, C, ε; reg_func = "L2", method = "lorenz", maxiter=500); | |
ot_reg_plan(μ, ν, C, ε; reg_func="L2", method="lorenz", maxiter=500); |
@@ -188,7 +188,7 @@ heatmap( | |||
# Notice how the "edges" of the transport plan are sharper if we use quadratic regularisation | |||
# instead of entropic regularisation: | |||
|
|||
γquad = Matrix(quadreg(μ, ν, C, 5; maxiter=500)) | |||
γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func = "L2", method = "lorenz", maxiter=500)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func = "L2", method = "lorenz", maxiter=500)) | |
γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func="L2", method="lorenz", maxiter=500)) |
See also: [`ot_reg_plan`](@ref) | ||
|
||
""" | ||
function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) | |
nothing |
src/OptimalTransport.jl
Outdated
quadreg(mu, nu, C, eps; kwargs...) | ||
else | ||
@warn "Unimplemented" | ||
nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
nothing |
Pull Request Test Coverage Report for Build 869354795Warning: This coverage report may be inaccurate.This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
Details
💛 - Coveralls |
Codecov Report
@@ Coverage Diff @@
## master #66 +/- ##
==========================================
- Coverage 85.71% 85.50% -0.22%
==========================================
Files 2 2
Lines 259 269 +10
==========================================
+ Hits 222 230 +8
- Misses 37 39 +2
Continue to review full report at Codecov.
|
Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. Applied Mathematics & Optimization, pp.1-31. | ||
""" | ||
function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...) | ||
if (reg_func == "L2") && (method == "lorenz") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This approach is problematic in my opinion: it is not possible to add support for other methods or regularization types for users or downstream packages, one always has to modify this function here.
I guess this could be avoided with the suggestion in #63 (comment) - everyone could just add other regularizations and/or algorithms.
Can this PR be closed, @zsteve ? |
Instead of using the
quadreg
function call for the quadratic regularisation, this PR provides a general set of functionsot_reg_plan
andot_reg_cost
for computing OT with 'general' regularisations. Nowquadreg
is wrapped asreg_func = "L2"
, andmethod = "lorenz"
selects the current algorithm. However, there are other approaches out there (e.g. L-BFGS-B) for the same OT problem that could be implemented using the same interface.The idea is that eventually more regularisations can be added (see e.g. https://arxiv.org/abs/1710.06276), and there should be a unified interface for this.