From 2ef3e2b3e6837037c69bb8edebf7f4dce53e9b61 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sat, 2 Oct 2021 13:14:12 +1000 Subject: [PATCH 01/23] first attempt at gromov-wasserstein --- src/gromov.jl | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 src/gromov.jl diff --git a/src/gromov.jl b/src/gromov.jl new file mode 100644 index 00000000..de43e947 --- /dev/null +++ b/src/gromov.jl @@ -0,0 +1,75 @@ +# Gromov-Wasserstein solver + +abstract type EntropicGromovWasserstein end + +struct EntropicGromovWassersteinGibbs <: EntropicGromovWasserstein + alg_step::Sinkhorn +end + +function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, + alg::EntropicGromovWassersteinGibbs; atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) + T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) + C = similar(Cμ, T, size(μ, 1), size(ν, 1)) + tmp = similar(C) + plan = similar(C) + @. plan = μ * ν' + + function get_new_cost!(C, plan, tmp, Cμ, Cν) + A_batched_mul_B!(tmp, Cμ, plan) + A_batched_mul_B!(C, tmp, Cν) + end + + get_new_cost!(C, plan, tmp, Cμ, Cν) + solver_step = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) 
+ to_check_step = check_convergence + for iter in 1:maxiter + # perform Sinkhorn algorithm + solve!(solver_step) + # compute optimal transport plan + plan = sinkhorn_plan(solver) + + to_check_step -= 1 + if to_check_step == 0 || iter == maxiter + # reset counter + to_check_step = check_convergence + + # TODO: convergence check + # isconverged, abserror = OptimalTransport.check_convergence(solver) + # @debug string(solver.alg) * + # " (" * + # string(iter) * + # "/" * + # string(maxiter) * + # ": absolute error of source marginal = " * + # string(maximum(abserror)) + + if isconverged + @debug "$(solver.alg) ($iter/$maxiter): converged" + break + end + end + update_cost!(solver, C) + end + + return plan +end + +# support for `SinkhornGibbs` and `SinkhornStabilized` +function update_cost!(solver::SinkhornSolver{SinkhornGibbs}, C::AbstractMatrix) + cache = solver.cache + @. cache.K = exp(-C / solver.eps) + newsolver = SinkhornSolver( + solver.source, + solver.target, + C, + solver.eps, + solver.alg, + solver.atol, + solver.rtol, + solver.maxiter, + solver.check_convergence, + cache, + solver.convergence_cache, + ) + return newsolver +end From 11efd8c669335e245a5d25a298a91d4f9c565478 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Tue, 8 Mar 2022 20:28:49 +1100 Subject: [PATCH 02/23] update --- src/gromov.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gromov.jl b/src/gromov.jl index de43e947..5e3ab14f 100644 --- a/src/gromov.jl +++ b/src/gromov.jl @@ -26,7 +26,7 @@ function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ # perform Sinkhorn algorithm solve!(solver_step) # compute optimal transport plan - plan = sinkhorn_plan(solver) + plan = sinkhorn_plan(solver_step) to_check_step -= 1 if to_check_step == 0 || iter == maxiter From 0956c3bdd2e3e44099e1eb2a1543f71cac88d57e Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Tue, 8 Mar 2022 22:26:51 +1100 Subject: [PATCH 03/23] fixed computation of entropic 
gromov-wasserstein --- Project.toml | 3 ++- src/#gromov.jl# | 57 +++++++++++++++++++++++++++++++++++++++ src/OptimalTransport.jl | 3 +++ src/gromov.jl | 60 +++++++++++++++-------------------------- test/gromov.jl | 29 ++++++++++++++++++++ test/gromov.jl~ | 0 6 files changed, 112 insertions(+), 40 deletions(-) create mode 100644 src/#gromov.jl# create mode 100644 test/gromov.jl create mode 100644 test/gromov.jl~ diff --git a/Project.toml b/Project.toml index dbd16a25..a6d1aaf1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.19" +version = "0.3.20" [deps] ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" diff --git a/src/#gromov.jl# b/src/#gromov.jl# new file mode 100644 index 00000000..72a85935 --- /dev/null +++ b/src/#gromov.jl# @@ -0,0 +1,57 @@ +# Gromov-Wasserstein solver + +abstract type EntropicGromovWasserstein end + +struct EntropicGromovWassersteinGibbs <: EntropicGromovWasserstein + alg_step::Sinkhorn +end + +function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, + alg::EntropicGromovWasserstein = EntropicGromovWassersteinGibbs(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) + T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) + C = similar(Cμ, T, size(μ, 1), size(ν, 1)) + tmp = similar(C) + plan = similar(C) + @. plan = μ * ν' + plan_prev = similar(C) + plan_prev .= plan + norm_plan = sum(plan) + + _atol = atol === nothing ? 0 : atol + _rtol = rtol === nothing ? 
(_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol + + function get_new_cost!(C, plan, tmp, Cμ, Cν) + A_batched_mul_B!(tmp, Cμ, plan) + A_batched_mul_B!(C, tmp, -4Cν) + # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? + # added the factor of 4 here to ensure reproducibility for the same value of ε. + # https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 + end + + get_new_cost!(C, plan, tmp, Cμ, Cν) + to_check_step = check_convergence + + isconverged = false + for iter in 1:maxiter + # perform Sinkhorn algorithm + solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) + solve!(solver) + # compute optimal transport plan + plan = sinkhorn_plan(solver) + + to_check_step -= 1 + if to_check_step == 0 || iter == maxiter + # reset counter + to_check_step = check_convergence + isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) + if isconverged + @debug "$Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" + break + end + plan_prev .= plan + end + get_new_cost!(C, plan, tmp, Cμ, Cν) + end + + return plan +end diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 1653431e..7eb165f4 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -13,6 +13,7 @@ using LinearAlgebra using IterativeSolvers using LogExpFunctions: LogExpFunctions using NNlib: NNlib +using Logging export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling export SinkhornBarycenterGibbs @@ -42,4 +43,6 @@ include("quadratic_newton.jl") include("dual/entropic_dual.jl") +include("gromov.jl") + end diff --git a/src/gromov.jl b/src/gromov.jl index 5e3ab14f..b77a5347 100644 --- a/src/gromov.jl +++ b/src/gromov.jl @@ -7,69 +7,51 @@ struct EntropicGromovWassersteinGibbs <: EntropicGromovWasserstein end function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, - 
alg::EntropicGromovWassersteinGibbs; atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) + alg::EntropicGromovWasserstein = EntropicGromovWassersteinGibbs(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) C = similar(Cμ, T, size(μ, 1), size(ν, 1)) tmp = similar(C) plan = similar(C) @. plan = μ * ν' + plan_prev = similar(C) + plan_prev .= plan + norm_plan = sum(plan) + + _atol = atol === nothing ? 0 : atol + _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol function get_new_cost!(C, plan, tmp, Cμ, Cν) A_batched_mul_B!(tmp, Cμ, plan) - A_batched_mul_B!(C, tmp, Cν) + A_batched_mul_B!(C, tmp, -4Cν) + # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? + # added the factor of 4 here to ensure reproducibility for the same value of ε. + # https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 end get_new_cost!(C, plan, tmp, Cμ, Cν) - solver_step = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) - to_check_step = check_convergence + to_check_step = check_convergence + + isconverged = false for iter in 1:maxiter # perform Sinkhorn algorithm - solve!(solver_step) + solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) 
+ solve!(solver) # compute optimal transport plan - plan = sinkhorn_plan(solver_step) + plan = sinkhorn_plan(solver) to_check_step -= 1 if to_check_step == 0 || iter == maxiter # reset counter to_check_step = check_convergence - - # TODO: convergence check - # isconverged, abserror = OptimalTransport.check_convergence(solver) - # @debug string(solver.alg) * - # " (" * - # string(iter) * - # "/" * - # string(maxiter) * - # ": absolute error of source marginal = " * - # string(maximum(abserror)) - + isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) if isconverged - @debug "$(solver.alg) ($iter/$maxiter): converged" + @debug "Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" break end + plan_prev .= plan end - update_cost!(solver, C) + get_new_cost!(C, plan, tmp, Cμ, Cν) end return plan end - -# support for `SinkhornGibbs` and `SinkhornStabilized` -function update_cost!(solver::SinkhornSolver{SinkhornGibbs}, C::AbstractMatrix) - cache = solver.cache - @. 
cache.K = exp(-C / solver.eps) - newsolver = SinkhornSolver( - solver.source, - solver.target, - C, - solver.eps, - solver.alg, - solver.atol, - solver.rtol, - solver.maxiter, - solver.check_convergence, - cache, - solver.convergence_cache, - ) - return newsolver -end diff --git a/test/gromov.jl b/test/gromov.jl new file mode 100644 index 00000000..52ec64f9 --- /dev/null +++ b/test/gromov.jl @@ -0,0 +1,29 @@ +using OptimalTransport + +using Distances +using PythonOT: PythonOT + +using Random +using Test +using LinearAlgebra + +const POT = PythonOT + +Random.seed!(100) + +M, N = 10, 10 + +μ = fill(1/M, M) +μ_spt = rand(M) +ν = fill(1/N, N) +ν_spt = rand(N) + +Cμ = pairwise(SqEuclidean(), μ_spt) +Cν = pairwise(SqEuclidean(), ν_spt) + +γ = OptimalTransport.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence = 10) +γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01) + +norm(γ .- γ_pot, 1) +norm(γ, 1) +norm(γ_pot, 1) diff --git a/test/gromov.jl~ b/test/gromov.jl~ new file mode 100644 index 00000000..e69de29b From c22d7e71e67ec49fb181ef4efefbc1c74a05dda4 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Tue, 8 Mar 2022 22:26:51 +1100 Subject: [PATCH 04/23] fixed computation of entropic gromov-wasserstein --- Project.toml | 3 ++- src/#gromov.jl# | 57 +++++++++++++++++++++++++++++++++++++++ src/OptimalTransport.jl | 3 +++ src/gromov.jl | 60 +++++++++++++++-------------------------- test/gromov.jl | 29 ++++++++++++++++++++ 5 files changed, 112 insertions(+), 40 deletions(-) create mode 100644 src/#gromov.jl# create mode 100644 test/gromov.jl diff --git a/Project.toml b/Project.toml index dbd16a25..a6d1aaf1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.19" +version = "0.3.20" [deps] ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LinearAlgebra = 
"37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" diff --git a/src/#gromov.jl# b/src/#gromov.jl# new file mode 100644 index 00000000..72a85935 --- /dev/null +++ b/src/#gromov.jl# @@ -0,0 +1,57 @@ +# Gromov-Wasserstein solver + +abstract type EntropicGromovWasserstein end + +struct EntropicGromovWassersteinGibbs <: EntropicGromovWasserstein + alg_step::Sinkhorn +end + +function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, + alg::EntropicGromovWasserstein = EntropicGromovWassersteinGibbs(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) + T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) + C = similar(Cμ, T, size(μ, 1), size(ν, 1)) + tmp = similar(C) + plan = similar(C) + @. plan = μ * ν' + plan_prev = similar(C) + plan_prev .= plan + norm_plan = sum(plan) + + _atol = atol === nothing ? 0 : atol + _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol + + function get_new_cost!(C, plan, tmp, Cμ, Cν) + A_batched_mul_B!(tmp, Cμ, plan) + A_batched_mul_B!(C, tmp, -4Cν) + # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? + # added the factor of 4 here to ensure reproducibility for the same value of ε. + # https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 + end + + get_new_cost!(C, plan, tmp, Cμ, Cν) + to_check_step = check_convergence + + isconverged = false + for iter in 1:maxiter + # perform Sinkhorn algorithm + solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) 
+ solve!(solver) + # compute optimal transport plan + plan = sinkhorn_plan(solver) + + to_check_step -= 1 + if to_check_step == 0 || iter == maxiter + # reset counter + to_check_step = check_convergence + isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) + if isconverged + @debug "$Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" + break + end + plan_prev .= plan + end + get_new_cost!(C, plan, tmp, Cμ, Cν) + end + + return plan +end diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 1653431e..7eb165f4 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -13,6 +13,7 @@ using LinearAlgebra using IterativeSolvers using LogExpFunctions: LogExpFunctions using NNlib: NNlib +using Logging export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling export SinkhornBarycenterGibbs @@ -42,4 +43,6 @@ include("quadratic_newton.jl") include("dual/entropic_dual.jl") +include("gromov.jl") + end diff --git a/src/gromov.jl b/src/gromov.jl index 5e3ab14f..b77a5347 100644 --- a/src/gromov.jl +++ b/src/gromov.jl @@ -7,69 +7,51 @@ struct EntropicGromovWassersteinGibbs <: EntropicGromovWasserstein end function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, - alg::EntropicGromovWassersteinGibbs; atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) + alg::EntropicGromovWasserstein = EntropicGromovWassersteinGibbs(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) C = similar(Cμ, T, size(μ, 1), size(ν, 1)) tmp = similar(C) plan = similar(C) @. plan = μ * ν' + plan_prev = similar(C) + plan_prev .= plan + norm_plan = sum(plan) + + _atol = atol === nothing ? 0 : atol + _rtol = rtol === nothing ? (_atol > zero(_atol) ? 
zero(T) : sqrt(eps(T))) : rtol function get_new_cost!(C, plan, tmp, Cμ, Cν) A_batched_mul_B!(tmp, Cμ, plan) - A_batched_mul_B!(C, tmp, Cν) + A_batched_mul_B!(C, tmp, -4Cν) + # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? + # added the factor of 4 here to ensure reproducibility for the same value of ε. + # https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 end get_new_cost!(C, plan, tmp, Cμ, Cν) - solver_step = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) - to_check_step = check_convergence + to_check_step = check_convergence + + isconverged = false for iter in 1:maxiter # perform Sinkhorn algorithm - solve!(solver_step) + solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) + solve!(solver) # compute optimal transport plan - plan = sinkhorn_plan(solver_step) + plan = sinkhorn_plan(solver) to_check_step -= 1 if to_check_step == 0 || iter == maxiter # reset counter to_check_step = check_convergence - - # TODO: convergence check - # isconverged, abserror = OptimalTransport.check_convergence(solver) - # @debug string(solver.alg) * - # " (" * - # string(iter) * - # "/" * - # string(maxiter) * - # ": absolute error of source marginal = " * - # string(maximum(abserror)) - + isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) if isconverged - @debug "$(solver.alg) ($iter/$maxiter): converged" + @debug "Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" break end + plan_prev .= plan end - update_cost!(solver, C) + get_new_cost!(C, plan, tmp, Cμ, Cν) end return plan end - -# support for `SinkhornGibbs` and `SinkhornStabilized` -function update_cost!(solver::SinkhornSolver{SinkhornGibbs}, C::AbstractMatrix) - cache = solver.cache - @. 
cache.K = exp(-C / solver.eps) - newsolver = SinkhornSolver( - solver.source, - solver.target, - C, - solver.eps, - solver.alg, - solver.atol, - solver.rtol, - solver.maxiter, - solver.check_convergence, - cache, - solver.convergence_cache, - ) - return newsolver -end diff --git a/test/gromov.jl b/test/gromov.jl new file mode 100644 index 00000000..52ec64f9 --- /dev/null +++ b/test/gromov.jl @@ -0,0 +1,29 @@ +using OptimalTransport + +using Distances +using PythonOT: PythonOT + +using Random +using Test +using LinearAlgebra + +const POT = PythonOT + +Random.seed!(100) + +M, N = 10, 10 + +μ = fill(1/M, M) +μ_spt = rand(M) +ν = fill(1/N, N) +ν_spt = rand(N) + +Cμ = pairwise(SqEuclidean(), μ_spt) +Cν = pairwise(SqEuclidean(), ν_spt) + +γ = OptimalTransport.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence = 10) +γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01) + +norm(γ .- γ_pot, 1) +norm(γ, 1) +norm(γ_pot, 1) From 267dfadd5d1681c9c69b4a444d79aa487061b191 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Tue, 8 Mar 2022 23:05:17 +1100 Subject: [PATCH 05/23] exports and tests --- Project.toml | 1 - src/OptimalTransport.jl | 2 ++ src/gromov.jl | 4 ++-- test/Project.toml | 2 ++ test/gromov.jl | 26 ++++++++++++++------------ test/runtests.jl | 4 ++++ 6 files changed, 24 insertions(+), 15 deletions(-) create mode 100644 test/Project.toml diff --git a/Project.toml b/Project.toml index a6d1aaf1..3c57e214 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 7eb165f4..7bc59673 100644 --- 
a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -18,12 +18,14 @@ using Logging export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling export SinkhornBarycenterGibbs export QuadraticOTNewton +export EntropicGromovWassersteinSinkhorn export sinkhorn, sinkhorn2 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter export sinkhorn_unbalanced, sinkhorn_unbalanced2 export sinkhorn_divergence export quadreg +export entropic_gromov_wasserstein include("utils.jl") diff --git a/src/gromov.jl b/src/gromov.jl index b77a5347..317d288b 100644 --- a/src/gromov.jl +++ b/src/gromov.jl @@ -2,12 +2,12 @@ abstract type EntropicGromovWasserstein end -struct EntropicGromovWassersteinGibbs <: EntropicGromovWasserstein +struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein alg_step::Sinkhorn end function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, - alg::EntropicGromovWasserstein = EntropicGromovWassersteinGibbs(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) + alg::EntropicGromovWasserstein = EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) 
T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) C = similar(Cμ, T, size(μ, 1), size(ν, 1)) tmp = similar(C) diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 00000000..794185b8 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,2 @@ +[deps] +OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33" diff --git a/test/gromov.jl b/test/gromov.jl index 52ec64f9..a55abbfb 100644 --- a/test/gromov.jl +++ b/test/gromov.jl @@ -11,19 +11,21 @@ const POT = PythonOT Random.seed!(100) -M, N = 10, 10 +@testset "gromov.jl" begin + @testset "entropic_gromov_wasserstein" begin + M, N = 250, 200 -μ = fill(1/M, M) -μ_spt = rand(M) -ν = fill(1/N, N) -ν_spt = rand(N) + μ = fill(1/M, M) + μ_spt = rand(M) + ν = fill(1/N, N) + ν_spt = rand(N) -Cμ = pairwise(SqEuclidean(), μ_spt) -Cν = pairwise(SqEuclidean(), ν_spt) + Cμ = pairwise(SqEuclidean(), μ_spt) + Cν = pairwise(SqEuclidean(), ν_spt) -γ = OptimalTransport.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence = 10) -γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01) + γ = entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence = 10) + γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01) -norm(γ .- γ_pot, 1) -norm(γ, 1) -norm(γ_pot, 1) + @test γ ≈ γ_pot rtol = 1e-6 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 66314dfa..ffce6ef0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,10 @@ const GROUP = get(ENV, "GROUP", "All") @safetestset "Quadratically regularized OT" begin include("quadratic.jl") end + + @safetestset "Gromov-Wasserstein OT" begin + include("gromov.jl") + end end # CUDA requires Julia >= 1.6 From 21609b0dc564e2110e17a159fb1ea7f1483a81bc Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 13 Mar 2022 09:12:37 +1100 Subject: [PATCH 06/23] formatting --- src/gromov.jl | 19 +++++++++++++++---- test/gpu/simple_gpu.jl | 8 +++++--- test/gromov.jl | 6 +++--- 3 files changed, 23 insertions(+), 10 
deletions(-) diff --git a/src/gromov.jl b/src/gromov.jl index 317d288b..23d9ad7b 100644 --- a/src/gromov.jl +++ b/src/gromov.jl @@ -2,12 +2,23 @@ abstract type EntropicGromovWasserstein end -struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein +struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein alg_step::Sinkhorn end -function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, - alg::EntropicGromovWasserstein = EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) +function entropic_gromov_wasserstein( + μ::AbstractVector, + ν::AbstractVector, + Cμ::AbstractMatrix, + Cν::AbstractMatrix, + ε::Real, + alg::EntropicGromovWasserstein=EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); + atol=nothing, + rtol=nothing, + check_convergence=10, + maxiter::Int=1_000, + kwargs..., +) T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) C = similar(Cμ, T, size(μ, 1), size(ν, 1)) tmp = similar(C) @@ -22,7 +33,7 @@ function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ function get_new_cost!(C, plan, tmp, Cμ, Cν) A_batched_mul_B!(tmp, Cμ, plan) - A_batched_mul_B!(C, tmp, -4Cν) + return A_batched_mul_B!(C, tmp, -4Cν) # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? # added the factor of 4 here to ensure reproducibility for the same value of ε. 
# https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 diff --git a/test/gpu/simple_gpu.jl b/test/gpu/simple_gpu.jl index aaca5b86..8f72214d 100644 --- a/test/gpu/simple_gpu.jl +++ b/test/gpu/simple_gpu.jl @@ -97,11 +97,13 @@ Random.seed!(100) @testset "quadreg" begin # use a different reg parameter ε_quad = 1.0f0 - γ = quadreg(cu_μ, cu_ν, cu_C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1f-5, 50)) + γ = quadreg( + cu_μ, cu_ν, cu_C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1.0f-5, 50) + ) # compare with results on the CPU @test convert(Array, γ) ≈ - quadreg(μ, ν, C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1f-5, 50)) atol = - 1f-4 rtol = 1f-4 + quadreg(μ, ν, C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1.0f-5, 50)) atol = + 1.0f-4 rtol = 1.0f-4 end end end diff --git a/test/gromov.jl b/test/gromov.jl index a55abbfb..c65ece78 100644 --- a/test/gromov.jl +++ b/test/gromov.jl @@ -15,15 +15,15 @@ Random.seed!(100) @testset "entropic_gromov_wasserstein" begin M, N = 250, 200 - μ = fill(1/M, M) + μ = fill(1 / M, M) μ_spt = rand(M) - ν = fill(1/N, N) + ν = fill(1 / N, N) ν_spt = rand(N) Cμ = pairwise(SqEuclidean(), μ_spt) Cν = pairwise(SqEuclidean(), ν_spt) - γ = entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence = 10) + γ = entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence=10) γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01) @test γ ≈ γ_pot rtol = 1e-6 From 9699e04db95a1601b49525c9b24f031897292902 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 13 Mar 2022 09:19:35 +1100 Subject: [PATCH 07/23] Update test/gpu/simple_gpu.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/gpu/simple_gpu.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/simple_gpu.jl b/test/gpu/simple_gpu.jl index 8f72214d..f07355d8 100644 --- a/test/gpu/simple_gpu.jl +++ b/test/gpu/simple_gpu.jl @@ -102,7 +102,7 @@ Random.seed!(100) ) # 
compare with results on the CPU @test convert(Array, γ) ≈ - quadreg(μ, ν, C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1.0f-5, 50)) atol = + quadreg(μ, ν, C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1.0f-5, 50)) atol = 1.0f-4 rtol = 1.0f-4 end end From 85103974ffd010e7f6d0754fbae6a664d2cbea6e Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 13 Mar 2022 10:02:37 +1100 Subject: [PATCH 08/23] update docstrings --- src/#gromov.jl# | 57 ------------------------------------------------- src/gromov.jl | 23 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 57 deletions(-) delete mode 100644 src/#gromov.jl# diff --git a/src/#gromov.jl# b/src/#gromov.jl# deleted file mode 100644 index 72a85935..00000000 --- a/src/#gromov.jl# +++ /dev/null @@ -1,57 +0,0 @@ -# Gromov-Wasserstein solver - -abstract type EntropicGromovWasserstein end - -struct EntropicGromovWassersteinGibbs <: EntropicGromovWasserstein - alg_step::Sinkhorn -end - -function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real, - alg::EntropicGromovWasserstein = EntropicGromovWassersteinGibbs(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...) - T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) - C = similar(Cμ, T, size(μ, 1), size(ν, 1)) - tmp = similar(C) - plan = similar(C) - @. plan = μ * ν' - plan_prev = similar(C) - plan_prev .= plan - norm_plan = sum(plan) - - _atol = atol === nothing ? 0 : atol - _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol - - function get_new_cost!(C, plan, tmp, Cμ, Cν) - A_batched_mul_B!(tmp, Cμ, plan) - A_batched_mul_B!(C, tmp, -4Cν) - # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? - # added the factor of 4 here to ensure reproducibility for the same value of ε. 
- # https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 - end - - get_new_cost!(C, plan, tmp, Cμ, Cν) - to_check_step = check_convergence - - isconverged = false - for iter in 1:maxiter - # perform Sinkhorn algorithm - solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) - solve!(solver) - # compute optimal transport plan - plan = sinkhorn_plan(solver) - - to_check_step -= 1 - if to_check_step == 0 || iter == maxiter - # reset counter - to_check_step = check_convergence - isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) - if isconverged - @debug "$Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" - break - end - plan_prev .= plan - end - get_new_cost!(C, plan, tmp, Cμ, Cν) - end - - return plan -end diff --git a/src/gromov.jl b/src/gromov.jl index 23d9ad7b..aa5b8533 100644 --- a/src/gromov.jl +++ b/src/gromov.jl @@ -6,6 +6,29 @@ struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein alg_step::Sinkhorn end +""" + entropic_gromov_wasserstein( + μ, ν, Cμ, Cν, ε, alg=EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); + atol = nothing, rtol = nothing, check_convergence = 10, maxiter = 1_000, kwargs... + ) + +Computes the transport plan for the entropically regularized Gromov-Wasserstein optimal transport problem with source and target +marginals `μ` and `ν` and corresponding cost matrices `Cμ` and `Cν`. That is, we seek `γ` a local minimizer of +```math + \\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, i', j'} |C^{(\\mu)}_{i,i'} - C^{(\\nu)}_{j,j'}|^2 \\gamma_{i,j} \\gamma_{i',j'} + \\varepsilon \\Omega(\\gamma), +``` +where ``\\Omega(\\gamma)`` is the entropic regularization term, see e.g. [`sinkhorn`](@ref). + +This function employs the iterative method described in (Section 10.6.4, [^PC19]), which solves a series of Sinkhorn iteration sub-problems to arrive at a solution. 
Note that the Gromov-Wasserstein problem is non-convex owing to the cross-terms in the +objective function, and thus in general one is only guaranteed to arrive at a local optimum. + +Every `check_convergence` steps, the current iterate of `γ` is compared with `γ_prev` (the iterate from `check_convergence` steps ago). +The quantity ``\\| \\gamma - \\gamma_\\text{prev} \\|_1`` is compared against `atol` and `rtol`. + +[^PC19]: Peyré, G. and Cuturi, M., 2019. Computational optimal transport: With applications to data science. Foundations and Trends® in Machine Learning, 11(5-6), pp.355-607. + +See also: [`sinkhorn`](@ref) +""" function entropic_gromov_wasserstein( μ::AbstractVector, ν::AbstractVector, From 20d5885a64e10a6ece2f13f656ac283ec8af9d22 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 13 Mar 2022 10:03:13 +1100 Subject: [PATCH 09/23] delete cache file --- test/gromov.jl~ | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/gromov.jl~ diff --git a/test/gromov.jl~ b/test/gromov.jl~ deleted file mode 100644 index e69de29b..00000000 From df41c287317497a445a52360234d35dc3d7782d1 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 13 Mar 2022 10:32:59 +1100 Subject: [PATCH 10/23] add docs and format --- docs/src/index.md | 9 +++++++++ test/gpu/simple_gpu.jl | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index ca941f9d..6f7e6c02 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -67,6 +67,15 @@ Currently the following algorithms for solving quadratically regularised optimal QuadraticOTNewton ``` +## Gromov-Wasserstein optimal transport + +```@docs +entropic_gromov_wasserstein +``` + +Currently, only entropy-regularised Gromov-Wasserstein is supported. For exact computations, we refer the user to +[PythonOT](https://github.com/JuliaOptimalTransport/PythonOT.jl) to access functionality from the [Python Optimal Transport library](https://pythonot.github.io/). 
+ ## Dual ```@docs diff --git a/test/gpu/simple_gpu.jl b/test/gpu/simple_gpu.jl index f07355d8..8f72214d 100644 --- a/test/gpu/simple_gpu.jl +++ b/test/gpu/simple_gpu.jl @@ -102,7 +102,7 @@ Random.seed!(100) ) # compare with results on the CPU @test convert(Array, γ) ≈ - quadreg(μ, ν, C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1.0f-5, 50)) atol = + quadreg(μ, ν, C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1.0f-5, 50)) atol = 1.0f-4 rtol = 1.0f-4 end end From a7c1a381b04655bc89905bb813993c894319a0c8 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 13 Mar 2022 10:42:29 +1100 Subject: [PATCH 11/23] remove unnecessary Logging import --- gpu/Manifest.toml | 129 ++++++++++++++++++++++++++++++++++++++++ gpu/Project.toml | 2 + src/OptimalTransport.jl | 1 - 3 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 gpu/Manifest.toml create mode 100644 gpu/Project.toml diff --git a/gpu/Manifest.toml b/gpu/Manifest.toml new file mode 100644 index 00000000..3fca1e6f --- /dev/null +++ b/gpu/Manifest.toml @@ -0,0 +1,129 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.7.0" +manifest_format = "2.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + +[[deps.Conda]] +deps = ["Downloads", "JSON", "VersionParsing"] +git-tree-sha1 = "6e47d11ea2776bc5627421d59cdcc1296c058071" +uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" +version = "1.7.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = 
"3c837543ddb02250ef42f4738347454f95079d4e" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.9" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" + +[[deps.Parsers]] +deps = ["Dates"] +git-tree-sha1 = "85b5da0fa43588c75bb1ff986493443f821c70b7" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.2.3" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.PyCall]] +deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"] +git-tree-sha1 = "1fc929f47d7c151c839c5fc1375929766fb8edcc" +uuid = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" +version = "1.93.1" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + 
+[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.VersionParsing]] +git-tree-sha1 = "58d6e80b4ee071f5efd07fda82cb9fbe17200868" +uuid = "81def892-9a0e-5fdd-b105-ffc91e053289" +version = "1.3.0" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" diff --git a/gpu/Project.toml b/gpu/Project.toml new file mode 100644 index 00000000..7d22f28a --- /dev/null +++ b/gpu/Project.toml @@ -0,0 +1,2 @@ +[deps] +PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 7bc59673..685c3ab6 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -13,7 +13,6 @@ using LinearAlgebra using IterativeSolvers using LogExpFunctions: LogExpFunctions using NNlib: NNlib -using Logging export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling export SinkhornBarycenterGibbs From 19e4cabec86343411c5dac621024454f125602e7 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 13 Mar 2022 11:17:55 +1100 Subject: [PATCH 12/23] fix missing power of 2 --- src/gromov.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gromov.jl b/src/gromov.jl index aa5b8533..a6989be3 100644 --- a/src/gromov.jl +++ b/src/gromov.jl @@ -15,7 +15,7 @@ end Computes the transport map for the entropically regularized Gromov-Wasserstein optimal transport problem with source and target marginals `μ` and `ν` and corresponding cost matrices `Cμ` and `Cν`. 
That is, we seek `γ` a local minimizer of ```math - \\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, i', j'} |C^{(\\mu)}_{i,i'} - C^{(\\nu)}_{j,j'}| \\gamma_{i,j} \\gamma_{i',j'} + \\varepsilon \\Omega(\\gamma), + \\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, i', j'} |C^{(\\mu)}_{i,i'} - C^{(\\nu)}_{j,j'}|^2 \\gamma_{i,j} \\gamma_{i',j'} + \\varepsilon \\Omega(\\gamma), ``` where ``\\Omega(\\gamma)`` is the entropic regularization term, see e.g. [`sinkhorn`](@ref). From 6e3ac4cb54f53f20a6c6b917f0ce7d5fe13b0289 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 28 Aug 2022 22:33:14 +1000 Subject: [PATCH 13/23] update version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3c57e214..0ccb544c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.20" +version = "0.3.21" [deps] ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" From 5c376ae45014ba22196ae6ef78f605b18e7529f2 Mon Sep 17 00:00:00 2001 From: stephen zhang Date: Tue, 20 Dec 2022 14:49:01 +1100 Subject: [PATCH 14/23] add docs workflow --- .github/workflows/main | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 .github/workflows/main diff --git a/.github/workflows/main b/.github/workflows/main new file mode 100644 index 00000000..e69de29b From af2a49324b1a46a91144dc13189e2a8e7dab902d Mon Sep 17 00:00:00 2001 From: stephen zhang Date: Wed, 25 Jan 2023 12:44:22 +1100 Subject: [PATCH 15/23] add Gromov-Wasserstein to readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c66d42f7..48b58ddc 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ 
[![Coveralls](https://coveralls.io/repos/github/JuliaOptimalTransport/OptimalTransport.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaOptimalTransport/OptimalTransport.jl?branch=master) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) -This package provides some [Julia](https://julialang.org/) implementations of algorithms for computational [optimal transport](https://optimaltransport.github.io/), including the Earth-Mover's (Wasserstein) distance, Sinkhorn algorithm for entropically regularized optimal transport as well as some variants or extensions. +This package provides some [Julia](https://julialang.org/) implementations of algorithms for computational [optimal transport](https://optimaltransport.github.io/), including the Earth-Mover's (Wasserstein) distance, Sinkhorn algorithm for entropically regularized optimal transport as well as variants and extensions, including unbalanced transport and Gromov-Wasserstein matching. Notably, OptimalTransport.jl provides GPU acceleration through [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl/) and [NNlibCUDA.jl](https://github.com/FluxML/NNlibCUDA.jl). 
From 6bc3127004d858171576b601115f0365af977fdb Mon Sep 17 00:00:00 2001 From: stephen zhang Date: Wed, 25 Jan 2023 12:46:20 +1100 Subject: [PATCH 16/23] bump Julia ver for CI --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 953bfade..6c1f04ea 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: strategy: matrix: version: - - '1.6' + - '1.8' - '1' - 'nightly' os: From a806f0f337633c7d5b89350a3c3fc5fd7ab4950a Mon Sep 17 00:00:00 2001 From: stephen zhang Date: Wed, 25 Jan 2023 13:06:28 +1100 Subject: [PATCH 17/23] minor edit to runtests --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index ffce6ef0..ad1a7f4d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using OptimalTransport -using Pkg: Pkg +using Pkg using SafeTestsets using Test @@ -43,7 +43,7 @@ const GROUP = get(ENV, "GROUP", "All") end # CUDA requires Julia >= 1.6 - if (GROUP == "All" || GROUP == "GPU") && VERSION >= v"1.6" + if (GROUP == "All" || GROUP == "GPU") && VERSION >= v"1.8" # activate separate environment: CUDA can't be added to test/Project.toml since it # is not available on older Julia versions Pkg.activate("gpu") From f704397cc7ec19fc63fbbb9d0e1dac2f143014f9 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Fri, 27 Jan 2023 18:37:48 +1100 Subject: [PATCH 18/23] Update .github/workflows/CI.yml Co-authored-by: David Widmann --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6c1f04ea..953bfade 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: strategy: matrix: version: - - '1.8' + - '1.6' - '1' - 'nightly' os: From 71351b90cdcf6f3503f4e832b2a61761bf7e98dc Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Fri, 27 Jan 
2023 18:40:35 +1100 Subject: [PATCH 19/23] Update test/runtests.jl Co-authored-by: David Widmann --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index ad1a7f4d..30969e3b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using OptimalTransport -using Pkg +using Pkg: Pkg using SafeTestsets using Test From f2acc569274a34bebc976f71d8250de95d16872a Mon Sep 17 00:00:00 2001 From: stephen zhang Date: Fri, 27 Jan 2023 18:41:24 +1100 Subject: [PATCH 20/23] delete junk files/dirs --- gpu/Manifest.toml | 129 ---------------------------------------------- gpu/Project.toml | 2 - test/Project.toml | 2 - 3 files changed, 133 deletions(-) delete mode 100644 gpu/Manifest.toml delete mode 100644 gpu/Project.toml delete mode 100644 test/Project.toml diff --git a/gpu/Manifest.toml b/gpu/Manifest.toml deleted file mode 100644 index 3fca1e6f..00000000 --- a/gpu/Manifest.toml +++ /dev/null @@ -1,129 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.7.0" -manifest_format = "2.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" - -[[deps.Conda]] -deps = ["Downloads", "JSON", "VersionParsing"] -git-tree-sha1 = "6e47d11ea2776bc5627421d59cdcc1296c058071" -uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" -version = "1.7.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" -uuid = 
"682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.3" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.9" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" - -[[deps.Parsers]] -deps = ["Dates"] -git-tree-sha1 = "85b5da0fa43588c75bb1ff986493443f821c70b7" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.2.3" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.PyCall]] -deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"] -git-tree-sha1 = "1fc929f47d7c151c839c5fc1375929766fb8edcc" -uuid = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" -version = "1.93.1" - -[[deps.Random]] -deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.SHA]] -uuid = 
"ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.VersionParsing]] -git-tree-sha1 = "58d6e80b4ee071f5efd07fda82cb9fbe17200868" -uuid = "81def892-9a0e-5fdd-b105-ffc91e053289" -version = "1.3.0" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" diff --git a/gpu/Project.toml b/gpu/Project.toml deleted file mode 100644 index 7d22f28a..00000000 --- a/gpu/Project.toml +++ /dev/null @@ -1,2 +0,0 @@ -[deps] -PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index 794185b8..00000000 --- a/test/Project.toml +++ /dev/null @@ -1,2 +0,0 @@ -[deps] -OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33" From 06353057e58be5c83a6d27f66a27b973938870d5 Mon Sep 17 00:00:00 2001 From: stephen zhang Date: Sat, 28 Jan 2023 00:19:27 +1100 Subject: [PATCH 21/23] revert runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 30969e3b..ffce6ef0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,7 +43,7 @@ const GROUP = get(ENV, "GROUP", "All") end # CUDA requires Julia >= 1.6 - if (GROUP == "All" || GROUP == "GPU") && VERSION >= v"1.8" + if (GROUP == "All" || GROUP == "GPU") && VERSION >= v"1.6" # activate separate environment: CUDA can't be added to test/Project.toml since it # is not available on older Julia versions Pkg.activate("gpu") From c3efe5a981b98e91532d4f318809ab0972ea272a Mon Sep 17 00:00:00 2001 From: stephen zhang Date: Sat, 28 Jan 2023 01:15:02 +1100 Subject: [PATCH 22/23] avoid unnecessary allocations --- 
src/gromov.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gromov.jl b/src/gromov.jl index a6989be3..7821ae38 100644 --- a/src/gromov.jl +++ b/src/gromov.jl @@ -56,7 +56,8 @@ function entropic_gromov_wasserstein( function get_new_cost!(C, plan, tmp, Cμ, Cν) A_batched_mul_B!(tmp, Cμ, plan) - return A_batched_mul_B!(C, tmp, -4Cν) + lmul!(-4, tmp) + return A_batched_mul_B!(C, tmp, Cν) # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? # added the factor of 4 here to ensure reproducibility for the same value of ε. # https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 @@ -77,7 +78,8 @@ function entropic_gromov_wasserstein( if to_check_step == 0 || iter == maxiter # reset counter to_check_step = check_convergence - isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) + plan_prev .-= plan + isconverged = sum(abs, plan_prev) < max(_atol, _rtol * norm_plan) if isconverged @debug "Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" break From 39f0b362f5e64edbc6bf17fbe9a02f947c6cf990 Mon Sep 17 00:00:00 2001 From: stephen zhang Date: Sat, 28 Jan 2023 01:18:53 +1100 Subject: [PATCH 23/23] format --- src/gromov.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gromov.jl b/src/gromov.jl index 7821ae38..9c69a116 100644 --- a/src/gromov.jl +++ b/src/gromov.jl @@ -78,7 +78,7 @@ function entropic_gromov_wasserstein( if to_check_step == 0 || iter == maxiter # reset counter to_check_step = check_convergence - plan_prev .-= plan + plan_prev .-= plan isconverged = sum(abs, plan_prev) < max(_atol, _rtol * norm_plan) if isconverged @debug "Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged"