
# NUTS.jl

A non-allocating NUTS implementation. It is equivalent to, and faster than, Stan's default implementation, DynamicHMC.jl's implementation, and AdvancedHMC.jl's `HMCKernel(Trajectory{MultinomialTS}(Leapfrog(stepsize), StrictGeneralisedNoUTurn()))`.

For a 100-dimensional standard normal target with unit stepsize and 1k samples, I measure it to be ~5x slower than direct sampling (`randn!(...)`), ~6x faster than DynamicHMC, ~15x faster than AdvancedHMC, and ~25x faster than Stan.jl. For most other posteriors, the computational cost will be dominated by evaluating the log-density gradient, so any real-world speed-ups should be smaller.
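For reference, a timing along these lines can be set up as sketched below (assuming BenchmarkTools.jl; the NUTS timing itself would reuse the `nuts_sample!` helper defined under Usage below, and exact numbers depend on hardware and package versions):

```julia
using BenchmarkTools, Random

samples = Matrix{Float64}(undef, 100, 1_000)  # 100 dimensions, 1k samples

# Direct-sampling baseline: fill the sample matrix with standard normal draws.
@btime randn!(rng, $samples) setup=(rng = Xoshiro(0))
```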

## Usage

Exports a single function, `nuts!!(state)`. Use it e.g. as

```julia
nuts_sample!(samples, rng, posterior; stepsize, position=randn(rng, size(samples, 1)), n_samples=size(samples, 2)) = begin
    # Bundle everything `nuts!!` needs into its state tuple.
    state = (;rng, posterior, stepsize, position)
    for i in 1:n_samples
        state = nuts!!(state)             # advance the chain by one NUTS transition
        samples[:, i] .= state.position   # record the new position
    end
    state
end
```

where `posterior` has to implement `log_density = NUTS.log_density_gradient!(posterior, position, log_density_gradient)`, i.e. a method that returns the log density at `position` and writes its gradient into `log_density_gradient`.
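As a minimal sketch of that interface, here is a hypothetical standard normal posterior (the struct name `StandardNormal` is illustrative, not part of NUTS.jl) used with the `nuts_sample!` helper above:

```julia
using Random, NUTS

struct StandardNormal end  # hypothetical example target

# Log density of a standard normal, up to an additive constant; the
# gradient -position is written into the preallocated buffer in place.
function NUTS.log_density_gradient!(::StandardNormal, position, log_density_gradient)
    log_density_gradient .= .-position
    return -sum(abs2, position) / 2
end

samples = Matrix{Float64}(undef, 100, 1_000)
state = nuts_sample!(samples, Xoshiro(0), StandardNormal(); stepsize=1.0)
```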
