From a3b8f6d5983567f861740a5f6b970fd670390dea Mon Sep 17 00:00:00 2001 From: Sean <10673535+slwu89@users.noreply.github.com> Date: Sun, 26 Feb 2023 10:01:58 -0800 Subject: [PATCH] add method for Distributions.rand(::ResettableRNG, ::Binomial) (#249) * add method for Distributions.rand(::ResettableRNG, ::Binomial * Fix ResettableRNG dispatch * Test with MersenneTwister (Xoshiro not available in Julia 1.6) * bump version --------- Co-authored-by: Chad Scherrer --- Project.toml | 2 +- src/parameterized/binomial.jl | 4 ++-- src/resettable-rng.jl | 8 ++++++++ test/runtests.jl | 4 ++++ 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 26a3a666..a88a1b63 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MeasureTheory" uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b" authors = ["Chad Scherrer and contributors"] -version = "0.18.1" +version = "0.18.2" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/src/parameterized/binomial.jl b/src/parameterized/binomial.jl index 72981b96..1825cb69 100644 --- a/src/parameterized/binomial.jl +++ b/src/parameterized/binomial.jl @@ -23,8 +23,8 @@ end x ∈ (0, 1) end -function Base.rand(rng::AbstractRNG, ::Type, d::Binomial{(:n, :p)}) - rand(rng, Dists.Binomial(d.n, d.p)) +function Base.rand(rng::AbstractRNG, ::Type{T}, d::Binomial{(:n, :p)}) where {T} + rand(rng, T, Dists.Binomial(d.n, d.p)) end Binomial(n) = Binomial(n, 0.5) diff --git a/src/resettable-rng.jl b/src/resettable-rng.jl index a6297a21..02223e97 100644 --- a/src/resettable-rng.jl +++ b/src/resettable-rng.jl @@ -57,6 +57,14 @@ for T in vcat(subtypes(Signed), subtypes(Unsigned), subtypes(AbstractFloat)) end end +function Base.rand(r::ResettableRNG, d::AbstractMeasure) + rand(r.rng, d) +end + +function Base.rand(r::ResettableRNG, ::Type{T}, d::AbstractMeasure) where {T} + rand(r.rng, T, d) +end + Base.iterate(r::ResettableRNG) = iterate(r, nothing) function Base.iterate(r::ResettableRNG, ::Nothing) diff --git a/test/runtests.jl b/test/runtests.jl index 3ccf0f45..97f6833e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -116,6 +116,10 @@ end @test ℓ ≈ logdensity_def(Binomial(; n, logitp), y) @test ℓ ≈ logdensity_def(Binomial(; n, probitp), y) + rng = ResettableRNG(Random.MersenneTwister()) + @test rand(rng, Binomial(n=0, p=1.0)) == 0 + @test rand(rng, Binomial(n=10, p=1.0)) == 10 + @test_broken logdensity_def(Binomial(n, p), CountingMeasure(ℤ[0:n]), x) ≈ binomlogpdf(n, p, x) end