Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #26

Draft
wants to merge 48 commits into
base: main
Choose a base branch
from
Draft

Dev #26

Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
d77eaa2
bouncy updates
cscherrer Jun 7, 2022
8fff1cc
rand
cscherrer Jun 11, 2022
059ab30
minor fix
cscherrer Jun 11, 2022
4d65f63
Fix method ambiguity
cscherrer Jun 11, 2022
2dd7924
refactoring
cscherrer Jun 17, 2022
939f341
Merge branch 'dev' into temp
cscherrer Jun 17, 2022
78c3df1
Merge pull request #25 from cscherrer/temp
cscherrer Jun 17, 2022
97c29c5
Merge remote-tracking branch 'origin/dev' into dev
cscherrer Jun 17, 2022
52179a9
Merge pull request #20 from cscherrer/rand
cscherrer Jun 21, 2022
dc3c376
Disallow `rand` on ModelPosteriors
cscherrer Jun 21, 2022
3c03bd4
Merge branch 'dev' of https://github.com/cscherrer/Tilde.jl into dev
cscherrer Jun 21, 2022
a352754
Move astmodel.jl contents into model.jl
cscherrer Jun 21, 2022
3921a40
Drop old `DAGModel` stuff
cscherrer Jun 21, 2022
3aaaa74
drop dead code
cscherrer Jun 21, 2022
2f832a5
change `S` type param to `F`
cscherrer Jun 21, 2022
69ce0d8
drop on include
cscherrer Jun 21, 2022
36a2e33
edit for readability
cscherrer Jun 22, 2022
781c29d
update dependencies
cscherrer Jun 22, 2022
fbef2d1
Update src/core/models/abstractmodel.jl
cscherrer Jun 22, 2022
d41b73a
Update src/core/models/abstractmodel.jl
cscherrer Jun 22, 2022
cf209dc
Update src/core/models/model.jl
cscherrer Jun 22, 2022
bedd3fa
drop the F
cscherrer Jun 22, 2022
564ec84
Merge branch 'dev' of https://github.com/cscherrer/Tilde.jl into dev
cscherrer Jun 22, 2022
40933c3
predict (#21)
cscherrer Jun 22, 2022
9eefb93
test unbound args
cscherrer Jun 23, 2022
5e12f9a
cleanup
cscherrer Jun 23, 2022
2d29d0d
bugfix
cscherrer Jun 23, 2022
07dcf28
start on refactoring PDMP example
cscherrer Jun 24, 2022
c4ce48e
comments, mostly
cscherrer Jun 24, 2022
4011b4a
bugfix
cscherrer Jun 25, 2022
dd740bd
working on `predict`
cscherrer Jun 27, 2022
afb72fc
update Lens!!
cscherrer Jun 27, 2022
3d35531
Update benchmarks/bouncy.jl
cscherrer Jun 27, 2022
39fabf0
Update benchmarks/bouncy.jl
cscherrer Jun 27, 2022
761c60e
Merge branch 'main' into dev
cscherrer Jun 27, 2022
1d81bf7
lens stuff
cscherrer Jun 27, 2022
5c2ed88
drop optics change
cscherrer Jun 27, 2022
2ce7f40
moving things around a bit
cscherrer Jun 27, 2022
a0aabb8
Contexts (#32)
cscherrer Jul 4, 2022
e6cb833
Merge branch 'main' of https://github.com/cscherrer/Tilde.jl into dev
cscherrer Jul 4, 2022
6513e51
formatting
cscherrer Jul 4, 2022
75bdc37
bugfixes
cscherrer Jul 5, 2022
d8fb3cc
bugfix
cscherrer Jul 5, 2022
fd6ce5d
fix `predict`
cscherrer Jul 5, 2022
f315d9b
MaybeObserved shoudl be Unobserved
cscherrer Jul 5, 2022
7c93602
Move `anyfy`
cscherrer Jul 6, 2022
2181564
update `callify`
cscherrer Jul 6, 2022
ee2ab1a
deps
cscherrer Jul 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ TupleVectors = "615932cf-77b6-4358-adcd-5b7eba981d7e"

[compat]
Accessors = "0.1"
ArrayInterface = "4, 5, 6"
ArrayInterface = "5, 6"
DataStructures = "0.18"
DensityInterface = "0.4"
DiffResults = "1"
Expand All @@ -47,7 +47,7 @@ JuliaVariables = "0.2"
MLStyle = "0.3,0.4"
MacroTools = "0.5"
MappedArrays = "0.3, 0.4"
MeasureBase = "0.9"
MeasureBase = "0.12"
MeasureTheory = "0.16"
NamedTupleTools = "0.12, 0.13, 0.14"
NestedTuples = "0.3"
Expand All @@ -63,7 +63,7 @@ StatsFuns = "0.9, 1"
TransformVariables = "0.5, 0.6"
Tricks = "0.1"
TupleVectors = "0.1"
julia = "1.5"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
1 change: 1 addition & 0 deletions benchmarks/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Expand Down
9 changes: 4 additions & 5 deletions benchmarks/bouncy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,9 @@ M = pf_result.fit_distribution.Σ;
x0 = pf_result.fit_distribution.μ;
v0 = PDMats.unwhiten(M, randn(length(x0)));





MAP = pf_result.optim_solution; # MAP, could be useful for control variates


# define BouncyParticle sampler (has two relevant parameters)
Z = BouncyParticle(missing, # graphical structure
MAP, # MAP estimate, unused
Expand All @@ -103,6 +100,7 @@ sampler = ZZB.NotFactSampler(Z, (dneglogp, ∇neglogp!), ZZB.LocalBound(c), t0 =
using TupleVectors: chainvec
using Tilde.MeasureTheory: transform

# @time first(Iterators.drop(tvs,1000))

function collect_sampler(t, sampler, n; progress=true, progress_stops=20)
if progress
Expand Down Expand Up @@ -173,4 +171,5 @@ ylabel!(plt, "DynamicHMC");
plt_bounds = collect(extrema(ess_hmc));
lineplot!(plt, plt_bounds, plt_bounds);
plt
@info "For each coordinate, a point (x,y) shows the effective sample size per second for BPS (x) and HMC (y) . In blue is the diagonal x=y"
@info "For each coordinate, a point (x,y) shows the effective sample size per second for BPS (x) and HMC (y) . In blue is the diagonal x=y"

1 change: 0 additions & 1 deletion src/Tilde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ end
include("optics.jl")
include("maybe.jl")
include("core/models/abstractmodel.jl")
include("core/models/astmodel/astmodel.jl")
include("core/models/model.jl")
include("core/dependencies.jl")
include("core/utils.jl")
Expand Down
8 changes: 7 additions & 1 deletion src/core/models/abstractmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
# struct MixedSupport <: ValueSupport end
# struct MixedVariate <: VariateForm end

abstract type ModelSupport end

struct LatentSupport <: ModelSupport end
struct PushforwardSupport <: ModelSupport end
struct JointSupport <: ModelSupport end

cscherrer marked this conversation as resolved.
Show resolved Hide resolved
"""
AbstractModel{A,B}

Expand All @@ -15,7 +21,7 @@ N gives the Names of arguments (each a Symbol)
B gives the Body, as an Expr
M gives the Module where the model is defined
"""
abstract type AbstractModel{A,B,M} <: AbstractTransitionKernel end
abstract type AbstractModel{A,B,M,F} <: AbstractTransitionKernel end
cscherrer marked this conversation as resolved.
Show resolved Hide resolved

abstract type AbstractConditionalModel{M, Args, Obs} <: AbstractMeasure end

Expand Down
Empty file added src/core/models/astmodel.jl
Empty file.
82 changes: 0 additions & 82 deletions src/core/models/astmodel/astmodel.jl

This file was deleted.

46 changes: 46 additions & 0 deletions src/core/models/model.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,49 @@
struct Model{A,B,M<:GG.TypeLevel,F} <: AbstractModel{A,B,M,F}
args :: Vector{Symbol}
body :: Expr
f :: F
end
cscherrer marked this conversation as resolved.
Show resolved Hide resolved

function Model(theModule::Module, args::Vector{Symbol}, body::Expr)
A = NamedTuple{Tuple(args)}

B = to_type(body)
M = to_type(theModule)
return Model{A,B,M,typeof(last)}(args, body, last)
end

model(m::Model) = m

function Base.convert(::Type{Expr}, m::Model)
numArgs = length(m.args)
args = if numArgs == 1
m.args[1]
elseif numArgs > 1
Expr(:tuple, [x for x in m.args]...)
end

body = m.body

q = if numArgs == 0
@q begin
@model $body
end
else
@q begin
@model $(args) $body
end
end

striplines(q).args[1]
end

Base.show(io::IO, m :: Model) = println(io, convert(Expr, m))

function type2model(::Type{Model{A,B,M,F}}) where {A,B,M,S}
args = [fieldnames(A)...]
body = from_type(B)
Model{A,B,M,F}(from_type(M), convert(Vector{Symbol},args), body)
end

toargs(vs :: Vector{Symbol}) = Tuple(vs)
toargs(vs :: NTuple{N,Symbol} where {N}) = vs
Expand Down
9 changes: 0 additions & 9 deletions src/core/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,6 @@ end
import MacroTools: striplines, @q




# function arguments(model::DAGModel)
# model.args
# end




allequal(xs) = all(xs[1] .== xs)


Expand Down
2 changes: 1 addition & 1 deletion src/primitives/insupport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ import MeasureBase: insupport
export insupport

@inline function insupport(m::AbstractConditionalModel, x::NamedTuple)
mapreduce(insupport, (a,b) -> a && b, measures!(m,x), x)
mapreduce(insupport, (a,b) -> a && b, measures(m), x)
end
14 changes: 11 additions & 3 deletions src/primitives/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,21 @@ function make_body(M, f, ast::Expr, retfun, argsT, obsT, parsT)
# X = to_type(unsolved_lhs)
# M = to_type(unsolve(rhs))

inargs = inkeys(sx, argsT)
# inargs = inkeys(sx, argsT)
inobs = inkeys(sx, obsT)
inpars = inkeys(sx, parsT)
# inpars = inkeys(sx, parsT)
rhs = unsolve(rhs)

xval = inobs ? :($Observed($x)) : (x ∈ knownvars ? :($Unobserved($x)) : :($Unobserved(missing)))
xval = if inkeys(sx, obsT)
:($Observed($x))
elseif x ∈ knownvars
:($Unobserved($x))
else
:($Unobserved(missing))
end

st = :(($x, _ctx, _retn) = $tilde($f, $l, $sx, $xval, $rhs, _cfg, _ctx))

qst = QuoteNode(st)
q = quote
# println($qst)
Expand Down
49 changes: 37 additions & 12 deletions src/primitives/rand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,73 @@ using TupleVectors: chainvec
export rand
EmptyNTtype = NamedTuple{(),Tuple{}} where T<:Tuple

@inline function Base.rand(rng::AbstractRNG, d::AbstractConditionalModel, N::Int)
r = chainvec(rand(rng, d), N)
@inline function Base.rand(m::ModelClosure, args...; kwargs...)
rand(GLOBAL_RNG, Float64, m, args...; kwargs...)
end

@inline function Base.rand(rng::AbstractRNG, m::ModelClosure, args...; kwargs...)
rand(rng, Float64, m, args...; kwargs...)
end

@inline function Base.rand(::Type{T_rng}, m::ModelClosure, args...; kwargs...) where {T_rng}
rand(GLOBAL_RNG, T_rng, m, args...; kwargs...)
end

@inline function Base.rand(m::ModelClosure, d::Integer, dims::Integer...; kwargs...)
rand(GLOBAL_RNG, Float64, m, d, dims...; kwargs...)
end

@inline function Base.rand(rng::AbstractRNG, m::ModelClosure, d::Integer, dims::Integer...; kwargs...)
rand(rng, Float64, m, d, dims...; kwargs...)
end

@inline function Base.rand(::Type{T_rng}, m::ModelClosure, d::Integer, dims::Integer...; kwargs...) where {T_rng}
rand(GLOBAL_RNG, T_rng, m, d, dims...; kwargs...)
end

@inline function Base.rand(rng::AbstractRNG, ::Type{T_rng}, d::ModelClosure, N::Integer, v::Vararg{Integer}) where {T_rng}
@assert isempty(v)
r = chainvec(rand(rng, T_rng, d), N)
for j in 2:N
@inbounds r[j] = rand(rng, d)
@inbounds r[j] = rand(rng, T_rng, d)
end
return r
end

@inline Base.rand(d::AbstractConditionalModel, N::Int) = rand(GLOBAL_RNG, d, N)
@inline Base.rand(d::ModelClosure, N::Int) = rand(GLOBAL_RNG, d, N)

@inline function Base.rand(m::AbstractConditionalModel; kwargs...)
@inline function Base.rand(m::ModelClosure; kwargs...)
rand(GLOBAL_RNG, m; kwargs...)
end

@inline function Base.rand(rng::AbstractRNG, m::AbstractConditionalModel; ctx=NamedTuple(), retfun = (r, ctx) -> r)
@inline function Base.rand(rng::AbstractRNG, m::ModelClosure; ctx=NamedTuple(), retfun = (r, ctx) -> r)
cfg = (rng=rng,)

@inline function Base.rand(rng::AbstractRNG, ::Type{T_rng}, m::ModelClosure; ctx=NamedTuple(), retfun = (r, ctx) -> r) where {T_rng}
cfg = (rng=rng, T_rng=T_rng)
gg_call(rand, m, NamedTuple(), cfg, ctx, retfun)
end

###############################################################################
# ctx::NamedTuple
@inline function tilde(::typeof(Base.rand), lens, xname, x::Unobserved, d, cfg, ctx::NamedTuple)
@inline function tilde(::typeof(Base.rand), lens, xname, x, d, cfg, ctx::NamedTuple)
xnew = set(x.value, Lens!!(lens), rand(cfg.rng, d))
ctx′ = merge(ctx, NamedTuple{(dynamic(xname),)}((xnew,)))
(xnew, ctx′, nothing)
end

@inline function tilde(::typeof(Base.rand), lens, xname, x::Observed, d, cfg, ctx::NamedTuple)
(x.value, ctx, nothing)
end


###############################################################################
# ctx::Dict

@inline function tilde(::typeof(Base.rand), lens::typeof(identity), xname, x, d, cfg, ctx::Dict)
x = rand(cfg.rng, d)
x = rand(cfg.rng, cfg.T_rng, d)
ctx[dynamic(xname)] = x
(x, ctx, nothing)
end

@inline function tilde(::typeof(Base.rand), lens, xname, x, m::AbstractConditionalModel, cfg, ctx::Dict)
@inline function tilde(::typeof(Base.rand), lens, xname, x, m::ModelClosure, cfg, ctx::Dict)
args = get(cfg.args, dynamic(xname), Dict())
cfg = merge(cfg, (args = args,))
tilde(rand, lens, xname, x, m(cfg.args), cfg, ctx)
Expand Down
9 changes: 0 additions & 9 deletions test/transforms.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
# Check for Model equality up to reorderings of a few fields
function ≊(m1::DAGModel,m2::DAGModel)
function eq_tuples(nt1::NamedTuple,nt2::NamedTuple)
return length(nt1)==length(nt2) && all(nt1[k]==nt2[k] for k in keys(nt1))
end
return Set(arguments(m1))==Set(arguments(m2)) && m1.retn==m2.retn && eq_tuples(m1.dists,m2.dists) && eq_tuples(m1.vals,m2.vals)
end


m = @model (n,α,β) begin
p ~ Beta(α, β)
x ~ Binomial(n, p)
Expand Down