diff --git a/docs/src/index.md b/docs/src/index.md
index ca941f9d..f3bd10fd 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -40,6 +40,7 @@ Currently the following variants of the Sinkhorn algorithm are supported:
 SinkhornGibbs
 SinkhornStabilized
 SinkhornEpsilonScaling
+Greenkhorn
 ```
 
 The following methods are deprecated and will be removed:
diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl
index 1653431e..83f53c74 100644
--- a/src/OptimalTransport.jl
+++ b/src/OptimalTransport.jl
@@ -17,6 +17,7 @@ using NNlib: NNlib
 export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
 export SinkhornBarycenterGibbs
 export QuadraticOTNewton
+export Greenkhorn
 
 export sinkhorn, sinkhorn2
 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
@@ -36,6 +37,7 @@ include("entropic/sinkhorn_unbalanced.jl")
 include("entropic/sinkhorn_barycenter.jl")
 include("entropic/sinkhorn_barycenter_gibbs.jl")
 include("entropic/sinkhorn_solve.jl")
+include("entropic/greenkhorn.jl")
 
 include("quadratic.jl")
 include("quadratic_newton.jl")
diff --git a/src/entropic/greenkhorn.jl b/src/entropic/greenkhorn.jl
new file mode 100644
index 00000000..f9b04a5b
--- /dev/null
+++ b/src/entropic/greenkhorn.jl
@@ -0,0 +1,120 @@
+# Greenkhorn is a greedy version of the Sinkhorn algorithm.
+# The method is described in https://arxiv.org/pdf/1705.09634.pdf.
+# The code is based on the implementation in the package POT.
+
+"""
+    Greenkhorn()
+
+Greenkhorn is a greedy version of the Sinkhorn algorithm.
+
+In every iteration only the single row or column of the transport plan with the
+largest violation of its marginal constraint is rescaled.
+
+Batches of histograms are not supported.
+"""
+struct Greenkhorn <: Sinkhorn end
+
+# Working storage of the Greenkhorn algorithm.
+struct GreenkhornCache{U,V,KT}
+    u::U # scaling of the rows
+    v::V # scaling of the columns
+    K::KT # Gibbs kernel
+    Kv::U # buffer for the product of `K` and `v`
+    G::KT # current transport plan
+    du::U # row sums of `G` minus the source marginal
+    dv::V # column sums of `G` minus the target marginal
+end
+
+Base.show(io::IO, ::Greenkhorn) = print(io, "Greenkhorn algorithm")
+
+function build_cache(
+    ::Type{T},
+    ::Greenkhorn,
+    size2::Tuple,
+    μ::AbstractVecOrMat,
+    ν::AbstractVecOrMat,
+    C::AbstractMatrix,
+    ε::Real,
+) where {T}
+    # `step!` reads and writes single entries of vectors; reject batches up front
+    # with a clear message instead of failing later with an obscure indexing error
+    if ndims(μ) != 1 || ndims(ν) != 1
+        throw(ArgumentError("Greenkhorn does not support batches of histograms"))
+    end
+
+    # compute Gibbs kernel
+    K = similar(C, T)
+    @. K = exp(-C / ε)
+
+    # create the scalings and initialize them with uniform weights
+    u = similar(μ, T, size(μ, 1), size2...)
+    v = similar(ν, T, size(ν, 1), size2...)
+    fill!(u, one(T) / size(μ, 1))
+    fill!(v, one(T) / size(ν, 1))
+
+    # initial transport plan
+    G = sinkhorn_plan(u, v, K)
+
+    # uninitialized buffer; `step!` overwrites it at the end of every iteration
+    Kv = similar(u)
+
+    # initial violations of the marginal constraints
+    du = reshape(sum(G; dims=2), size(μ)) - μ
+    dv = reshape(sum(G; dims=1), size(ν)) - ν
+
+    return GreenkhornCache(u, v, K, Kv, G, du, dv)
+end
+
+# Greenkhorn requires neither a prestep nor an initial step.
+prestep!(::SinkhornSolver{<:Greenkhorn}, ::Int) = nothing
+init_step!(::SinkhornSolver{<:Greenkhorn}) = nothing
+
+# Perform a single greedy update of one row or one column of the transport plan.
+function step!(solver::SinkhornSolver{<:Greenkhorn}, iter::Int)
+    μ = solver.source
+    ν = solver.target
+    cache = solver.cache
+    u = cache.u
+    v = cache.v
+    K = cache.K
+    G = cache.G
+    Δμ = cache.du
+    Δν = cache.dv
+
+    # The original paper selects the update with ρ(a, b) = b - a + a log(a / b).
+    # The implementation in POT uses the absolute violation instead, as does this one.
+    i₁ = argmax(abs.(Δμ))
+    i₂ = argmax(abs.(Δν))
+
+    if abs(Δμ[i₁]) > abs(Δν[i₂])
+        # rescale row `i₁` such that its sum is equal to `μ[i₁]`
+        Krow = view(K, i₁, :)
+        Krow_v = Krow ⋅ v
+        old_u = u[i₁]
+        u[i₁] = μ[i₁] / Krow_v
+        Δ = u[i₁] - old_u
+        G[i₁, :] .= u[i₁] .* Krow .* v
+        Δμ[i₁] = u[i₁] * Krow_v - μ[i₁]
+        Δν .+= Δ .* Krow .* v
+    else
+        # rescale column `i₂` such that its sum is equal to `ν[i₂]`
+        Kcol = view(K, :, i₂)
+        Kcol_u = Kcol ⋅ u
+        old_v = v[i₂]
+        v[i₂] = ν[i₂] / Kcol_u
+        Δ = v[i₂] - old_v
+        G[:, i₂] .= v[i₂] .* Kcol .* u
+        Δν[i₂] = v[i₂] * Kcol_u - ν[i₂]
+        Δμ .+= Δ .* Kcol .* u
+    end
+
+    # update `Kv`, which is needed to evaluate convergence
+    # NOTE(review): presumably a full product of `K` and `v` although at most one
+    # entry of `v` changed; an incremental update might be cheaper - confirm
+    return A_batched_mul_B!(cache.Kv, K, v)
+end
+
+# `step!` updates the transport plan in place, so the cached plan is returned as is.
+function sinkhorn_plan(solver::SinkhornSolver{<:Greenkhorn})
+    return solver.cache.G
+end
diff --git a/test/entropic/greenkhorn.jl b/test/entropic/greenkhorn.jl
new file mode 100644
index 00000000..2b9ab2b4
--- /dev/null
+++ b/test/entropic/greenkhorn.jl
@@ -0,0 +1,238 @@
+using OptimalTransport
+
+using Distances
+using ForwardDiff
+using ReverseDiff
+using LogExpFunctions
+using PythonOT: PythonOT
+
+using LinearAlgebra
+using Random
+using Test
+
+const POT = PythonOT
+
+Random.seed!(100)
+
+@testset "greenkhorn.jl" begin
+    # size of source and target
+    M = 250
+    N = 200
+
+    # create two random histograms
+    μ = normalize!(rand(M), 1)
+    ν = normalize!(rand(N), 1)
+
+    # create random cost matrix
+    C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
+
+    # regularization parameter
+    ε = 0.01
+
+    @testset "example" begin
+        # compute optimal transport plan and optimal transport cost
+        γ = sinkhorn(μ, ν, C, ε, Greenkhorn(); maxiter=200_000, rtol=1e-9)
+        c = sinkhorn2(μ, ν, C, ε, Greenkhorn(); maxiter=200_000, rtol=1e-9)
+
+        # check that plan and cost are consistent
+        @test c ≈ dot(γ, C)
+
+        # compare with POT
+        γ_pot = POT.sinkhorn(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)
+        c_pot = POT.sinkhorn2(μ, ν, C, ε; numItermax=5_000, stopThr=1e-9)[1]
+        @test γ_pot ≈ γ rtol = 1e-6
+        @test c_pot ≈ c rtol = 1e-7
+
+        # compute optimal transport cost with regularization term
+        c_w_regularization = sinkhorn2(
+            μ, ν, C, ε, Greenkhorn(); maxiter=200_000, regularization=true
+        )
+        @test c_w_regularization ≈ c + ε * sum(x -> iszero(x) ? x : x * log(x), γ)
+        @test c_w_regularization ≈
+            sinkhorn2(μ, ν, C, ε; maxiter=5_000, regularization=true)
+
+        # ensure that provided plan is used and correct
+        c2 = sinkhorn2(similar(μ), similar(ν), C, rand(), Greenkhorn(); plan=γ)
+        @test c2 ≈ c
+        @test c2 == sinkhorn2(similar(μ), similar(ν), C, rand(); plan=γ)
+        c2_w_regularization = sinkhorn2(
+            similar(μ), similar(ν), C, ε, Greenkhorn(); plan=γ, regularization=true
+        )
+        @test c2_w_regularization ≈ c_w_regularization
+        @test c2_w_regularization ==
+            sinkhorn2(similar(μ), similar(ν), C, ε; plan=γ, regularization=true)
+
+
+        ################################################################
+        # FIX BATCHES CASE!!! Not working for Greenkhorn implementation#
+        ################################################################
+
+        # # batches of histograms
+        # d = 10
+        # for (size2_μ, size2_ν) in
+        #     (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,)))
+        #     # generate batches of histograms
+        #     μ_batch = repeat(μ, 1, size2_μ...)
+        #     ν_batch = repeat(ν, 1, size2_ν...)
+
+        #     # compute optimal transport plan and check that it is consistent with the
+        #     # plan for individual histograms
+        #     γ_all = sinkhorn(
+        #         μ_batch, ν_batch, C, ε, Greenkhorn(); maxiter=5_000, rtol=1e-9
+        #     )
+        #     @test size(γ_all) == (M, N, d)
+        #     @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3))
+        #     @test γ_all == sinkhorn(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9)
+
+        #     # compute optimal transport cost and check that it is consistent with the
+        #     # cost for individual histograms
+        #     c_all = sinkhorn2(
+        #         μ_batch, ν_batch, C, ε, Greenkhorn(); maxiter=5_000, rtol=1e-9
+        #     )
+        #     @test size(c_all) == (d,)
+        #     @test all(x ≈ c for x in c_all)
+        #     @test c_all == sinkhorn2(μ_batch, ν_batch, C, ε; maxiter=5_000, rtol=1e-9)
+        # end
+    end
+
+    # different element type
+    @testset "Float32" begin
+        # create histograms and cost matrix with element type `Float32`
+        μ32 = map(Float32, μ)
+        ν32 = map(Float32, ν)
+        C32 = map(Float32, C)
+        ε32 = Float32(ε)
+
+        # compute optimal transport plan and optimal transport cost
+        γ = sinkhorn(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=200_000, rtol=1e-6)
+        c = sinkhorn2(μ32, ν32, C32, ε32, Greenkhorn(); maxiter=200_000, rtol=1e-6)
+        @test eltype(γ) === Float32
+        @test typeof(c) === Float32
+
+        # check that plan and cost are consistent
+        @test c ≈ dot(γ, C32)
+
+        # compare with default algorithm
+        γ_default = sinkhorn(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6)
+        c_default = sinkhorn2(μ32, ν32, C32, ε32; maxiter=5_000, rtol=1e-6)
+        @test γ_default ≈ γ rtol=1e-4
+        @test c_default ≈ c rtol=1e-4
+
+        # compare with POT
+        γ_pot = POT.sinkhorn(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)
+        c_pot = POT.sinkhorn2(μ32, ν32, C32, ε32; numItermax=5_000, stopThr=1e-6)[1]
+        @test map(Float32, γ_pot) ≈ γ rtol = 1e-3
+        @test Float32(c_pot) ≈ c rtol = 1e-3
+
+        ################################################################
+        # FIX BATCHES CASE!!! Not working for Greenkhorn implementation#
+        ################################################################
+
+        # batches of histograms
+        # d = 10
+        # for (size2_μ, size2_ν) in
+        #     (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,)))
+        #     # generate batches of histograms
+        #     μ32_batch = repeat(μ32, 1, size2_μ...)
+        #     ν32_batch = repeat(ν32, 1, size2_ν...)
+
+        #     # compute optimal transport plan and check that it is consistent with the
+        #     # plan for individual histograms
+        #     γ_all = sinkhorn(
+        #         μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6
+        #     )
+        #     @test size(γ_all) == (M, N, d)
+        #     @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3))
+        #     @test γ_all ==
+        #         sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6)
+
+        #     # compute optimal transport cost and check that it is consistent with the
+        #     # cost for individual histograms
+        #     c_all = sinkhorn2(
+        #         μ32_batch, ν32_batch, C32, ε32, Greenkhorn(); maxiter=5_000, rtol=1e-6
+        #     )
+        #     @test size(c_all) == (d,)
+        #     @test all(x ≈ c for x in c_all)
+        #     @test c_all ==
+        #         sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6)
+        # end
+    end
+
+
+    ################################################################
+    # FIX AD !!! Not working for Greenkhorn implementation         #
+    ################################################################
+
+    # https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/86
+    # @testset "AD" begin
+    #     # compute gradients with respect to source and target marginals separately and
+    #     # together. test against gradient computed using analytic formula of Proposition 2.3 of
+    #     # Cuturi, Marco, and Gabriel Peyré. "A smoothed dual approach for variational Wasserstein problems." SIAM Journal on Imaging Sciences 9.1 (2016): 320-343.
+    #     #
+    #     ε = 0.05 # use a larger ε to avoid having to do many iterations
+    #     # target marginal
+    #     for Diff in [ReverseDiff, ForwardDiff]
+    #         ∇ = Diff.gradient(log.(ν)) do xs
+    #             sinkhorn2(μ, softmax(xs), C, ε, Greenkhorn(); regularization=true)
+    #         end
+    #         ∇default = Diff.gradient(log.(ν)) do xs
+    #             sinkhorn2(μ, softmax(xs), C, ε; regularization=true)
+    #         end
+    #         @test ∇ == ∇default
+
+    #         solver = OptimalTransport.build_solver(μ, ν, C, ε, Greenkhorn())
+    #         OptimalTransport.solve!(solver)
+    #         # helper function
+    #         function dualvar_to_grad(x, ε)
+    #             x = -ε * log.(x)
+    #             x .-= sum(x) / size(x, 1)
+    #             return -x
+    #         end
+    #         ∇_ot = dualvar_to_grad(solver.cache.v, ε)
+    #         # chain rule because target measure parameterised by softmax
+    #         J_softmax = ForwardDiff.jacobian(log.(ν)) do xs
+    #             softmax(xs)
+    #         end
+    #         ∇analytic_target = J_softmax * ∇_ot
+    #         # check that gradient obtained by AD matches the analytic formula
+    #         @test ∇ ≈ ∇analytic_target rtol = 1e-6
+
+    #         # source marginal
+    #         ∇ = Diff.gradient(log.(μ)) do xs
+    #             sinkhorn2(softmax(xs), ν, C, ε, Greenkhorn(); regularization=true)
+    #         end
+    #         ∇default = Diff.gradient(log.(μ)) do xs
+    #             sinkhorn2(softmax(xs), ν, C, ε; regularization=true)
+    #         end
+    #         @test ∇ == ∇default
+
+    #         # check that gradient obtained by AD matches the analytic formula
+    #         solver = OptimalTransport.build_solver(μ, ν, C, ε, Greenkhorn())
+    #         OptimalTransport.solve!(solver)
+    #         J_softmax = ForwardDiff.jacobian(log.(μ)) do xs
+    #             softmax(xs)
+    #         end
+    #         ∇_ot = dualvar_to_grad(solver.cache.u, ε)
+    #         ∇analytic_source = J_softmax * ∇_ot
+    #         @test ∇ ≈ ∇analytic_source rtol = 1e-6
+
+    #         # both marginals
+    #         ∇ = Diff.gradient(log.(vcat(μ, ν))) do xs
+    #             sinkhorn2(
+    #                 softmax(xs[1:M]),
+    #                 softmax(xs[(M + 1):end]),
+    #                 C,
+    #                 ε,
+    #                 Greenkhorn();
+    #                 regularization=true,
+    #             )
+    #         end
+    #         ∇default = Diff.gradient(log.(vcat(μ, ν))) do xs
+    #             sinkhorn2(softmax(xs[1:M]), softmax(xs[(M + 1):end]), C, ε; regularization=true)
+    #         end
+    #         @test ∇ == ∇default
+    #         ∇analytic = vcat(∇analytic_source, ∇analytic_target)
+    #         @test ∇ ≈ ∇analytic rtol = 1e-6
+    #     end
+    # end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 66314dfa..77c13fdc 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -31,6 +31,9 @@ const GROUP = get(ENV, "GROUP", "All")
             @safetestset "Sinkhorn divergence" begin
                 include(joinpath("entropic", "sinkhorn_divergence.jl"))
             end
+            @safetestset "Greenkhorn" begin
+                include(joinpath("entropic", "greenkhorn.jl"))
+            end
         end
 
         @safetestset "Quadratically regularized OT" begin