Skip to content

Commit

Permalink
Support permutations of names in kldivergence
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Jan 16, 2025
1 parent 8188d35 commit d33c31d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/namedtuple/productnamedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@ std(d::ProductNamedTupleDistribution) = map(std, d.dists)
entropy(d::ProductNamedTupleDistribution) = sum(entropy, values(d.dists))

function kldivergence(
d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution{K}
d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution
) where {K}
_named_fields_match(d1.dists, d2.dists) || throw(
ArgumentError(
"Sets of named tuple fields are not the same: !issetequal($(fieldnames(d1)), $(fieldnames(d2)))",
"Sets of named tuple fields are not the same: !issetequal($(keys(d1.dists)), $(keys(d2.dists)))",
),
)
return sum(map(kldivergence, d1.dists, NamedTuple{K}(d2.dists)))
Expand Down
5 changes: 5 additions & 0 deletions test/namedtuple/productnamedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,13 @@ using Test

d1 = ProductNamedTupleDistribution((x=Normal(1.0, 2.0), y=Gamma()))
d2 = ProductNamedTupleDistribution((x=Normal(), y=Gamma(2.0, 3.0)))
d2_perm = ProductNamedTupleDistribution((y=Gamma(2.0, 3.0), x=Normal()))
d2_sub = ProductNamedTupleDistribution((x=Normal(1.0, 2.0),))
@test kldivergence(d1, d2) ==
kldivergence(d1.dists.x, d2.dists.x) + kldivergence(d1.dists.y, d2.dists.y)
@test kldivergence(d1, d2_perm) == kldivergence(d1, d2)
@test_throws ArgumentError kldivergence(d1, d2_sub)
@test_throws ArgumentError kldivergence(d2_sub, d1)

d3 = ProductNamedTupleDistribution((x=Normal(1.0, 2.0), y=Gamma(6.0, 7.0)))
@test std(d3) == (x=std(d3.dists.x), y=std(d3.dists.y))
Expand Down

0 comments on commit d33c31d

Please sign in to comment.