From 9da044cb33c4cb9c9c3eb5b67b07c6660df77f35 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 21 Aug 2022 17:34:49 +1000 Subject: [PATCH] Unbalanced divergence (#170) * first stab at unbalanced sinkhorn divergence * fixed point in log domain * increment version number * add docstrings for sinkhorn_divergence_unbalanced * format * clean emacs tmp files * clean emacs tmp files * missing math tag * missing backslash * add sinkhorn_divergence_unbalanced to docs * increase epsilon for gpu sinkhorn tests * format * udpate examples/OneDimension * add StatsPlots deps to examples/OneDimension * change Julia LTS from 1.0 to 1.6 Co-authored-by: Stephen Zhang --- .github/workflows/CI.yml | 2 +- .gitignore | 5 + Project.toml | 2 +- docs/src/index.md | 1 + examples/OneDimension/Manifest.toml | 303 +++++++++++++++++---------- examples/OneDimension/Project.toml | 1 - src/OptimalTransport.jl | 2 +- src/entropic/sinkhorn_solve.jl | 12 +- src/entropic/sinkhorn_stabilized.jl | 10 +- src/entropic/sinkhorn_unbalanced.jl | 148 +++++++++++-- src/quadratic_newton.jl | 12 +- test/entropic/sinkhorn_gibbs.jl | 8 +- test/entropic/sinkhorn_unbalanced.jl | 23 ++ test/gpu/simple_gpu.jl | 10 +- test/utils.jl | 2 +- 15 files changed, 381 insertions(+), 160 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 556390a4..953bfade 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: strategy: matrix: version: - - '1.0' + - '1.6' - '1' - 'nightly' os: diff --git a/.gitignore b/.gitignore index 571bf2a4..af0ccd7c 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,8 @@ docs/src/examples/ # Files generated by Jupyter Notebooks *.ipynb_checkpoints *.ipynb + +# emacs temp files +*~undo-tree~* +\#*\# +.\#* diff --git a/Project.toml b/Project.toml index dbd16a25..3c57e214 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.19" +version = "0.3.20" [deps] ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" diff --git a/docs/src/index.md b/docs/src/index.md index ca941f9d..2e80ae22 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -54,6 +54,7 @@ sinkhorn_stabilized_epsscaling ```@docs sinkhorn_unbalanced sinkhorn_unbalanced2 +sinkhorn_divergence_unbalanced ``` ## Quadratically regularised optimal transport diff --git a/examples/OneDimension/Manifest.toml b/examples/OneDimension/Manifest.toml index 19762eeb..f456ee23 100644 --- a/examples/OneDimension/Manifest.toml +++ b/examples/OneDimension/Manifest.toml @@ -1,13 +1,14 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.7.1" +julia_version = "1.8.0" manifest_format = "2.0" +project_hash = "824119576de9361cc1136b4911ed53b8e992d091" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] -git-tree-sha1 = "6f1d9bc1c08f9f4a8fa92e3ea3cb50153a1b40d4" +git-tree-sha1 = "69f7020bd72f069c219b5e8c236c1fa90d2cb409" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.1.0" +version = "1.2.1" [[deps.Adapt]] deps = ["LinearAlgebra"] @@ -17,18 +18,19 @@ version = "3.3.3" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" [[deps.Arpack]] -deps = ["Arpack_jll", "Libdl", "LinearAlgebra"] -git-tree-sha1 = "2ff92b71ba1747c5fdd541f8fc87736d82f40ec9" +deps = ["Arpack_jll", "Libdl", "LinearAlgebra", "Logging"] +git-tree-sha1 = "91ca22c4b8437da89b030f08d71db55a379ce958" uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" -version = "0.4.0" +version = "0.5.3" [[deps.Arpack_jll]] -deps = ["Libdl", "OpenBLAS_jll", "Pkg"] -git-tree-sha1 = "e214a9b9bd1b4e1b4f15b22c0994862b66af7ff7" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "OpenBLAS_jll", "Pkg"] +git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e" uuid = "68821587-b530-5797-8361-c406ea357684" -version = "3.5.0+3" +version = "3.5.1+1" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -91,16 +93,22 @@ uuid = "944b1d66-785c-5afd-91f1-9de20f533193" version = "0.7.0" [[deps.ColorSchemes]] -deps = ["ColorTypes", "Colors", "FixedPointNumbers", "Random"] -git-tree-sha1 = "6b6f04f93710c71550ec7e16b650c1b9a612d0b6" +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "Random"] +git-tree-sha1 = "1fd869cc3875b57347f7027521f561cf46d1fcd8" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.16.0" +version = "3.19.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" +git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.0" +version = "0.11.4" + +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] +git-tree-sha1 = "d08c20eef1f2cbc6e60fd3612ac4340b89fea322" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.9.9" [[deps.Colors]] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] @@ -117,12 +125,12 @@ version = "3.41.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.5.2+0" [[deps.Contour]] -deps = ["StaticArrays"] -git-tree-sha1 = "9f02045d934dc030edad45944ea80dbd1f0ebea7" +git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" -version = "0.5.7" +version = "0.6.2" [[deps.DataAPI]] git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8" @@ -183,8 +191,9 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.8.6" [[deps.Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" [[deps.EarCut_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -200,9 +209,14 @@ version = "0.2.1" [[deps.Expat_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "b3bfd02e98aedfa5cf885665493c5598c350cd2f" +git-tree-sha1 = "bad72f730e9e91c08d9427d5e8db95478a3c323d" uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.2.10+0" +version = "2.4.8+0" + +[[deps.Extents]] +git-tree-sha1 = "5e1e4c53fa39afe63a7d356e30452249365fba99" +uuid = "411431e0-e8b7-467b-b5e0-f676ba4f2910" +version = "0.1.1" [[deps.FFMPEG]] deps = ["FFMPEG_jll"] @@ -211,16 +225,16 @@ uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" version = "0.4.1" [[deps.FFMPEG_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "Pkg", "Zlib_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] -git-tree-sha1 = "d8a578692e3077ac998b50c0217dfd67f21d1e5f" +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "Pkg", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] +git-tree-sha1 = "ccd479984c7838684b3ac204b716c89955c76623" uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" -version = "4.4.0+0" +version = "4.4.2+0" [[deps.FFTW]] deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] -git-tree-sha1 = "463cb335fa22c4ebacfd1faba5fde14edb80d96c" +git-tree-sha1 = "90630efff0894f8142308e334473eba54c433549" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.4.5" +version = "1.5.0" [[deps.FFTW_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -228,6 +242,9 @@ git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea" uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" version = "3.3.10+0" +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] git-tree-sha1 = "8756f9935b7ccc9064c6eef0bff0ad643df733a3" @@ -266,27 +283,33 @@ version = "1.0.10+0" [[deps.GLFW_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pkg", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll"] -git-tree-sha1 = "0c603255764a1fa0b61752d2bec14cfbd18f7fe8" +git-tree-sha1 = "d972031d28c8c8d9d7b41a536ad7bb0c2579caca" uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" -version = "3.3.5+1" +version = "3.3.8+0" [[deps.GR]] deps = ["Base64", "DelimitedFiles", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Printf", "Random", "RelocatableFolders", "Serialization", "Sockets", "Test", "UUIDs"] -git-tree-sha1 = "4a740db447aae0fbeb3ee730de1afbb14ac798a1" +git-tree-sha1 = "cf0a9940f250dc3cb6cc6c6821b4bf8a4286cf9c" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.63.1" +version = "0.66.2" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Pkg", "Qt5Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "aa22e1ee9e722f1da183eb33370df4c1aeb6c2cd" +git-tree-sha1 = "2d908286d120c584abbe7621756c341707096ba4" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.63.1+0" +version = "0.66.2+0" + +[[deps.GeoInterface]] +deps = ["Extents"] +git-tree-sha1 = "fb28b5dc239d0174d7297310ef7b84a11804dfab" +uuid = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" +version = "1.0.1" [[deps.GeometryBasics]] -deps = ["EarCut_jll", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] -git-tree-sha1 = "58bcdf5ebc057b085e58d95c138725628dd7453c" +deps = ["EarCut_jll", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] +git-tree-sha1 = "a7a97895780dab1085a97769316aa348830dc991" uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" -version = "0.4.1" +version = "0.4.3" [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] @@ -312,10 +335,10 @@ uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" version = "1.0.2" [[deps.HTTP]] -deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"] -git-tree-sha1 = "0fa77022fe4b511826b39c894c90daf5fce3334a" +deps = ["Base64", "CodecZlib", "Dates", "IniFile", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "f0956f8d42a92816d2bf062f8a6a6a0ad7f9b937" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "0.9.17" +version = "1.2.1" [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] @@ -330,10 +353,9 @@ uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" version = "0.2.2" [[deps.IniFile]] -deps = ["Test"] -git-tree-sha1 = "098e4d2c533924c921f9f9847274f2ad89e018b8" +git-tree-sha1 = "f550e6e32074c939295eb5ea6de31849ac2c9625" uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" -version = "0.5.0" +version = "0.5.1" [[deps.IntelOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -346,10 +368,10 @@ deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[deps.Interpolations]] -deps = ["AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] -git-tree-sha1 = "b15fc0a95c564ca2e0a7ae12c1f095ca848ceb31" +deps = ["Adapt", "AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] +git-tree-sha1 = "64f138f9453a018c8f3562e7bae54edc059af249" uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -version = "0.13.5" +version = "0.14.4" [[deps.InverseFunctions]] deps = ["Test"] @@ -392,15 +414,15 @@ version = "0.21.2" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "d735490ac75c5cb9f1b00d8b5509c11984dc6943" +git-tree-sha1 = "b53380851c6e6664204efb2e62cd24fa5c47e4ba" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "2.1.0+0" +version = "2.1.2+0" [[deps.KernelDensity]] deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] -git-tree-sha1 = "591e8dc09ad18386189610acafb970032c519707" +git-tree-sha1 = "9816b296736292a80b9a3200eb7fbb57aaa3917a" uuid = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" -version = "0.6.3" +version = "0.6.5" [[deps.LAME_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -408,6 +430,12 @@ git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" version = "3.100.1+0" +[[deps.LERC_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" +uuid = "88015f11-f218-50d7-93a8-a6af411a945d" +version = "3.0.0+1" + [[deps.LZO_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "e5b909bcf985c5e2605737d2ce278ed791b89be6" @@ -421,9 +449,9 @@ version = "1.3.0" [[deps.Latexify]] deps = ["Formatting", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "Printf", "Requires"] -git-tree-sha1 = "a8f4f279b6fa3c3c4f1adadd78a621b13a506bce" +git-tree-sha1 = "1a43be956d433b5d0321197150c2f94e16c0aaa0" uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" -version = "0.15.9" +version = "0.15.16" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -432,10 +460,12 @@ uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" [[deps.LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] @@ -444,6 +474,7 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -485,10 +516,10 @@ uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" version = "2.35.0+0" [[deps.Libtiff_jll]] -deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Pkg", "Zlib_jll", "Zstd_jll"] -git-tree-sha1 = "340e257aada13f95f98ee352d316c3bed37c8ab9" +deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "Pkg", "Zlib_jll", "Zstd_jll"] +git-tree-sha1 = "3eb79b0ca5764d4799c06699573fd8f533259713" uuid = "89763e89-9b03-5906-acba-b20f662cd828" -version = "4.3.0+0" +version = "4.4.0+0" [[deps.Libuuid_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -515,11 +546,17 @@ version = "0.3.6" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "5d4d2d9904227b8bd66386c1138cf4d5ffa826bf" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "0.4.9" + [[deps.MKL_jll]] deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "5455aef09b40e5020e1520f551fa3135040d4ed0" +git-tree-sha1 = "e595b205efd49508358f7dc670a940c790204629" uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -version = "2021.1.1+2" +version = "2022.0.0+0" [[deps.MacroTools]] deps = ["Markdown", "Random"] @@ -538,14 +575,15 @@ uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" version = "0.10.7" [[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "Random", "Sockets"] -git-tree-sha1 = "1c38e51c3d08ef2278062ebceade0e46cefc96fe" +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] +git-tree-sha1 = "d9ab10da9de748859a7780338e1d6566993d1f25" uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.0.3" +version = "1.1.3" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" [[deps.Measures]] git-tree-sha1 = "e498ddeee6f9fdb4551ce855a46f54dbd900245f" @@ -563,12 +601,13 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.2.1" [[deps.MultivariateStats]] -deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "8d958ff1854b166003238fe191ec34b9d592860a" +deps = ["Arpack", "LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI", "StatsBase"] +git-tree-sha1 = "7008a3412d823e29d370ddc77411d593bd8a3d03" uuid = "6f286f6a-111f-5878-ab1e-185364afe411" -version = "0.8.0" +version = "0.9.1" [[deps.MutableArithmetics]] deps = ["LinearAlgebra", "SparseArrays", "Test"] @@ -583,29 +622,31 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.8.0" [[deps.NaNMath]] -git-tree-sha1 = "b086b7ea07f8e38cf122f5016af580881ac914fe" +deps = ["OpenLibm_jll"] +git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.7" +version = "1.0.1" [[deps.NearestNeighbors]] deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "16baacfdc8758bc374882566c9187e785e85c2f0" +git-tree-sha1 = "0e353ed734b1747fc20cd4cba0edd9ac027eff6a" uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.9" +version = "0.4.11" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" [[deps.Observables]] -git-tree-sha1 = "fe29afdef3d0c4a8286128d4e45cc50621b1e43d" +git-tree-sha1 = "dfd8d34871bc3ad08cd16026c1828e271d554db9" uuid = "510215fc-4207-5dde-b226-833fc4488ee2" -version = "0.4.0" +version = "0.5.1" [[deps.OffsetArrays]] deps = ["Adapt"] -git-tree-sha1 = "043017e0bdeff61cfbb7afeb558ab29536bbb5ed" +git-tree-sha1 = "1ea784113a6aa054c5ebd95945fa5e52c2f378e7" uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.10.8" +version = "1.12.7" [[deps.Ogg_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -616,16 +657,18 @@ version = "1.3.5+1" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.20+0" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "648107615c15d4e09f7eca16307bc821c1f718d8" +git-tree-sha1 = "e60321e3f2616584ff98f0a4f18d98ae6f89bbb3" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "1.1.13+0" +version = "1.1.17+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -635,7 +678,7 @@ version = "0.5.5+0" [[deps.OptimalTransport]] deps = ["ExactOptimalTransport", "IterativeSolvers", "LinearAlgebra", "LogExpFunctions", "NNlib", "Reexport"] -path = "../.." +git-tree-sha1 = "79ba1dab46dfc7b677278ebe892a431788da86a9" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" version = "0.3.19" @@ -677,24 +720,25 @@ version = "0.40.1+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" [[deps.PlotThemes]] -deps = ["PlotUtils", "Requires", "Statistics"] -git-tree-sha1 = "a3a964ce9dc7898193536002a6dd892b1b5a6f1d" +deps = ["PlotUtils", "Statistics"] +git-tree-sha1 = "8162b2f8547bc23876edd0c5181b27702ae58dce" uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" -version = "2.0.1" +version = "3.0.0" [[deps.PlotUtils]] deps = ["ColorSchemes", "Colors", "Dates", "Printf", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "6f1b25e8ea06279b5689263cc538f51331d7ca17" +git-tree-sha1 = "9888e59493658e476d3073f1ce24348bdc086660" uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" -version = "1.1.3" +version = "1.3.0" [[deps.Plots]] -deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "GeometryBasics", "JSON", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"] -git-tree-sha1 = "68e602f447344154f3b80f7d14bfb459a0f4dadf" +deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "GeometryBasics", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"] +git-tree-sha1 = "a19652399f43938413340b2068e11e55caa46b65" uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -version = "1.25.5" +version = "1.31.7" [[deps.Preferences]] deps = ["TOML"] @@ -712,9 +756,9 @@ uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" [[deps.Qt5Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "xkbcommon_jll"] -git-tree-sha1 = "ad368663a5e20dbb8d6dc2fddeefe4dae0781ae8" +git-tree-sha1 = "c6c0f690d0cc7caddb74cef7aa847b824a16b256" uuid = "ea2cea3b-5b76-57ae-a6ef-0a8af62496e1" -version = "5.15.3+0" +version = "5.15.3+1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] @@ -732,9 +776,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[deps.Ratios]] deps = ["Requires"] -git-tree-sha1 = "01d341f502250e81f6fec0afe662aa861392a3aa" +git-tree-sha1 = "dc84268fe0e3335a62e315a3a7cf2afa7178a734" uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" -version = "0.4.2" +version = "0.4.3" [[deps.RecipesBase]] git-tree-sha1 = "6bf3f380ff52ce0832ddd3a2a7b9538ed1bcca7d" @@ -743,9 +787,9 @@ version = "1.2.1" [[deps.RecipesPipeline]] deps = ["Dates", "NaNMath", "PlotUtils", "RecipesBase"] -git-tree-sha1 = "7ad0dfa8d03b7bcf8c597f59f5292801730c55b8" +git-tree-sha1 = "e7eac76a958f8664f2718508435d058168c7953d" uuid = "01d81517-befc-4cb6-b9ec-a95719d0359c" -version = "0.4.1" +version = "0.6.3" [[deps.Reexport]] git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" @@ -754,9 +798,9 @@ version = "1.2.2" [[deps.RelocatableFolders]] deps = ["SHA", "Scratch"] -git-tree-sha1 = "cdbd3b1338c72ce29d9584fdbe9e9b70eeb5adca" +git-tree-sha1 = "22c5201127d7b243b9ee1de3b43c408879dff60f" uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" -version = "0.1.3" +version = "0.3.0" [[deps.Requires]] deps = ["UUIDs"] @@ -778,18 +822,19 @@ version = "0.3.0+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" [[deps.Scratch]] deps = ["Dates"] -git-tree-sha1 = "0b4b7f1393cff97c33891da2a0bf69c6ed241fda" +git-tree-sha1 = "f94f779c94e58bf9ea243e77a37e16d9de9126bd" uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.1.0" +version = "1.1.1" [[deps.SentinelArrays]] deps = ["Dates", "Random"] -git-tree-sha1 = "15dfe6b103c2a993be24404124b8791a09460983" +git-tree-sha1 = "db8481cf5d6278a121184809e9eb1628943c7704" uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.3.11" +version = "1.3.13" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -804,6 +849,11 @@ git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" version = "1.0.3" +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -824,10 +874,15 @@ uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.1.0" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "2884859916598f974858ff01df7dfc6c708dd895" +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "85bc4b051546db130aeb1e8a696f1da6d4497200" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.3.3" +version = "1.5.5" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "5b413a57dd3cea38497d745ce088ac8592fbb5be" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.1.0" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -851,16 +906,16 @@ uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" version = "0.9.15" [[deps.StatsPlots]] -deps = ["Clustering", "DataStructures", "DataValues", "Distributions", "Interpolations", "KernelDensity", "LinearAlgebra", "MultivariateStats", "Observables", "Plots", "RecipesBase", "RecipesPipeline", "Reexport", "StatsBase", "TableOperations", "Tables", "Widgets"] -git-tree-sha1 = "e1e5ed9669d5521d4bbdd4fab9f0945a0ffceba2" +deps = ["AbstractFFTs", "Clustering", "DataStructures", "DataValues", "Distributions", "Interpolations", "KernelDensity", "LinearAlgebra", "MultivariateStats", "Observables", "Plots", "RecipesBase", "RecipesPipeline", "Reexport", "StatsBase", "TableOperations", "Tables", "Widgets"] +git-tree-sha1 = "2b35ba790f1f823872dcf378a6d3c3b520092eac" uuid = "f3b207a7-027a-5e70-b257-86293d7955fd" -version = "0.14.30" +version = "0.15.1" [[deps.StructArrays]] -deps = ["Adapt", "DataAPI", "StaticArrays", "Tables"] -git-tree-sha1 = "d21f2c564b21a202f4677c0fba5b5ee431058544" +deps = ["Adapt", "DataAPI", "StaticArraysCore", "Tables"] +git-tree-sha1 = "8c6ac65ec9ab781af05b08ff305ddc727c25f680" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.4" +version = "0.6.12" [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] @@ -869,6 +924,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.0" [[deps.TableOperations]] deps = ["SentinelArrays", "Tables", "Test"] @@ -883,14 +939,21 @@ uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" version = "1.0.1" [[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] -git-tree-sha1 = "bb1064c9a84c52e277f1096cf41434b675cd368b" +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "5ce79ce186cc678bbb5c5681ca3379d1ddae11a1" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.6.1" +version = "1.7.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] @@ -903,9 +966,9 @@ uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" version = "0.9.6" [[deps.URIs]] -git-tree-sha1 = "97bbe755a53fe859669cd907f2d96aee8d2c1355" +git-tree-sha1 = "e59ecc5a41b000fa94423a578d29290c7266fc10" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.3.0" +version = "1.4.0" [[deps.UUIDs]] deps = ["Random", "SHA"] @@ -933,15 +996,15 @@ version = "1.19.0+0" [[deps.Wayland_protocols_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "66d72dc6fcc86352f01676e8f0f698562e60510f" +git-tree-sha1 = "4528479aa01ee1b3b4cd0e6faef0e04cf16466da" uuid = "2381bf8a-dfd0-557d-9999-79630e7b1b91" -version = "1.23.0+0" +version = "1.25.0+0" [[deps.Widgets]] deps = ["Colors", "Dates", "Observables", "OrderedCollections"] -git-tree-sha1 = "80661f59d28714632132c73779f8becc19a113f2" +git-tree-sha1 = "fcdae142c1cfc7d89de2d11e08721d0f2f86c98a" uuid = "cc8bc4a8-27d6-5769-a93b-9d913e69aa62" -version = "0.6.4" +version = "0.6.6" [[deps.WoodburyMatrices]] deps = ["LinearAlgebra", "SparseArrays"] @@ -951,9 +1014,9 @@ version = "0.5.5" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "1acf5bdf07aa0907e0a37d3718bb88d4b687b74a" +git-tree-sha1 = "58443b63fb7e465a8a7210828c91c08b92132dff" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.9.12+0" +version = "2.9.14+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] @@ -1090,12 +1153,19 @@ version = "1.4.0+3" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.12+3" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "cc4bf3fdde8b7e3e9fa0351bdeedba1cf3b7f6e6" +git-tree-sha1 = "e45044cd873ded54b6a5bac0eb5c971392cf1927" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.0+0" +version = "1.5.2+0" + +[[deps.libaom_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "3a2ea60308f0996d26f1e5354e10c24e9ef905d4" +uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" +version = "3.4.0+0" [[deps.libass_jll]] deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] @@ -1106,6 +1176,7 @@ version = "0.15.1+0" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.1.1+0" [[deps.libfdk_aac_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1128,10 +1199,12 @@ version = "1.3.7+1" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" [[deps.x264_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1147,6 +1220,6 @@ version = "3.5.0+0" [[deps.xkbcommon_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Wayland_jll", "Wayland_protocols_jll", "Xorg_libxcb_jll", "Xorg_xkeyboard_config_jll"] -git-tree-sha1 = "ece2350174195bb31de1a63bea3a41ae1aa593b6" +git-tree-sha1 = "9ebfc140cc56e8c2156a15ceac2f0302e327ac0a" uuid = "d8fb68d0-12a3-5cfd-a85a-d49703b185fd" -version = "0.9.1+5" +version = "1.4.1+0" diff --git a/examples/OneDimension/Project.toml b/examples/OneDimension/Project.toml index 4c7efe83..63fd8dbf 100644 --- a/examples/OneDimension/Project.toml +++ b/examples/OneDimension/Project.toml @@ -12,5 +12,4 @@ Distances = "0.10" Distributions = "0.25" Literate = "2.9" OptimalTransport = "0.3" -StatsPlots = "0.14" julia = "1" diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 1653431e..bbf0a29a 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -21,7 +21,7 @@ export QuadraticOTNewton export sinkhorn, sinkhorn2 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter export sinkhorn_unbalanced, sinkhorn_unbalanced2 -export sinkhorn_divergence +export sinkhorn_divergence, sinkhorn_divergence_unbalanced export quadreg include("utils.jl") diff --git a/src/entropic/sinkhorn_solve.jl b/src/entropic/sinkhorn_solve.jl index e41c24b1..919bcbb9 100644 --- a/src/entropic/sinkhorn_solve.jl +++ b/src/entropic/sinkhorn_solve.jl @@ -85,12 +85,12 @@ function solve!( isconverged, abserror = OptimalTransport.check_convergence(solver) @debug string(solver.alg) * - " (" * - string(iter) * - "/" * - string(maxiter) * - ": absolute error of source marginal = " * - string(maximum(abserror)) + " (" * + string(iter) * + "/" * + string(maxiter) * + ": absolute error of source marginal = " * + string(maximum(abserror)) if isconverged @debug "$(solver.alg) ($iter/$maxiter): converged" diff --git a/src/entropic/sinkhorn_stabilized.jl b/src/entropic/sinkhorn_stabilized.jl index 861cb72a..7519aa67 100644 --- a/src/entropic/sinkhorn_stabilized.jl +++ b/src/entropic/sinkhorn_stabilized.jl @@ -83,11 +83,11 @@ function prestep!(solver::SinkhornSolver{<:SinkhornStabilized}, iter::Int) # absorption step if maximum(abs, u) > absorb_tol || maximum(abs, v) > absorb_tol @debug string(solver.alg) * - " (" * - string(iter) * - "/" * - string(maxiter) * - ": absorbing `u` and `v` into `α` and `β`" + " (" * + string(iter) * + "/" * + string(maxiter) * + ": absorbing `u` and `v` into `α` and `β`" # absorb `u` and `v` into `α` and `β` absorb!(solver) diff --git a/src/entropic/sinkhorn_unbalanced.jl b/src/entropic/sinkhorn_unbalanced.jl index 43d7775e..8cde35f9 100644 --- a/src/entropic/sinkhorn_unbalanced.jl +++ b/src/entropic/sinkhorn_unbalanced.jl @@ -1,3 +1,13 @@ +""" + proxdivKL!(s, p, ε, λ) + +Operator ``\\operatorname{proxdiv}_F(s, p, ε)`` associated with the marginal penalty +``q \\mapsto \\lambda \\operatorname{KL}(q | p)``. For further details see [^CPSV18]. + +[^CPSV18]: Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F.-X. (2018). [Scaling algorithms for unbalanced optimal transport problems](https://doi.org/10.1090/mcom/3303). Mathematics of Computation, 87(314), 2563–2609. +""" +proxdivKL!(s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ))) + """ sinkhorn_unbalanced(μ, ν, C, λ1::Real, λ2::Real, ε; kwargs...) @@ -39,13 +49,17 @@ function sinkhorn_unbalanced( end # define "proxdiv" functions for the unbalanced OT problem - proxdivF!(s, p, ε, λ) = (s .= (p ./ s) .^ (λ / (ε + λ))) - proxdivF1!(s, p, ε) = proxdivF!(s, p, ε, λ1) - proxdivF2!(s, p, ε) = proxdivF!(s, p, ε, λ2) + proxdivF1!(s, p, ε) = proxdivKL!(s, p, ε, λ1) + proxdivF2!(s, p, ε) = proxdivKL!(s, p, ε, λ2) return sinkhorn_unbalanced(μ, ν, C, proxdivF1!, proxdivF2!, ε; kwargs...) end +function sinkhorn_unbalanced(μ, C, λ::Real, ε; kwargs...) + proxdivF!(s, p, ε) = proxdivKL!(s, p, ε, λ) + return sinkhorn_unbalanced(μ, C, proxdivF!, ε; kwargs...) +end + """ sinkhorn_unbalanced( μ, ν, C, proxdivF1!, proxdivF2!, ε; @@ -177,11 +191,11 @@ function sinkhorn_unbalanced( b_old .-= b sqeuclidean_a_b = sum(abs2, a_old) + sum(abs2, b_old) @debug "Sinkhorn algorithm (" * - string(iter) * - "/" * - string(_maxiter) * - ": squared Euclidean distance of iterates = " * - string(sqeuclidean_a_b) + string(iter) * + "/" * + string(_maxiter) * + ": squared Euclidean distance of iterates = " * + string(sqeuclidean_a_b) # check convergence of `a` if sqeuclidean_a_b < max(sqatol, sqrtol * max(sqnorm_a_b, sqnorm_a_b_old)) @@ -199,6 +213,71 @@ function sinkhorn_unbalanced( return K .* a .* b' end +""" + function sinkhorn_unbalanced( + μ, C, proxdivF!, ε; atol = nothing, rtol = nothing, maxiter::Int = 1_000, check_convergence::Int=10 + ) + + Specialised case of [`sinkhorn_unbalanced`](@ref) to the special symmetric case where both inputs `μ, ν` are identical and the cost `C` is symmetric. + This implementation takes advantage of additional structure in the symmetric case which allows for a fixed point iteration with much faster convergence, + similar to that described by [^FeydyP19] and also employed in [`sinkhorn_divergence`](@ref) for the balanced case. + + [^FeydyP19]: Jean Feydy, Thibault Séjourné, François-Xavier Vialard, Shun-ichi Amari, Alain Trouvé, and Gabriel Peyré. Interpolating between optimal transport and mmd using sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics, pages 2681–2690. PMLR, 2019. +""" +function sinkhorn_unbalanced( + μ, + C, + proxdivF!, + ε; + atol=nothing, + rtol=nothing, + maxiter::Int=1_000, + check_convergence::Int=10, +) + # compute Gibbs kernel + K = @. exp(-C / ε) + + # set default values of squared tolerances + T = float(Base.promote_eltype(μ, K)) + sqatol = atol === nothing ? 0 : atol^2 + sqrtol = rtol === nothing ? (sqatol > zero(sqatol) ? zero(T) : eps(T)) : rtol^2 + + # initialize iterate and cache + a = similar(μ, T) + sum!(a, K) + tmp = similar(a) + + isconverged = false + for iter in 1:maxiter + ischeck = iter % check_convergence == 0 + mul!(tmp, K, a) + proxdivF!(tmp, μ, ε) + if ischeck + sqnorm_a = sum(abs2, tmp) + sqnorm_a_old = sum(abs2, a) + sqeuclidean_a = sum(abs2, a - tmp) + @debug "Sinkhorn algorithm (" * + string(iter) * + "/" * + string(maxiter) * + ": squared Euclidean distance of iterates = " * + string(sqeuclidean_a) + + # check convergence of `a` + if sqeuclidean_a < max(sqatol, sqrtol * max(sqnorm_a, sqnorm_a_old)) + @debug "Sinkhorn algorithm ($iter/$maxiter): converged" + isconverged = true + break + end + end + @. a = exp.(0.5 * (log.(a) + log.(tmp))) + end + if !isconverged + @warn "Sinkhorn algorithm ($maxiter/$maxiter): not converged" + end + return K .* a .* a' +end + """ sinkhorn_unbalanced2(μ, ν, C, λ1, λ2, ε; plan=nothing, kwargs...) sinkhorn_unbalanced2(μ, ν, C, proxdivF1!, proxdivF2!, ε; plan=nothing, kwargs...) @@ -216,18 +295,57 @@ optimal transport problems with general soft marginal constraints. See also: [`sinkhorn_unbalanced`](@ref) """ function sinkhorn_unbalanced2( - μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; plan=nothing, kwargs... + μ, ν, c, λ1_or_proxdivf1, λ2_or_proxdivf2, ε; plan=nothing, kwargs... ) γ = if plan === nothing - sinkhorn_unbalanced(μ, ν, C, λ1_or_proxdivF1, λ2_or_proxdivF2, ε; kwargs...) + sinkhorn_unbalanced(μ, ν, c, λ1_or_proxdivf1, λ2_or_proxdivf2, ε; kwargs...) else # check dimensions - size(C) == (length(μ), length(ν)) || - error("cost matrix `C` must be of size `(length(μ), length(ν))`") - size(plan) == size(C) || error( - "optimal transport plan `plan` and cost matrix `C` must be of the same size", + size(c) == (length(μ), length(ν)) || + error("cost matrix `c` must be of size `(length(μ), length(ν))`") + size(plan) == size(c) || error( + "optimal transport plan `plan` and cost matrix `c` must be of the same size", ) plan end - return dot(γ, C) + return dot(γ, c) +end + +function sinkhorn_unbalanced2(μ, c, λ_or_proxdivf, ε; plan=nothing, kwargs...) + γ = if plan === nothing + sinkhorn_unbalanced(μ, c, λ_or_proxdivf, ε; kwargs...) + else + # check dimensions + size(c) == (length(μ), length(μ)) || + error("cost matrix `c` must be of size `(length(μ), length(μ))`") + size(plan) == size(c) || error( + "optimal transport plan `plan` and cost matrix `c` must be of the same size", + ) + plan + end + return dot(γ, c) +end + +""" + sinkhorn_divergence_unbalanced(μ, ν, cμν, cμ, cν, λ, ε; kwargs...) + +Compute the unbalanced Sinkhorn divergence between unnormalized inputs `μ` and `ν` with cost matrix `cμν`, `cμ` and `cν` between `(μ,ν)`, `(μ, μ)` and `(ν, ν)` respectively, +regularization level `ε` and marginal constraint parameter `λ`. Following [^SFVTP19], the unbalanced Sinkhorn divergence is defined as +```math + \\operatorname{S}_{\\varepsilon, \\lambda} (\\mu, \\nu) := \\operatorname{OT}_{ε, λ}(μ,ν) + - \\frac{1}{2}(\\operatorname{OT}_{ε, λ}(μ,μ) + \\operatorname{OT}_{ε, λ}(ν,ν)) + \\frac{ε}{2}(m(μ) + m(ν))^2, +``` +where ``\\operatorname{OT}_{ε, λ}(\\alpha, \\beta)`` is defined to be +```math + \\operatorname{OT}_{ε, λ}(\\alpha, \\beta) = \\inf_{\\gamma} \\langle C, \\gamma \\rangle + \\varepsilon \\operatorname{KL}(\\gamma | \\alpha \\otimes \\beta) + \\lambda ( \\operatorname{KL}(\\gamma_1 | \\alpha) + \\operatorname{KL}(\\gamma_2 | \\beta) ), +``` +i.e. the output of calling `sinkhorn_unbalanced2` with the default Kullback-Leibler marginal penalties. + +[^SFVTP19]: Séjourné, T., Feydy, J., Vialard, F.X., Trouvé, A. and Peyré, G., 2019. Sinkhorn divergences for unbalanced optimal transport. arXiv preprint arXiv:1910.12958. +""" +function sinkhorn_divergence_unbalanced(μ, ν, cμν, cμ, cν, λ, ε; kwargs...) + Sμν = sinkhorn_unbalanced2(μ, ν, cμν, λ, λ, ε; kwargs...) + Sμ = sinkhorn_unbalanced2(μ, cμ, λ, ε; kwargs...) + Sν = sinkhorn_unbalanced2(ν, cν, λ, ε; kwargs...) + return max(0, Sμν - (Sμ + Sν) / 2 + ε * (sum(μ) - sum(ν))^2 / 2) end diff --git a/src/quadratic_newton.jl b/src/quadratic_newton.jl index a033a133..4ca2a749 100644 --- a/src/quadratic_newton.jl +++ b/src/quadratic_newton.jl @@ -198,12 +198,12 @@ function solve!(solver::QuadraticOTSolver{<:QuadraticOTNewton}) μ, ν, cache, convergence_cache, atol, rtol ) @debug string(solver.alg) * - " (" * - string(iter) * - "/" * - string(maxiter) * - ": absolute error of source marginal = " * - string(maximum(abserror)) + " (" * + string(iter) * + "/" * + string(maxiter) * + ": absolute error of source marginal = " * + string(maximum(abserror)) if isconverged @debug "$(solver.alg) ($iter/$maxiter): converged" diff --git a/test/entropic/sinkhorn_gibbs.jl b/test/entropic/sinkhorn_gibbs.jl index e9577e85..582e76c8 100644 --- a/test/entropic/sinkhorn_gibbs.jl +++ b/test/entropic/sinkhorn_gibbs.jl @@ -55,7 +55,7 @@ Random.seed!(100) ) @test c_w_regularization ≈ c + ε * sum(x -> iszero(x) ? x : x * log(x), γ) @test c_w_regularization == - sinkhorn2(μ, ν, C, ε; maxiter=5_000, regularization=true) + sinkhorn2(μ, ν, C, ε; maxiter=5_000, regularization=true) # ensure that provided plan is used and correct c2 = sinkhorn2(similar(μ), similar(ν), C, rand(), SinkhornGibbs(); plan=γ) @@ -66,7 +66,7 @@ Random.seed!(100) ) @test c2_w_regularization ≈ c_w_regularization @test c2_w_regularization == - sinkhorn2(similar(μ), similar(ν), C, ε; plan=γ, regularization=true) + sinkhorn2(similar(μ), similar(ν), C, ε; plan=γ, regularization=true) # batches of histograms d = 10 @@ -141,7 +141,7 @@ Random.seed!(100) @test size(γ_all) == (M, N, d) @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) @test γ_all == - sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) + sinkhorn(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) # compute optimal transport cost and check that it is consistent with the # cost for individual histograms @@ -151,7 +151,7 @@ Random.seed!(100) @test size(c_all) == (d,) @test all(x ≈ c for x in c_all) @test c_all == - sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) + sinkhorn2(μ32_batch, ν32_batch, C32, ε32; maxiter=5_000, rtol=1e-6) end end diff --git a/test/entropic/sinkhorn_unbalanced.jl b/test/entropic/sinkhorn_unbalanced.jl index 25ea6015..354ac74e 100644 --- a/test/entropic/sinkhorn_unbalanced.jl +++ b/test/entropic/sinkhorn_unbalanced.jl @@ -89,6 +89,29 @@ Random.seed!(100) @test c_balanced ≈ c rtol = 1e-4 end + @testset "unbalanced Sinkhorn divergences" begin + μ = fill(1 / M, M) + μ_spt = rand(1, M) + ν = fill(1 / N, N) + ν_spt = rand(1, N) + ε = 0.01 + λ = 1.0 + Cμν = pairwise(SqEuclidean(), μ_spt, ν_spt; dims=2) + Cμμ = pairwise(SqEuclidean(), μ_spt, μ_spt; dims=2) + Cνν = pairwise(SqEuclidean(), ν_spt, ν_spt; dims=2) + + # check the symmetric terms + @test sinkhorn_unbalanced(μ, Cμμ, λ, ε) ≈ sinkhorn_unbalanced(μ, μ, Cμμ, λ, λ, ε) rtol = + 1e-4 + @test sinkhorn_unbalanced(ν, Cνν, λ, ε) ≈ sinkhorn_unbalanced(ν, ν, Cνν, λ, λ, ε) rtol = + 1e-4 + + # check against balanced case + proxdivF!(s, p, ε) = (s .= p ./ s) + @test sinkhorn_divergence_unbalanced(μ, ν, Cμν, Cμμ, Cνν, proxdivF!, ε) ≈ + sinkhorn_divergence(μ, ν, Cμν, Cμμ, Cνν, ε) rtol = 1e-4 + end + @testset "deprecations" begin μ = fill(1 / N, M) ν = fill(1 / N, N) diff --git a/test/gpu/simple_gpu.jl b/test/gpu/simple_gpu.jl index aaca5b86..edd664f7 100644 --- a/test/gpu/simple_gpu.jl +++ b/test/gpu/simple_gpu.jl @@ -32,7 +32,7 @@ Random.seed!(100) cu_C = cu(C) # regularization parameter - ε = 0.01f0 + ε = 0.05f0 @testset "sinkhorn" begin for alg in ( @@ -97,11 +97,13 @@ Random.seed!(100) @testset "quadreg" begin # use a different reg parameter ε_quad = 1.0f0 - γ = quadreg(cu_μ, cu_ν, cu_C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1f-5, 50)) + γ = quadreg( + cu_μ, cu_ν, cu_C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1.0f-5, 50) + ) # compare with results on the CPU @test convert(Array, γ) ≈ - quadreg(μ, ν, C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1f-5, 50)) atol = - 1f-4 rtol = 1f-4 + quadreg(μ, ν, C, ε_quad, QuadraticOTNewton(0.1f0, 0.5f0, 1.0f-5, 50)) atol = + 1.0f-4 rtol = 1.0f-4 end end end diff --git a/test/utils.jl b/test/utils.jl index 7b890681..e832c852 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -39,7 +39,7 @@ Random.seed!(100) y = rand(l, m, n) @test OptimalTransport.dot_matwise(x, y) ≈ - mapreduce(vcat, (view(y, :, :, i) for i in axes(y, 3))) do yi + mapreduce(vcat, (view(y, :, :, i) for i in axes(y, 3))) do yi dot(x, yi) end @test OptimalTransport.dot_matwise(y, x) == OptimalTransport.dot_matwise(x, y)