Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General regularisation #66

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open

General regularisation #66

wants to merge 7 commits into from

Conversation

zsteve
Copy link
Member

@zsteve zsteve commented May 23, 2021

Instead of using the quadreg function call for the quadratic regularisation, this PR provides a general set of functions ot_reg_plan and ot_reg_cost for computing OT with 'general' regularisations. Now quadreg is wrapped as reg_func = "L2", and method = "lorenz" selects the current algorithm. However, there are other approaches out there (e.g. L-BFGS-B) for the same OT problem that could be implemented using the same interface.

The idea is that eventually more regularisations can be added (see e.g. https://arxiv.org/abs/1710.06276), and there should be a unified interface for this.

@@ -92,7 +92,7 @@ sinkhorn2(μ, ν, C, ε)
# resulting transport plan $\gamma$ is *sparse*. We take advantage of this and represent it as
# a sparse matrix.

quadreg(μ, ν, C, ε; maxiter=500);
ot_reg_plan(μ, ν, C, ε; reg_func = "L2", method = "lorenz", maxiter=500);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
ot_reg_plan(μ, ν, C, ε; reg_func = "L2", method = "lorenz", maxiter=500);
ot_reg_plan(μ, ν, C, ε; reg_func="L2", method="lorenz", maxiter=500);

@@ -188,7 +188,7 @@ heatmap(
# Notice how the "edges" of the transport plan are sharper if we use quadratic regularisation
# instead of entropic regularisation:

γquad = Matrix(quadreg(μ, ν, C, 5; maxiter=500))
γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func = "L2", method = "lorenz", maxiter=500))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func = "L2", method = "lorenz", maxiter=500))
γquad = Matrix(ot_reg_plan(μ, ν, C, 5; reg_func="L2", method="lorenz", maxiter=500))

See also: [`ot_reg_plan`](@ref)

"""
function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function ot_reg_cost(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
nothing

quadreg(mu, nu, C, eps; kwargs...)
else
@warn "Unimplemented"
nothing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
nothing

@coveralls
Copy link

coveralls commented May 23, 2021

Pull Request Test Coverage Report for Build 869354795

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 14 of 15 (93.33%) changed or added relevant lines in 1 file are covered.
  • 2 unchanged lines in 1 file lost coverage.
  • Overall coverage decreased (-0.2%) to 85.502%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/OptimalTransport.jl 14 15 93.33%
Files with Coverage Reduction New Missed Lines %
src/OptimalTransport.jl 2 84.46%
Totals Coverage Status
Change from base Build 869183384: -0.2%
Covered Lines: 230
Relevant Lines: 269

💛 - Coveralls

@codecov-commenter
Copy link

codecov-commenter commented May 23, 2021

Codecov Report

Merging #66 (38d770d) into master (d5ebf7a) will decrease coverage by 0.21%.
The diff coverage is 80.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master      #66      +/-   ##
==========================================
- Coverage   85.71%   85.50%   -0.22%     
==========================================
  Files           2        2              
  Lines         259      269      +10     
==========================================
+ Hits          222      230       +8     
- Misses         37       39       +2     
Impacted Files Coverage Δ
src/OptimalTransport.jl 84.46% <80.00%> (-0.19%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d5ebf7a...38d770d. Read the comment docs.

Lorenz, D.A., Manns, P. and Meyer, C., 2019. Quadratically regularized optimal transport. Applied Mathematics & Optimization, pp.1-31.
"""
function ot_reg_plan(mu, nu, C, eps; reg_func="L2", method="lorenz", kwargs...)
if (reg_func == "L2") && (method == "lorenz")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach is problematic in my opinion: it is not possible to add support for other methods or regularization types for users or downstream packages, one always has to modify this function here.

I guess this could be avoided with the suggestion in #63 (comment) - everyone could just add other regularizations and/or algorithms.

@davibarreira
Copy link
Member

Can this PR be closed, @zsteve ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants