Support "transitive" Scan log-probabilities #75

Open
Tracked by #229
brandonwillard opened this issue Oct 15, 2021 · 0 comments
Labels: enhancement, graph rewriting, help wanted, important, op-probability, scan

Comments

@brandonwillard
Member

Currently, Scan log-probability support only handles cases in which the MeasurableVariable is created inside the body/step function of the Scan, and not when the body/step function simply references a MeasurableVariable that is being iterated over by the Scan.

For example, the following is not supported:

```python
import aesara
import aesara.tensor as at

from aeppl.joint_logprob import factorized_joint_logprob


srng = at.random.RandomStream(seed=2320)
N = 10

Y_rv = srng.normal(0, 1, size=N, name="Y")


def step_fn(y_t):
    return y_t


Y_1T_rv, _ = aesara.scan(
    fn=step_fn,
    sequences=[Y_rv],
    strict=True,
)

y_vv = Y_1T_rv.clone()
y_vv.name = "y"

logp_parts = factorized_joint_logprob({Y_1T_rv: y_vv})
```
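If this case were supported, the expected result is easy to state: because `step_fn` is the identity, the log-probability of `Y_1T_rv` at the value `y` should coincide with that of `Y_rv` itself. With $T = N = 10$:

$$
\log p(Y_{1:T} = y_{1:T}) = \sum_{t=1}^{T} \log \mathcal{N}(y_t \mid 0, 1)
= -\frac{T}{2}\log(2\pi) - \frac{1}{2}\sum_{t=1}^{T} y_t^2
$$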

This example is trivial, but if we changed `step_fn` to perform a supported, measurable operation on `y_t` (e.g. indexing a mixture), it still wouldn't work, for the same reason.

When a value is assigned to a Scan output like `Y_1T_rv`, we could "push" the relevant `sequences` inputs into the step function. In other words, we could rewrite the graph into the form we currently handle, in which the MeasurableVariable is created inside the step function.
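To make the proposed rewrite concrete without depending on Aesara, here is a minimal pure-Python sketch. Everything in it is a hypothetical toy stand-in, not an aeppl or Aesara API: `Normal`, `Indexed`, `Scan`, `scan_logp`, and `push_rv_into_step` are invented for illustration. `scan_logp` only supports steps that *create* a `Normal`, mirroring the current limitation; `push_rv_into_step` replaces a vector `Normal` sequence with its (already `size`-broadcast) parameter sequences and creates the per-step `Normal` inside the new step function.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence

# Hypothetical toy stand-ins for Aesara/aeppl objects (NOT real APIs).


@dataclass
class Normal:
    """A normal random variable; `mu`/`sigma` are floats (scalar) or lists (vector)."""
    mu: object
    sigma: object


@dataclass
class Indexed:
    """`seq[t]` for a random variable `seq` that was created *outside* the Scan."""
    seq: object
    t: int


@dataclass
class Scan:
    """Toy Scan: at step t, `step` receives the t-th element of each sequence."""
    step: Callable
    sequences: List[object]


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def element(seq: object, t: int) -> object:
    # Iterating over a random variable only yields a reference to it, not a new RV
    return Indexed(seq, t) if isinstance(seq, Normal) else seq[t]


def scan_logp(scan: Scan, values: Sequence[float]) -> float:
    """Supports only the current case: each step *creates* a scalar `Normal`."""
    total = 0.0
    for t, y_t in enumerate(values):
        out = scan.step(*(element(s, t) for s in scan.sequences))
        if not isinstance(out, Normal):
            raise NotImplementedError("step output is not an RV created inside the step")
        total += normal_logpdf(y_t, out.mu, out.sigma)
    return total


def push_rv_into_step(scan: Scan) -> Scan:
    """Iterate over the RV's (size-broadcast) parameters instead of the RV itself,
    and create the per-step `Normal` inside the new step function."""
    (rv,) = scan.sequences  # this sketch handles a single `Normal` sequence only
    old_step = scan.step

    def new_step(mu_t: float, sigma_t: float) -> object:
        return old_step(Normal(mu_t, sigma_t))

    return Scan(step=new_step, sequences=[rv.mu, rv.sigma])


# Analogue of the issue's example, with N = 3 and parameters already broadcast to size N
N = 3
Y = Normal(mu=[0.0] * N, sigma=[1.0] * N)
identity = Scan(step=lambda y_t: y_t, sequences=[Y])
y = [0.5, -1.0, 2.0]

# scan_logp(identity, y) raises NotImplementedError: the unsupported case
logp = scan_logp(push_rv_into_step(identity), y)
```

After the rewrite, the identity Scan's log-probability is simply the sum of the per-step normal log-densities, which is what a "transitive" implementation would be expected to return.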

Working from the `Y_rv` example, we would rewrite the Scan into something like the following:

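```python
# Apply a rewrite like `local_rv_size_lift` to get properly `size`-broadcasted parameters
# in a new variable `new_Y_rv`.
# A RandomVariable node's inputs are `[rng, size, dtype, *dist_params]`,
# so `[3:]` selects the distribution parameters (here, `mu` and `sigma`).
mu_bcast, sigma_bcast = new_Y_rv.owner.inputs[3:]


def new_step_fn(mu_t, sigma_t):
    # Create the measurable variable *inside* the step function,
    # which is the case Scan log-probabilities already support
    return Y_rv.owner.op(mu_t, sigma_t, name="Y[t]")


new_Y_1T_rv, _ = aesara.scan(
    fn=new_step_fn,
    sequences=[mu_bcast, sigma_bcast],
    strict=True,
)
```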
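The `mu_bcast`/`sigma_bcast` step relies on `size`-lifting, i.e. broadcasting the distribution parameters to the shape implied by `size` so that `size` can be dropped. As a rough, 1-D-only illustration (the `lift_size` helper is hypothetical and is not how `local_rv_size_lift` is implemented), lifting `size=N` into the parameters of `normal(0, 1, size=N)` yields length-`N` parameter sequences that the rewritten Scan can iterate over, one `(mu_t, sigma_t)` pair per step:

```python
from typing import List, Sequence, Tuple


def lift_size(size: int, *params: Sequence[float]) -> Tuple[List[float], ...]:
    """Hypothetical helper: broadcast 1-D parameters to length `size`.

    Uses NumPy-style rules restricted to one dimension: a parameter must have
    length 1 (repeated) or length `size` (kept as-is).
    """
    lifted = []
    for p in params:
        if len(p) == size:
            lifted.append(list(p))
        elif len(p) == 1:
            lifted.append(list(p) * size)
        else:
            raise ValueError(f"cannot broadcast length {len(p)} to size {size}")
    return tuple(lifted)


# `normal(0, 1, size=10)` becomes `normal(mu_bcast, sigma_bcast)` with no `size`
mu_bcast, sigma_bcast = lift_size(10, [0.0], [1.0])
```

Scanning over `mu_bcast` and `sigma_bcast` then hands each step a scalar `(mu_t, sigma_t)` pair, from which `new_step_fn` creates `Y[t]`.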
@brandonwillard added the graph rewriting, important, op-probability, help wanted, and enhancement labels Oct 15, 2021
@brandonwillard pinned this issue Oct 25, 2021
@rlouf moved this to Graph features in AePPL Roadmap Feb 6, 2023
@rlouf removed this from AePPL Roadmap Feb 13, 2023