Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more information to the migration guide #296

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 139 additions & 18 deletions docs/src/manuals/migration-guide-v2-v3.md
Original file line number Diff line number Diff line change
@@ -1,31 +1,158 @@
# Migration Guide from version 2.x to 3.x

This guide is intended to help you migrate your project from version 2.x to 3.x of `RxInfer`. The main difference between these two versions is the redefinition of the model specification language. A detailed explanation of the new model definition language can be found in the [GraphPPL documentation](https://reactivebayes.github.io/GraphPPL.jl/dev/migration_3_to_4/). Here, we will give an overview of the most important changes and introduce `RxInfer` specific changes.
This guide is intended to help you migrate your project from version 2.x to 3.x of `RxInfer`. The main difference between these two versions is the redefinition of the model specification language. A detailed explanation of the new model definition language can be found in the [GraphPPL documentation](https://reactivebayes.github.io/GraphPPL.jl/stable/migration_3_to_4/). Here, we will give an overview of the most important changes and introduce `RxInfer` specific changes.

## Model specification

Also read the [Model specification](@ref user-guide-model-specification) guide.

### `randomvar`, `datavar` and `constvar` have been removed

The most notable change in the model specification is the removal of the `randomvar`, `datavar`, and `constvar` functions.
Now, the `@model` macro automatically determines whether to use `randomvar` or `constvar` based on their usage.
Previously declared `datavar` variables must now be listed in the argument list of the model.

The following example is a simple model definition in version 3:
```julia
@model function SSM(n, x0, A, B, Q, P)
x = randomvar(n)
y = datavar(Vector{Float64}, n)
x_prior ~ MvNormal(μ = mean(x0), Σ = cov(x0))
x_prev = x_prior
for i in 1:n
x[i] ~ MvNormal(μ = A * x_prev, Σ = Q)
y[i] ~ MvNormal(μ = B * x[i], Σ = P)
x_prev = x[i]
end
end
```

The equivalent model definition in version 4 is as follows:
```julia
@model function SSM(y, prior_x, A, B, Q, P)
x_prev ~ prior_x
for i in eachindex(y)
x[i] ~ MvNormal(μ = A * x_prev, Σ = Q)
y[i] ~ MvNormal(μ = B * x[i], Σ = P)
x_prev = x[i]
end
end
```

Read more about the change in the [GraphPPL documentation](https://reactivebayes.github.io/GraphPPL.jl/stable/migration_3_to_4/) and
in the updated [Model specification](@ref user-guide-model-specification) guide.

## Model Definition
### Positional arguments are converted to keyword arguments

The model definition in the `@model` macro has changed significantly. This change also has implications for the `infer` function. Since all interfaces to a model are now passed as arguments to the `@model` macro, the `infer` function needs additional information on model construction. Therefore we only support keyword arguments on model construction. An example of the new model definition is shown below:
The changes in the model specification also have implications for the [`infer`](@ref) function. Since all interfaces to a model are now passed as arguments to the `@model` macro, the `infer` function needs additional information on model construction. Therefore, the model function definition converts all positional arguments to keyword arguments. Positional arguments are no longer supported in the model function definition. Below is an example of the new model definition:

```@example migration-guide
using Test #hide
using RxInfer

@model function coin_toss(prior, y)
θ ~ prior
y .~ Bernoulli(θ)
end

# Here, we pass a prior as a parameter to the model, and the data `y` is passed as data. Since we have to distinguish between what should be used as which argument, we have to pass the data as a keyword argument.
infer(model = coin_toss(prior=Beta(1, 1)),
data=(y=[1, 0, 1],)
# Here, we pass a prior as a parameter to the model, and the data `y` is passed as data.
# Since we have to distinguish between what should be used as which argument, we have to pass the data as a keyword argument.
infer(
model = coin_toss(prior = Beta(1, 1)),
data = (y = [1, 0, 1],)
)
```

### Multiple dispatch is no long supported

Due to the previous change, it is not possible to use multiple dispatch for model function definitions. In other words, type constraints for model arguments are ignored because Julia does not support multiple dispatch for keyword arguments.

### Return value from the model function

Accessing the return value of the model function has changed. Previously, the return value was returned together with the model upon creation. Now, the return value is saved in the model's data structure, which can be accessed with the [`RxInfer.getreturnval`](@ref) function. To demonstrate the difference, previously we could do the following:
```julia
@model function test_model(a, b)
y = datavar(Float64)
θ ~ Beta(1.0, 1.0)
y ~ Bernoulli(θ)
return "Hello, world!"
end
modelgenerator = test_model(1.0, 1.0)
model, returnval = RxInfer.create_model(modelgenerator)
returnval # "Hello, world!"
```
The new API is changed to:
```@example migration-guide
@model function test_model(y, a, b) #hide
θ ~ Beta(1.0, 1.0) #hide
y ~ Bernoulli(θ) #hide
return "Hello, world!" #hide
end #hide
modelgenerator = test_model(a = 1.0, b = 1.0) | (y = 1, )
model = RxInfer.create_model(modelgenerator)
@test RxInfer.getreturnval(model) == "Hello, world!" #hide
RxInfer.getreturnval(model)
```

The [`InferenceResult`](@ref) also no longer stores the `returnval` field. Instead, use the `model` field and the [`RxInfer.getreturnval`](@ref) function:
```@example migration-guide
result = infer(
model = test_model(a = 1.0, b = 1.0),
data = (y = 1, )
)
@test RxInfer.getreturnval(result.model) == "Hello, world!" #hide
RxInfer.getreturnval(result.model)
```

### Returning variables from the model

Similar to the previous version, you can still return latent variables from the model definition:
```@example migration-guide
@model function test_model(y, a, b)
θ ~ Beta(1.0, 1.0)
y ~ Bernoulli(θ)
return θ
end
```
However, their type has changed to internal data structures from the `GraphPPL` package. To access the `ReactiveMP` data structures (e.g., to retrieve the messages or marginals streams), use `RxInfer.getvarref` along with `RxInfer.getvariable`:
```@example migration-guide
using ReactiveMP, Rocket
result = infer(
model = test_model(a = 1.0, b = 1.0),
data = (y = 1, )
)

θlabel = RxInfer.getreturnval(result.model)
θvarref = RxInfer.getvarref(result.model, θlabel)
θvar = RxInfer.getvariable(θvarref)
@test θvar isa ReactiveMP.RandomVariable #hide
qθ_test = [] #hide
subscribe!(ReactiveMP.getmarginal(θvar) |> take(1), (qθ) -> push!(qθ_test, qθ)) #hide
@test length(qθ_test) === 1 #hide
@test first(ReactiveMP.getdata(qθ_test)) == Beta(2.0, 1.0) #hide

# `|> take(1)` ensures automatic unsubscription
θmarginals_subscription = subscribe!(ReactiveMP.getmarginal(θvar) |> take(1), (qθ) -> println(qθ))
nothing #hide
```

## Initialization

Initialization of messages and marginals to kickstart the inference procedure was previously done with the `initmessages` and `initmarginals` keyword. With the introduction of a nested model specificiation in the `@model` macro, we now need a more specific way to initialize messages and marginals. This is done with the new `@initialization` macro. The syntax for the `@initialization` macro is similar to the `@constraints` and `@meta` macro. An example is shown below:
Also read the [Initialization](@ref initialization) guide.

Initialization of messages and marginals to kickstart the inference procedure was previously done with the `initmessages` and `initmarginals` keyword. With the introduction of a nested model specificiation in the `@model` macro, we now need a more specific way to initialize messages and marginals. This is done with the new [`@initialization`](@ref) macro. The syntax for the `@initialization` macro is similar to the `@constraints` and `@meta` macro. An example is shown below:

```@example migration-guide
@model function submodel() end #hide
@model function submodel(z, x)
t := x + 1
z ~ Normal(mean = t, var = 1.0)
end

@model function mymodel(y)
x ~ Normal(mean = 0.0, var = 1.0)
z ~ submodel(x = x)
y ~ Normal(mean = z, var = 1.0)
end

@initialization begin
# Initialize the marginal for the variable x
Expand All @@ -36,12 +163,12 @@ Initialization of messages and marginals to kickstart the inference procedure wa

# Specify the initialization for a submodel of type `submodel`
for init in submodel
q(some_var) = vague(NormalMeanVariance)
q(t) = vague(NormalMeanVariance)
end

# Specify the initialization for a submodel of type `submodel` with a specific index
for init in (submodel, 1)
q(some_var) = vague(NormalMeanVariance)
q(t) = vague(NormalMeanVariance)
end
end
```
Expand All @@ -58,20 +185,14 @@ Similar to the `@constraints` macro, the `@initialization` macro also supports f

# Specify the initialization for a submodel of type `submodel`
for init in submodel
q(some_var) = vague(NormalMeanVariance)
q(t) = vague(NormalMeanVariance)
end

# Specify the initialization for a submodel of type `submodel` with a specific index
for init in (submodel, 1)
q(some_var) = vague(NormalMeanVariance)
q(t) = vague(NormalMeanVariance)
end
end
```

The result of the initialization macro can be passed to the inference function with keyword argument `initialization`.

## Deprecated syntax

The following syntax is deprecated and will be removed in future versions of `RxInfer`:
- `initmessages` and `initmarginals` keyword arguments
- `randomvar` and `datavar` syntax in the `@model` macro
3 changes: 3 additions & 0 deletions src/model/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
"Returns the factor nodes from the model specification."
getfactornodes(model::ProbabilisticModel) = getfactornodes(getmodel(model))

# Redirect the `getvarref` call to the underlying model
getvarref(model::ProbabilisticModel, label) = getvarref(getmodel(model), label)

Check warning on line 37 in src/model/model.jl

View check run for this annotation

Codecov / codecov/patch

src/model/model.jl#L37

Added line #L37 was not covered by tests

"""
ConditionedModelGenerator(generator, conditioned_on)

Expand Down
Loading