Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redo all Discrete Transition nodes #455

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

wouterwln
Copy link
Member

Now adds 1(!!!) function to govern all Discrete Transition categorical rules. Probably the generic implementation is a bit slow at times but I will work on performance (and maybe add a few alternative strategies for when we have a nice and known structure)

@wouterwln wouterwln requested a review from bvdmitri March 5, 2025 13:39
Copy link

codecov bot commented Mar 5, 2025

Codecov Report

Attention: Patch coverage is 97.11538% with 3 lines in your changes missing coverage. Please review.

Project coverage is 74.92%. Comparing base (90112d3) to head (486cc8d).

Files with missing lines Patch % Lines
src/rules/discrete_transition/marginals.jl 76.92% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #455      +/-   ##
==========================================
+ Coverage   74.76%   74.92%   +0.15%     
==========================================
  Files         198      196       -2     
  Lines        5723     5771      +48     
==========================================
+ Hits         4279     4324      +45     
- Misses       1444     1447       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Member

@bvdmitri bvdmitri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks neat, but some tests were either removed or changed. Fine with me, because those tests were checking some numerical answers? But what is the reason for changing tests in in_tests.jl and out_tests.jl?

Comment on lines -7 to -13
@testset "Belief Propagation: (m_out::Categorical, m_a::PointMass)" begin
@test_rules [check_type_promotion = false] DiscreteTransition(:in, Marginalisation) [(
input = (m_out = Categorical([0.1, 0.4, 0.5]), m_a = PointMass([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])),
output = Categorical([0.23000000000000004, 0.43, 0.33999999999999997])
)]
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were these wrong?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is wrong, since we cannot get an incoming message on a, we can only set it to PointMass and it will become a marginal (for which the rule is implemented). The rule for VMP was actually wrong, as was the test, and this was caught by the new general rule (I double checked this with Thijs)

@bvdmitri
Copy link
Member

bvdmitri commented Mar 6, 2025

I'm running tests for RxInfer and RxInferExamples now, if they pass I'll merge :)

@bvdmitri
Copy link
Member

bvdmitri commented Mar 6, 2025

@wouterwln there is a failing test in RxInfer repo called "infer with UnfactorizedData"

This is the stacktrace

ERROR: LoadError: MethodError: no method matching sum_out_dimensions(::Matrix{Float64}, ::Tuple{Int64}, ::Vector{Int64})
The function `sum_out_dimensions` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  sum_out_dimensions(::AbstractArray{T, M}, ::NTuple{N, Int64}, !Matched::AbstractArray{T, N}) where {T, M, N}
   @ ReactiveMP ~/.julia/dev/ReactiveMP.jl/src/rules/discrete_transition/categoricals.jl:30

Stacktrace:
   [1] (::ReactiveMP.var"#3049#3050")(::Tuple{Symbol, ReactiveMP.Marginal{BayesBase.PointMass{Vector{Int64}}, Nothing}})
     @ ReactiveMP ~/.julia/dev/ReactiveMP.jl/src/rules/discrete_transition/categoricals.jl:97
   [2] foreach(f::ReactiveMP.var"#3049#3050", itr::Base.Iterators.Zip{Tuple{Tuple{Symbol, Symbol}, Tuple{ReactiveMP.Marginal{BayesBase.PointMass{Vector{Int64}}, Nothing}, ReactiveMP.Marginal{BayesBase.PointMass{Matrix{Float64}}, Nothing}}}})
     @ Base ./abstractarray.jl:3187
   [3] discrete_transition_process_marginals
     @ ~/.julia/dev/ReactiveMP.jl/src/rules/discrete_transition/categoricals.jl:92 [inlined]
   [4] discrete_transition_structured_message_rule(message_names::Tuple{}, messages::Tuple{}, marginals_names::Tuple{Symbol, Symbol}, marginals::Tuple{ReactiveMP.Marginal{BayesBase.PointMass{Vector{Int64}}, Nothing}, ReactiveMP.Marginal{BayesBase.PointMass{Matrix{Float64}}, Nothing}}, q_a::ReactiveMP.Marginal{BayesBase.PointMass{Matrix{Float64}}, Nothing})
     @ ReactiveMP ~/.julia/dev/ReactiveMP.jl/src/rules/discrete_transition/categoricals.jl:147
   [5] rule
     @ ~/.julia/dev/ReactiveMP.jl/src/rules/discrete_transition/categoricals.jl:184 [inlined]
   [6] MessageMapping
     @ ~/.julia/dev/ReactiveMP.jl/src/message.jl:353 [inlined]
   [7] as_message(message::ReactiveMP.DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{ReactiveMP.DiscreteTransition, Val{:in}, ReactiveMP.Marginalisation, Nothing, Val{(:out, :a)}, Nothing, Nothing, Nothing, Nothing}}, cache::Nothing, messages::Nothing, marginals::Tuple{ReactiveMP.Marginal{BayesBase.PointMass{Vector{Int64}}, Nothing}, ReactiveMP.Marginal{BayesBase.PointMass{Matrix{Float64}}, Nothing}})
     @ ReactiveMP ~/.julia/dev/ReactiveMP.jl/src/message.jl:235
   [8] as_message(message::ReactiveMP.DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{ReactiveMP.DiscreteTransition, Val{:in}, ReactiveMP.Marginalisation, Nothing, Val{(:out, :a)}, Nothing, Nothing, Nothing, Nothing}}, cache::Nothing)
     @ ReactiveMP ~/.julia/dev/ReactiveMP.jl/src/message.jl:231
   [9] as_message(message::ReactiveMP.DeferredMessage{Nothing, Tuple{ReactiveMP.MarginalObservable, ReactiveMP.MarginalObservable}, ReactiveMP.MessageMapping{ReactiveMP.DiscreteTransition, Val{:in}, ReactiveMP.Marginalisation, Nothing, Val{(:out, :a)}, Nothing, Nothing, Nothing, Nothing}})
     @ ReactiveMP ~/.julia/dev/ReactiveMP.jl/src/message.jl:223
  [10] materialize!(type::ReactiveMP.EqualityRightOutbound, chain::ReactiveMP.EqualityChain{Rocket.ScheduleOnOperator{Rocket.AsapScheduler}, ReactiveMP.var"#15#17"{BayesBase.GenericProd, ReactiveMP.CompositeFormConstraint{Tuple{ReactiveMP.UnspecifiedFormConstraint, RxInfer.EnsureSupportedFunctionalForm}}}}, node_index::Int64)
     @ ReactiveMP ~/.julia/dev/ReactiveMP.jl/src/nodes/equality.jl:125
  [11] (::ReactiveMP.ChainOutboundMapping)(::Tuple{Rocket.LazyObservable{Missing}, Rocket.SingleObservable{Missing, Rocket.AsapScheduler}})
     @ ReactiveMP ~/.julia/dev/ReactiveMP.jl/src/nodes/equality.jl:165

RxInferExamples repo is fine and compiled without errors

@bvdmitri
Copy link
Member

bvdmitri commented Mar 6, 2025

The problem is in unmatched T, first argument has T=Float64 and the second has it T=Int64, I noticed this actually in tests that you changed here. I'm not sure if its 100% related. Maybe revert the test or add a new one with Int such that we catch this edge case in ReactiveMP repo as well without need to run RxInfer testset?

@wouterwln
Copy link
Member Author

@bvdmitri should be fixed now

@wouterwln wouterwln requested a review from bvdmitri March 6, 2025 15:58
@bvdmitri
Copy link
Member

bvdmitri commented Mar 7, 2025

The same test in RxInfer is still failing but for a different reason now

ERROR: LoadError: DomainError with [NaN, NaN, NaN, NaN]:
Categorical: vector p is not a probability vector
Stacktrace:
   [1] #116
     @ ~/.julia/packages/Distributions/tQhJE/src/univariate/discrete/categorical.jl:30 [inlined]
   [2] check_args
     @ ~/.julia/packages/Distributions/tQhJE/src/utils.jl:89 [inlined]
   [3] #_#115
     @ ~/.julia/packages/Distributions/tQhJE/src/univariate/discrete/categorical.jl:30 [inlined]
   [4] DiscreteNonParametric
     @ ~/.julia/packages/Distributions/tQhJE/src/univariate/discrete/categorical.jl:29 [inlined]
   [5] (Distributions.Categorical{P} where P<:Real)(p::Vector{Float64})
     @ Distributions ~/.julia/packages/Distributions/tQhJE/src/univariate/discrete/categorical.jl:34
   [6] discrete_transition_structured_message_rule(message_names::Tuple{}, messages::Tuple{}, marginals_names::Tuple{Symbol, Symbol}, marginals::Tuple{ReactiveMP.Marginal{BayesBase.PointMass{Vector{Int64}}, Nothing}, ReactiveMP.Marginal{BayesBase.PointMass{Matrix{Float64}}, Nothing}}, q_a::ReactiveMP.Marginal{BayesBase.PointMass{Matrix{Float64}}, Nothing})
     @ ReactiveMP ~/.julia/dev/ReactiveMP.jl/src/rules/discrete_transition/categoricals.jl:166
   [7] rule
     @ ~/.julia/dev/ReactiveMP.jl/src/rules/discrete_transition/categoricals.jl:198 [inlined]
   [8] MessageMapping
     @ ~/.julia/dev/ReactiveMP.jl/src/message.jl:353 [inlined]

this is the only test that is failing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants