Skip to content

Commit

Permalink
Merge pull request #905 from null-a/ais
Browse files Browse the repository at this point in the history
Add AIS
  • Loading branch information
stuhlmueller authored Jul 30, 2018
2 parents da4a023 + 50b8a98 commit 9e4d473
Show file tree
Hide file tree
Showing 14 changed files with 217 additions and 5 deletions.
51 changes: 51 additions & 0 deletions docs/functions/other.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,54 @@ Other

Constructs a :js:func:`KDE` distribution from a sample based
marginal distribution.

.. js:function:: AIS(model[, options])

Returns an estimate of the log of the normalization constant of
``model``. This is not an unbiased estimator, rather it is a
stochastic lower bound. [grosse16]_

The sequence of intermediate distributions used by AIS is obtained
by scaling the contribution to the overall score made by the
``factor`` statements in ``model``.

When a model includes hard factors (e.g. ``factor(-Infinity)``,
``condition(bool)``) this approach does not produce an estimate of
the expected quantity. Hence, to avoid confusion, an error is
generated by ``AIS`` if a hard factor is encountered in the model.

The length of the sequence of distributions is given by the
``steps`` option. At step ``k`` the score given by each ``factor``
is scaled by ``k / steps``.


The MCMC transition operator used is based on the :ref:`MH kernel
<mh>`.


The following options are supported:

.. describe:: steps

The length of the sequence of intermediate distributions.

Default: ``20``

.. describe:: samples

The number of times the AIS procedure is repeated. ``AIS``
returns the average of the log of the estimates produced by the
individual runs.

Default: ``1``

Example usage::

AIS(model, {samples: 100, steps: 100})

.. rubric:: Bibliography

.. [grosse16] Grosse, Roger B., Siddharth Ancha, and Daniel M. Roy.
"Measuring the reliability of MCMC inference with
bidirectional Monte Carlo." Advances in Neural
Information Processing Systems. 2016.
2 changes: 2 additions & 0 deletions docs/inference/methods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ Kernels

The following kernels are available:

.. _mh:

.. describe:: MH

Implements single site Metropolis-Hastings. [wingate11]_
Expand Down
3 changes: 2 additions & 1 deletion src/header.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ try {
var smc = require('./inference/smc');
var rejection = require('./inference/rejection');
var incrementalmh = require('./inference/incrementalmh');
var ais = require('./inference/ais');
var optimize = require('./inference/optimize');
var forwardSample = require('./inference/forwardSample');
var checkSampleAfterFactor = require('./inference/checkSampleAfterFactor');
Expand Down Expand Up @@ -186,7 +187,7 @@ module.exports = function(env) {
// Inference functions and header utils
var headerModules = [
enumerate, asyncpf, mcmc, incrementalmh, pmcmc,
smc, rejection, optimize, forwardSample, dreamSample, checkSampleAfterFactor,
smc, rejection, optimize, ais, forwardSample, dreamSample, checkSampleAfterFactor,
headerUtils, params
];
headerModules.forEach(function(mod) {
Expand Down
10 changes: 10 additions & 0 deletions src/header.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,15 @@ var OptimizeThenSample = function(wpplFn, options) {
return SampleGuide(wpplFn, opts);
};

// Make AIS available via Infer to allow AIS to be tested using
// existing inference tests.
var AISforInfer = function(wpplFn, options) {
var dummyMarginal = Infer(constF(true));
// This mutates dummyMarginal.
_.assign(dummyMarginal, {normalizationConstant: AIS(wpplFn, options)});
return dummyMarginal;
};

/*
* DefaultInfer() called when no options are specified
* var maxEnumTreeSize: upper bound for enumeration tree size, enumeration ends when above threshold.
Expand Down Expand Up @@ -538,6 +547,7 @@ var Infer = function(options, maybeFn) {
incrementalMH: IncrementalMH,
forward: ForwardSample,
optimize: OptimizeThenSample,
AIS: AISforInfer,
defaultInfer: DefaultInfer
};

Expand Down
76 changes: 76 additions & 0 deletions src/inference/ais.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// This closely follows the AIS implementation developed for WebPPL as
// part of "Measuring the reliability of MCMC inference with
// bidirectional Monte Carlo" (Grosse et al).

// https://arxiv.org/abs/1606.02275
// https://github.com/siddancha/webppl/tree/b607efe714d78c44f763ffd36324c0b67de96f56

'use strict';

var _ = require('lodash');
var assert = require('assert');
var util = require('../util');
var numeric = require('../math/numeric');

module.exports = function(env){

var Initialize = require('./initialize')(env);
var kernels = require('./kernels')(env);

function AIS(s, k, a, wpplFn, options) {
options = util.mergeDefaults(options, {
steps: 20,
samples: 1
});

var weights = [];

var singleSample = function(k) {

var initialize, run, finish;

initialize = function() {
return Initialize(run, wpplFn, s, env.exit, a, {});
};

run = function(initialTrace) {

var curStep = 0;
var increment = 1 / options.steps;
var weight = 0;

var MHKernel = kernels.parseOptions('MH');

var mhStepKernel = function(k, trace) {
weight += increment * trace.scoreAllFactors();
curStep += 1;
return MHKernel(k, trace, {
factorCoeff: curStep * increment,
allowHardFactors: false
});
};

var mhChainKernel = kernels.repeat(options.steps, mhStepKernel);

return mhChainKernel(function(trace) {
return k(weight);
}, initialTrace);
};

return initialize();
};

return util.cpsLoop(options.samples, function(i, next) {
return singleSample(function(weight) {
weights.push(weight);
return next();
});
}, function() {
var avgWeight = numeric._sum(weights) / options.samples;
return k(s, avgWeight);
});
}

return {AIS: AIS};

};
25 changes: 22 additions & 3 deletions src/inference/mhkernel.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,18 @@ module.exports = function(env) {

runOpts = util.mergeDefaults(runOpts, {
proposalBoundary: 0,
exitFactor: 0
exitFactor: 0,
factorCoeff: 1,
allowHardFactors: true
});

this.proposalBoundary = runOpts.proposalBoundary;
this.exitFactor = runOpts.exitFactor;

this.factorCoeff = runOpts.factorCoeff;
assert.ok(0 <= this.factorCoeff && this.factorCoeff <= 1);
this.allowHardFactors = runOpts.allowHardFactors;

this.cont = cont;
this.oldTrace = oldTrace;
this.a = oldTrace.baseAddress; // Support relative addressing.
Expand All @@ -61,6 +67,9 @@ module.exports = function(env) {
MHKernel.prototype.factor = function(s, k, a, score) {
// Optimization: Bail early if we know acceptProb will be zero.
if (ad.value(score) === -Infinity) {
if (!this.allowHardFactors) {
throw new Error('Hard factor statements are not allowed.');
}
return this.finish(this.oldTrace, false);
}
this.trace.numFactors += 1;
Expand Down Expand Up @@ -203,10 +212,20 @@ module.exports = function(env) {
// assert(_.isNumber(ad.value(oldTrace.score)));
// assert(_.isNumber(this.regenFrom));
// assert(_.isNumber(this.proposalBoundary));

var fw = this.transitionProb(oldTrace, trace, this.fwdProposalDist);
var bw = this.transitionProb(trace, oldTrace, this.revProposalDist);
var p = Math.exp(ad.value(trace.score) - ad.value(oldTrace.score) + bw - fw);

var newTraceScore, oldTraceScore;
if (this.factorCoeff == 1) {
// Optimise for the common case.
newTraceScore = ad.value(trace.score);
oldTraceScore = ad.value(oldTrace.score);
} else {
newTraceScore = ad.value(trace.scoreAllChoices()) + this.factorCoeff * ad.value(trace.scoreAllFactors());
oldTraceScore = ad.value(oldTrace.scoreAllChoices()) + this.factorCoeff * ad.value(oldTrace.scoreAllFactors());
}

var p = Math.exp(newTraceScore - oldTraceScore + bw - fw);
assert(!isNaN(p));
return Math.min(1, p);
};
Expand Down
15 changes: 14 additions & 1 deletion src/trace.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ Trace.prototype.addChoice = function(dist, val, address, store, continuation, op
// assert(_.isObject(store));
// assert(_.isFunction(continuation));

var choiceScore = dist.score(val);

var choice = {
k: continuation,
address: address,
Expand All @@ -67,6 +69,7 @@ Trace.prototype.addChoice = function(dist, val, address, store, continuation, op
// Record the score without adding the choiceScore. This is the score we'll
// need if we regen from this choice.
score: this.score,
choiceScore: choiceScore,
val: val,
store: _.clone(store),
numFactors: this.numFactors
Expand All @@ -75,10 +78,20 @@ Trace.prototype.addChoice = function(dist, val, address, store, continuation, op
this.choices.push(choice);
this.addressMap[address] = choice;
this.length += 1;
this.score = ad.scalar.add(this.score, dist.score(val));
this.score = ad.scalar.add(this.score, choiceScore);
// this.checkConsistency();
};

Trace.prototype.scoreAllChoices = function() {
return this.choices.reduce(function(acc, choice) {
return ad.scalar.add(acc, choice.choiceScore);
}, 0);
};

Trace.prototype.scoreAllFactors = function() {
return ad.scalar.sub(this.score, this.scoreAllChoices());
};

Trace.prototype.complete = function(value) {
// Called at coroutine exit.
assert.strictEqual(this.value, undefined);
Expand Down
3 changes: 3 additions & 0 deletions tests/test-data/stochastic/expected/partition1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"logZ": 0
}
3 changes: 3 additions & 0 deletions tests/test-data/stochastic/expected/partition2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"logZ": -0.6539264674066638
}
3 changes: 3 additions & 0 deletions tests/test-data/stochastic/expected/partition3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"logZ": -2.266
}
3 changes: 3 additions & 0 deletions tests/test-data/stochastic/models/partition1.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
var model = function() {
return flip();
};
6 changes: 6 additions & 0 deletions tests/test-data/stochastic/models/partition2.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
var model = function() {
var x = flip(0.2);
var y = flip(0.8);
observe(Bernoulli({p : x && y ? 0.1 : 0.6}), true);
return {x, y};
};
7 changes: 7 additions & 0 deletions tests/test-data/stochastic/models/partition3.wppl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// The exact normalization constant for this model is:
// exp(-1) / (2 * sqrt(pi))
var model = function() {
var mu = gaussian(0, 1);
observe(Gaussian({mu, sigma: 1}), 2);
return mu;
};
15 changes: 15 additions & 0 deletions tests/test-inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,21 @@ var tests = [
dream2: true,
dream3: true
}
},
{
name: 'AIS',
settings: {
args: {
steps: 500,
samples: 100
},
logZ: {check: true, tol: 0.05}
},
models: {
partition1: {logZ: {check: true, tol: 1e-6}},
partition2: true,
partition3: true
}
}
];

Expand Down

0 comments on commit 9e4d473

Please sign in to comment.