diff --git a/Project.toml b/Project.toml index a079ce2..8392fda 100644 --- a/Project.toml +++ b/Project.toml @@ -6,9 +6,11 @@ version = "0.1.2" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" JuliaVariables = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -17,6 +19,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14" MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b" +MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" NestedTuples = "a734d2a7-8d68-409b-9419-626914d4061d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -35,10 +38,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" TupleVectors = "615932cf-77b6-4358-adcd-5b7eba981d7e" +Umlaut = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841" +Yota = "cd998857-8626-517d-b929-70ad188a48f0" [compat] Accessors = "0.1" -ArrayInterface = "4, 5, 6" +ArrayInterface = "5, 6" DataStructures = "0.18" DensityInterface = "0.4" DiffResults = "1" @@ -47,7 +52,7 @@ JuliaVariables = "0.2" MLStyle = "0.3,0.4" MacroTools = "0.5" MappedArrays = "0.3, 0.4" -MeasureBase = "0.9" +MeasureBase = "0.12" MeasureTheory = "0.16" NamedTupleTools = "0.12, 0.13, 0.14" NestedTuples = "0.3" @@ -63,7 +68,7 @@ StatsFuns = "0.9, 1" TransformVariables = "0.5, 0.6" Tricks = "0.1" TupleVectors = "0.1" -julia = "1.5" +julia = "1.6" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index a32aeef..7a322e0 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -1,8 +1,6 @@ [deps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -Optim = "429524aa-4258-5aef-a3af-852621145aeb" Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/benchmarks/bouncy.jl b/benchmarks/bouncy.jl index 53b41d2..c22cf86 100644 --- a/benchmarks/bouncy.jl +++ b/benchmarks/bouncy.jl @@ -14,56 +14,73 @@ using ForwardDiff using ForwardDiff: Dual using Pathfinder using Pathfinder.PDMats +using MCMCChains +using TupleVectors: chainvec +using Tilde.MeasureTheory: transform Random.seed!(1) +function make_grads(post) + as_post = as(post) + d = TV.dimension(as_post) + obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ)) + ℓ(θ) = -obj(θ) + @inline function dneglogp(t, x, v, args...) # two directional derivatives + f(t) = obj(x + t * v) + u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0)) + u.value, u.partials[] + end + + gconfig = ForwardDiff.GradientConfig(obj, rand(d), ForwardDiff.Chunk{d}()) + function ∇neglogp!(y, t, x, args...) + ForwardDiff.gradient!(y, obj, x, gconfig) + y + end + ℓ, dneglogp, ∇neglogp! +end + +# ↑ general purpose +############################################################ +# ↓ problem-specific + # read data function readlrdata() fname = joinpath("lr.data") z = readdlm(fname) - A = z[:, 1:end-1] + A = z[:, 1:(end-1)] A = [ones(size(A, 1)) A] y = z[:, end] .- 1 return A, y end -A, y = readlrdata(); -At = collect(A'); model_lr = @model (At, y, σ) begin d, n = size(At) θ ~ Normal(σ = σ)^d for j in 1:n - logitp = dot(view(At, :, j), θ) + logitp = view(At, :, j)' * θ y[j] ~ Bernoulli(logitp = logitp) end end + +# Define model arguments +A, y = readlrdata(); +At = collect(A'); σ = 100.0 -function make_grads(model_lr, At, y, σ) - post = model_lr(At, y, σ) | (; y) - as_post = as(post) - obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ)) - ℓ(θ) = -obj(θ) - @inline function dneglogp(t, x, v) # two directional derivatives - f(t) = obj(x + t * v) - u = ForwardDiff.derivative(f, Dual{:hSrkahPmmC}(0.0, 1.0)) - u.value, u.partials[] - end +# Represent the posterior +post = model_lr(At, y, σ) | (; y) - gconfig = ForwardDiff.GradientConfig(obj, rand(25), ForwardDiff.Chunk{25}()) - function ∇neglogp!(y, t, x) - ForwardDiff.gradient!(y, obj, x, gconfig) - return - end - post, ℓ, dneglogp, ∇neglogp! -end +d = TV.dimension(as(post)) -post, ℓ, dneglogp, ∇neglogp! = make_grads(model_lr, At, y, σ) -# Try things out -dneglogp(2.4, randn(25), randn(25)); -∇neglogp!(randn(25), 2.1, randn(25)); +# Make sure gradients are working +let + ℓ, dneglogp, ∇neglogp! = make_grads(post) + @show dneglogp(2.4, randn(d), randn(d)) + y = Vector{Float64}(undef, d) + @show ∇neglogp!(y, 2.1, randn(d)) + nothing +end -d = 25 # number of parameters t0 = 0.0; x0 = zeros(d); # starting point sampler # estimated posterior mean (n=100000, 797s) @@ -129,8 +146,7 @@ sampler = ZZB.NotFactSampler( ), ); -using TupleVectors: chainvec -using Tilde.MeasureTheory: transform +# @time first(Iterators.drop(tvs,1000)) function collect_sampler(t, sampler, n; progress = true, progress_stops = 20) if progress @@ -166,7 +182,6 @@ elapsed_time = @elapsed @time begin bps_samples, info = collect_sampler(as(post), sampler, n; progress = false) end -using MCMCChains bps_chain = MCMCChains.Chains(bps_samples.θ); bps_chain = setinfo(bps_chain, (; start_time = 0.0, stop_time = elapsed_time)); diff --git a/src/GG/deprecated_codes/explicit_scope.jl b/src/GG/deprecated_codes/explicit_scope.jl index 3c0d6d2..283f5f6 100644 --- a/src/GG/deprecated_codes/explicit_scope.jl +++ b/src/GG/deprecated_codes/explicit_scope.jl @@ -3,7 +3,7 @@ function scoping(ast) @match ast begin :([$(frees...)]($(args...)) -> begin $(stmts...) - end) => let stmts = map(rec, stmts), arw = :(($(args...),) -> begin + end) => let stmts = map(rec, stmts), arw = :(($(args...),) -> begin $(stmts...) end) Expr(:scope, (), Tuple(frees), (), arw) diff --git a/src/GG/deprecated_codes/static_closure_conv.jl b/src/GG/deprecated_codes/static_closure_conv.jl index 542ffa7..95a4dd2 100644 --- a/src/GG/deprecated_codes/static_closure_conv.jl +++ b/src/GG/deprecated_codes/static_closure_conv.jl @@ -39,6 +39,11 @@ function mk_closure_static(expr, toplevel::Vector{Expr}) $Closure{$glob_name,typeof(frees)}(frees) end ) + ret = :( + let frees = $closure_arg + $Closure{$glob_name,typeof(frees)}(frees) + end + ) (fn_expr, ret) end diff --git a/src/Tilde.jl b/src/Tilde.jl index 7cf0c18..3c2da99 100644 --- a/src/Tilde.jl +++ b/src/Tilde.jl @@ -82,13 +82,12 @@ end include("optics.jl") include("maybe.jl") include("core/models/abstractmodel.jl") -include("core/models/astmodel/astmodel.jl") include("core/models/model.jl") include("core/dependencies.jl") include("core/utils.jl") include("core/models/closure.jl") +include("maybeobserved.jl") include("core/models/posterior.jl") -include("primitives/interpret.jl") include("distributions/iid.jl") include("primitives/rand.jl") @@ -96,12 +95,14 @@ include("primitives/logdensity.jl") include("primitives/logdensity_rel.jl") include("primitives/insupport.jl") -# include("primitives/basemeasure.jl") include("primitives/testvalue.jl") include("primitives/testparams.jl") include("primitives/weightedsampling.jl") include("primitives/measures.jl") include("primitives/basemeasure.jl") +include("primitives/predict.jl") +include("primitives/dag.jl") +include("primitives/interpret.jl") include("transforms/utils.jl") diff --git a/src/callify.jl b/src/callify.jl index 80d39b6..4f8a574 100644 --- a/src/callify.jl +++ b/src/callify.jl @@ -5,10 +5,31 @@ using MLStyle Replace every `f(args...; kwargs..)` with `mycall(f, args...; kwargs...)` """ -function callify(mycall, ast) +function callify(g, ast) leaf(x) = x function branch(f, head, args) default() = Expr(head, map(f, args)...) + + # Convert `for` to `while` + if head == :for + arg1 = args[1] + @assert arg1.head == :(=) + a,A0 = arg1.args + A0 = callify(g, A0) + @gensym temp + @gensym state + @gensym A + return quote + $A = $A0 + $temp = $call($g, iterate, $A) + while $temp !== nothing + $a, $state = $temp + $(args[2]) + $temp = $call($g, iterate, $A, $state) + end + end + end + head == :call || return default() if first(args) == :~ && length(args) == 3 @@ -16,71 +37,19 @@ function callify(mycall, ast) end # At this point we know it's a function call - length(args) == 1 && return Expr(:call, mycall, first(args)) + length(args) == 1 && return Expr(:call, call, g, first(args)) fun = args[1] arg2 = args[2] if arg2 isa Expr && arg2.head == :parameters # keyword arguments (try dump(:(f(x,y;a=1, b=2))) to see this) - return Expr(:call, mycall, arg2, fun, map(f, Base.rest(args, 3))...) + return Expr(:call, call, g, arg2, fun, map(f, Base.rest(args, 3))...) else - return Expr(:call, mycall, map(f, args)...) + return Expr(:call, call, g, map(f, args)...) end end - foldast(leaf, branch)(ast) + foldast(leaf, branch)(ast) |> MacroTools.flatten end -# struct Provenance{T,S} -# value::T -# sources::S -# end - -# getvalue(p::Provenance) = p.value -# getvalue(x) = x - -# getsources(p::Provenance) = p.sources -# getsources(x) = Set() - -# function trace_provenance(f, args...; kwargs...) -# (newargs, arg_sources) = (getvalue.(args), union(getsources.(args)...)) - -# k = keys(kwargs) -# v = values(kwargs) -# newkwargs = NamedTuple{k}(map(getvalue, v)) - -# k = keys(kwargs) -# v = values(NamedTuple(kwargs)) -# newkwargs = NamedTuple{k}(getvalue.(v)) -# kwarg_sources = union(getsources.(args)...) - -# sources = union(arg_sources, kwarg_sources) -# Provenance(f(newargs...; newkwargs), sources) -# end - -# macro call(expr) -# callify(expr) -# end - -# julia> callify(:(f(g(x,y)))) -# :(call(f, call(g, x, y))) - -# julia> callify(:(f(x; a=3))) -# :(call(f, x; a = 3)) - -# julia> callify(:(a+b)) -# :(call(+, a, b)) - -# julia> callify(:(call(f,3))) -# :(call(f, 3)) - -# f(x) = x+1 - -# @call f(2) - -# using SymbolicUtils - -# @syms x::Vector{Float64} i::Int - -# @call getindex(x,i) diff --git a/src/core/models/astmodel.jl b/src/core/models/astmodel.jl new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/core/models/astmodel.jl @@ -0,0 +1 @@ + diff --git a/src/core/models/astmodel/astmodel.jl b/src/core/models/astmodel/astmodel.jl deleted file mode 100644 index ff81810..0000000 --- a/src/core/models/astmodel/astmodel.jl +++ /dev/null @@ -1,82 +0,0 @@ -struct Model{A,B,M<:GG.TypeLevel} <: AbstractModel{A,B,M} - args::Vector{Symbol} - body::Expr -end - -function Model(theModule::Module, args::Vector{Symbol}, body::Expr) - A = NamedTuple{Tuple(args)} - - B = to_type(body) - M = to_type(theModule) - return Model{A,B,M}(args, body) -end - -model(m::Model) = m - -# ModelClosure{A,B,M,Args,Obs} <: AbstractModel{A,B,M,Argvals,Obs} -# model::Model{A,B,M} -# argvals :: Argvals -# obs :: Obs -# end - -function Base.convert(::Type{Expr}, m::Model) - numArgs = length(m.args) - args = if numArgs == 1 - m.args[1] - elseif numArgs > 1 - Expr(:tuple, [x for x in m.args]...) - end - - body = m.body - - q = if numArgs == 0 - @q begin - @model $body - end - else - @q begin - @model $(args) $body - end - end - - striplines(q).args[1] -end - -Base.show(io::IO, m::Model) = println(io, convert(Expr, m)) - -function type2model(::Type{Model{A,B,M}}) where {A,B,M} - args = [fieldnames(A)...] - body = from_type(B) - Model(from_type(M), convert(Vector{Symbol}, args), body) -end - -# julia> using Tilde, MeasureTheory - -# julia> m = @model begin -# p ~ Uniform() -# x ~ Bernoulli(p) |> iid(3) -# end; - -# julia> f = interpret(m); - -# julia> f(NamedTuple()) do x,d,ctx -# r = rand(d) -# (r, merge(ctx, NamedTuple{(x,)}((r,)))) -# end -# (p = 0.3863623559358842, x = Bool[0, 0, 0]) - -# julia> f(0) do x,d,n -# r = rand(d) -# (r, n+1) -# end -# 2 - -# julia> f -# function = (_tilde, _ctx0;) -> begin -# begin -# _ctx = _ctx0 -# (p, _ctx) = _tilde(:p, (Main).Uniform(), _ctx) -# (x, _ctx) = _tilde(:x, (Main).:|>((Main).Bernoulli(p), (Main).iid(3)), _ctx) -# return _ctx -# end -# end diff --git a/src/core/models/model.jl b/src/core/models/model.jl index ed4fe58..48944e3 100644 --- a/src/core/models/model.jl +++ b/src/core/models/model.jl @@ -1,3 +1,47 @@ +struct Model{A,B,M<:GG.TypeLevel} <: AbstractModel{A,B,M} + args::Vector{Symbol} + body::Expr +end + +function Model(theModule::Module, args::Vector{Symbol}, body::Expr) + A = NamedTuple{Tuple(args)} + B = to_type(body) + M = to_type(theModule) + return Model{A,B,M}(args, body) +end + +model(m::Model) = m + +function Base.convert(::Type{Expr}, m::Model) + numArgs = length(m.args) + args = if numArgs == 1 + m.args[1] + elseif numArgs > 1 + Expr(:tuple, [x for x in m.args]...) + end + + body = m.body + + q = if numArgs == 0 + @q begin + @model $body + end + else + @q begin + @model $(args) $body + end + end + + striplines(q).args[1] +end + +Base.show(io::IO, m::Model) = println(io, convert(Expr, m)) + +function type2model(::Type{Model{A,B,M}}) where {A,B,M} + args = Symbol[fieldnames(A)...] + body = from_type(B) + Model{A,B,M}(args, body) +end toargs(vs::Vector{Symbol}) = Tuple(vs) toargs(vs::NTuple{N,Symbol} where {N}) = vs diff --git a/src/core/utils.jl b/src/core/utils.jl index 60b1878..3028ea9 100644 --- a/src/core/utils.jl +++ b/src/core/utils.jl @@ -101,10 +101,6 @@ end import MacroTools: striplines, @q -# function arguments(model::DAGModel) -# model.args -# end - allequal(xs) = all(xs[1] .== xs) # # fold example usage: @@ -130,20 +126,19 @@ allequal(xs) = all(xs[1] .== xs) # # (s = [0.545324, 0.281332, 0.418541, 0.485946], a = 2.217762640580984) # From https://github.com/thautwarm/MLStyle.jl/issues/66 -@active LamExpr(x) begin - @match x begin - :($a -> begin - $(bs...) - end) => let exprs = filter(x -> !(x isa LineNumberNode), bs) - if length(exprs) == 1 - (a, exprs[1]) - else - (a, Expr(:block, bs...)) - end - end - _ => nothing - end -end +# @active LamExpr(x) begin +# @match x begin +# :($a -> begin $(bs...) end) => +# let exprs = filter(x -> !(x isa LineNumberNode), bs) +# if length(exprs) == 1 +# (a, exprs[1]) +# else +# (a, Expr(:block, bs...)) +# end +# end +# _ => nothing +# end +# end # using BenchmarkTools # f(;kwargs...) = kwargs[:a] + kwargs[:b] @@ -156,35 +151,12 @@ end # @__MODULE__ # names -# getprototype(::Type{NamedTuple{(),Tuple{}}}) = NamedTuple() -getprototype(::Type{NamedTuple{N,T} where {T<:Tuple}}) where {N} = NamedTuple{N} -getprototype(::NamedTuple{N,T} where {T<:Tuple}) where {N} = NamedTuple{N} - -function loadvals(argstype, obstype) - args = getntkeys(argstype) - obs = getntkeys(obstype) - loader = @q begin end - - for k in args - push!(loader.args, :($k = _args.$k)) - end - for k in obs - push!(loader.args, :($k = _obs.$k)) - end - - src -> (@q begin - $loader - $src - end) |> MacroTools.flatten -end - function loadvals(argstype, obstype, parstype) args = schema(argstype) data = schema(obstype) pars = schema(parstype) - loader = @q begin - end + loader = @q begin end for k in keys(args) ∪ keys(pars) ∪ keys(data) push!(loader.args, :(local $k)) @@ -223,11 +195,6 @@ function loadvals(argstype, obstype, parstype) end) |> MacroTools.flatten end -getntkeys(::NamedTuple{A,B}) where {A,B} = A -getntkeys(::Type{NamedTuple{A,B}}) where {A,B} = A -getntkeys(::Type{NamedTuple{A}}) where {A} = A -getntkeys(::Type{LazyMerge{X,Y}}) where {X,Y} = Tuple(getntkeys(X) ∪ getntkeys(Y)) - # This is just handy for REPLing, no direct connection to Tilde # julia> tower(Int) diff --git a/src/maybeobserved.jl b/src/maybeobserved.jl new file mode 100644 index 0000000..96abb12 --- /dev/null +++ b/src/maybeobserved.jl @@ -0,0 +1,16 @@ +abstract type MaybeObserved{N,T} end + +struct Observed{N,T} <: MaybeObserved{N,T} + value::T +end + +Observed{N}(x::T) where {N,T} = Observed{N,T}(x) + +struct Unobserved{N,T} <: MaybeObserved{N,T} + value::T +end + +Unobserved{N}(x::T) where {N,T} = Unobserved{N,T}(x) +NamedTuple(o::MaybeObserved{N,T}) where {N,T} = NamedTuple{(N,)}((o.value,)) + +value(obj::MaybeObserved) = obj.value diff --git a/src/optics.jl b/src/optics.jl index 7ddaea7..8de996a 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -43,6 +43,17 @@ end end end +# @inline function _setindex!(o::AbstractArray{T}, val::T, l::Lens!!{<:IndexLens}) where {T} +# setindex!(o, val, l.pure.indices...) +# end + +# # Attempting to set a value outside the current eltype widens the eltype +# @inline function _setindex!(o::AbstractArray{T}, val::V, l::Lens!!{<:IndexLens}) where {T,V} +# new_o = similar(o, Union{T,V}) +# new_o .= o +# setindex!(new_o, val, l.pure.indices...) +# end + @inline function Accessors.modify(f, o, l::Lens!!) set(o, l, f(l(o))) end diff --git a/src/primitives/basemeasure.jl b/src/primitives/basemeasure.jl index bf5eeae..9f16f83 100644 --- a/src/primitives/basemeasure.jl +++ b/src/primitives/basemeasure.jl @@ -6,38 +6,15 @@ end @inline function tilde( ::typeof(basemeasure), + x::MaybeObserved{X}, lens, - xname, - x, d, cfg, ctx::NamedTuple, - _, - ::True, -) - xname = dynamic(xname) - xparent = getproperty(cfg.obs, xname) +) where {X} + xparent = getproperty(cfg.obs, X) x = lens(xparent) b = basemeasure(d, x) - ctx = merge(ctx, NamedTuple{(xname,)}((b,))) - (x, ctx, productmeasure(ctx)) -end - -@inline function tilde( - ::typeof(basemeasure), - lens, - xname, - x, - d, - cfg, - ctx::NamedTuple, - _, - ::False, -) - xname = dynamic(xname) - xparent = getproperty(cfg.pars, xname) - x = getproperty(cfg.pars, xname) - b = basemeasure(d, x) - ctx = merge(ctx, NamedTuple{(xname,)}((b,))) + ctx = merge(ctx, NamedTuple{(X,)}((b,))) (x, ctx, productmeasure(ctx)) end diff --git a/src/primitives/dag.jl b/src/primitives/dag.jl new file mode 100644 index 0000000..f0afb6f --- /dev/null +++ b/src/primitives/dag.jl @@ -0,0 +1,89 @@ +abstract type AbstractContext end + +struct GenericContext{T,M} <: AbstractContext + value::T + meta::M +end + +struct EmptyMeta end + +context(value, meta) = GenericContext(value, meta) +context(value) = GenericContext(value, EmptyMeta()) + +context_value(ctx::AbstractContext) = ctx.value +context_meta(ctx::AbstractContext) = ctx.meta + +export getdag + +using Graphs +using MetaGraphsNext + +struct MarkovContext{T} <: AbstractContext + value::T + meta::Set{Tuple{Symbol,Any}} +end + +function MarkovContext(ctx::MarkovContext, m::Set{Tuple{Symbol,Any}}) + newset = union(ctx.meta, m) + MarkovContext(ctx.value, newset) +end + +function MarkovContext(ctx::MarkovContext, m::Set{Tuple{Symbol,T}}) where {T} + newset = union(ctx.meta, Set{Tuple{Symbol,Any}}([m])) + MarkovContext(ctx.value, newset) +end + +function Base.show(io::IO, mc::MarkovContext) + print(io, "MarkovContext(", mc.value, ", ", mc.meta, ")") +end + +function markovinate(nt::NamedTuple{N,T}) where {N,T} + vals = tuple( + ( + MarkovContext(v, Set{Tuple{Symbol,Any}}([(k, identity)])) for + (k, v) in pairs(nt) + )..., + ) + NamedTuple{N}(vals) +end + +MarkovContext(x::MarkovContext) = x +MarkovContext(x) = MarkovContext(x, Set{Tuple{Symbol,Any}}()) + +markov_value(x) = x +markov_parents(x) = Set{Tuple{Symbol,Any}}() + +markov_value(x::MarkovContext) = x.value +markov_parents(x::MarkovContext) = x.meta + +function getdag(m::AbstractConditionalModel, pars) + cfg = NamedTuple() + pars = markovinate(pars) + ctx = (dag = MetaGraph(DiGraph(), Label = Tuple{Symbol,Any}),) + ctx = gg_call(getdag, m, pars, cfg, ctx, (r, ctx) -> ctx) + return ctx.dag +end + +# When a Tilde primitive `f` is called, every `g(args...)` is converted to +# `call(f, g, args...)` +function call(::typeof(getdag), g, args...) + val = g(map(markov_value, args)...) + parents = if isempty(args) + Set{Tuple{Symbol,Any}}() + else + union(map(markov_parents, args)...) + end + MarkovContext(val, parents) +end + +@inline function tilde(::typeof(getdag), x::MaybeObserved{X}, lens, d, pars, ctx) where {X} + dag = ctx.dag + for p in markov_parents(d) + # Make sure vertices exist + dag[p] = nothing + dag[(X, lens)] = nothing + # Add a new edge in the DAG + dag[p, (X, lens)] = nothing + end + (MarkovContext(value(x), Set{Tuple{Symbol,Any}}([(X, lens)])), ctx, dag) +end diff --git a/src/primitives/insupport.jl b/src/primitives/insupport.jl index 71d869d..228c53e 100644 --- a/src/primitives/insupport.jl +++ b/src/primitives/insupport.jl @@ -2,5 +2,5 @@ import MeasureBase: insupport export insupport @inline function insupport(m::AbstractConditionalModel, x::NamedTuple) - mapreduce(insupport, (a, b) -> a && b, measures!(m, x), x) + mapreduce(insupport, (a, b) -> a && b, measures(m), x) end diff --git a/src/primitives/interpret.jl b/src/primitives/interpret.jl index 4d6fe0d..d2c4c22 100644 --- a/src/primitives/interpret.jl +++ b/src/primitives/interpret.jl @@ -13,13 +13,7 @@ function make_body(M, f, m::AbstractModel) make_body(M, body(m)) end -struct Observed{T} - value::T -end - -struct Unobserved{T} - value::T -end +call(f, g, args...; kwargs...) = g(args...; kwargs...) function make_body(M, f, ast::Expr, retfun, argsT, obsT, parsT) knownvars = union(keys.(schema.((argsT, obsT, parsT)))...) @@ -43,18 +37,22 @@ function make_body(M, f, ast::Expr, retfun, argsT, obsT, parsT) # X = to_type(unsolved_lhs) # M = to_type(unsolve(rhs)) - inargs = inkeys(sx, argsT) + # inargs = inkeys(sx, argsT) inobs = inkeys(sx, obsT) - inpars = inkeys(sx, parsT) + # inpars = inkeys(sx, parsT) rhs = unsolve(rhs) - xval = if inobs - :($Observed($x)) + obj = if inobs + :($Observed{$qx}($x)) else - (x ∈ knownvars ? :($Unobserved($x)) : :($Unobserved(missing))) + (if x ∈ knownvars + :($Unobserved{$qx}($x)) + else + :($Unobserved{$qx}(missing)) + end) end - st = :(($x, _ctx, _retn) = $tilde($f, $l, $sx, $xval, $rhs, _cfg, _ctx)) - qst = QuoteNode(st) + st = :(($x, _ctx, _retn) = $tilde($f, $obj, $l, $rhs, _cfg, _ctx)) + # qst = QuoteNode(st) q = quote # println($qst) $st @@ -76,9 +74,10 @@ function make_body(M, f, ast::Expr, retfun, argsT, obsT, parsT) end end - body = go(@q begin - $(solve_scope(opticize(ast))) - end) |> unsolve |> MacroTools.flatten + body = + go(@q begin + $(solve_scope(opticize(callify(f, ast)))) + end) |> unsolve |> MacroTools.flatten body end diff --git a/src/primitives/logdensity.jl b/src/primitives/logdensity.jl index ce44c37..6399410 100644 --- a/src/primitives/logdensity.jl +++ b/src/primitives/logdensity.jl @@ -18,8 +18,15 @@ using Accessors gg_call(logdensityof, cm, pars, cfg, ctx, retfun) end -@inline function tilde(::typeof(logdensityof), lens, xname, x, d, cfg, ctx::NamedTuple) - x = x.value +@inline function tilde( + ::typeof(logdensityof), + x::MaybeObserved{X}, + lens, + d, + cfg, + ctx::NamedTuple, +) where {X} + x = value(x) insupport(d, lens(x)) || return (x, ctx, ReturnNow(-Inf)) @reset ctx.ℓ += MeasureBase.unsafe_logdensityof(d, lens(x)) (x, ctx, nothing) @@ -39,14 +46,13 @@ end @inline function tilde( ::typeof(unsafe_logdensityof), + x::MaybeObserved{X}, lens, - xname, - x, d, cfg, ctx::NamedTuple, -) - x = x.value +) where {X} + x = value(x) @reset ctx.ℓ += MeasureBase.unsafe_logdensityof(d, lens(x)) - (x, ctx, ctx.ℓ) -end + (x, ctx, nothing) +end \ No newline at end of file diff --git a/src/primitives/measures.jl b/src/primitives/measures.jl index 37f3e9d..ec42b14 100644 --- a/src/primitives/measures.jl +++ b/src/primitives/measures.jl @@ -61,32 +61,21 @@ export measures rmap(f, nt) end -@inline function tilde( - ::typeof(measures), - ::typeof(identity), - xname, - ::Unobserved, - d, - cfg, - ctx, -) +@inline function tilde(::typeof(measures), x::Unobserved{X}, d, cfg, ctx) where {X} x = testvalue(d) - xname = dynamic(xname) - ctx = merge(ctx, NamedTuple{(xname,)}((d,))) + ctx = merge(ctx, NamedTuple{X}((d,))) (x, ctx, ctx) end -@inline function tilde(::typeof(measures), lens, xname, x::Unobserved, d, cfg, ctx) - xname = dynamic(xname) - ctx = set(ctx, PropertyLens{xname}() ⨟ Lens!!(lens), d) +@inline function tilde(::typeof(measures), x::Unobserved{X}, lens, d, cfg, ctx) where {X} + ctx = set(ctx, PropertyLens{X}() ⨟ Lens!!(lens), d) - xnew = getproperty(cfg.pars, xname) + xnew = getproperty(cfg.pars, X) (xnew, ctx, ctx) end -@inline function tilde(::typeof(measures), lens, xname, x::Observed, d, cfg, ctx) - x = x.value - (x, ctx, ctx) +@inline function tilde(::typeof(measures), x::Observed{X}, lens, d, cfg, ctx) where {X} + (value(x), ctx, ctx) end function as(mdl::AbstractConditionalModel) diff --git a/src/primitives/predict.jl b/src/primitives/predict.jl new file mode 100644 index 0000000..df40685 --- /dev/null +++ b/src/primitives/predict.jl @@ -0,0 +1,103 @@ +using Random: GLOBAL_RNG +using TupleVectors +export predict + + +anyfy(x) = x +anyfy(x::AbstractArray) = collect(Any, x) + +function anyfy(mc::ModelClosure) + m = model(mc) + a = rmap(anyfy, argvals(mc)) + m(a) +end + +function anyfy(mp::ModelPosterior) + m = model(mp) + a = rmap(anyfy, argvals(mp)) + o = rmap(anyfy, observations(mp)) + m(a) | o +end + +@inline function predict(m::AbstractConditionalModel, pars) + f(d, x) = rand(GLOBAL_RNG, d) + return predict(f, m, pars) +end + +@inline function predict(rng::AbstractRNG, m::AbstractConditionalModel, pars) + f(d, x) = rand(rng, d) + return predict(f, m, pars) +end + +@inline function predict(f, m::AbstractConditionalModel, pars::NamedTuple) + m = anyfy(m) + pars = rmap(anyfy, pars) + cfg = (f = f, pars = pars) + ctx = NamedTuple() + gg_call(predict, m, pars, cfg, ctx, (r, ctx) -> r) +end + +@inline function predict(f, m::AbstractConditionalModel, tv::TupleVector) + n = length(tv) + @inbounds result = chainvec(predict(f, m, tv[1]), n) + @inbounds for j in 2:n + result[j] = predict(f, m, tv[j]) + end + return result +end + +@inline function tilde(::typeof(predict), x, lens, d, cfg, ctx) + tilde_predict(cfg.f, x, lens, d, cfg.pars, ctx) +end + +@generated function tilde_predict( + f, + x::Observed{X}, + lens, + d, + pars::NamedTuple{N}, + ctx, +) where {X,N} + if X ∈ N + quote + # @info "$X ∈ N" + xnew = set(x.value, Lens!!(lens), lens(getproperty(pars, X))) + # ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) + (xnew, ctx, ctx) + end + else + quote + # @info "$X ∉ N" + x = x.value + xnew = set(copy(x), Lens!!(lens), f(d, lens(x))) + ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) + (xnew, ctx, ctx) + end + end +end + +@generated function tilde_predict( + f, + x::Unobserved{X}, + lens, + d, + pars::NamedTuple{N}, + ctx, +) where {X,N} + if X ∈ N + quote + # @info "$X ∈ N" + xnew = set(value(x), Lens!!(lens), lens(getproperty(pars, X))) + # ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) + (xnew, ctx, ctx) + end + else + quote + # @info "$X ∉ N" + # In this case x == Unobserved(missing) + xnew = set(value(x), Lens!!(lens), f(d, missing)) + ctx = merge(ctx, NamedTuple{(X,)}((xnew,))) + (xnew, ctx, ctx) + end + end +end diff --git a/src/primitives/rand.jl b/src/primitives/rand.jl index 5f16d01..2adc6e4 100644 --- a/src/primitives/rand.jl +++ b/src/primitives/rand.jl @@ -4,27 +4,73 @@ using TupleVectors: chainvec export rand EmptyNTtype = NamedTuple{(),Tuple{}} where {T<:Tuple} -@inline function Base.rand(rng::AbstractRNG, d::AbstractConditionalModel, N::Int) - r = chainvec(rand(rng, d), N) +@inline function Base.rand(m::ModelClosure, args...; kwargs...) + rand(GLOBAL_RNG, Float64, m, args...; kwargs...) +end + +@inline function Base.rand(rng::AbstractRNG, m::ModelClosure, args...; kwargs...) + rand(rng, Float64, m, args...; kwargs...) +end + +@inline function Base.rand(::Type{T_rng}, m::ModelClosure, args...; kwargs...) where {T_rng} + rand(GLOBAL_RNG, T_rng, m, args...; kwargs...) +end + +@inline function Base.rand(m::ModelClosure, d::Integer, dims::Integer...; kwargs...) + rand(GLOBAL_RNG, Float64, m, d, dims...; kwargs...) +end + +@inline function Base.rand( + rng::AbstractRNG, + m::ModelClosure, + d::Integer, + dims::Integer...; + kwargs..., +) + rand(rng, Float64, m, d, dims...; kwargs...) +end + +@inline function Base.rand( + ::Type{T_rng}, + m::ModelClosure, + d::Integer, + dims::Integer...; + kwargs..., +) where {T_rng} + rand(GLOBAL_RNG, T_rng, m, d, dims...; kwargs...) +end + +@inline function Base.rand( + rng::AbstractRNG, + ::Type{T_rng}, + d::ModelClosure, + N::Integer, + v::Vararg{Integer}, +) where {T_rng} + @assert isempty(v) + r = chainvec(rand(rng, T_rng, d), N) for j in 2:N - @inbounds r[j] = rand(rng, d) + @inbounds r[j] = rand(rng, T_rng, d) end return r end -@inline Base.rand(d::AbstractConditionalModel, N::Int) = rand(GLOBAL_RNG, d, N) +@inline Base.rand(d::ModelClosure, N::Int) = rand(GLOBAL_RNG, d, N) -@inline function Base.rand(m::AbstractConditionalModel; kwargs...) +@inline function Base.rand(m::ModelClosure; kwargs...) rand(GLOBAL_RNG, m; kwargs...) end +@inline Base.rand(rng::AbstractRNG, m::ModelClosure) = rand(rng, Float64, m) + @inline function Base.rand( rng::AbstractRNG, - m::AbstractConditionalModel; + ::Type{T_rng}, + m::ModelClosure; ctx = NamedTuple(), retfun = (r, ctx) -> r, -) - cfg = (rng = rng,) +) where {T_rng} + cfg = (rng = rng, T_rng = T_rng) gg_call(rand, m, NamedTuple(), cfg, ctx, retfun) end @@ -32,57 +78,24 @@ end # ctx::NamedTuple @inline function tilde( ::typeof(Base.rand), + x::Unobserved{X}, lens, - xname, - x::Unobserved, d, cfg, ctx::NamedTuple, -) - xnew = set(x.value, Lens!!(lens), rand(cfg.rng, d)) - ctx′ = merge(ctx, NamedTuple{(dynamic(xname),)}((xnew,))) +) where {X} + xnew = set(value(x), Lens!!(lens), rand(cfg.rng, d)) + ctx′ = merge(ctx, NamedTuple{(X,)}((xnew,))) (xnew, ctx′, nothing) end @inline function tilde( ::typeof(Base.rand), + x::Observed{X}, lens, - xname, - x::Observed, d, cfg, ctx::NamedTuple, -) - (x.value, ctx, nothing) -end - -############################################################################### -# ctx::Dict - -@inline function tilde( - ::typeof(Base.rand), - lens::typeof(identity), - xname, - x, - d, - cfg, - ctx::Dict, -) - x = rand(cfg.rng, d) - ctx[dynamic(xname)] = x - (x, ctx, nothing) -end - -@inline function tilde( - ::typeof(Base.rand), - lens, - xname, - x, - m::AbstractConditionalModel, - cfg, - ctx::Dict, -) - args = get(cfg.args, dynamic(xname), Dict()) - cfg = merge(cfg, (args = args,)) - tilde(rand, lens, xname, x, m(cfg.args), cfg, ctx) +) where {X} + (value(x), ctx, nothing) end diff --git a/src/primitives/testparams.jl b/src/primitives/testparams.jl index ca9893c..c32297e 100644 --- a/src/primitives/testparams.jl +++ b/src/primitives/testparams.jl @@ -14,22 +14,26 @@ end @inline function tilde( ::typeof(testparams), + x::MaybeObserved{X}, lens::typeof(identity), - xname, - x, d, cfg, ctx::NamedTuple, - _, - _, -) +) where {X} xnew = testparams(d) - ctx′ = merge(ctx, NamedTuple{(dynamic(xname),)}((xnew,))) + ctx′ = merge(ctx, NamedTuple{(X,)}((xnew,))) (xnew, ctx′, ctx′) end -@inline function tilde(::typeof(testparams), lens, xname, x, d, cfg, ctx::NamedTuple, _, _) +@inline function tilde( + ::typeof(testparams), + x::MaybeObserved{X}, + lens, + d, + cfg, + ctx::NamedTuple, +) where {X} xnew = set(x, Lens!!(lens), testparams(d)) - ctx′ = merge(ctx, NamedTuple{(dynamic(xname),)}((xnew,))) + ctx′ = merge(ctx, NamedTuple{(X,)}((xnew,))) (xnew, ctx′, ctx′) end diff --git a/src/primitives/testvalue.jl b/src/primitives/testvalue.jl index acc46b1..967e682 100644 --- a/src/primitives/testvalue.jl +++ b/src/primitives/testvalue.jl @@ -14,26 +14,24 @@ end @inline function tilde( ::typeof(testvalue), + x::Unobserved{X}, lens, - xname, - x::Unobserved, d, cfg, ctx::NamedTuple, -) - xnew = set(x.value, Lens!!(lens), testvalue(d)) - ctx′ = merge(ctx, NamedTuple{(dynamic(xname),)}((xnew,))) +) where {X} + xnew = set(value(x), Lens!!(lens), testvalue(d)) + ctx′ = merge(ctx, NamedTuple{(X,)}((xnew,))) (xnew, ctx′, nothing) end @inline function tilde( ::typeof(testvalue), + x::Observed{X}, lens, - xname, - x::Observed, d, cfg, ctx::NamedTuple, -) - (x.value, ctx, nothing) +) where {X} + (lens(value(x)), ctx, nothing) end diff --git a/src/primitives/weightedsampling.jl b/src/primitives/weightedsampling.jl index 5a90c7d..fa80a1c 100644 --- a/src/primitives/weightedsampling.jl +++ b/src/primitives/weightedsampling.jl @@ -18,15 +18,13 @@ end @inline function tilde( ::typeof(weightedrand), + x::Observed{X}, lens, - xname, - x::Observed, d, cfg, ctx::NamedTuple, -) - x = x.value - xname = dynamic(xname) +) where {X} + x = value(x) Δℓ = logdensityof(d, lens(x)) @reset ctx.ℓ += Δℓ (x, ctx, ctx) @@ -34,15 +32,14 @@ end @inline function tilde( ::typeof(weightedrand), + x::Unobserved{X}, lens, - xname, - x::Unobserved, d, cfg, ctx::NamedTuple, -) - xnew = set(x.value, Lens!!(lens), rand(cfg.rng, d)) - pars = merge(ctx.pars, NamedTuple{(dynamic(xname),)}((xnew,))) +) where {X} + xnew = set(value(x), Lens!!(lens), rand(cfg.rng, d)) + pars = merge(ctx.pars, NamedTuple{(X,)}((xnew,))) ctx = merge(ctx, (pars = pars,)) (xnew, ctx, nothing) end diff --git a/test/runtests.jl b/test/runtests.jl index 8f57c3b..86827a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ import TransformVariables as TV using Aqua using Tilde -Aqua.test_all(Tilde; ambiguities=false, unbound_args=false) +Aqua.test_all(Tilde; ambiguities=false) include("examples-list.jl") diff --git a/test/transforms.jl b/test/transforms.jl index f50f45d..ea80567 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -1,12 +1,3 @@ -# Check for Model equality up to reorderings of a few fields -function ≊(m1::DAGModel,m2::DAGModel) - function eq_tuples(nt1::NamedTuple,nt2::NamedTuple) - return length(nt1)==length(nt2) && all(nt1[k]==nt2[k] for k in keys(nt1)) - end - return Set(arguments(m1))==Set(arguments(m2)) && m1.retn==m2.retn && eq_tuples(m1.dists,m2.dists) && eq_tuples(m1.vals,m2.vals) -end - - m = @model (n,α,β) begin p ~ Beta(α, β) x ~ Binomial(n, p)