diff --git a/docs/functions/other.rst b/docs/functions/other.rst index 763091ec..96097428 100644 --- a/docs/functions/other.rst +++ b/docs/functions/other.rst @@ -134,3 +134,54 @@ Other Constructs a :js:func:`KDE` distribution from a sample based marginal distribution. + +.. js:function:: AIS(model[, options]) + + Returns an estimate of the log of the normalization constant of + ``model``. This is not an unbiased estimator, rather it is a + stochastic lower bound. [grosse16]_ + + The sequence of intermediate distributions used by AIS is obtained + by scaling the contribution to the overall score made by the + ``factor`` statements in ``model``. + + When a model includes hard factors (e.g. ``factor(-Infinity)``, + ``condition(bool)``) this approach does not produce an estimate of + the expected quantity. Hence, to avoid confusion, an error is + generated by ``AIS`` if a hard factor is encountered in the model. + + The length of the sequence of distributions is given by the + ``steps`` option. At step ``k`` the score given by each ``factor`` + is scaled by ``k / steps``. + + + The MCMC transition operator used is based on the :ref:`MH kernel + <mh>`. + + + The following options are supported: + + .. describe:: steps + + The length of the sequence of intermediate distributions. + + Default: ``20`` + + .. describe:: samples + + The number of times the AIS procedure is repeated. ``AIS`` + returns the average of the log of the estimates produced by the + individual runs. + + Default: ``1`` + + Example usage:: + + AIS(model, {samples: 100, steps: 100}) + +.. rubric:: Bibliography + +.. [grosse16] Grosse, Roger B., Siddharth Ancha, and Daniel M. Roy. + "Measuring the reliability of MCMC inference with + bidirectional Monte Carlo." Advances in Neural + Information Processing Systems. 2016. 
diff --git a/docs/inference/methods.rst b/docs/inference/methods.rst index ea4ae2ac..ec0b0d97 100644 --- a/docs/inference/methods.rst +++ b/docs/inference/methods.rst @@ -125,6 +125,8 @@ Kernels The following kernels are available: +.. _mh: + .. describe:: MH Implements single site Metropolis-Hastings. [wingate11]_ diff --git a/src/header.js b/src/header.js index 0d1e112c..b174fcc0 100644 --- a/src/header.js +++ b/src/header.js @@ -32,6 +32,7 @@ try { var smc = require('./inference/smc'); var rejection = require('./inference/rejection'); var incrementalmh = require('./inference/incrementalmh'); + var ais = require('./inference/ais'); var optimize = require('./inference/optimize'); var forwardSample = require('./inference/forwardSample'); var checkSampleAfterFactor = require('./inference/checkSampleAfterFactor'); @@ -186,7 +187,7 @@ module.exports = function(env) { // Inference functions and header utils var headerModules = [ enumerate, asyncpf, mcmc, incrementalmh, pmcmc, - smc, rejection, optimize, forwardSample, dreamSample, checkSampleAfterFactor, + smc, rejection, optimize, ais, forwardSample, dreamSample, checkSampleAfterFactor, headerUtils, params ]; headerModules.forEach(function(mod) { diff --git a/src/header.wppl b/src/header.wppl index 24211ae8..dd118488 100644 --- a/src/header.wppl +++ b/src/header.wppl @@ -473,6 +473,15 @@ var OptimizeThenSample = function(wpplFn, options) { return SampleGuide(wpplFn, opts); }; +// Make AIS available via Infer to allow AIS to be tested using +// existing inference tests. +var AISforInfer = function(wpplFn, options) { + var dummyMarginal = Infer(constF(true)); + // This mutates dummyMarginal. + _.assign(dummyMarginal, {normalizationConstant: AIS(wpplFn, options)}); + return dummyMarginal; +}; + /* * DefaultInfer() called when no options are specified * var maxEnumTreeSize: upper bound for enumeration tree size, enumeration ends when above threshold. 
@@ -538,6 +547,7 @@ var Infer = function(options, maybeFn) { incrementalMH: IncrementalMH, forward: ForwardSample, optimize: OptimizeThenSample, + AIS: AISforInfer, defaultInfer: DefaultInfer }; diff --git a/src/inference/ais.js b/src/inference/ais.js new file mode 100644 index 00000000..a599e025 --- /dev/null +++ b/src/inference/ais.js @@ -0,0 +1,76 @@ +// This closely follows the AIS implementation developed for WebPPL as +// part of "Measuring the reliability of MCMC inference with +// bidirectional Monte Carlo" (Grosse et al). + +// https://arxiv.org/abs/1606.02275 +// https://github.com/siddancha/webppl/tree/b607efe714d78c44f763ffd36324c0b67de96f56 + +'use strict'; + +var _ = require('lodash'); +var assert = require('assert'); +var util = require('../util'); +var numeric = require('../math/numeric'); + +module.exports = function(env){ + + var Initialize = require('./initialize')(env); + var kernels = require('./kernels')(env); + + function AIS(s, k, a, wpplFn, options) { + options = util.mergeDefaults(options, { + steps: 20, + samples: 1 + }); + + var weights = []; + + var singleSample = function(k) { + + var initialize, run, finish; + + initialize = function() { + return Initialize(run, wpplFn, s, env.exit, a, {}); + }; + + run = function(initialTrace) { + + var curStep = 0; + var increment = 1 / options.steps; + var weight = 0; + + var MHKernel = kernels.parseOptions('MH'); + + var mhStepKernel = function(k, trace) { + weight += increment * trace.scoreAllFactors(); + curStep += 1; + return MHKernel(k, trace, { + factorCoeff: curStep * increment, + allowHardFactors: false + }); + }; + + var mhChainKernel = kernels.repeat(options.steps, mhStepKernel); + + return mhChainKernel(function(trace) { + return k(weight); + }, initialTrace); + }; + + return initialize(); + }; + + return util.cpsLoop(options.samples, function(i, next) { + return singleSample(function(weight) { + weights.push(weight); + return next(); + }); + }, function() { + var avgWeight = 
numeric._sum(weights) / options.samples; + return k(s, avgWeight); + }); + } + + return {AIS: AIS}; + +}; diff --git a/src/inference/mhkernel.js b/src/inference/mhkernel.js index 7a4f36d1..9ad8e06e 100644 --- a/src/inference/mhkernel.js +++ b/src/inference/mhkernel.js @@ -30,12 +30,18 @@ module.exports = function(env) { runOpts = util.mergeDefaults(runOpts, { proposalBoundary: 0, - exitFactor: 0 + exitFactor: 0, + factorCoeff: 1, + allowHardFactors: true }); this.proposalBoundary = runOpts.proposalBoundary; this.exitFactor = runOpts.exitFactor; + this.factorCoeff = runOpts.factorCoeff; + assert.ok(0 <= this.factorCoeff && this.factorCoeff <= 1); + this.allowHardFactors = runOpts.allowHardFactors; + this.cont = cont; this.oldTrace = oldTrace; this.a = oldTrace.baseAddress; // Support relative addressing. @@ -61,6 +67,9 @@ module.exports = function(env) { MHKernel.prototype.factor = function(s, k, a, score) { // Optimization: Bail early if we know acceptProb will be zero. if (ad.value(score) === -Infinity) { + if (!this.allowHardFactors) { + throw new Error('Hard factor statements are not allowed.'); + } return this.finish(this.oldTrace, false); } this.trace.numFactors += 1; @@ -203,10 +212,20 @@ module.exports = function(env) { // assert(_.isNumber(ad.value(oldTrace.score))); // assert(_.isNumber(this.regenFrom)); // assert(_.isNumber(this.proposalBoundary)); - var fw = this.transitionProb(oldTrace, trace, this.fwdProposalDist); var bw = this.transitionProb(trace, oldTrace, this.revProposalDist); - var p = Math.exp(ad.value(trace.score) - ad.value(oldTrace.score) + bw - fw); + + var newTraceScore, oldTraceScore; + if (this.factorCoeff == 1) { + // Optimise for the common case. 
+ newTraceScore = ad.value(trace.score); + oldTraceScore = ad.value(oldTrace.score); + } else { + newTraceScore = ad.value(trace.scoreAllChoices()) + this.factorCoeff * ad.value(trace.scoreAllFactors()); + oldTraceScore = ad.value(oldTrace.scoreAllChoices()) + this.factorCoeff * ad.value(oldTrace.scoreAllFactors()); + } + + var p = Math.exp(newTraceScore - oldTraceScore + bw - fw); assert(!isNaN(p)); return Math.min(1, p); }; diff --git a/src/trace.js b/src/trace.js index 46432a25..cfe2728c 100644 --- a/src/trace.js +++ b/src/trace.js @@ -59,6 +59,8 @@ Trace.prototype.addChoice = function(dist, val, address, store, continuation, op // assert(_.isObject(store)); // assert(_.isFunction(continuation)); + var choiceScore = dist.score(val); + var choice = { k: continuation, address: address, @@ -67,6 +69,7 @@ Trace.prototype.addChoice = function(dist, val, address, store, continuation, op // Record the score without adding the choiceScore. This is the score we'll // need if we regen from this choice. score: this.score, + choiceScore: choiceScore, val: val, store: _.clone(store), numFactors: this.numFactors @@ -75,10 +78,20 @@ Trace.prototype.addChoice = function(dist, val, address, store, continuation, op this.choices.push(choice); this.addressMap[address] = choice; this.length += 1; - this.score = ad.scalar.add(this.score, dist.score(val)); + this.score = ad.scalar.add(this.score, choiceScore); // this.checkConsistency(); }; +Trace.prototype.scoreAllChoices = function() { + return this.choices.reduce(function(acc, choice) { + return ad.scalar.add(acc, choice.choiceScore); + }, 0); +}; + +Trace.prototype.scoreAllFactors = function() { + return ad.scalar.sub(this.score, this.scoreAllChoices()); +}; + Trace.prototype.complete = function(value) { // Called at coroutine exit. 
assert.strictEqual(this.value, undefined); diff --git a/tests/test-data/stochastic/expected/partition1.json b/tests/test-data/stochastic/expected/partition1.json new file mode 100644 index 00000000..977fcfcc --- /dev/null +++ b/tests/test-data/stochastic/expected/partition1.json @@ -0,0 +1,3 @@ +{ + "logZ": 0 +} diff --git a/tests/test-data/stochastic/expected/partition2.json b/tests/test-data/stochastic/expected/partition2.json new file mode 100644 index 00000000..168fd70c --- /dev/null +++ b/tests/test-data/stochastic/expected/partition2.json @@ -0,0 +1,3 @@ +{ + "logZ": -0.6539264674066638 +} diff --git a/tests/test-data/stochastic/expected/partition3.json b/tests/test-data/stochastic/expected/partition3.json new file mode 100644 index 00000000..4f94f556 --- /dev/null +++ b/tests/test-data/stochastic/expected/partition3.json @@ -0,0 +1,3 @@ +{ + "logZ": -2.266 +} diff --git a/tests/test-data/stochastic/models/partition1.wppl b/tests/test-data/stochastic/models/partition1.wppl new file mode 100644 index 00000000..caa72c26 --- /dev/null +++ b/tests/test-data/stochastic/models/partition1.wppl @@ -0,0 +1,3 @@ +var model = function() { + return flip(); +}; diff --git a/tests/test-data/stochastic/models/partition2.wppl b/tests/test-data/stochastic/models/partition2.wppl new file mode 100644 index 00000000..25a66360 --- /dev/null +++ b/tests/test-data/stochastic/models/partition2.wppl @@ -0,0 +1,6 @@ +var model = function() { + var x = flip(0.2); + var y = flip(0.8); + observe(Bernoulli({p : x && y ? 
0.1 : 0.6}), true); + return {x, y}; +}; diff --git a/tests/test-data/stochastic/models/partition3.wppl b/tests/test-data/stochastic/models/partition3.wppl new file mode 100644 index 00000000..10612bb0 --- /dev/null +++ b/tests/test-data/stochastic/models/partition3.wppl @@ -0,0 +1,7 @@ +// The exact normalization constant for this model is: +// exp(-1) / (2 * sqrt(pi)) +var model = function() { + var mu = gaussian(0, 1); + observe(Gaussian({mu, sigma: 1}), 2); + return mu; +}; diff --git a/tests/test-inference.js b/tests/test-inference.js index dc61305d..4f6fe8b5 100644 --- a/tests/test-inference.js +++ b/tests/test-inference.js @@ -754,6 +754,21 @@ var tests = [ dream2: true, dream3: true } + }, + { + name: 'AIS', + settings: { + args: { + steps: 500, + samples: 100 + }, + logZ: {check: true, tol: 0.05} + }, + models: { + partition1: {logZ: {check: true, tol: 1e-6}}, + partition2: true, + partition3: true + } } ];