From 6d273d0a6998330e5e87ee41d3c18712c08ca165 Mon Sep 17 00:00:00 2001
From: Chengfeng-Jia
Date: Mon, 27 Nov 2023 09:19:36 +0100
Subject: [PATCH 1/9] add an example for adding ReTestItems

---
 test/helpers_test.jl | 37 +++++++++++++++++++++++++++++++++++++
 test/run_runtest.jl  | 13 +++++++++++++
 test/test_helpers.jl | 42 ------------------------------------------
 test/test_test.jl    | 10 ++++++++++
 test/test_test2.jl   | 10 ++++++++++
 5 files changed, 70 insertions(+), 42 deletions(-)
 create mode 100644 test/helpers_test.jl
 create mode 100644 test/run_runtest.jl
 delete mode 100644 test/test_helpers.jl
 create mode 100644 test/test_test.jl
 create mode 100644 test/test_test2.jl

diff --git a/test/helpers_test.jl b/test/helpers_test.jl
new file mode 100644
index 000000000..35168ea6c
--- /dev/null
+++ b/test/helpers_test.jl
@@ -0,0 +1,37 @@
+
+
+@testitem "NamedTuple helpers" begin
+    import RxInfer: fields, nthasfield
+
+    @test fields((x = 1, y = 2)) === (:x, :y)
+    @test fields((x = 1, y = 2, c = 3)) === (:x, :y, :c)
+    @test fields(typeof((x = 1, y = 2))) === (:x, :y)
+    @test fields(typeof((x = 1, y = 2, c = 3))) === (:x, :y, :c)
+
+    @test nthasfield(:x, (x = 1, y = 2)) === true
+    @test nthasfield(:c, (x = 1, y = 2)) === false
+    @test nthasfield(:x, typeof((x = 1, y = 2))) === true
+    @test nthasfield(:c, typeof((x = 1, y = 2))) === false
+end
+
+@testitem "Tuple helpers" begin
+    import RxInfer: as_tuple
+
+    @test as_tuple(1) === (1,)
+    @test as_tuple((1,)) === (1,)
+
+    @test as_tuple("string") === ("string",)
+    @test as_tuple(("string",)) === ("string",)
+end
+
+@testitem "Val helpers" begin
+    import RxInfer: unval
+
+    @test unval(Val(1)) === 1
+    @test unval(Val(())) === ()
+    @test unval(Val(nothing)) === nothing
+
+    @test_throws ErrorException unval(1)
+    @test_throws ErrorException unval(())
+    @test_throws ErrorException unval(nothing)
+end
diff --git a/test/run_runtest.jl b/test/run_runtest.jl
new file mode 100644
index 000000000..d284e0f9d
--- /dev/null
+++ b/test/run_runtest.jl
@@ -0,0 +1,13 @@
+using Aqua, CpuId, ReTestItems, RxInfer
+
+# Run Aqua's package-quality checks first, then let ReTestItems discover
+# and execute every `@testitem` in the package test files.
+# (Method ambiguity and piracy checks are skipped, as in the previous runner.)
+Aqua.test_all(RxInfer; ambiguities = false, piracies = false, deps_compat = (; check_extras = false, check_weakdeps = true))
+
+nthreads = max(cputhreads(), 1)
+ncores = max(cpucores(), 1)
+
+runtests(
+    RxInfer; nworkers = ncores, nworker_threads = max(1, nthreads ÷ ncores), memory_threshold = 1.0
+)
\ No newline at end of file
diff --git a/test/test_helpers.jl b/test/test_helpers.jl
deleted file mode 100644
index 1a242f8ef..000000000
--- a/test/test_helpers.jl
+++ /dev/null
@@ -1,42 +0,0 @@
-module RxInferHelpersTest
-
-using Test
-using RxInfer
-
-@testset "NamedTuple helpers" begin
-    import RxInfer: fields, nthasfield
-
-    @test fields((x = 1, y = 2)) === (:x, :y)
-    @test fields((x = 1, y = 2, c = 3)) === (:x, :y, :c)
-    @test fields(typeof((x = 1, y = 2))) === (:x, :y)
-    @test fields(typeof((x = 1, y = 2, c = 3))) === (:x, :y, :c)
-
-    @test nthasfield(:x, (x = 1, y = 2)) === true
-    @test nthasfield(:c, (x = 1, y = 2)) === false
-    @test nthasfield(:x, typeof((x = 1, y = 2))) === true
-    @test nthasfield(:c, typeof((x = 1, y = 2))) === false
-end
-
-@testset "Tuple helpers" begin
-    import RxInfer: as_tuple
-
-    @test as_tuple(1) === (1,)
-    @test as_tuple((1,)) === (1,)
-
-    @test as_tuple("string") === ("string",)
-    @test as_tuple(("string",)) === ("string",)
-end
-
-@testset "Val helpers" begin
-    import RxInfer: unval
-
-    @test unval(Val(1)) === 1
-    @test unval(Val(())) === ()
-    @test unval(Val(nothing)) === nothing
-
-    @test_throws ErrorException
unval(1) - @test_throws ErrorException unval(()) - @test_throws ErrorException unval(nothing) -end - -end diff --git a/test/test_test.jl b/test/test_test.jl new file mode 100644 index 000000000..d027cf7af --- /dev/null +++ b/test/test_test.jl @@ -0,0 +1,10 @@ +@testitem "addition" begin + @test 1 + 2 == 3 + @test 0 + 2 == 2 + @test -1 + 2 == 1 +end +@testitem "multiplication" begin + @test 1 * 2 == 2 + @test 0 * 2 == 0 + @test -1 * 2 == -2 +end \ No newline at end of file diff --git a/test/test_test2.jl b/test/test_test2.jl new file mode 100644 index 000000000..d027cf7af --- /dev/null +++ b/test/test_test2.jl @@ -0,0 +1,10 @@ +@testitem "addition" begin + @test 1 + 2 == 3 + @test 0 + 2 == 2 + @test -1 + 2 == 1 +end +@testitem "multiplication" begin + @test 1 * 2 == 2 + @test 0 * 2 == 0 + @test -1 * 2 == -2 +end \ No newline at end of file From 8bb1a2504af3f0c06cdb3402c07ec18593963135 Mon Sep 17 00:00:00 2001 From: MarcoH Date: Mon, 27 Nov 2023 11:40:57 +0100 Subject: [PATCH 2/9] using ReTestItems for tests --- Project.toml | 1 + test/{test_helpers.jl => helpers_tests.jl} | 12 +- test/{test_inference.jl => inference_test.jl} | 22 +- test/{test_model.jl => model_tests.jl} | 10 +- test/{test_node.jl => node_tests.jl} | 12 +- test/runtests.jl | 264 +----------------- test/runtests_prev.jl | 257 +++++++++++++++++ 7 files changed, 282 insertions(+), 296 deletions(-) rename test/{test_helpers.jl => helpers_tests.jl} (86%) rename test/{test_inference.jl => inference_test.jl} (98%) rename test/{test_model.jl => model_tests.jl} (98%) rename test/{test_node.jl => node_tests.jl} (99%) create mode 100644 test/runtests_prev.jl diff --git a/Project.toml b/Project.toml index f8d8871cd..4a6719014 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Optim = "429524aa-4258-5aef-a3af-852621145aeb" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReactiveMP = "a194aa59-28ba-4574-a09c-4a745416d6e3" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Rocket = "df971d30-c9d6-4b37-b8ff-e965b2cb3a40" diff --git a/test/test_helpers.jl b/test/helpers_tests.jl similarity index 86% rename from test/test_helpers.jl rename to test/helpers_tests.jl index 1a242f8ef..230a6cc99 100644 --- a/test/test_helpers.jl +++ b/test/helpers_tests.jl @@ -1,9 +1,4 @@ -module RxInferHelpersTest - -using Test -using RxInfer - -@testset "NamedTuple helpers" begin +@testitem "NamedTuple helpers" begin import RxInfer: fields, nthasfield @test fields((x = 1, y = 2)) === (:x, :y) @@ -17,7 +12,7 @@ using RxInfer @test nthasfield(:c, typeof((x = 1, y = 2))) === false end -@testset "Tuple helpers" begin +@testitem "Tuple helpers" begin import RxInfer: as_tuple @test as_tuple(1) === (1,) @@ -27,7 +22,7 @@ end @test as_tuple(("string",)) === ("string",) end -@testset "Val helpers" begin +@testitem "Val helpers" begin import RxInfer: unval @test unval(Val(1)) === 1 @@ -39,4 +34,3 @@ end @test_throws ErrorException unval(nothing) end -end diff --git a/test/test_inference.jl b/test/inference_test.jl similarity index 98% rename from test/test_inference.jl rename to test/inference_test.jl index abfe19e5b..c2150274d 100644 --- a/test/test_inference.jl +++ b/test/inference_test.jl @@ -1,10 +1,4 @@ -module RxInferInferenceTest - -using Test -using RxInfer -using Random - -@testset "__inference_check_itertype" begin +@testitem "__inference_check_itertype" begin import RxInfer: 
__inference_check_itertype @test __inference_check_itertype(:something, nothing) === nothing @@ -18,7 +12,7 @@ using Random @test_throws ErrorException __inference_check_itertype(:something, missing) end -@testset "__inference_check_dicttype" begin +@testitem "__inference_check_dicttype" begin import RxInfer: __inference_check_dicttype @test __inference_check_dicttype(:something, nothing) === nothing @@ -33,7 +27,7 @@ end @test_throws ErrorException __inference_check_dicttype(:something, (missing)) end -@testset "`@autoupdates` macro" begin +@testitem "`@autoupdates` macro" begin function somefunction(something) return nothing end @@ -143,7 +137,7 @@ end end end -@testset "Static inference with `inference`" begin +@testitem "Static inference with `inference`" begin # A simple model for testing that resembles a simple kalman filter with # random walk state transition and unknown observational noise @@ -311,7 +305,7 @@ end end end -@testset "Test warn argument in `inference()`" begin +@testitem "Test warn argument in `inference()`" begin @testset "Test warning for addons" begin #Add a new case for testing warning of addons @@ -406,7 +400,7 @@ end end end -@testset "Reactive inference with `rxinference` for test model #1" begin +@testitem "Reactive inference with `rxinference` for test model #1" begin # A simple model for testing that resembles a simple kalman filter with # random walk state transition and unknown observational noise @@ -815,7 +809,7 @@ end end end -@testset "Predictions functionality" begin +@testitem "Predictions functionality" begin # test #1 (array with missing + predictvars) data = (y = [1.0, -500.0, missing, 100.0],) @@ -1032,5 +1026,3 @@ end @test all(result.predictions[:y] .== Bernoulli(mean(Beta(1.0, 1.0)))) end - -end diff --git a/test/test_model.jl b/test/model_tests.jl similarity index 98% rename from test/test_model.jl rename to test/model_tests.jl index 44890d926..d24af5292 100644 --- a/test/test_model.jl +++ b/test/model_tests.jl @@ -1,10 +1,4 @@ -module RxInferModelTest - -using Test -using RxInfer -using Random - -@testset "@model macro tests" begin +@testitem "@model macro tests" begin @testset "Tuple based variables usage #1" begin @model function mixture_model() mean1 ~ Normal(mean = 10, variance = 10000) @@ -97,6 +91,7 @@ using Random end @testset "Priors in arguments" begin + import Random: MersenneTwister @model function coin_model_priors1(n, prior) y = datavar(Float64, n) θ ~ prior @@ -188,4 +183,3 @@ using Random end end -end diff --git a/test/test_node.jl b/test/node_tests.jl similarity index 99% rename from test/test_node.jl rename to test/node_tests.jl index 74d79a6a9..0435e468b 100644 --- a/test/test_node.jl +++ b/test/node_tests.jl @@ -1,10 +1,4 @@ -module RxInferNodeTest - -using Test -using RxInfer -using Random - -@testset "@node macro integration tests" begin +@testitem "@node macro integration tests" begin @testset "make_node compatibility tests for stochastic nodes" begin struct CustomStochasticNode end @@ -715,6 +709,4 @@ using Random end end end -end - -end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 997b64939..d284e0f9d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,257 +1,13 @@ +using Aqua, CpuId,ReTestItems,RxInfer -## https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 -## https://gr-framework.org/workstations.html#no-output -ENV["GKSwstype"] = "100" +# runtests( +# "./"; +# ) +Aqua.test_all(RxInfer; ambiguities=false, 
piracies=false, deps_compat = (; check_extras = false, check_weakdeps = true)) -const IS_USE_DEV = get(ENV, "USE_DEV", "false") == "true" -const IS_BENCHMARK = get(ENV, "BENCHMARK", "false") == "true" +nthreads = max(cputhreads(), 1) +ncores = max(cpucores(), 1) -# We use only `1` runner in case if benchmarks are enabled to improve the -# quality of the benchmarking procedure -const NUM_RUNNERS = IS_BENCHMARK ? 1 : min(Sys.CPU_THREADS, 4) - -using Distributed - -const worker_io_lock = ReentrantLock() -const worker_ios = Dict() - -worker_io(ident) = get!(() -> IOBuffer(), worker_ios, string(ident)) - -# Dynamically overwrite default worker's `print` function for better control over stdout -Distributed.redirect_worker_output(ident, stream) = begin - task = @async while !eof(stream) - line = readline(stream) - lock(worker_io_lock) do - io = worker_io(ident) - write(io, line, "\n") - end - end - @static if VERSION >= v"1.7" - Base.errormonitor(task) - end -end - -# This function prints `worker's` standard output into the global standard output -function flush_workerio(ident) - lock(worker_io_lock) do - wio = worker_io(ident) - str = String(take!(wio)) - println(stdout, str) - flush(stdout) - end -end - -import Pkg - -if IS_USE_DEV - Pkg.rm("ReactiveMP") - Pkg.rm("GraphPPL") - Pkg.rm("Rocket") - Pkg.develop(Pkg.PackageSpec(path = joinpath(Pkg.devdir(), "ReactiveMP.jl"))) - Pkg.develop(Pkg.PackageSpec(path = joinpath(Pkg.devdir(), "GraphPPL.jl"))) - Pkg.develop(Pkg.PackageSpec(path = joinpath(Pkg.devdir(), "Rocket.jl"))) - Pkg.update() -end - -# DocMeta.setdocmeta!(RxInfer, :DocTestSetup, :(using RxInfer, Distributions); recursive=true) - -# Example usage of a reduced testset -# julia --project --color=yes -e 'import Pkg; Pkg.test(test_args = [ "distributions:normal_mean_variance" ])' - -@info "Running tests using $(NUM_RUNNERS) runners." 
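# The custom Distributed-based runner being deleted here existed mainly to
# select and parallelise test files by hand; ReTestItems, as invoked in the
# new test/runtests.jl above, provides the same selection through keywords.
# A hedged sketch, assuming ReTestItems >= 1.0 (the filter values are
# illustrative):
#
#   using ReTestItems, RxInfer
#   runtests(RxInfer; name = r"NamedTuple helpers")  # filter by `@testitem` name or regex
#   runtests(RxInfer; tags = [:models])              # or by `tags = [...]` attached to items
#
# Also worth noting: `Int(nthreads / ncores)` in the new runtests.jl throws an
# `InexactError` whenever `cputhreads()` is not an exact multiple of
# `cpucores()`; `max(1, nthreads ÷ ncores)` is a safer equivalent.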
-addprocs(NUM_RUNNERS) - -@everywhere using Test, Documenter, RxInfer -@everywhere using TestSetExtensions - -import Base: wait - -mutable struct TestRunner - enabled_tests - found_tests - test_tasks - workerpool - jobschannel - exschannel - iochannel - - function TestRunner(ARGS) - enabled_tests = lowercase.(ARGS) - found_tests = Dict(map(test -> test => false, enabled_tests)) - test_tasks = [] - jobschannel = RemoteChannel(() -> Channel(Inf), myid()) # Channel for jobs - exschannel = RemoteChannel(() -> Channel(Inf), myid()) # Channel for exceptions - iochannel = RemoteChannel(() -> Channel(0), myid()) - @async begin - while isopen(iochannel) - ident = take!(iochannel) - flush_workerio(ident) - end - end - return new(enabled_tests, found_tests, test_tasks, 2:nprocs(), jobschannel, exschannel, iochannel) - end -end - -function Base.run(testrunner::TestRunner) - println("") # New line for 'better' alignment of the `testrunner` results - - foreach(testrunner.workerpool) do worker - # For each worker we create a `nothing` token in the `jobschannel` - # This token indicates that there are no other jobs left - put!(testrunner.jobschannel, nothing) - # We create a remote call for another Julia process to execute our test with `include(filename)` - task = remotecall(worker, testrunner.jobschannel, testrunner.exschannel, testrunner.iochannel) do jobschannel, exschannel, iochannel - finish = false - while !finish - # Each worker takes jobs sequentially from the shared jobs pool - job_filename = take!(jobschannel) - if isnothing(job_filename) # At the end there are should be only `emptyjobs`, in which case the worker finishes its tasks - finish = true - else # Otherwise we assume that the `job` contains the valid `filename` and execute test - try # Here we can easily get the `LoadError` if some tests are failing - include(job_filename) - catch iexception - put!(exschannel, iexception) - end - # After the work is done we put the worker's `id` into `iochannel` (this triggers test info printing) - put!(iochannel, myid()) - end - end - return nothing - end - # We save the created task for later synchronization - push!(testrunner.test_tasks, task) - end - - # For each remotely called task we `fetch` its result or save an exception - foreach(fetch, testrunner.test_tasks) - - # If exception are not empty we notify the user and force-fail - if isready(testrunner.exschannel) - println(stderr, "Tests have failed with the following exceptions: ") - while isready(testrunner.exschannel) - exception = take!(testrunner.exschannel) - showerror(stderr, exception) - println(stderr, "\n", "="^80) - end - exit(-1) - end - - close(testrunner.iochannel) - close(testrunner.exschannel) - close(testrunner.jobschannel) - - # At the very last stage we check that there are no "missing" tests, - # aka tests that have been specified in the `enabled_tests`, - # but for which the corresponding `filename` does not exist in the `test/` folder - notfound_tests = filter(v -> v[2] === false, testrunner.found_tests) - if !isempty(notfound_tests) - println(stderr, "There are missing tests, double check correct spelling/path for the following entries:") - foreach(keys(notfound_tests)) do key - println(stderr, " - ", key) - end - exit(-1) - end -end - -const testrunner = TestRunner(lowercase.(ARGS)) - -@everywhere workerlocal_lock = ReentrantLock() - -function addtests(testrunner::TestRunner, filename) - # First we transform filename into `key` and check if we have this entry in the `enabled_tests` (if `enabled_tests` is not empty) - key 
= filename_to_key(filename) - if isempty(testrunner.enabled_tests) || key in testrunner.enabled_tests - # If `enabled_tests` is not empty we mark the corresponding key with the `true` value to indicate that we found the corresponding `file` in the `/test` folder - if !isempty(testrunner.enabled_tests) - setindex!(testrunner.found_tests, true, key) # Mark that test has been found - end - # At this stage we simply put the `filename` into the `jobschannel` that will be processed later (see the `execute` function) - put!(testrunner.jobschannel, filename) - end -end - -function key_to_filename(key) - splitted = split(key, ":") - return if length(splitted) === 1 - string("test_", first(splitted), ".jl") - else - string(join(splitted[1:(end - 1)], "/"), "/test_", splitted[end], ".jl") - end -end - -function filename_to_key(filename) - splitted = split(filename, "/") - if length(splitted) === 1 - return replace(replace(first(splitted), ".jl" => ""), "test_" => "") - else - path, name = splitted[1:(end - 1)], splitted[end] - return string(join(path, ":"), ":", replace(replace(name, ".jl" => ""), "test_" => "")) - end -end - -using Aqua - -if isempty(testrunner.enabled_tests) - println("Running all tests (including Aqua)...") - # We pirate some methods from ReactiveMP for now - Aqua.test_all(RxInfer; ambiguities = false, piracies = false, deps_compat = (; check_extras = false, check_weakdeps = true)) -else - println("Running specific tests only:") - foreach(testrunner.enabled_tests) do test - println(" - ", test) - end -end - -@testset ExtendedTestSet "RxInfer" begin - @testset "Testset helpers" begin - @test key_to_filename(filename_to_key("distributions/test_normal_mean_variance.jl")) == "distributions/test_normal_mean_variance.jl" - @test filename_to_key(key_to_filename("distributions:normal_mean_variance")) == "distributions:normal_mean_variance" - @test key_to_filename(filename_to_key("test_message.jl")) == "test_message.jl" - @test filename_to_key(key_to_filename("message")) == "message" - end - - addtests(testrunner, "test_helpers.jl") - - addtests(testrunner, "score/test_bfe.jl") - - addtests(testrunner, "constraints/test_meta_constraints.jl") - addtests(testrunner, "constraints/test_form_constraints.jl") - addtests(testrunner, "constraints/test_factorisation_constraints.jl") - addtests(testrunner, "constraints/form/test_form_point_mass.jl") - addtests(testrunner, "constraints/form/test_form_sample_list.jl") - - addtests(testrunner, "test_node.jl") - addtests(testrunner, "test_model.jl") - addtests(testrunner, "test_inference.jl") - - addtests(testrunner, "models/aliases/test_aliases_binary.jl") - addtests(testrunner, "models/aliases/test_aliases_normal.jl") - - addtests(testrunner, "models/autoregressive/test_ar.jl") - addtests(testrunner, "models/autoregressive/test_lar.jl") - - addtests(testrunner, "models/datavars/test_fn_datavars.jl") - - addtests(testrunner, "models/mixtures/test_gmm_univariate.jl") - addtests(testrunner, "models/mixtures/test_gmm_multivariate.jl") - addtests(testrunner, "models/mixtures/test_mixture.jl") - - addtests(testrunner, "models/statespace/test_ulgssm.jl") - addtests(testrunner, "models/statespace/test_mlgssm.jl") - addtests(testrunner, "models/statespace/test_hmm.jl") - addtests(testrunner, "models/statespace/test_probit.jl") - addtests(testrunner, "models/statespace/test_hgf.jl") - - addtests(testrunner, "models/iid/test_mv_iid_precision.jl") - addtests(testrunner, "models/iid/test_mv_iid_precision_known_mean.jl") - addtests(testrunner, 
"models/iid/test_mv_iid_covariance.jl") - addtests(testrunner, "models/iid/test_mv_iid_covariance_known_mean.jl") - - addtests(testrunner, "models/nonlinear/test_generic_applicability.jl") - addtests(testrunner, "models/nonlinear/test_cvi.jl") - - addtests(testrunner, "models/regression/test_linreg.jl") - - run(testrunner) -end +runtests( + RxInfer; nworkers=ncores, nworker_threads=Int(nthreads / ncores), memory_threshold=1.0 +) \ No newline at end of file diff --git a/test/runtests_prev.jl b/test/runtests_prev.jl new file mode 100644 index 000000000..997b64939 --- /dev/null +++ b/test/runtests_prev.jl @@ -0,0 +1,257 @@ + +## https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 +## https://gr-framework.org/workstations.html#no-output +ENV["GKSwstype"] = "100" + +const IS_USE_DEV = get(ENV, "USE_DEV", "false") == "true" +const IS_BENCHMARK = get(ENV, "BENCHMARK", "false") == "true" + +# We use only `1` runner in case if benchmarks are enabled to improve the +# quality of the benchmarking procedure +const NUM_RUNNERS = IS_BENCHMARK ? 1 : min(Sys.CPU_THREADS, 4) + +using Distributed + +const worker_io_lock = ReentrantLock() +const worker_ios = Dict() + +worker_io(ident) = get!(() -> IOBuffer(), worker_ios, string(ident)) + +# Dynamically overwrite default worker's `print` function for better control over stdout +Distributed.redirect_worker_output(ident, stream) = begin + task = @async while !eof(stream) + line = readline(stream) + lock(worker_io_lock) do + io = worker_io(ident) + write(io, line, "\n") + end + end + @static if VERSION >= v"1.7" + Base.errormonitor(task) + end +end + +# This function prints `worker's` standard output into the global standard output +function flush_workerio(ident) + lock(worker_io_lock) do + wio = worker_io(ident) + str = String(take!(wio)) + println(stdout, str) + flush(stdout) + end +end + +import Pkg + +if IS_USE_DEV + Pkg.rm("ReactiveMP") + Pkg.rm("GraphPPL") + Pkg.rm("Rocket") + Pkg.develop(Pkg.PackageSpec(path = joinpath(Pkg.devdir(), "ReactiveMP.jl"))) + Pkg.develop(Pkg.PackageSpec(path = joinpath(Pkg.devdir(), "GraphPPL.jl"))) + Pkg.develop(Pkg.PackageSpec(path = joinpath(Pkg.devdir(), "Rocket.jl"))) + Pkg.update() +end + +# DocMeta.setdocmeta!(RxInfer, :DocTestSetup, :(using RxInfer, Distributions); recursive=true) + +# Example usage of a reduced testset +# julia --project --color=yes -e 'import Pkg; Pkg.test(test_args = [ "distributions:normal_mean_variance" ])' + +@info "Running tests using $(NUM_RUNNERS) runners." 
+addprocs(NUM_RUNNERS) + +@everywhere using Test, Documenter, RxInfer +@everywhere using TestSetExtensions + +import Base: wait + +mutable struct TestRunner + enabled_tests + found_tests + test_tasks + workerpool + jobschannel + exschannel + iochannel + + function TestRunner(ARGS) + enabled_tests = lowercase.(ARGS) + found_tests = Dict(map(test -> test => false, enabled_tests)) + test_tasks = [] + jobschannel = RemoteChannel(() -> Channel(Inf), myid()) # Channel for jobs + exschannel = RemoteChannel(() -> Channel(Inf), myid()) # Channel for exceptions + iochannel = RemoteChannel(() -> Channel(0), myid()) + @async begin + while isopen(iochannel) + ident = take!(iochannel) + flush_workerio(ident) + end + end + return new(enabled_tests, found_tests, test_tasks, 2:nprocs(), jobschannel, exschannel, iochannel) + end +end + +function Base.run(testrunner::TestRunner) + println("") # New line for 'better' alignment of the `testrunner` results + + foreach(testrunner.workerpool) do worker + # For each worker we create a `nothing` token in the `jobschannel` + # This token indicates that there are no other jobs left + put!(testrunner.jobschannel, nothing) + # We create a remote call for another Julia process to execute our test with `include(filename)` + task = remotecall(worker, testrunner.jobschannel, testrunner.exschannel, testrunner.iochannel) do jobschannel, exschannel, iochannel + finish = false + while !finish + # Each worker takes jobs sequentially from the shared jobs pool + job_filename = take!(jobschannel) + if isnothing(job_filename) # At the end there are should be only `emptyjobs`, in which case the worker finishes its tasks + finish = true + else # Otherwise we assume that the `job` contains the valid `filename` and execute test + try # Here we can easily get the `LoadError` if some tests are failing + include(job_filename) + catch iexception + put!(exschannel, iexception) + end + # After the work is done we put the worker's `id` into `iochannel` (this triggers test info printing) + put!(iochannel, myid()) + end + end + return nothing + end + # We save the created task for later synchronization + push!(testrunner.test_tasks, task) + end + + # For each remotely called task we `fetch` its result or save an exception + foreach(fetch, testrunner.test_tasks) + + # If exception are not empty we notify the user and force-fail + if isready(testrunner.exschannel) + println(stderr, "Tests have failed with the following exceptions: ") + while isready(testrunner.exschannel) + exception = take!(testrunner.exschannel) + showerror(stderr, exception) + println(stderr, "\n", "="^80) + end + exit(-1) + end + + close(testrunner.iochannel) + close(testrunner.exschannel) + close(testrunner.jobschannel) + + # At the very last stage we check that there are no "missing" tests, + # aka tests that have been specified in the `enabled_tests`, + # but for which the corresponding `filename` does not exist in the `test/` folder + notfound_tests = filter(v -> v[2] === false, testrunner.found_tests) + if !isempty(notfound_tests) + println(stderr, "There are missing tests, double check correct spelling/path for the following entries:") + foreach(keys(notfound_tests)) do key + println(stderr, " - ", key) + end + exit(-1) + end +end + +const testrunner = TestRunner(lowercase.(ARGS)) + +@everywhere workerlocal_lock = ReentrantLock() + +function addtests(testrunner::TestRunner, filename) + # First we transform filename into `key` and check if we have this entry in the `enabled_tests` (if `enabled_tests` is not empty) + key 
= filename_to_key(filename) + if isempty(testrunner.enabled_tests) || key in testrunner.enabled_tests + # If `enabled_tests` is not empty we mark the corresponding key with the `true` value to indicate that we found the corresponding `file` in the `/test` folder + if !isempty(testrunner.enabled_tests) + setindex!(testrunner.found_tests, true, key) # Mark that test has been found + end + # At this stage we simply put the `filename` into the `jobschannel` that will be processed later (see the `execute` function) + put!(testrunner.jobschannel, filename) + end +end + +function key_to_filename(key) + splitted = split(key, ":") + return if length(splitted) === 1 + string("test_", first(splitted), ".jl") + else + string(join(splitted[1:(end - 1)], "/"), "/test_", splitted[end], ".jl") + end +end + +function filename_to_key(filename) + splitted = split(filename, "/") + if length(splitted) === 1 + return replace(replace(first(splitted), ".jl" => ""), "test_" => "") + else + path, name = splitted[1:(end - 1)], splitted[end] + return string(join(path, ":"), ":", replace(replace(name, ".jl" => ""), "test_" => "")) + end +end + +using Aqua + +if isempty(testrunner.enabled_tests) + println("Running all tests (including Aqua)...") + # We pirate some methods from ReactiveMP for now + Aqua.test_all(RxInfer; ambiguities = false, piracies = false, deps_compat = (; check_extras = false, check_weakdeps = true)) +else + println("Running specific tests only:") + foreach(testrunner.enabled_tests) do test + println(" - ", test) + end +end + +@testset ExtendedTestSet "RxInfer" begin + @testset "Testset helpers" begin + @test key_to_filename(filename_to_key("distributions/test_normal_mean_variance.jl")) == "distributions/test_normal_mean_variance.jl" + @test filename_to_key(key_to_filename("distributions:normal_mean_variance")) == "distributions:normal_mean_variance" + @test key_to_filename(filename_to_key("test_message.jl")) == "test_message.jl" + @test filename_to_key(key_to_filename("message")) == "message" + end + + addtests(testrunner, "test_helpers.jl") + + addtests(testrunner, "score/test_bfe.jl") + + addtests(testrunner, "constraints/test_meta_constraints.jl") + addtests(testrunner, "constraints/test_form_constraints.jl") + addtests(testrunner, "constraints/test_factorisation_constraints.jl") + addtests(testrunner, "constraints/form/test_form_point_mass.jl") + addtests(testrunner, "constraints/form/test_form_sample_list.jl") + + addtests(testrunner, "test_node.jl") + addtests(testrunner, "test_model.jl") + addtests(testrunner, "test_inference.jl") + + addtests(testrunner, "models/aliases/test_aliases_binary.jl") + addtests(testrunner, "models/aliases/test_aliases_normal.jl") + + addtests(testrunner, "models/autoregressive/test_ar.jl") + addtests(testrunner, "models/autoregressive/test_lar.jl") + + addtests(testrunner, "models/datavars/test_fn_datavars.jl") + + addtests(testrunner, "models/mixtures/test_gmm_univariate.jl") + addtests(testrunner, "models/mixtures/test_gmm_multivariate.jl") + addtests(testrunner, "models/mixtures/test_mixture.jl") + + addtests(testrunner, "models/statespace/test_ulgssm.jl") + addtests(testrunner, "models/statespace/test_mlgssm.jl") + addtests(testrunner, "models/statespace/test_hmm.jl") + addtests(testrunner, "models/statespace/test_probit.jl") + addtests(testrunner, "models/statespace/test_hgf.jl") + + addtests(testrunner, "models/iid/test_mv_iid_precision.jl") + addtests(testrunner, "models/iid/test_mv_iid_precision_known_mean.jl") + addtests(testrunner, 
"models/iid/test_mv_iid_covariance.jl") + addtests(testrunner, "models/iid/test_mv_iid_covariance_known_mean.jl") + + addtests(testrunner, "models/nonlinear/test_generic_applicability.jl") + addtests(testrunner, "models/nonlinear/test_cvi.jl") + + addtests(testrunner, "models/regression/test_linreg.jl") + + run(testrunner) +end From d48cce1f77bdb24d347f99dc07f9b53ccc7a0f98 Mon Sep 17 00:00:00 2001 From: MarcoH Date: Mon, 27 Nov 2023 13:48:07 +0100 Subject: [PATCH 3/9] renaming functions --- Project.toml | 1 - ....jl => factorisation_constraints_tests.jl} | 14 ++----- ...point_mass.jl => form_point_mass_tests.jl} | 41 ++++++++----------- ...mple_list.jl => form_sample_list_tests.jl} | 13 +----- ...nstraints.jl => form_constraints_tests.jl} | 19 +++------ ...nstraints.jl => meta_constraints_tests.jl} | 14 ++----- 6 files changed, 32 insertions(+), 70 deletions(-) rename test/constraints/{test_factorisation_constraints.jl => factorisation_constraints_tests.jl} (99%) rename test/constraints/form/{test_form_point_mass.jl => form_point_mass_tests.jl} (71%) rename test/constraints/form/{test_form_sample_list.jl => form_sample_list_tests.jl} (86%) rename test/constraints/{test_form_constraints.jl => form_constraints_tests.jl} (96%) rename test/constraints/{test_meta_constraints.jl => meta_constraints_tests.jl} (98%) diff --git a/Project.toml b/Project.toml index 4a6719014..f8d8871cd 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Optim = "429524aa-4258-5aef-a3af-852621145aeb" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReactiveMP = "a194aa59-28ba-4574-a09c-4a745416d6e3" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Rocket = "df971d30-c9d6-4b37-b8ff-e965b2cb3a40" diff --git a/test/constraints/test_factorisation_constraints.jl b/test/constraints/factorisation_constraints_tests.jl similarity index 99% rename from test/constraints/test_factorisation_constraints.jl rename to test/constraints/factorisation_constraints_tests.jl index c4e4a2714..03a45f266 100644 --- a/test/constraints/test_factorisation_constraints.jl +++ b/test/constraints/factorisation_constraints_tests.jl @@ -1,12 +1,7 @@ -module RxInferFactorisationConstraintsTest - -using Test, Logging -using RxInfer - -import ReactiveMP: resolve_factorisation, setanonymous! -import ReactiveMP: activate! - -@testset "Factorisation constraints resolution with @constraints" begin +@testitem "Factorisation constraints resolution with @constraints" begin + using Logging + import ReactiveMP: resolve_factorisation, setanonymous! + import ReactiveMP: activate! # Factorisation constrains resolution function accepts a `fform` symbol as an input for error printing # We don't care about actual symbol in tests @@ -636,4 +631,3 @@ import ReactiveMP: activate! 
end end -end diff --git a/test/constraints/form/test_form_point_mass.jl b/test/constraints/form/form_point_mass_tests.jl similarity index 71% rename from test/constraints/form/test_form_point_mass.jl rename to test/constraints/form/form_point_mass_tests.jl index 1238aa23c..4290335da 100644 --- a/test/constraints/form/test_form_point_mass.jl +++ b/test/constraints/form/form_point_mass_tests.jl @@ -1,29 +1,24 @@ -module RxInferPointMassFormConstraintTest +@testitem "PointMassFormConstraint" begin + using Test + using RxInfer, LinearAlgebra + using Random, StableRNGs, DomainSets, Distributions -using Test -using RxInfer, LinearAlgebra -using Random, StableRNGs, DomainSets, Distributions - -import ReactiveMP: constrain_form -import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_boundaries, call_starting_point, call_optimizer - -struct MyDistributionWithMode <: ContinuousUnivariateDistribution - mode::Float64 -end + struct MyDistributionWithMode <: ContinuousUnivariateDistribution + mode::Float64 + end -# We are testing specifically that the point mass optimizer does not call `logpdf` and -# chooses a fast path with `mode` for `<: Distribution` objects -Distributions.logpdf(::MyDistributionWithMode, _) = error("This should not be called") -Distributions.mode(d::MyDistributionWithMode) = d.mode -Distributions.support(::MyDistributionWithMode) = RealInterval(-Inf, Inf) + # We are testing specifically that the point mass optimizer does not call `logpdf` and + # chooses a fast path with `mode` for `<: Distribution` objects + Distributions.logpdf(::MyDistributionWithMode, _) = error("This should not be called") + Distributions.mode(d::MyDistributionWithMode) = d.mode + Distributions.support(::MyDistributionWithMode) = RealInterval(-Inf, Inf) -const arbitrary_dist_1 = ContinuousUnivariateLogPdf(RealLine(), (x) -> logpdf(NormalMeanVariance(0, 1), x)) -const arbitrary_dist_2 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(Gamma(1, 1), x)) -const arbitrary_dist_3 = ContinuousUnivariateLogPdf(RealLine(), (x) -> logpdf(NormalMeanVariance(-10, 10), x)) -const arbitrary_dist_4 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(GammaShapeRate(100, 10), x)) -const arbitrary_dist_5 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(GammaShapeRate(100, 100), x)) + const arbitrary_dist_1 = ContinuousUnivariateLogPdf(RealLine(), (x) -> logpdf(NormalMeanVariance(0, 1), x)) + const arbitrary_dist_2 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(Gamma(1, 1), x)) + const arbitrary_dist_3 = ContinuousUnivariateLogPdf(RealLine(), (x) -> logpdf(NormalMeanVariance(-10, 10), x)) + const arbitrary_dist_4 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(GammaShapeRate(100, 10), x)) + const arbitrary_dist_5 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(GammaShapeRate(100, 100), x)) -@testset "PointMassFormConstraint" begin @testset "is_point_mass_form_constraint" begin @test is_point_mass_form_constraint(PointMassFormConstraint()) end @@ -112,5 +107,3 @@ const arbitrary_dist_5 = ContinuousUnivariateLogPdf(HalfLine(), (x) -> logpdf(Ga end end end - -end diff --git a/test/constraints/form/test_form_sample_list.jl b/test/constraints/form/form_sample_list_tests.jl similarity index 86% rename from test/constraints/form/test_form_sample_list.jl rename to test/constraints/form/form_sample_list_tests.jl index f7f554029..7f53a80e8 100644 --- a/test/constraints/form/test_form_sample_list.jl +++ b/test/constraints/form/form_sample_list_tests.jl @@ -1,13 +1,5 @@ - -module 
RxInferSampleListFormConstraintTest - -using Test -using RxInfer, LinearAlgebra -using Random, StableRNGs, DomainSets - -import RxInfer: SampleListFormConstraint, is_point_mass_form_constraint, constrain_form - -@testset "PointMassFormConstraint" begin +@testitem "PointMassFormConstraint" begin + import RxInfer: SampleListFormConstraint, is_point_mass_form_constraint, constrain_form @testset "is_point_mass_form_constraint" begin @test !is_point_mass_form_constraint(SampleListFormConstraint(100)) end @@ -59,4 +51,3 @@ import RxInfer: SampleListFormConstraint, is_point_mass_form_constraint, constra end end -end diff --git a/test/constraints/test_form_constraints.jl b/test/constraints/form_constraints_tests.jl similarity index 96% rename from test/constraints/test_form_constraints.jl rename to test/constraints/form_constraints_tests.jl index da0d6ee62..237a87491 100644 --- a/test/constraints/test_form_constraints.jl +++ b/test/constraints/form_constraints_tests.jl @@ -1,14 +1,9 @@ -module RxInferFormConstraintsSpecificationTest - -using Test, Logging -using RxInfer - -import RxInfer: PointMassFormConstraint, SampleListFormConstraint, FixedMarginalFormConstraint -import ReactiveMP: CompositeFormConstraint -import ReactiveMP: resolve_marginal_form_prod, resolve_messages_form_prod -import ReactiveMP: activate! - -@testset "Form constraints specification with @constraints macro" begin +@testitem "Form constraints specification with @constraints macro" begin + using Logging + import RxInfer: PointMassFormConstraint, SampleListFormConstraint, FixedMarginalFormConstraint + import ReactiveMP: CompositeFormConstraint + import ReactiveMP: resolve_marginal_form_prod, resolve_messages_form_prod + import ReactiveMP: activate! @testset "Use case #1" begin cs = @constraints begin q(x)::PointMass @@ -410,5 +405,3 @@ import ReactiveMP: activate! end end end - -end diff --git a/test/constraints/test_meta_constraints.jl b/test/constraints/meta_constraints_tests.jl similarity index 98% rename from test/constraints/test_meta_constraints.jl rename to test/constraints/meta_constraints_tests.jl index f8c2f9d81..85ce62631 100644 --- a/test/constraints/test_meta_constraints.jl +++ b/test/constraints/meta_constraints_tests.jl @@ -1,12 +1,6 @@ -module ReactiveMPMetaSpecificationHelpers - -using Test -using RxInfer -using Distributions -using Logging - -@testset "Meta specification with @meta macro" begin +@testitem "Meta specification with @meta macro" begin import ReactiveMP: resolve_meta, make_node, activate! 
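# A `@testitem` does not inherit anything from the file that surrounds it;
# each item is evaluated in its own module, which is why `using Logging` is
# added just below (and in several other items in this commit). A minimal
# sketch of the pattern (the item name is illustrative):
#
#   @testitem "silencing expected log output" begin
#       using Logging
#       with_logger(NullLogger()) do
#           @test true  # code that would emit warnings runs quietly here
#       end
#   end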
+ using Logging struct SomeNode end struct SomeOtherNode end @@ -230,6 +224,4 @@ using Logging @test resolve_meta(meta, SomeNode, (y, z)) === nothing @test resolve_meta(meta, SomeNode, (z,)) === nothing end -end - -end +end \ No newline at end of file From e73c0e46c92d49d072f42650d97c275c39fd9b86 Mon Sep 17 00:00:00 2001 From: MarcoH Date: Mon, 27 Nov 2023 15:38:33 +0100 Subject: [PATCH 4/9] using RetestItems for tests --- .../constraints/form/form_point_mass_tests.jl | 4 +- .../form/form_sample_list_tests.jl | 1 + test/constraints/meta_constraints_tests.jl | 1 + test/models/aliases/aliases_binary_tests.jl | 21 ++ test/models/aliases/aliases_normal_tests.jl | 45 +++++ test/models/aliases/test_aliases_binary.jl | 29 --- test/models/aliases/test_aliases_normal.jl | 51 ----- test/models/autoregressive/ar_tests.jl | 66 +++++++ test/models/autoregressive/lar_tests.jl | 181 +++++++++++++++++ test/models/autoregressive/test_ar.jl | 70 ------- test/models/autoregressive/test_lar.jl | 187 ------------------ ...st_fn_datavars.jl => fn_datavars_tests.jl} | 104 +++++----- ... => mv_iid_covariance_known_mean_tests.jl} | 38 ++-- ...variance.jl => mv_iid_covariance_tests.jl} | 70 +++---- ...l => mv_iid_precision_known_mean_tests.jl} | 41 ++-- test/models/iid/mv_iid_precision_tests.jl | 70 +++++++ test/models/iid/test_mv_iid_precision.jl | 77 -------- ...tivariate.jl => gmm_multivariate_tests.jl} | 139 +++++++------ ..._univariate.jl => gmm_univariate_tests.jl} | 62 +++--- .../{test_mixture.jl => mixture_tests.jl} | 73 ++++--- .../nonlinear/{test_cvi.jl => cvi_tests.jl} | 135 ++++++------- .../nonlinear/generic_applicability_tests.jl | 158 +++++++++++++++ .../nonlinear/test_generic_applicability.jl | 164 --------------- test/models/regression/linreg_tests.jl | 78 ++++++++ test/models/regression/test_linreg.jl | 84 -------- .../statespace/{test_hgf.jl => hgf_tests.jl} | 116 ++++++----- .../statespace/{test_hmm.jl => hmm_tests.jl} | 76 ++++--- .../{test_mlgssm.jl => mlgssm_test.jl} | 60 +++--- .../{test_probit.jl => probit_tests.jl} | 46 ++--- .../{test_ulgssm.jl => ulgssm_tests.jl} | 46 ++--- test/score/{test_actor.jl => actor_tests.jl} | 17 +- test/score/{test_bfe.jl => bfe_tests.jl} | 17 +- 32 files changed, 1097 insertions(+), 1230 deletions(-) create mode 100644 test/models/aliases/aliases_binary_tests.jl create mode 100644 test/models/aliases/aliases_normal_tests.jl delete mode 100644 test/models/aliases/test_aliases_binary.jl delete mode 100644 test/models/aliases/test_aliases_normal.jl create mode 100644 test/models/autoregressive/ar_tests.jl create mode 100644 test/models/autoregressive/lar_tests.jl delete mode 100644 test/models/autoregressive/test_ar.jl delete mode 100644 test/models/autoregressive/test_lar.jl rename test/models/datavars/{test_fn_datavars.jl => fn_datavars_tests.jl} (56%) rename test/models/iid/{test_mv_iid_covariance_known_mean.jl => mv_iid_covariance_known_mean_tests.jl} (57%) rename test/models/iid/{test_mv_iid_covariance.jl => mv_iid_covariance_tests.jl} (50%) rename test/models/iid/{test_mv_iid_precision_known_mean.jl => mv_iid_precision_known_mean_tests.jl} (57%) create mode 100644 test/models/iid/mv_iid_precision_tests.jl delete mode 100644 test/models/iid/test_mv_iid_precision.jl rename test/models/mixtures/{test_gmm_multivariate.jl => gmm_multivariate_tests.jl} (53%) rename test/models/mixtures/{test_gmm_univariate.jl => gmm_univariate_tests.jl} (72%) rename test/models/mixtures/{test_mixture.jl => mixture_tests.jl} (75%) rename test/models/nonlinear/{test_cvi.jl => 
cvi_tests.jl} (57%) create mode 100644 test/models/nonlinear/generic_applicability_tests.jl delete mode 100644 test/models/nonlinear/test_generic_applicability.jl create mode 100644 test/models/regression/linreg_tests.jl delete mode 100644 test/models/regression/test_linreg.jl rename test/models/statespace/{test_hgf.jl => hgf_tests.jl} (50%) rename test/models/statespace/{test_hmm.jl => hmm_tests.jl} (58%) rename test/models/statespace/{test_mlgssm.jl => mlgssm_test.jl} (68%) rename test/models/statespace/{test_probit.jl => probit_tests.jl} (62%) rename test/models/statespace/{test_ulgssm.jl => ulgssm_tests.jl} (60%) rename test/score/{test_actor.jl => actor_tests.jl} (95%) rename test/score/{test_bfe.jl => bfe_tests.jl} (93%) diff --git a/test/constraints/form/form_point_mass_tests.jl b/test/constraints/form/form_point_mass_tests.jl index 4290335da..34d7a3b19 100644 --- a/test/constraints/form/form_point_mass_tests.jl +++ b/test/constraints/form/form_point_mass_tests.jl @@ -1,7 +1,7 @@ @testitem "PointMassFormConstraint" begin - using Test - using RxInfer, LinearAlgebra + using LinearAlgebra using Random, StableRNGs, DomainSets, Distributions + import RxInfer: PointMassFormConstraint, is_point_mass_form_constraint, call_boundaries, call_starting_point, call_optimizer struct MyDistributionWithMode <: ContinuousUnivariateDistribution mode::Float64 diff --git a/test/constraints/form/form_sample_list_tests.jl b/test/constraints/form/form_sample_list_tests.jl index 7f53a80e8..45f6ca69b 100644 --- a/test/constraints/form/form_sample_list_tests.jl +++ b/test/constraints/form/form_sample_list_tests.jl @@ -1,4 +1,5 @@ @testitem "PointMassFormConstraint" begin + using DomainSets, StableRNGs, DomainSets, Distributions, Random, LinearAlgebra import RxInfer: SampleListFormConstraint, is_point_mass_form_constraint, constrain_form @testset "is_point_mass_form_constraint" begin @test !is_point_mass_form_constraint(SampleListFormConstraint(100)) diff --git a/test/constraints/meta_constraints_tests.jl b/test/constraints/meta_constraints_tests.jl index 85ce62631..f83168628 100644 --- a/test/constraints/meta_constraints_tests.jl +++ b/test/constraints/meta_constraints_tests.jl @@ -1,5 +1,6 @@ @testitem "Meta specification with @meta macro" begin import ReactiveMP: resolve_meta, make_node, activate! 
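# Helper types and `const` fixtures obey the same isolation rule and move
# inside the item, as done for the point-mass tests earlier in this commit;
# `using Distributions` is added just below for the same reason. A toy sketch
# (`MyMode` is a placeholder type):
#
#   @testitem "item-local fixture types" begin
#       using Distributions
#       struct MyMode <: ContinuousUnivariateDistribution
#           mode::Float64
#       end
#       Distributions.mode(d::MyMode) = d.mode
#       @test mode(MyMode(1.5)) === 1.5
#   end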
+ using Distributions using Logging struct SomeNode end diff --git a/test/models/aliases/aliases_binary_tests.jl b/test/models/aliases/aliases_binary_tests.jl new file mode 100644 index 000000000..673646bb8 --- /dev/null +++ b/test/models/aliases/aliases_binary_tests.jl @@ -0,0 +1,21 @@ +@testitem "aliases for binary operations" begin + @model function binary_aliases() + x1 ~ Bernoulli(0.5) + x2 ~ Bernoulli(0.5) + x3 ~ Bernoulli(0.5) + x4 ~ Bernoulli(0.5) + + x ~ x1 -> x2 && x3 || ¬x4 + + y = datavar(Float64) + x ~ Bernoulli(y) + end + + function binary_aliases_inference() + return inference(model = binary_aliases(), data = (y = 0.5,), free_energy = true) + end + results = binary_aliases_inference() + # Here we simply test that it ran and gave some output + @test mean(results.posteriors[:x1]) ≈ 0.5 + @test first(results.free_energy) ≈ 0.6931471805599454 +end \ No newline at end of file diff --git a/test/models/aliases/aliases_normal_tests.jl b/test/models/aliases/aliases_normal_tests.jl new file mode 100644 index 000000000..fb38b5eff --- /dev/null +++ b/test/models/aliases/aliases_normal_tests.jl @@ -0,0 +1,45 @@ +@testitem "aliases for `Normal` family of distributions" begin + @model function normal_aliases() + x1 ~ MvNormal(μ = zeros(2), Σ⁻¹ = diageye(2)) + x2 ~ MvNormal(μ = zeros(2), Λ = diageye(2)) + x3 ~ MvNormal(mean = zeros(2), W = diageye(2)) + x4 ~ MvNormal(μ = zeros(2), prec = diageye(2)) + x5 ~ MvNormal(m = zeros(2), precision = diageye(2)) + + y1 ~ MvNormal(mean = zeros(2), Σ = diageye(2)) + y2 ~ MvNormal(m = zeros(2), Λ⁻¹ = diageye(2)) + y3 ~ MvNormal(μ = zeros(2), V = diageye(2)) + y4 ~ MvNormal(mean = zeros(2), cov = diageye(2)) + y5 ~ MvNormal(mean = zeros(2), covariance = diageye(2)) + + x ~ x1 + x2 + x3 + x4 + x5 + y ~ y1 + y2 + y3 + y4 + y5 + + r1 ~ Normal(μ = dot(x + y, ones(2)), τ = 1.0) + r2 ~ Normal(m = r1, γ = 1.0) + r3 ~ Normal(mean = r2, σ⁻² = 1.0) + r4 ~ Normal(mean = r3, w = 1.0) + r5 ~ Normal(mean = r4, p = 1.0) + r6 ~ Normal(mean = r5, prec = 1.0) + r7 ~ Normal(mean = r6, precision = 1.0) + + s1 ~ Normal(μ = r7, σ² = 1.0) + s2 ~ Normal(m = s1, τ⁻¹ = 1.0) + s3 ~ Normal(mean = s2, v = 1.0) + s4 ~ Normal(mean = s3, var = 1.0) + s5 ~ Normal(mean = s4, variance = 1.0) + + d = datavar(Float64) + d ~ Normal(μ = s5, variance = 1.0) + end + + function normal_aliases_inference() + return inference(model = normal_aliases(), data = (d = 1.0,), returnvars = (x1 = KeepLast(),), free_energy = true) + end + result = normal_aliases_inference() + # Here we simply test that it ran and gave some output + @test first(mean(result.posteriors[:x1])) ≈ 0.04182509505703423 + @test first(result.free_energy) ≈ 2.319611135721246 +end + + diff --git a/test/models/aliases/test_aliases_binary.jl b/test/models/aliases/test_aliases_binary.jl deleted file mode 100644 index c010762dc..000000000 --- a/test/models/aliases/test_aliases_binary.jl +++ /dev/null @@ -1,29 +0,0 @@ -module RxInferModelsAliasesTest - -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, LinearAlgebra, StableRNGs - -@model function binary_aliases() - x1 ~ Bernoulli(0.5) - x2 ~ Bernoulli(0.5) - x3 ~ Bernoulli(0.5) - x4 ~ Bernoulli(0.5) - - x ~ x1 -> x2 && x3 || ¬x4 - - y = datavar(Float64) - x ~ Bernoulli(y) -end - -function binary_aliases_inference() - return inference(model = binary_aliases(), data = (y = 0.5,), free_energy = true) -end - -@testset "aliases for binary operations" begin - results = binary_aliases_inference() - # Here we simply test that it ran and gave some output - @test 
mean(results.posteriors[:x1]) ≈ 0.5 - @test first(results.free_energy) ≈ 0.6931471805599454 -end - -end diff --git a/test/models/aliases/test_aliases_normal.jl b/test/models/aliases/test_aliases_normal.jl deleted file mode 100644 index 6776dee81..000000000 --- a/test/models/aliases/test_aliases_normal.jl +++ /dev/null @@ -1,51 +0,0 @@ -module RxInferModelsAliasesTest - -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, LinearAlgebra, StableRNGs - -@model function normal_aliases() - x1 ~ MvNormal(μ = zeros(2), Σ⁻¹ = diageye(2)) - x2 ~ MvNormal(μ = zeros(2), Λ = diageye(2)) - x3 ~ MvNormal(mean = zeros(2), W = diageye(2)) - x4 ~ MvNormal(μ = zeros(2), prec = diageye(2)) - x5 ~ MvNormal(m = zeros(2), precision = diageye(2)) - - y1 ~ MvNormal(mean = zeros(2), Σ = diageye(2)) - y2 ~ MvNormal(m = zeros(2), Λ⁻¹ = diageye(2)) - y3 ~ MvNormal(μ = zeros(2), V = diageye(2)) - y4 ~ MvNormal(mean = zeros(2), cov = diageye(2)) - y5 ~ MvNormal(mean = zeros(2), covariance = diageye(2)) - - x ~ x1 + x2 + x3 + x4 + x5 - y ~ y1 + y2 + y3 + y4 + y5 - - r1 ~ Normal(μ = dot(x + y, ones(2)), τ = 1.0) - r2 ~ Normal(m = r1, γ = 1.0) - r3 ~ Normal(mean = r2, σ⁻² = 1.0) - r4 ~ Normal(mean = r3, w = 1.0) - r5 ~ Normal(mean = r4, p = 1.0) - r6 ~ Normal(mean = r5, prec = 1.0) - r7 ~ Normal(mean = r6, precision = 1.0) - - s1 ~ Normal(μ = r7, σ² = 1.0) - s2 ~ Normal(m = s1, τ⁻¹ = 1.0) - s3 ~ Normal(mean = s2, v = 1.0) - s4 ~ Normal(mean = s3, var = 1.0) - s5 ~ Normal(mean = s4, variance = 1.0) - - d = datavar(Float64) - d ~ Normal(μ = s5, variance = 1.0) -end - -function normal_aliases_inference() - return inference(model = normal_aliases(), data = (d = 1.0,), returnvars = (x1 = KeepLast(),), free_energy = true) -end - -@testset "aliases for `Normal` family of distributions" begin - result = normal_aliases_inference() - # Here we simply test that it ran and gave some output - @test first(mean(result.posteriors[:x1])) ≈ 0.04182509505703423 - @test first(result.free_energy) ≈ 2.319611135721246 -end - -end diff --git a/test/models/autoregressive/ar_tests.jl b/test/models/autoregressive/ar_tests.jl new file mode 100644 index 000000000..f05b3045a --- /dev/null +++ b/test/models/autoregressive/ar_tests.jl @@ -0,0 +1,66 @@ +@testitem "Autoregressive model" begin + using StableRNGs, BenchmarkTools + + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + + @model function ar_model(n, order) + x = datavar(Vector{Float64}, n) + y = datavar(Float64, n) + + γ ~ Gamma(shape = 1.0, rate = 1.0) + θ ~ MvNormal(mean = zeros(order), precision = diageye(order)) + + for i in 1:n + y[i] ~ Normal(mean = dot(x[i], θ), precision = γ) + end + end + + function ar_inference(inputs, outputs, order, niter) + return inference( + model = ar_model(length(outputs), order), + data = (x = inputs, y = outputs), + constraints = MeanField(), + options = (limit_stack_depth = 500,), + initmarginals = (γ = GammaShapeRate(1.0, 1.0),), + returnvars = (γ = KeepEach(), θ = KeepEach()), + iterations = niter, + free_energy = Float64 + ) + end + + function ar_ssm(series, order) + inputs = [reverse!(series[1:order])] + outputs = [series[order + 1]] + for x in series[(order + 2):end] + push!(inputs, vcat(outputs[end], inputs[end])[1:(end - 1)]) + push!(outputs, x) + end + return inputs, outputs + end + rng = StableRNG(1234) + + + + ## Inference execution and test inference results + for order in 1:5 + series = randn(rng, 1_000) + inputs, outputs = ar_ssm(series, order) + result = ar_inference(inputs, 
outputs, order, 15) + qs = result.posteriors + + (γ, θ) = (qs[:γ], qs[:θ]) + fe = result.free_energy + + @test length(γ) === 15 + @test length(θ) === 15 + @test length(fe) === 15 + @test last(fe) < first(fe) + @test all(filter(e -> abs(e) > 1e-3, diff(fe)) .< 0) + end + + benchrng = randn(StableRNG(32), 1_000) + inputs5, outputs5 = ar_ssm(benchrng, 5) + + @test_benchmark "models" "ar" ar_inference($inputs5, $outputs5, 5, 15) +end \ No newline at end of file diff --git a/test/models/autoregressive/lar_tests.jl b/test/models/autoregressive/lar_tests.jl new file mode 100644 index 000000000..661e2e77b --- /dev/null +++ b/test/models/autoregressive/lar_tests.jl @@ -0,0 +1,181 @@ +@testitem "Latent autoregressive model" begin + using StableRNGs, Plots, BenchmarkTools + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + + @model function lar_model(::Type{Multivariate}, n, order, c, stype, τ) + + # Parameter priors + γ ~ Gamma(shape = 1.0, rate = 1.0) + θ ~ MvNormal(mean = zeros(order), precision = diageye(order)) + + # We create a sequence of random variables for hidden states + x = randomvar(n) + # As well a sequence of observartions + y = datavar(Float64, n) + + ct = constvar(c) + # We assume observation noise to be known + cτ = constvar(τ) + + # Prior for first state + x0 ~ MvNormal(mean = zeros(order), precision = diageye(order)) + + x_prev = x0 + + # AR process requires extra meta information + meta = ARMeta(Multivariate, order, stype) + + for i in 1:n + # Autoregressive node uses structured factorisation assumption between states + x[i] ~ AR(x_prev, θ, γ) where {q = q(y, x)q(γ)q(θ), meta = meta} + y[i] ~ Normal(mean = dot(ct, x[i]), precision = cτ) + x_prev = x[i] + end + end + + @model function lar_model(::Type{Univariate}, n, order, c, stype, τ) + + # Parameter priors + γ ~ Gamma(shape = 1.0, rate = 1.0) + θ ~ Normal(mean = 0.0, precision = 1.0) + + # We create a sequence of random variables for hidden states + x = randomvar(n) + # As well a sequence of observartions + y = datavar(Float64, n) + + ct = constvar(c) + # We assume observation noise to be known + cτ = constvar(τ) + + # Prior for first state + x0 ~ Normal(mean = 0.0, precision = 1.0) + + x_prev = x0 + + # AR process requires extra meta information + meta = ARMeta(Univariate, order, stype) + + for i in 1:n + x[i] ~ AR(x_prev, θ, γ) where {q = q(y, x)q(γ)q(θ), meta = meta} + y[i] ~ Normal(mean = ct * x[i], precision = cτ) + x_prev = x[i] + end + end + + function lar_init_marginals(::Type{Multivariate}, order) + return (γ = GammaShapeRate(1.0, 1.0), θ = MvNormalMeanPrecision(zeros(order), diageye(order))) + end + + function lar_init_marginals(::Type{Univariate}, order) + return (γ = GammaShapeRate(1.0, 1.0), θ = NormalMeanPrecision(0.0, 1.0)) + end + + function lar_inference(data, order, artype, stype, niter, τ) + n = length(data) + c = ReactiveMP.ar_unit(artype, order) + return inference( + model = lar_model(artype, n, order, c, stype, τ), + data = (y = data,), + initmarginals = lar_init_marginals(artype, order), + returnvars = (γ = KeepEach(), θ = KeepEach(), x = KeepLast()), + iterations = niter, + free_energy = Float64 + ) + end + + # The following coefficients correspond to stable poles + coefs_ar_5 = [0.10699399235785655, -0.5237303489793305, 0.3068897071844715, -0.17232255282458891, 0.13323964347539288] + + function generate_lar_data(rng, n, θ, γ, τ) + order = length(θ) + states = Vector{Vector{Float64}}(undef, n + 3order) + observations = Vector{Float64}(undef, n + 3order) + + γ_std 
= sqrt(inv(γ)) + τ_std = sqrt(inv(γ)) + + states[1] = randn(rng, order) + + for i in 2:(n + 3order) + states[i] = vcat(rand(rng, Normal(dot(θ, states[i - 1]), γ_std)), states[i - 1][1:(end - 1)]) + observations[i] = rand(rng, Normal(states[i][1], τ_std)) + end + + return states[(1 + 3order):end], observations[(1 + 3order):end] + end + + + # Seed for reproducibility + rng = StableRNG(123) + + # Number of observations in synthetic dataset + n = 500 + + # AR process parameters + real_γ = 5.0 + real_τ = 5.0 + real_θ = coefs_ar_5 + states, observations = generate_lar_data(rng, n, real_θ, real_γ, real_τ) + + # Test AR(1) + Univariate + result = lar_inference(observations, 1, Univariate, ARsafe(), 15, real_τ) + qs = result.posteriors + fe = result.free_energy + + (γ, θ, xs) = (qs[:γ], qs[:θ], qs[:x]) + + @test length(xs) === n + @test length(γ) === 15 + @test length(θ) === 15 + @test length(fe) === 15 + @test abs(last(fe) - 518.9182342) < 0.01 + @test last(fe) < first(fe) + @test all(filter(e -> abs(e) > 1e-3, diff(fe)) .< 0) + + # Test AR(k) + Multivariate + for k in 1:4 + result = lar_inference(observations, k, Multivariate, ARsafe(), 15, real_τ) + qs = result.posteriors + fe = result.free_energy + + (γ, θ, xs) = (qs[:γ], qs[:θ], qs[:x]) + + @test length(xs) === n + @test length(γ) === 15 + @test length(θ) === 15 + @test length(fe) === 15 + @test last(fe) < first(fe) + end + + # AR(5) + Multivariate + result = lar_inference(observations, length(real_θ), Multivariate, ARsafe(), 15, real_τ) + qs = result.posteriors + fe = result.free_energy + + (γ, θ, xs) = (qs[:γ], qs[:θ], qs[:x]) + + @test length(xs) === n + @test length(γ) === 15 + @test length(θ) === 15 + @test length(fe) === 15 + @test abs(last(fe) - 514.66086) < 0.01 + @test all(filter(e -> abs(e) > 1e-1, diff(fe)) .< 0) + @test (mean(last(γ)) - 3.0std(last(γ)) < real_γ < mean(last(γ)) + 3.0std(last(γ))) + + @test_plot "models" "lar" begin + p1 = plot(first.(states), label = "Hidden state") + p1 = scatter!(p1, observations, label = "Observations") + p1 = plot!(p1, first.(mean.(xs)), ribbon = sqrt.(first.(var.(xs))), label = "Inferred states", legend = :bottomright) + + p2 = plot(mean.(γ), ribbon = std.(γ), label = "Inferred transition precision", legend = :bottomright) + p2 = plot!([real_γ], seriestype = :hline, label = "Real transition precision") + + p3 = plot(fe, label = "Bethe Free Energy") + + p = plot(p1, p2, p3, layout = @layout([a; b c])) + end + + @test_benchmark "models" "lar" lar_inference($observations, length($real_θ), Multivariate, ARsafe(), 15, $real_τ) +end \ No newline at end of file diff --git a/test/models/autoregressive/test_ar.jl b/test/models/autoregressive/test_ar.jl deleted file mode 100644 index 1f6acf8c9..000000000 --- a/test/models/autoregressive/test_ar.jl +++ /dev/null @@ -1,70 +0,0 @@ -module RxInferModelsAutoregressiveTest - -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, LinearAlgebra, StableRNGs - -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - -@model function ar_model(n, order) - x = datavar(Vector{Float64}, n) - y = datavar(Float64, n) - - γ ~ Gamma(shape = 1.0, rate = 1.0) - θ ~ MvNormal(mean = zeros(order), precision = diageye(order)) - - for i in 1:n - y[i] ~ Normal(mean = dot(x[i], θ), precision = γ) - end -end - -function ar_inference(inputs, outputs, order, niter) - return inference( - model = ar_model(length(outputs), order), - data = (x = inputs, y = outputs), - constraints = MeanField(), - options = (limit_stack_depth = 
500,), - initmarginals = (γ = GammaShapeRate(1.0, 1.0),), - returnvars = (γ = KeepEach(), θ = KeepEach()), - iterations = niter, - free_energy = Float64 - ) -end - -function ar_ssm(series, order) - inputs = [reverse!(series[1:order])] - outputs = [series[order + 1]] - for x in series[(order + 2):end] - push!(inputs, vcat(outputs[end], inputs[end])[1:(end - 1)]) - push!(outputs, x) - end - return inputs, outputs -end - -@testset "Autoregressive model" begin - rng = StableRNG(1234) - - ## Inference execution and test inference results - for order in 1:5 - series = randn(rng, 1_000) - inputs, outputs = ar_ssm(series, order) - result = ar_inference(inputs, outputs, order, 15) - qs = result.posteriors - - (γ, θ) = (qs[:γ], qs[:θ]) - fe = result.free_energy - - @test length(γ) === 15 - @test length(θ) === 15 - @test length(fe) === 15 - @test last(fe) < first(fe) - @test all(filter(e -> abs(e) > 1e-3, diff(fe)) .< 0) - end - - benchrng = randn(StableRNG(32), 1_000) - inputs5, outputs5 = ar_ssm(benchrng, 5) - - @test_benchmark "models" "ar" ar_inference($inputs5, $outputs5, 5, 15) -end - -end diff --git a/test/models/autoregressive/test_lar.jl b/test/models/autoregressive/test_lar.jl deleted file mode 100644 index a1f3aa904..000000000 --- a/test/models/autoregressive/test_lar.jl +++ /dev/null @@ -1,187 +0,0 @@ -module RxInferModelsAutoregressiveTest - -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, LinearAlgebra, StableRNGs - -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - -@model function lar_model(::Type{Multivariate}, n, order, c, stype, τ) - - # Parameter priors - γ ~ Gamma(shape = 1.0, rate = 1.0) - θ ~ MvNormal(mean = zeros(order), precision = diageye(order)) - - # We create a sequence of random variables for hidden states - x = randomvar(n) - # As well a sequence of observartions - y = datavar(Float64, n) - - ct = constvar(c) - # We assume observation noise to be known - cτ = constvar(τ) - - # Prior for first state - x0 ~ MvNormal(mean = zeros(order), precision = diageye(order)) - - x_prev = x0 - - # AR process requires extra meta information - meta = ARMeta(Multivariate, order, stype) - - for i in 1:n - # Autoregressive node uses structured factorisation assumption between states - x[i] ~ AR(x_prev, θ, γ) where {q = q(y, x)q(γ)q(θ), meta = meta} - y[i] ~ Normal(mean = dot(ct, x[i]), precision = cτ) - x_prev = x[i] - end -end - -@model function lar_model(::Type{Univariate}, n, order, c, stype, τ) - - # Parameter priors - γ ~ Gamma(shape = 1.0, rate = 1.0) - θ ~ Normal(mean = 0.0, precision = 1.0) - - # We create a sequence of random variables for hidden states - x = randomvar(n) - # As well a sequence of observartions - y = datavar(Float64, n) - - ct = constvar(c) - # We assume observation noise to be known - cτ = constvar(τ) - - # Prior for first state - x0 ~ Normal(mean = 0.0, precision = 1.0) - - x_prev = x0 - - # AR process requires extra meta information - meta = ARMeta(Univariate, order, stype) - - for i in 1:n - x[i] ~ AR(x_prev, θ, γ) where {q = q(y, x)q(γ)q(θ), meta = meta} - y[i] ~ Normal(mean = ct * x[i], precision = cτ) - x_prev = x[i] - end -end - -function lar_init_marginals(::Type{Multivariate}, order) - return (γ = GammaShapeRate(1.0, 1.0), θ = MvNormalMeanPrecision(zeros(order), diageye(order))) -end - -function lar_init_marginals(::Type{Univariate}, order) - return (γ = GammaShapeRate(1.0, 1.0), θ = NormalMeanPrecision(0.0, 1.0)) -end - -function lar_inference(data, order, artype, stype, niter, τ) - n 
= length(data) - c = ReactiveMP.ar_unit(artype, order) - return inference( - model = lar_model(artype, n, order, c, stype, τ), - data = (y = data,), - initmarginals = lar_init_marginals(artype, order), - returnvars = (γ = KeepEach(), θ = KeepEach(), x = KeepLast()), - iterations = niter, - free_energy = Float64 - ) -end - -# The following coefficients correspond to stable poles -coefs_ar_5 = [0.10699399235785655, -0.5237303489793305, 0.3068897071844715, -0.17232255282458891, 0.13323964347539288] - -function generate_lar_data(rng, n, θ, γ, τ) - order = length(θ) - states = Vector{Vector{Float64}}(undef, n + 3order) - observations = Vector{Float64}(undef, n + 3order) - - γ_std = sqrt(inv(γ)) - τ_std = sqrt(inv(γ)) - - states[1] = randn(rng, order) - - for i in 2:(n + 3order) - states[i] = vcat(rand(rng, Normal(dot(θ, states[i - 1]), γ_std)), states[i - 1][1:(end - 1)]) - observations[i] = rand(rng, Normal(states[i][1], τ_std)) - end - - return states[(1 + 3order):end], observations[(1 + 3order):end] -end - -@testset "Latent autoregressive model" begin - - # Seed for reproducibility - rng = StableRNG(123) - - # Number of observations in synthetic dataset - n = 500 - - # AR process parameters - real_γ = 5.0 - real_τ = 5.0 - real_θ = coefs_ar_5 - states, observations = generate_lar_data(rng, n, real_θ, real_γ, real_τ) - - # Test AR(1) + Univariate - result = lar_inference(observations, 1, Univariate, ARsafe(), 15, real_τ) - qs = result.posteriors - fe = result.free_energy - - (γ, θ, xs) = (qs[:γ], qs[:θ], qs[:x]) - - @test length(xs) === n - @test length(γ) === 15 - @test length(θ) === 15 - @test length(fe) === 15 - @test abs(last(fe) - 518.9182342) < 0.01 - @test last(fe) < first(fe) - @test all(filter(e -> abs(e) > 1e-3, diff(fe)) .< 0) - - # Test AR(k) + Multivariate - for k in 1:4 - result = lar_inference(observations, k, Multivariate, ARsafe(), 15, real_τ) - qs = result.posteriors - fe = result.free_energy - - (γ, θ, xs) = (qs[:γ], qs[:θ], qs[:x]) - - @test length(xs) === n - @test length(γ) === 15 - @test length(θ) === 15 - @test length(fe) === 15 - @test last(fe) < first(fe) - end - - # AR(5) + Multivariate - result = lar_inference(observations, length(real_θ), Multivariate, ARsafe(), 15, real_τ) - qs = result.posteriors - fe = result.free_energy - - (γ, θ, xs) = (qs[:γ], qs[:θ], qs[:x]) - - @test length(xs) === n - @test length(γ) === 15 - @test length(θ) === 15 - @test length(fe) === 15 - @test abs(last(fe) - 514.66086) < 0.01 - @test all(filter(e -> abs(e) > 1e-1, diff(fe)) .< 0) - @test (mean(last(γ)) - 3.0std(last(γ)) < real_γ < mean(last(γ)) + 3.0std(last(γ))) - - @test_plot "models" "lar" begin - p1 = plot(first.(states), label = "Hidden state") - p1 = scatter!(p1, observations, label = "Observations") - p1 = plot!(p1, first.(mean.(xs)), ribbon = sqrt.(first.(var.(xs))), label = "Inferred states", legend = :bottomright) - - p2 = plot(mean.(γ), ribbon = std.(γ), label = "Inferred transition precision", legend = :bottomright) - p2 = plot!([real_γ], seriestype = :hline, label = "Real transition precision") - - p3 = plot(fe, label = "Bethe Free Energy") - - p = plot(p1, p2, p3, layout = @layout([a; b c])) - end - - @test_benchmark "models" "lar" lar_inference($observations, length($real_θ), Multivariate, ARsafe(), 15, $real_τ) -end - -end diff --git a/test/models/datavars/test_fn_datavars.jl b/test/models/datavars/fn_datavars_tests.jl similarity index 56% rename from test/models/datavars/test_fn_datavars.jl rename to test/models/datavars/fn_datavars_tests.jl index 1b8c4846b..e79c1921f 
100644 --- a/test/models/datavars/test_fn_datavars.jl +++ b/test/models/datavars/fn_datavars_tests.jl @@ -1,56 +1,52 @@ -module RxInferModelsDatavarsTest +@testitem "datavars" begin + using StableRNGs + # Please use StableRNGs for random number generators + + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + + ## Model definition + @model function sum_datavars_as_gaussian_mean_1() + a = datavar(Float64) + b = datavar(Float64) + y = datavar(Float64) + + x ~ Normal(mean = a + b, variance = 1.0) + y ~ Normal(mean = x, variance = 1.0) + end + + @model function sum_datavars_as_gaussian_mean_2() + a = datavar(Float64) + b = datavar(Float64) + c = constvar(0.0) # Should not change the result + y = datavar(Float64) + + x ~ Normal(mean = (a + b) + c, variance = 1.0) + y ~ Normal(mean = x, variance = 1.0) + end + + @model function ratio_datavars_as_gaussian_mean() + a = datavar(Float64) + b = datavar(Float64) + y = datavar(Float64) + + x ~ Normal(mean = a / b, variance = 1.0) + y ~ Normal(mean = x, variance = 1.0) + end + + @model function idx_datavars_as_gaussian_mean() + a = datavar(Vector{Float64}) + b = datavar(Matrix{Float64}) + y = datavar(Float64) + + x ~ Normal(mean = dot(a[1:2], b[1:2, 1]), variance = 1.0) + y ~ Normal(mean = x, variance = 1.0) + end + + # Inference function + function fn_datavars_inference(modelfn, adata, bdata, ydata) + return inference(model = modelfn(), data = (a = adata, b = bdata, y = ydata), free_energy = true) + end -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs - -# Please use StableRNGs for random number generators - -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - -## Model definition -@model function sum_datavars_as_gaussian_mean_1() - a = datavar(Float64) - b = datavar(Float64) - y = datavar(Float64) - - x ~ Normal(mean = a + b, variance = 1.0) - y ~ Normal(mean = x, variance = 1.0) -end - -@model function sum_datavars_as_gaussian_mean_2() - a = datavar(Float64) - b = datavar(Float64) - c = constvar(0.0) # Should not change the result - y = datavar(Float64) - - x ~ Normal(mean = (a + b) + c, variance = 1.0) - y ~ Normal(mean = x, variance = 1.0) -end - -@model function ratio_datavars_as_gaussian_mean() - a = datavar(Float64) - b = datavar(Float64) - y = datavar(Float64) - - x ~ Normal(mean = a / b, variance = 1.0) - y ~ Normal(mean = x, variance = 1.0) -end - -@model function idx_datavars_as_gaussian_mean() - a = datavar(Vector{Float64}) - b = datavar(Matrix{Float64}) - y = datavar(Float64) - - x ~ Normal(mean = dot(a[1:2], b[1:2, 1]), variance = 1.0) - y ~ Normal(mean = x, variance = 1.0) -end - -# Inference function -function fn_datavars_inference(modelfn, adata, bdata, ydata) - return inference(model = modelfn(), data = (a = adata, b = bdata, y = ydata), free_energy = true) -end - -@testset "datavars" begin adata = 2.0 bdata = 1.0 ydata = 0.0 @@ -99,6 +95,4 @@ end A_data = [1.0, 2.0, 3.0] B_data = [1.0 0.5; 0.5 1.0] @test_broken result = fn_datavars_inference(idx_datavars_as_gaussian_mean, A_data, B_data, ydata) -end - -end +end \ No newline at end of file diff --git a/test/models/iid/test_mv_iid_covariance_known_mean.jl b/test/models/iid/mv_iid_covariance_known_mean_tests.jl similarity index 57% rename from test/models/iid/test_mv_iid_covariance_known_mean.jl rename to test/models/iid/mv_iid_covariance_known_mean_tests.jl index 8e0eb8a08..5c7ccda03 100644 --- a/test/models/iid/test_mv_iid_covariance_known_mean.jl +++ b/test/models/iid/mv_iid_covariance_known_mean_tests.jl @@ 
-1,29 +1,25 @@ -module RxInferModelsMvIIDCovarianceKnownMeanTest +@testitem "Multivariate IID: Covariance parametrisation with known mean" begin + using StableRNGs, Plots, BenchmarkTools + # Please use StableRNGs for random number generators -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) -# Please use StableRNGs for random number generators + @model function mv_iid_inverse_wishart_known_mean(mean, n, d) + C ~ InverseWishart(d + 1, diageye(d)) -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + m = constvar(mean) + y = datavar(Vector{Float64}, n) -@model function mv_iid_inverse_wishart_known_mean(mean, n, d) - C ~ InverseWishart(d + 1, diageye(d)) - - m = constvar(mean) - y = datavar(Vector{Float64}, n) - - for i in 1:n - y[i] ~ MvNormal(mean = m, covariance = C) + for i in 1:n + y[i] ~ MvNormal(mean = m, covariance = C) + end end -end -function inference_mv_inverse_wishart_known_mean(mean, data, n, d) - return inference(model = mv_iid_inverse_wishart_known_mean(mean, n, d), data = (y = data,), iterations = 10, returnvars = KeepLast(), free_energy = Float64) -end + function inference_mv_inverse_wishart_known_mean(mean, data, n, d) + return inference(model = mv_iid_inverse_wishart_known_mean(mean, n, d), data = (y = data,), iterations = 10, returnvars = KeepLast(), free_energy = Float64) + end -@testset "Multivariate IID: Covariance parametrisation with known mean" begin ## Data creation rng = StableRNG(123) @@ -55,6 +51,4 @@ end end @test_benchmark "models" "iid_inverse_wishart_known_mean" inference_mv_inverse_wishart_known_mean($m, $data, $n, $d) -end - -end +end \ No newline at end of file diff --git a/test/models/iid/test_mv_iid_covariance.jl b/test/models/iid/mv_iid_covariance_tests.jl similarity index 50% rename from test/models/iid/test_mv_iid_covariance.jl rename to test/models/iid/mv_iid_covariance_tests.jl index f9dbd3ec3..266bbef14 100644 --- a/test/models/iid/test_mv_iid_covariance.jl +++ b/test/models/iid/mv_iid_covariance_tests.jl @@ -1,41 +1,35 @@ -module RxInferModelsMvIIDCovarianceTest - -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs - -# Please use StableRNGs for random number generators - -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - -@model function mv_iid_inverse_wishart(n, d) - m ~ MvNormal(mean = zeros(d), precision = 100 * diageye(d)) - C ~ InverseWishart(d + 1, diageye(d)) - - y = datavar(Vector{Float64}, n) - - for i in 1:n - y[i] ~ MvNormal(mean = m, covariance = C) +@testitem "Multivariate IID: Covariance parametrisation" begin + using StableRNGs, Plots, BenchmarkTools + + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + + @model function mv_iid_inverse_wishart(n, d) + m ~ MvNormal(mean = zeros(d), precision = 100 * diageye(d)) + C ~ InverseWishart(d + 1, diageye(d)) + + y = datavar(Vector{Float64}, n) + + for i in 1:n + y[i] ~ MvNormal(mean = m, covariance = C) + end + end + + @constraints function constraints_mv_iid_inverse_wishart() + q(m, C) = q(m)q(C) + end + + function inference_mv_inverse_wishart(data, n, d) + return inference( + model = mv_iid_inverse_wishart(n, d), + data = (y = data,), + constraints = constraints_mv_iid_inverse_wishart(), + initmarginals = (m = vague(MvNormalMeanCovariance, d), C = 
vague(InverseWishart, d)), + returnvars = KeepLast(), + iterations = 10, + free_energy = Float64 + ) end -end - -@constraints function constraints_mv_iid_inverse_wishart() - q(m, C) = q(m)q(C) -end - -function inference_mv_inverse_wishart(data, n, d) - return inference( - model = mv_iid_inverse_wishart(n, d), - data = (y = data,), - constraints = constraints_mv_iid_inverse_wishart(), - initmarginals = (m = vague(MvNormalMeanCovariance, d), C = vague(InverseWishart, d)), - returnvars = KeepLast(), - iterations = 10, - free_energy = Float64 - ) -end - -@testset "Multivariate IID: Covariance parametrisation" begin ## Data creation rng = StableRNG(123) @@ -68,5 +62,3 @@ end @test_benchmark "models" "iid_inverse_wishart" inference_mv_inverse_wishart($data, $n, $d) end - -end diff --git a/test/models/iid/test_mv_iid_precision_known_mean.jl b/test/models/iid/mv_iid_precision_known_mean_tests.jl similarity index 57% rename from test/models/iid/test_mv_iid_precision_known_mean.jl rename to test/models/iid/mv_iid_precision_known_mean_tests.jl index e900260f6..b6c3f36c3 100644 --- a/test/models/iid/test_mv_iid_precision_known_mean.jl +++ b/test/models/iid/mv_iid_precision_known_mean_tests.jl @@ -1,31 +1,26 @@ -module RxInferModelsMvIIDPrecisionKnownMeanTest +@testitem "Multivariate IID: Precision parametrisation with known mean" begin + using StableRNGs, BenchmarkTools, Plots + # Please use StableRNGs for random number generators -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) -# Please use StableRNGs for random number generators + ## Model and constraints definition -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + @model function mv_iid_wishart_known_mean(mean, n, d) + P ~ Wishart(d + 1, diageye(d)) -## Model and constraints definition + m = constvar(mean) + y = datavar(Vector{Float64}, n) -@model function mv_iid_wishart_known_mean(mean, n, d) - P ~ Wishart(d + 1, diageye(d)) - - m = constvar(mean) - y = datavar(Vector{Float64}, n) - - for i in 1:n - y[i] ~ MvNormal(mean = m, precision = P) + for i in 1:n + y[i] ~ MvNormal(mean = m, precision = P) + end end -end - -function inference_mv_wishart_known_mean(mean, data, n, d) - return inference(model = mv_iid_wishart_known_mean(mean, n, d), data = (y = data,), iterations = 10, returnvars = KeepLast(), free_energy = Float64) -end -@testset "Multivariate IID: Precision parametrisation with known mean" begin + function inference_mv_wishart_known_mean(mean, data, n, d) + return inference(model = mv_iid_wishart_known_mean(mean, n, d), data = (y = data,), iterations = 10, returnvars = KeepLast(), free_energy = Float64) + end ## Data creation rng = StableRNG(123) @@ -58,6 +53,4 @@ end end @test_benchmark "models" "iid_wishart_known_mean" inference_mv_wishart_known_mean($m, $data, $n, $d) -end - -end +end \ No newline at end of file diff --git a/test/models/iid/mv_iid_precision_tests.jl b/test/models/iid/mv_iid_precision_tests.jl new file mode 100644 index 000000000..edead3842 --- /dev/null +++ b/test/models/iid/mv_iid_precision_tests.jl @@ -0,0 +1,70 @@ +@testitem "Multivariate IID: Precision parametrisation" begin + using StableRNGs, BenchmarkTools, Plots + # Please use StableRNGs for random number generators + + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + + ## Model and constraints definition + + @model function 
mv_iid_wishart(n, d) + m ~ MvNormal(mean = zeros(d), precision = 100 * diageye(d)) + P ~ Wishart(d + 1, diageye(d)) + + y = datavar(Vector{Float64}, n) + + for i in 1:n + y[i] ~ MvNormal(mean = m, precision = P) + end + end + + @constraints function constraints_mv_iid_wishart() + q(m, P) = q(m)q(P) + end + + ## Inference definition + + function inference_mv_wishart(data, n, d) + return inference( + model = mv_iid_wishart(n, d), + data = (y = data,), + constraints = constraints_mv_iid_wishart(), + initmarginals = (m = vague(MvNormalMeanCovariance, d), P = vague(Wishart, d)), + returnvars = KeepLast(), + iterations = 10, + free_energy = Float64 + ) + end + + ## Data creation + rng = StableRNG(123) + + n = 1500 + d = 2 + + m = rand(rng, d) + L = randn(rng, d, d) + C = L * L' + P = inv(C) + + data = rand(rng, MvNormalMeanPrecision(m, P), n) |> eachcol |> collect .|> collect + + ## Inference execution + result = inference_mv_wishart(data, n, d) + + ## Test inference results + @test isapprox(mean(result.posteriors[:m]), m, atol = 0.05) + @test isapprox(mean(result.posteriors[:P]), P, atol = 0.07) + @test all(<(0), filter(e -> abs(e) > 1e-10, diff(result.free_energy))) + + @test_plot "models" "iid_mv_precision" begin + X = range(-5, 5, length = 200) + Y = range(-5, 5, length = 200) + + p = plot(title = "MvIID experiment / Precision parametrisation") + p = contour!(p, X, Y, (x, y) -> pdf(MvNormalMeanPrecision(mean(result.posteriors[:m]), mean(result.posteriors[:P])), [x, y]), label = "Estimated") + p = contour!(p, X, Y, (x, y) -> pdf(MvNormalMeanPrecision(m, P), [x, y]), label = "Real") + end + + @test_benchmark "models" "iid_mv_wishart" inference_mv_wishart($data, $n, $d) +end \ No newline at end of file diff --git a/test/models/iid/test_mv_iid_precision.jl b/test/models/iid/test_mv_iid_precision.jl deleted file mode 100644 index 6fdf25f9e..000000000 --- a/test/models/iid/test_mv_iid_precision.jl +++ /dev/null @@ -1,77 +0,0 @@ -module RxInferModelsMvIIDPrecisionTest - -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs - -# Please use StableRNGs for random number generators - -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - -## Model and constraints definition - -@model function mv_iid_wishart(n, d) - m ~ MvNormal(mean = zeros(d), precision = 100 * diageye(d)) - P ~ Wishart(d + 1, diageye(d)) - - y = datavar(Vector{Float64}, n) - - for i in 1:n - y[i] ~ MvNormal(mean = m, precision = P) - end -end - -@constraints function constraints_mv_iid_wishart() - q(m, P) = q(m)q(P) -end - -## Inference definition - -function inference_mv_wishart(data, n, d) - return inference( - model = mv_iid_wishart(n, d), - data = (y = data,), - constraints = constraints_mv_iid_wishart(), - initmarginals = (m = vague(MvNormalMeanCovariance, d), P = vague(Wishart, d)), - returnvars = KeepLast(), - iterations = 10, - free_energy = Float64 - ) -end - -@testset "Multivariate IID: Precision parametrisation" begin - - ## Data creation - rng = StableRNG(123) - - n = 1500 - d = 2 - - m = rand(rng, d) - L = randn(rng, d, d) - C = L * L' - P = inv(C) - - data = rand(rng, MvNormalMeanPrecision(m, P), n) |> eachcol |> collect .|> collect - - ## Inference execution - result = inference_mv_wishart(data, n, d) - - ## Test inference results - @test isapprox(mean(result.posteriors[:m]), m, atol = 0.05) - @test isapprox(mean(result.posteriors[:P]), P, atol = 0.07) - @test all(<(0), filter(e -> abs(e) > 1e-10, diff(result.free_energy))) - - 
@test_plot "models" "iid_mv_precision" begin - X = range(-5, 5, length = 200) - Y = range(-5, 5, length = 200) - - p = plot(title = "MvIID experiment / Precision parametrisation") - p = contour!(p, X, Y, (x, y) -> pdf(MvNormalMeanPrecision(mean(result.posteriors[:m]), mean(result.posteriors[:P])), [x, y]), label = "Estimated") - p = contour!(p, X, Y, (x, y) -> pdf(MvNormalMeanPrecision(m, P), [x, y]), label = "Real") - end - - @test_benchmark "models" "iid_mv_wishart" inference_mv_wishart($data, $n, $d) -end - -end diff --git a/test/models/mixtures/test_gmm_multivariate.jl b/test/models/mixtures/gmm_multivariate_tests.jl similarity index 53% rename from test/models/mixtures/test_gmm_multivariate.jl rename to test/models/mixtures/gmm_multivariate_tests.jl index 731f99093..36fcf548e 100644 --- a/test/models/mixtures/test_gmm_multivariate.jl +++ b/test/models/mixtures/gmm_multivariate_tests.jl @@ -1,79 +1,76 @@ -module RxInferModelsGMMTest - -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, LinearAlgebra, StableRNGs - -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - -@model function multivariate_gaussian_mixture_model(rng, L, nmixtures, n) - z = randomvar(n) - m = randomvar(nmixtures) - w = randomvar(nmixtures) - - basis_v = [1.0, 0.0] - - for i in 1:nmixtures - # Assume we now only approximate location of cluters's mean - approximate_angle_prior = ((2π + +rand(rng)) / nmixtures) * (i - 1) - approximate_basis_v = L / 2 * (basis_v .+ rand(rng, 2)) - approximate_rotation = [ - cos(approximate_angle_prior) -sin(approximate_angle_prior) - sin(approximate_angle_prior) cos(approximate_angle_prior) - ] - mean_mean_prior = approximate_rotation * approximate_basis_v - mean_mean_cov = [1e6 0.0; 0.0 1e6] - - m[i] ~ MvNormal(mean = mean_mean_prior, cov = mean_mean_cov) - w[i] ~ Wishart(3, [1e2 0.0; 0.0 1e2]) - end +@testitem "Multivariate Gaussian Mixture model" begin + using BenchmarkTools, Plots, StableRNGs, LinearAlgebra + + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + + @model function multivariate_gaussian_mixture_model(rng, L, nmixtures, n) + z = randomvar(n) + m = randomvar(nmixtures) + w = randomvar(nmixtures) + + basis_v = [1.0, 0.0] + + for i in 1:nmixtures + # Assume we now only approximate the location of the clusters' means + approximate_angle_prior = ((2π + rand(rng)) / nmixtures) * (i - 1) + approximate_basis_v = L / 2 * (basis_v .+ rand(rng, 2)) + approximate_rotation = [ + cos(approximate_angle_prior) -sin(approximate_angle_prior) + sin(approximate_angle_prior) cos(approximate_angle_prior) + ] + mean_mean_prior = approximate_rotation * approximate_basis_v + mean_mean_cov = [1e6 0.0; 0.0 1e6] + + m[i] ~ MvNormal(mean = mean_mean_prior, cov = mean_mean_cov) + w[i] ~ Wishart(3, [1e2 0.0; 0.0 1e2]) + end - s ~ Dirichlet(ones(nmixtures)) + s ~ Dirichlet(ones(nmixtures)) - y = datavar(Vector{Float64}, n) + y = datavar(Vector{Float64}, n) - means = tuple(m...) - precs = tuple(w...) + means = tuple(m...) + precs = tuple(w...)
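+
+ # `NormalMixture` takes its component means and precisions as tuples rather
+ # than as the `randomvar` arrays themselves, which is why `m` and `w` are
+ # splatted into `means` and `precs` above. A minimal sketch (hypothetical,
+ # written out for `nmixtures = 2` only), equivalent to the tuple form used
+ # in the loop below:
+ #
+ # z[i] ~ Categorical(s)
+ # y[i] ~ NormalMixture(z[i], (m[1], m[2]), (w[1], w[2]))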
- for i in 1:n - z[i] ~ Categorical(s) - y[i] ~ NormalMixture(z[i], means, precs) + for i in 1:n + z[i] ~ Categorical(s) + y[i] ~ NormalMixture(z[i], means, precs) + end end -end - -function inference_multivariate(rng, L, nmixtures, data, viters, constraints) - basis_v = [1.0, 0.0] - - minitmarginals = [] - winitmarginals = [] - - for i in 1:nmixtures - # Assume we now only approximate location of cluters's mean - approximate_angle_prior = ((2π + +rand(rng)) / nmixtures) * (i - 1) - approximate_basis_v = L / 2 * (basis_v .+ rand(rng, 2)) - approximate_rotation = [ - cos(approximate_angle_prior) -sin(approximate_angle_prior) - sin(approximate_angle_prior) cos(approximate_angle_prior) - ] - mean_mean_prior = approximate_rotation * approximate_basis_v - mean_mean_cov = [1e6 0.0; 0.0 1e6] - - push!(minitmarginals, MvNormalMeanCovariance(mean_mean_prior, mean_mean_cov)) - push!(winitmarginals, Wishart(3, [1e2 0.0; 0.0 1e2])) + + function inference_multivariate(rng, L, nmixtures, data, viters, constraints) + basis_v = [1.0, 0.0] + + minitmarginals = [] + winitmarginals = [] + + for i in 1:nmixtures + # Assume we now only approximate the location of the clusters' means + approximate_angle_prior = ((2π + rand(rng)) / nmixtures) * (i - 1) + approximate_basis_v = L / 2 * (basis_v .+ rand(rng, 2)) + approximate_rotation = [ + cos(approximate_angle_prior) -sin(approximate_angle_prior) + sin(approximate_angle_prior) cos(approximate_angle_prior) + ] + mean_mean_prior = approximate_rotation * approximate_basis_v + mean_mean_cov = [1e6 0.0; 0.0 1e6] + + push!(minitmarginals, MvNormalMeanCovariance(mean_mean_prior, mean_mean_cov)) + push!(winitmarginals, Wishart(3, [1e2 0.0; 0.0 1e2])) + end + + return inference( + model = multivariate_gaussian_mixture_model(rng, L, nmixtures, length(data)), + data = (y = data,), + constraints = constraints, + returnvars = KeepEach(), + free_energy = Float64, + iterations = viters, + initmarginals = (s = vague(Dirichlet, nmixtures), m = minitmarginals, w = winitmarginals) + ) end - return inference( - model = multivariate_gaussian_mixture_model(rng, L, nmixtures, length(data)), - data = (y = data,), - constraints = constraints, - returnvars = KeepEach(), - free_energy = Float64, - iterations = viters, - initmarginals = (s = vague(Dirichlet, nmixtures), m = minitmarginals, w = winitmarginals) - ) -end - -@testset "Multivariate Gaussian Mixture model" begin rng = StableRNG(43) L = 50.0 @@ -163,6 +160,4 @@ end end @test_benchmark "models" "gmm_multivariate" inference_multivariate(StableRNG(123), $L, $nmixtures, $y, 25, MeanField()) -end - -end +end \ No newline at end of file diff --git a/test/models/mixtures/test_gmm_univariate.jl b/test/models/mixtures/gmm_univariate_tests.jl similarity index 72% rename from test/models/mixtures/test_gmm_univariate.jl rename to test/models/mixtures/gmm_univariate_tests.jl index 86b9eb21c..ee17887e3 100644 --- a/test/models/mixtures/test_gmm_univariate.jl +++ b/test/models/mixtures/gmm_univariate_tests.jl @@ -1,42 +1,38 @@ -module RxInferModelsGMMTest +@testitem "Univariate Gaussian Mixture model " begin + using BenchmarkTools, Plots, LinearAlgebra, StableRNGs -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, LinearAlgebra, StableRNGs + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + @model function univariate_gaussian_mixture_model(n) + s ~ Beta(1.0, 1.0) -@model function
univariate_gaussian_mixture_model(n) - s ~ Beta(1.0, 1.0) + m1 ~ Normal(mean = -2.0, variance = 1e3) + w1 ~ Gamma(shape = 0.01, rate = 0.01) - m1 ~ Normal(mean = -2.0, variance = 1e3) - w1 ~ Gamma(shape = 0.01, rate = 0.01) + m2 ~ Normal(mean = 2.0, variance = 1e3) + w2 ~ Gamma(shape = 0.01, rate = 0.01) - m2 ~ Normal(mean = 2.0, variance = 1e3) - w2 ~ Gamma(shape = 0.01, rate = 0.01) + z = randomvar(n) + y = datavar(Float64, n) - z = randomvar(n) - y = datavar(Float64, n) + for i in 1:n + z[i] ~ Bernoulli(s) + y[i] ~ NormalMixture(z[i], (m1, m2), (w1, w2)) + end + end - for i in 1:n - z[i] ~ Bernoulli(s) - y[i] ~ NormalMixture(z[i], (m1, m2), (w1, w2)) + function inference_univariate(data, n_its, constraints) + return inference( + model = univariate_gaussian_mixture_model(length(data)), + data = (y = data,), + constraints = constraints, + returnvars = KeepEach(), + free_energy = Float64, + iterations = n_its, + initmarginals = (s = vague(Beta), m1 = NormalMeanVariance(-2.0, 1e3), m2 = NormalMeanVariance(2.0, 1e3), w1 = vague(GammaShapeRate), w2 = vague(GammaShapeRate)) + ) end -end - -function inference_univariate(data, n_its, constraints) - return inference( - model = univariate_gaussian_mixture_model(length(data)), - data = (y = data,), - constraints = constraints, - returnvars = KeepEach(), - free_energy = Float64, - iterations = n_its, - initmarginals = (s = vague(Beta), m1 = NormalMeanVariance(-2.0, 1e3), m2 = NormalMeanVariance(2.0, 1e3), w1 = vague(GammaShapeRate), w2 = vague(GammaShapeRate)) - ) -end - -@testset "Univariate Gaussian Mixture model " begin ## -------------------------------------------- ## ## Data creation ## -------------------------------------------- ## @@ -133,6 +129,4 @@ end end @test_benchmark "models" "gmm_univariate" inference_univariate($y, 10, MeanField()) -end - -end +end \ No newline at end of file diff --git a/test/models/mixtures/test_mixture.jl b/test/models/mixtures/mixture_tests.jl similarity index 75% rename from test/models/mixtures/test_mixture.jl rename to test/models/mixtures/mixture_tests.jl index fefa0dedf..d40d9e340 100644 --- a/test/models/mixtures/test_mixture.jl +++ b/test/models/mixtures/mixture_tests.jl @@ -1,60 +1,55 @@ +@testitem "Model mixture" begin + using Distributions + using BenchmarkTools, LinearAlgebra, StableRNGs, Plots -module RxInferModelsMixtureTest + # Please use StableRNGs for random number generators -using Test, InteractiveUtils -using RxInfer, Distributions -using BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) -# Please use StableRNGs for random number generators + ## Model definition + ## -------------------------------------------- ## -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + @model function beta_model1(n) + y = datavar(Float64, n) -## Model definition -## -------------------------------------------- ## + θ ~ Beta(4.0, 8.0) -@model function beta_model1(n) - y = datavar(Float64, n) - - θ ~ Beta(4.0, 8.0) + for i in 1:n + y[i] ~ Bernoulli(θ) + end - for i in 1:n - y[i] ~ Bernoulli(θ) + return y, θ end - return y, θ -end + @model function beta_model2(n) + y = datavar(Float64, n) -@model function beta_model2(n) - y = datavar(Float64, n) + θ ~ Beta(8.0, 4.0) - θ ~ Beta(8.0, 4.0) + for i in 1:n + y[i] ~ Bernoulli(θ) + end - for i in 1:n - y[i] ~ Bernoulli(θ) + return y, θ end - return y, θ -end + @model function beta_mixture_model(n) + y = datavar(Float64, n) -@model 
function beta_mixture_model(n) - y = datavar(Float64, n) + selector ~ Bernoulli(0.7) - selector ~ Bernoulli(0.7) + in1 ~ Beta(4.0, 8.0) + in2 ~ Beta(8.0, 4.0) - in1 ~ Beta(4.0, 8.0) - in2 ~ Beta(8.0, 4.0) + θ ~ Mixture(selector, (in1, in2)) - θ ~ Mixture(selector, (in1, in2)) + for i in 1:n + y[i] ~ Bernoulli(θ) + end - for i in 1:n - y[i] ~ Bernoulli(θ) + return y, θ end - - return y, θ -end - -@testset "Model mixture" begin @testset "Check inference results" begin ## -------------------------------------------- ## ## Data creation @@ -109,6 +104,4 @@ end return p end end -end - -end +end \ No newline at end of file diff --git a/test/models/nonlinear/test_cvi.jl b/test/models/nonlinear/cvi_tests.jl similarity index 57% rename from test/models/nonlinear/test_cvi.jl rename to test/models/nonlinear/cvi_tests.jl index fee9af3f1..546b5ec4b 100644 --- a/test/models/nonlinear/test_cvi.jl +++ b/test/models/nonlinear/cvi_tests.jl @@ -1,72 +1,69 @@ -module ReactiveMPModelsNonLinearDynamicsTest - -using Test, InteractiveUtils -using RxInfer, Distributions -using BenchmarkTools, Random, Plots, Dates, StableRNGs, Optimisers - -# Please use StableRNGs for random number generators - -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - -## Model definition -## -------------------------------------------- ## -sensor_location = 53 -P = 5 -sensor_var = 5 -function f(z) - (z - sensor_location)^2 -end - -@model function non_linear_dynamics(T) - z = randomvar(T) - x = randomvar(T) - y = datavar(Float64, T) - - τ ~ GammaShapeRate(1.0, 1.0e-12) - θ ~ GammaShapeRate(1.0, 1.0e-12) - - z[1] ~ NormalMeanPrecision(0, τ) - x[1] ~ f(z[1]) - y[1] ~ NormalMeanPrecision(x[1], θ) - - for t in 2:T - z[t] ~ NormalMeanPrecision(z[t - 1] + 1, τ) - x[t] ~ f(z[t]) - y[t] ~ NormalMeanPrecision(x[t], θ) +@testitem "Non linear dynamics" begin + using Distributions + using BenchmarkTools, Plots, StableRNGs, Optimisers, Random, Dates + + # Please use StableRNGs for random number generators + + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + + ## Model definition + ## -------------------------------------------- ## + sensor_location = 53 + P = 5 + sensor_var = 5 + function f(z) + (z - sensor_location)^2 + end + + @model function non_linear_dynamics(T) + z = randomvar(T) + x = randomvar(T) + y = datavar(Float64, T) + + τ ~ GammaShapeRate(1.0, 1.0e-12) + θ ~ GammaShapeRate(1.0, 1.0e-12) + + z[1] ~ NormalMeanPrecision(0, τ) + x[1] ~ f(z[1]) + y[1] ~ NormalMeanPrecision(x[1], θ) + + for t in 2:T + z[t] ~ NormalMeanPrecision(z[t - 1] + 1, τ) + x[t] ~ f(z[t]) + y[t] ~ NormalMeanPrecision(x[t], θ) + end + + return z, x, y + end + + constraints = @constraints begin + q(z, x, τ, θ) = q(z)q(x)q(τ)q(θ) + end + + @meta function model_meta(rng, n_iterations, n_samples, learning_rate) + f() -> CVI(rng, n_iterations, n_samples, Optimisers.Descent(learning_rate)) + end + + ## -------------------------------------------- ## + ## Inference definition + ## -------------------------------------------- ## + function inference_cvi(transformed, rng, iterations) + T = length(transformed) + + return inference( + model = non_linear_dynamics(T), + data = (y = transformed,), + iterations = iterations, + free_energy = true, + returnvars = (z = KeepLast(),), + constraints = constraints, + meta = model_meta(rng, 600, 600, 0.01), + initmessages = (z = NormalMeanVariance(0, P),), + initmarginals = (z = NormalMeanVariance(0, P), τ = GammaShapeRate(1.0, 1.0e-12), θ = GammaShapeRate(1.0, 
1.0e-12)) + ) end - return z, x, y -end - -constraints = @constraints begin - q(z, x, τ, θ) = q(z)q(x)q(τ)q(θ) -end - -@meta function model_meta(rng, n_iterations, n_samples, learning_rate) - f() -> CVI(rng, n_iterations, n_samples, Optimisers.Descent(learning_rate)) -end - -## -------------------------------------------- ## -## Inference definition -## -------------------------------------------- ## -function inference_cvi(transformed, rng, iterations) - T = length(transformed) - - return inference( - model = non_linear_dynamics(T), - data = (y = transformed,), - iterations = iterations, - free_energy = true, - returnvars = (z = KeepLast(),), - constraints = constraints, - meta = model_meta(rng, 600, 600, 0.01), - initmessages = (z = NormalMeanVariance(0, P),), - initmarginals = (z = NormalMeanVariance(0, P), τ = GammaShapeRate(1.0, 1.0e-12), θ = GammaShapeRate(1.0, 1.0e-12)) - ) -end - -@testset "Non linear dynamics" begin @testset "Use case #1" begin ## -------------------------------------------- ## ## Data creation @@ -122,6 +119,4 @@ end @test_benchmark "models" "cvi" inference_cvi($transformed, $rng, 110) end -end - -end +end \ No newline at end of file diff --git a/test/models/nonlinear/generic_applicability_tests.jl b/test/models/nonlinear/generic_applicability_tests.jl new file mode 100644 index 000000000..0a70d7114 --- /dev/null +++ b/test/models/nonlinear/generic_applicability_tests.jl @@ -0,0 +1,158 @@ +@testitem "Nonlinear models: generic applicability" begin + using BenchmarkTools, Random, Plots, LinearAlgebra, StableRNGs + + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + + # Please use StableRNGs for random number generators + + ## Model definition + ## -------------------------------------------- ## + + # We test that the function can depend on a global variable + # A particular value does not matter here, only the fact that it runs + globalvar = 0 + + function f₁(x) + return sqrt.(x .+ globalvar) + end + + function f₁_inv(x) + return x .^ 2 + end + + @model function delta_1input(meta) + y2 = datavar(Float64) + c = zeros(2) + c[1] = 1.0 + + x ~ MvNormal(μ = ones(2), Λ = diageye(2)) + z ~ f₁(x) where {meta = meta} + y1 ~ Normal(μ = dot(z, c), σ² = 1.0) + y2 ~ Normal(μ = y1, σ² = 0.5) + end + + function f₂(x, θ) + return x .+ θ + end + + function f₂_x(θ, z) + return z .- θ + end + + function f₂_θ(x, z) + return z .- x + end + + @model function delta_2inputs(meta) + y2 = datavar(Float64) + c = zeros(2) + c[1] = 1.0 + + θ ~ MvNormal(μ = ones(2), Λ = diageye(2)) + x ~ MvNormal(μ = zeros(2), Λ = diageye(2)) + z ~ f₂(x, θ) where {meta = meta} + y1 ~ Normal(μ = dot(z, c), σ² = 1.0) + y2 ~ Normal(μ = y1, σ² = 0.5) + end + + function f₃(x, θ, ζ) + return x .+ θ .+ ζ + end + + @model function delta_3inputs(meta) + y2 = datavar(Float64) + c = zeros(2) + c[1] = 1.0 + + θ ~ MvNormal(μ = ones(2), Λ = diageye(2)) + ζ ~ MvNormal(μ = 0.5ones(2), Λ = diageye(2)) + x ~ MvNormal(μ = zeros(2), Λ = diageye(2)) + z ~ f₃(x, θ, ζ) where {meta = meta} + y1 ~ Normal(μ = dot(z, c), σ² = 1.0) + y2 ~ Normal(μ = y1, σ² = 0.5) + end + + function f₄(x, θ) + return θ .* x + end + + @model function delta_2input_1d2d(meta) + y2 = datavar(Float64) + c = zeros(2) + c[1] = 1.0 + + θ ~ Normal(μ = 0.5, γ = 1.0) + x ~ MvNormal(μ = zeros(2), Λ = diageye(2)) + z ~ f₄(x, θ) where {meta = meta} + y1 ~ Normal(μ = dot(z, c), σ² = 1.0) + y2 ~ Normal(μ = y1, σ² = 0.5) + end + + ## -------------------------------------------- ## + ## Inference definition + ## 
-------------------------------------------- ## + function inference_1input(data) + + # We test here different approximation methods + metas = ( + DeltaMeta(method = Linearization(), inverse = f₁_inv), + DeltaMeta(method = Unscented(), inverse = f₁_inv), + DeltaMeta(method = Linearization()), + DeltaMeta(method = Unscented()), + Linearization(), + Unscented() + ) + + return map(metas) do meta + return inference(model = delta_1input(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())) + end + end + + function inference_2inputs(data) + metas = ( + DeltaMeta(method = Linearization(), inverse = (f₂_x, f₂_θ)), + DeltaMeta(method = Unscented(), inverse = (f₂_x, f₂_θ)), + DeltaMeta(method = Linearization()), + DeltaMeta(method = Unscented()), + Linearization(), + Unscented() + ) + + return map(metas) do meta + return inference(model = delta_2inputs(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())) + end + end + + function inference_3inputs(data) + metas = (DeltaMeta(method = Linearization()), DeltaMeta(method = Unscented()), Linearization(), Unscented()) + + return map(metas) do meta + return inference(model = delta_3inputs(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())) + end + end + + function inference_2input_1d2d(data) + metas = (DeltaMeta(method = Linearization()), DeltaMeta(method = Unscented()), Linearization(), Unscented()) + + return map(metas) do meta + return inference( + model = delta_2input_1d2d(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs()) + ) + end + end + @testset "Linearization, Unscented transforms" begin + ## -------------------------------------------- ## + ## Data creation + data = 4.0 + ## -------------------------------------------- ## + ## Inference execution + result₁ = inference_1input(data) + result₂ = inference_2inputs(data) + result₃ = inference_3inputs(data) + result₄ = inference_2input_1d2d(data) + + ## All models have been created. 
The inference finished without errors ## + @test true + end +end \ No newline at end of file diff --git a/test/models/nonlinear/test_generic_applicability.jl b/test/models/nonlinear/test_generic_applicability.jl deleted file mode 100644 index 5f1ccd9a1..000000000 --- a/test/models/nonlinear/test_generic_applicability.jl +++ /dev/null @@ -1,164 +0,0 @@ -module RxInferNonlinearityModelsDeltaTest - -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, LinearAlgebra, StableRNGs - -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - -# Please use StableRNGs for random number generators - -## Model definition -## -------------------------------------------- ## - -# We test that the function can depend on a global variable -# A particular value does not matter here, only the fact that it runs -globalvar = 0 - -function f₁(x) - return sqrt.(x .+ globalvar) -end - -function f₁_inv(x) - return x .^ 2 -end - -@model function delta_1input(meta) - y2 = datavar(Float64) - c = zeros(2) - c[1] = 1.0 - - x ~ MvNormal(μ = ones(2), Λ = diageye(2)) - z ~ f₁(x) where {meta = meta} - y1 ~ Normal(μ = dot(z, c), σ² = 1.0) - y2 ~ Normal(μ = y1, σ² = 0.5) -end - -function f₂(x, θ) - return x .+ θ -end - -function f₂_x(θ, z) - return z .- θ -end - -function f₂_θ(x, z) - return z .- x -end - -@model function delta_2inputs(meta) - y2 = datavar(Float64) - c = zeros(2) - c[1] = 1.0 - - θ ~ MvNormal(μ = ones(2), Λ = diageye(2)) - x ~ MvNormal(μ = zeros(2), Λ = diageye(2)) - z ~ f₂(x, θ) where {meta = meta} - y1 ~ Normal(μ = dot(z, c), σ² = 1.0) - y2 ~ Normal(μ = y1, σ² = 0.5) -end - -function f₃(x, θ, ζ) - return x .+ θ .+ ζ -end - -@model function delta_3inputs(meta) - y2 = datavar(Float64) - c = zeros(2) - c[1] = 1.0 - - θ ~ MvNormal(μ = ones(2), Λ = diageye(2)) - ζ ~ MvNormal(μ = 0.5ones(2), Λ = diageye(2)) - x ~ MvNormal(μ = zeros(2), Λ = diageye(2)) - z ~ f₃(x, θ, ζ) where {meta = meta} - y1 ~ Normal(μ = dot(z, c), σ² = 1.0) - y2 ~ Normal(μ = y1, σ² = 0.5) -end - -function f₄(x, θ) - return θ .* x -end - -@model function delta_2input_1d2d(meta) - y2 = datavar(Float64) - c = zeros(2) - c[1] = 1.0 - - θ ~ Normal(μ = 0.5, γ = 1.0) - x ~ MvNormal(μ = zeros(2), Λ = diageye(2)) - z ~ f₄(x, θ) where {meta = meta} - y1 ~ Normal(μ = dot(z, c), σ² = 1.0) - y2 ~ Normal(μ = y1, σ² = 0.5) -end - -## -------------------------------------------- ## -## Inference definition -## -------------------------------------------- ## -function inference_1input(data) - - # We test here different approximation methods - metas = ( - DeltaMeta(method = Linearization(), inverse = f₁_inv), - DeltaMeta(method = Unscented(), inverse = f₁_inv), - DeltaMeta(method = Linearization()), - DeltaMeta(method = Unscented()), - Linearization(), - Unscented() - ) - - return map(metas) do meta - return inference(model = delta_1input(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())) - end -end - -function inference_2inputs(data) - metas = ( - DeltaMeta(method = Linearization(), inverse = (f₂_x, f₂_θ)), - DeltaMeta(method = Unscented(), inverse = (f₂_x, f₂_θ)), - DeltaMeta(method = Linearization()), - DeltaMeta(method = Unscented()), - Linearization(), - Unscented() - ) - - return map(metas) do meta - return inference(model = delta_2inputs(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())) - end -end - -function inference_3inputs(data) - 
metas = (DeltaMeta(method = Linearization()), DeltaMeta(method = Unscented()), Linearization(), Unscented()) - - return map(metas) do meta - return inference(model = delta_3inputs(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())) - end -end - -function inference_2input_1d2d(data) - metas = (DeltaMeta(method = Linearization()), DeltaMeta(method = Unscented()), Linearization(), Unscented()) - - return map(metas) do meta - return inference( - model = delta_2input_1d2d(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs()) - ) - end -end - -@testset "Nonlinear models: generic applicability" begin - @testset "Linearization, Unscented transforms" begin - ## -------------------------------------------- ## - ## Data creation - data = 4.0 - ## -------------------------------------------- ## - ## Inference execution - result₁ = inference_1input(data) - result₂ = inference_2inputs(data) - result₃ = inference_3inputs(data) - result₄ = inference_2input_1d2d(data) - - ## All models have been created. The inference finished without errors ## - @test true - end -end - -end diff --git a/test/models/regression/linreg_tests.jl b/test/models/regression/linreg_tests.jl new file mode 100644 index 000000000..cf3c0f0ef --- /dev/null +++ b/test/models/regression/linreg_tests.jl @@ -0,0 +1,78 @@ +@testitem "Linear regression" begin + using BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs + + # Please use StableRNGs for random number generators + + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + + ## Model definition + @model function linear_regression(n) + a ~ Normal(mean = 0.0, var = 1.0) + b ~ Normal(mean = 0.0, var = 1.0) + + x = datavar(Float64, n) + y = datavar(Float64, n) + + for i in 1:n + y[i] ~ Normal(mean = x[i] * b + a, var = 1.0) + end + end + + @model function linear_regression_broadcasted(n) + a ~ Normal(mean = 0.0, var = 1.0) + b ~ Normal(mean = 0.0, var = 1.0) + + x = datavar(Float64, n) + y = datavar(Float64, n) + + # Variance is over-complicated on purpose, to check that these expressions are allowed; it should be equal to `1.0` + y .~ Normal(mean = x .* b .+ a, var = det((diageye(2) .+ diageye(2)) ./ 2)) + end + + ## Inference definition + function linreg_inference(modelfn, niters, xdata, ydata) + return inference( + model = modelfn(length(xdata)), + data = (x = xdata, y = ydata), + returnvars = (a = KeepLast(), b = KeepLast()), + initmessages = (b = NormalMeanVariance(0.0, 100.0),), + free_energy = true, + iterations = niters + ) + end + + ## Data creation + reala = 10.0 + realb = -10.0 + + N = 100 + + rng = StableRNG(1234) + + xdata = collect(1:N) .+ 1 * randn(rng, N) + ydata = reala .+ realb .* xdata + + ## Inference execution + result = linreg_inference(linear_regression, 25, xdata, ydata) + resultb = linreg_inference(linear_regression_broadcasted, 25, xdata, ydata) + + ares = result.posteriors[:a] + bres = result.posteriors[:b] + fres = result.free_energy + + aresb = resultb.posteriors[:a] + bresb = resultb.posteriors[:b] + fresb = resultb.free_energy + + ## Test inference results + @test mean(ares) ≈ mean(aresb) && var(ares) ≈ var(aresb) # Broadcasting may change the order of computations, so slight + @test mean(bres) ≈ mean(bresb) && var(bres) ≈ var(bresb) # differences are allowed + @test all(fres .≈ fresb) + @test isapprox(mean(ares), reala, atol = 5) + @test
isapprox(mean(bres), realb, atol = 0.1) + @test fres[end] < fres[2] # Loopy belief propagation has no guarantees though + + @test_benchmark "models" "linreg" linreg_inference(linear_regression, 25, $xdata, $ydata) + @test_benchmark "models" "linreg_broadcasted" linreg_inference(linear_regression_broadcasted, 25, $xdata, $ydata) +end \ No newline at end of file diff --git a/test/models/regression/test_linreg.jl b/test/models/regression/test_linreg.jl deleted file mode 100644 index 45bb87ca2..000000000 --- a/test/models/regression/test_linreg.jl +++ /dev/null @@ -1,84 +0,0 @@ -module RxInferModelsLinearRegressionTest - -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs - -# Please use StableRNGs for random number generators - -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - -## Model definition -@model function linear_regression(n) - a ~ Normal(mean = 0.0, var = 1.0) - b ~ Normal(mean = 0.0, var = 1.0) - - x = datavar(Float64, n) - y = datavar(Float64, n) - - for i in 1:n - y[i] ~ Normal(mean = x[i] * b + a, var = 1.0) - end -end - -@model function linear_regression_broadcasted(n) - a ~ Normal(mean = 0.0, var = 1.0) - b ~ Normal(mean = 0.0, var = 1.0) - - x = datavar(Float64, n) - y = datavar(Float64, n) - - # Variance over-complicated for a purpose of checking that this expressions are allowed, it should be equal to `1.0` - y .~ Normal(mean = x .* b .+ a, var = det((diageye(2) .+ diageye(2)) ./ 2)) -end - -## Inference definition -function linreg_inference(modelfn, niters, xdata, ydata) - return inference( - model = modelfn(length(xdata)), - data = (x = xdata, y = ydata), - returnvars = (a = KeepLast(), b = KeepLast()), - initmessages = (b = NormalMeanVariance(0.0, 100.0),), - free_energy = true, - iterations = niters - ) -end - -@testset "Linear regression" begin - - ## Data creation - reala = 10.0 - realb = -10.0 - - N = 100 - - rng = StableRNG(1234) - - xdata = collect(1:N) .+ 1 * randn(rng, N) - ydata = reala .+ realb .* xdata - - ## Inference execution - result = linreg_inference(linear_regression, 25, xdata, ydata) - resultb = linreg_inference(linear_regression_broadcasted, 25, xdata, ydata) - - ares = result.posteriors[:a] - bres = result.posteriors[:b] - fres = result.free_energy - - aresb = resultb.posteriors[:a] - bresb = resultb.posteriors[:b] - fresb = resultb.free_energy - - ## Test inference results - @test mean(ares) ≈ mean(aresb) && var(ares) ≈ var(aresb) # Broadcasting may change the order of computations, so slight - @test mean(bres) ≈ mean(bresb) && var(bres) ≈ var(bresb) # differences are allowed - @test all(fres .≈ fresb) - @test isapprox(mean(ares), reala, atol = 5) - @test isapprox(mean(bres), realb, atol = 0.1) - @test fres[end] < fres[2] # Loopy belief propagation has no guaranties though - - @test_benchmark "models" "linreg" linreg_inference(linear_regression, 25, $xdata, $ydata) - @test_benchmark "models" "linreg_broadcasted" linreg_inference(linear_regression_broadcasted, 25, $xdata, $ydata) -end - -end diff --git a/test/models/statespace/test_hgf.jl b/test/models/statespace/hgf_tests.jl similarity index 50% rename from test/models/statespace/test_hgf.jl rename to test/models/statespace/hgf_tests.jl index 2b968192c..8a760d0f9 100644 --- a/test/models/statespace/test_hgf.jl +++ b/test/models/statespace/hgf_tests.jl @@ -1,76 +1,72 @@ -module RxInferModelsHGFTest +@testitem "Hierarchical Gaussian Filter" begin + using RxInfer, BenchmarkTools, Random, Plots, Dates,
LinearAlgebra, StableRNGs -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + # We create a single time step of the corresponding state-space process to + # perform online learning (filtering) + @model function hgf(real_k, real_w, z_variance, y_variance) -# We create a single-time step of corresponding state-space process to -# perform online learning (filtering) -@model function hgf(real_k, real_w, z_variance, y_variance) + # Priors from previous time step for `z` + zt_min_mean = datavar(Float64) + zt_min_var = datavar(Float64) - # Priors from previous time step for `z` - zt_min_mean = datavar(Float64) - zt_min_var = datavar(Float64) + # Priors from previous time step for `x` + xt_min_mean = datavar(Float64) + xt_min_var = datavar(Float64) - # Priors from previous time step for `x` - xt_min_mean = datavar(Float64) - xt_min_var = datavar(Float64) + zt_min ~ Normal(mean = zt_min_mean, var = zt_min_var) + xt_min ~ Normal(mean = xt_min_mean, var = xt_min_var) - zt_min ~ Normal(mean = zt_min_mean, var = zt_min_var) - xt_min ~ Normal(mean = xt_min_mean, var = xt_min_var) + # Higher layer is modelled as a random walk + zt ~ Normal(mean = zt_min, var = z_variance) - # Higher layer is modelled as a random walk - zt ~ Normal(mean = zt_min, var = z_variance) + # Lower layer is modelled with `GCV` node + gcvnode, xt ~ GCV(xt_min, zt, real_k, real_w) - # Lower layer is modelled with `GCV` node - gcvnode, xt ~ GCV(xt_min, zt, real_k, real_w) + # Noisy observations + y = datavar(Float64) + y ~ Normal(mean = xt, var = y_variance) - # Noisy observations - y = datavar(Float64) - y ~ Normal(mean = xt, var = y_variance) + return gcvnode + end - return gcvnode -end + @constraints function hgfconstraints() + q(xt, zt, xt_min) = q(xt, xt_min)q(zt) + end -@constraints function hgfconstraints() - q(xt, zt, xt_min) = q(xt, xt_min)q(zt) -end + @meta function hgfmeta() + # Let's use 31 approximation points in the Gauss-Hermite cubature approximation method + GCV(xt_min, xt, zt) -> GCVMetadata(GaussHermiteCubature(31)) + end -@meta function hgfmeta() - # Lets use 31 approximation points in the Gauss Hermite cubature approximation method - GCV(xt_min, xt, zt) -> GCVMetadata(GaussHermiteCubature(31)) -end + ## Inference definition + function hgf_online_inference(data, vmp_iters, real_k, real_w, z_variance, y_variance) + autoupdates = @autoupdates begin + zt_min_mean, zt_min_var = mean_var(q(zt)) + xt_min_mean, xt_min_var = mean_var(q(xt)) + end -## Inference definition -function hgf_online_inference(data, vmp_iters, real_k, real_w, z_variance, y_variance) - autoupdates = @autoupdates begin - zt_min_mean, zt_min_var = mean_var(q(zt)) - xt_min_mean, xt_min_var = mean_var(q(xt)) + return
rxinference( - model = hgf(real_k, real_w, z_variance, y_variance), - constraints = hgfconstraints(), - meta = hgfmeta(), - data = (y = data,), - autoupdates = autoupdates, - keephistory = length(data), - historyvars = (xt = KeepLast(), zt = KeepLast()), - initmarginals = (zt = NormalMeanVariance(0.0, 5.0), xt = NormalMeanVariance(0.0, 5.0)), - iterations = vmp_iters, - free_energy = true, - autostart = true, - callbacks = (after_model_creation = (model, returnval) -> begin - gcvnode = returnval - setmarginal!(gcvnode, :y_x, MvNormalMeanCovariance([0.0, 0.0], [5.0, 5.0])) - end,) - ) -end - -@testset "Hierarchical Gaussian Filter" begin - ## Data creation function generate_data(rng, k, w, zv, yv) z_prev = 0.0 @@ -145,6 +141,4 @@ end end @test_benchmark "models" "hgf" hgf_online_inference($y, $vmp_iters, $real_k, $real_w, $z_variance, $y_variance) -end - -end +end \ No newline at end of file diff --git a/test/models/statespace/test_hmm.jl b/test/models/statespace/hmm_tests.jl similarity index 58% rename from test/models/statespace/test_hmm.jl rename to test/models/statespace/hmm_tests.jl index 07860057b..70f64e573 100644 --- a/test/models/statespace/test_hmm.jl +++ b/test/models/statespace/hmm_tests.jl @@ -1,49 +1,45 @@ -module ReactiveMPModelsHMMTest +@testitem "Hidden Markov Model" begin + using BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + ## Model definition + @model function hidden_markov_model(n) + A ~ MatrixDirichlet(ones(3, 3)) + B ~ MatrixDirichlet([10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0]) -## Model definition -@model function hidden_markov_model(n) - A ~ MatrixDirichlet(ones(3, 3)) - B ~ MatrixDirichlet([10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0]) + s_0 ~ Categorical(fill(1.0 / 3.0, 3)) - s_0 ~ Categorical(fill(1.0 / 3.0, 3)) + s = randomvar(n) + x = datavar(Vector{Float64}, n) - s = randomvar(n) - x = datavar(Vector{Float64}, n) + s_prev = s_0 - s_prev = s_0 + for t in 1:n + s[t] ~ Transition(s_prev, A) + x[t] ~ Transition(s[t], B) + s_prev = s[t] + end + end - for t in 1:n - s[t] ~ Transition(s_prev, A) - x[t] ~ Transition(s[t], B) - s_prev = s[t] + @constraints function hidden_markov_constraints() + q(s, s_0, A, B) = q(s, s_0)q(A)q(B) + end + + ## Inference definition + function hidden_markov_model_inference(data, vmp_iters) + return inference( + model = hidden_markov_model(length(data)), + constraints = hidden_markov_constraints(), + data = (x = data,), + options = (limit_stack_depth = 500,), + free_energy = true, + initmarginals = (A = vague(MatrixDirichlet, 3, 3), B = vague(MatrixDirichlet, 3, 3), s = vague(Categorical, 3)), + iterations = vmp_iters, + returnvars = (s = KeepEach(), A = KeepEach(), B = KeepEach()) + ) end -end - -@constraints function hidden_markov_constraints() - q(s, s_0, A, B) = q(s, s_0)q(A)q(B) -end - -## Inference definition -function hidden_markov_model_inference(data, vmp_iters) - return inference( - model = hidden_markov_model(length(data)), - constraints = hidden_markov_constraints(), - data = (x = data,), - options = (limit_stack_depth = 500,), - free_energy = true, - initmarginals = (A = vague(MatrixDirichlet, 3, 3), B = vague(MatrixDirichlet, 3, 3), s = vague(Categorical, 3)), - iterations = vmp_iters, - returnvars = (s = KeepEach(), A = 
KeepEach(), B = KeepEach()) - ) -end - -@testset "Hidden Markov Model" begin ## Data creation function rand_vec(rng, distribution::Categorical) @@ -106,6 +102,4 @@ end end @test_benchmark "models" "mlgssm" hidden_markov_model_inference($x_data, 20) -end - -end +end \ No newline at end of file diff --git a/test/models/statespace/test_mlgssm.jl b/test/models/statespace/mlgssm_test.jl similarity index 68% rename from test/models/statespace/test_mlgssm.jl rename to test/models/statespace/mlgssm_test.jl index 8cc9f8e6e..0c336fa92 100644 --- a/test/models/statespace/test_mlgssm.jl +++ b/test/models/statespace/mlgssm_test.jl @@ -1,41 +1,37 @@ -module RxInferModelsULGSSMTest +@testitem "Linear Gaussian State Space Model" begin + using BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + ## Model definition + @model function multivariate_lgssm_model(n, x0, A, B, Q, P) -## Model definition -@model function multivariate_lgssm_model(n, x0, A, B, Q, P) + # We create constvar references for better efficiency + cA = constvar(A) + cB = constvar(B) + cQ = constvar(Q) + cP = constvar(P) - # We create constvar references for better efficiency - cA = constvar(A) - cB = constvar(B) - cQ = constvar(Q) - cP = constvar(P) + # `x` is a sequence of hidden states + x = randomvar(n) + # `y` is a sequence of "clamped" observations + y = datavar(Vector{Float64}, n) - # `x` is a sequence of hidden states - x = randomvar(n) - # `y` is a sequence of "clamped" observations - y = datavar(Vector{Float64}, n) + x_prior ~ MvNormal(mean = mean(x0), cov = cov(x0)) + x_prev = x_prior - x_prior ~ MvNormal(mean = mean(x0), cov = cov(x0)) - x_prev = x_prior - - for i in 1:n - x[i] ~ MvNormal(mean = cA * x_prev, cov = cQ) - y[i] ~ MvNormal(mean = cB * x[i], cov = cP) - x_prev = x[i] + for i in 1:n + x[i] ~ MvNormal(mean = cA * x_prev, cov = cQ) + y[i] ~ MvNormal(mean = cB * x[i], cov = cP) + x_prev = x[i] + end end -end -## Inference definition -function multivariate_lgssm_inference(data, x0, A, B, Q, P) - return inference(model = multivariate_lgssm_model(length(data), x0, A, B, Q, P), data = (y = data,), free_energy = true, options = (limit_stack_depth = 500,)) -end - -@testset "Linear Gaussian State Space Model" begin + ## Inference definition + function multivariate_lgssm_inference(data, x0, A, B, Q, P) + return inference(model = multivariate_lgssm_model(length(data), x0, A, B, Q, P), data = (y = data,), free_energy = true, options = (limit_stack_depth = 500,)) + end ## Data creation function generate_data(rng, A, B, Q, P) @@ -110,6 +106,4 @@ end end @test_benchmark "models" "mlgssm" multivariate_lgssm_inference($y, $x0, $A, $B, $Q, $P) -end - -end +end \ No newline at end of file diff --git a/test/models/statespace/test_probit.jl b/test/models/statespace/probit_tests.jl similarity index 62% rename from test/models/statespace/test_probit.jl rename to test/models/statespace/probit_tests.jl index 19b8c44bb..d73a947d5 100644 --- a/test/models/statespace/test_probit.jl +++ b/test/models/statespace/probit_tests.jl @@ -1,34 +1,30 @@ -module RxInferModelsProbitTest +@testitem "Probit Model" begin + using BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, 
Plots, Dates, LinearAlgebra, StableRNGs + using StatsFuns: normcdf -using StatsFuns: normcdf + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + # Please use StableRNGs for random number generators -# Please use StableRNGs for random number generators + ## Model definition + @model function probit_model(nr_samples::Int64) + x = randomvar(nr_samples + 1) + y = datavar(Float64, nr_samples) -## Model definition -@model function probit_model(nr_samples::Int64) - x = randomvar(nr_samples + 1) - y = datavar(Float64, nr_samples) + x[1] ~ Normal(mean = 0.0, precision = 0.01) - x[1] ~ Normal(mean = 0.0, precision = 0.01) - - for k in 2:(nr_samples + 1) - x[k] ~ Normal(mean = x[k - 1] + 0.1, precision = 100) - y[k - 1] ~ Probit(x[k]) where {pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0))} + for k in 2:(nr_samples + 1) + x[k] ~ Normal(mean = x[k - 1] + 0.1, precision = 100) + y[k - 1] ~ Probit(x[k]) where {pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0))} + end end -end -## Inference definition -function probit_inference(data_y) - return inference(model = probit_model(length(data_y)), data = (y = data_y,), iterations = 10, returnvars = (x = KeepLast(),), free_energy = true) -end - -@testset "Probit Model" begin + ## Inference definition + function probit_inference(data_y) + return inference(model = probit_model(length(data_y)), data = (y = data_y,), iterations = 10, returnvars = (x = KeepLast(),), free_energy = true) + end ## Data creation function generate_data(nr_samples::Int64; seed = 123) @@ -86,6 +82,4 @@ end end @test_benchmark "models" "probit" probit_inference($data_y) -end - -end +end \ No newline at end of file diff --git a/test/models/statespace/test_ulgssm.jl b/test/models/statespace/ulgssm_tests.jl similarity index 60% rename from test/models/statespace/test_ulgssm.jl rename to test/models/statespace/ulgssm_tests.jl index 720d152a9..be1b124f3 100644 --- a/test/models/statespace/test_ulgssm.jl +++ b/test/models/statespace/ulgssm_tests.jl @@ -1,33 +1,29 @@ -module RxInferModelsMLGSSMTest +@testitem "Univariate Linear Gaussian State Space Model" begin + using BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs -using Test, InteractiveUtils -using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs + # `include(test/utiltests.jl)` + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) -# `include(test/utiltests.jl)` -include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) + @model function univariate_lgssm_model(n, x0, c_, P_) + x_prior ~ Normal(mean = mean(x0), var = var(x0)) -@model function univariate_lgssm_model(n, x0, c_, P_) - x_prior ~ Normal(mean = mean(x0), var = var(x0)) + x = randomvar(n) + c = constvar(c_) + P = constvar(P_) + y = datavar(Float64, n) - x = randomvar(n) - c = constvar(c_) - P = constvar(P_) - y = datavar(Float64, n) + x_prev = x_prior - x_prev = x_prior - - for i in 1:n - x[i] ~ x_prev + c - y[i] ~ Normal(mean = x[i], var = P) - x_prev = x[i] + for i in 1:n + x[i] ~ x_prev + c + y[i] ~ Normal(mean = x[i], var = P) + x_prev = x[i] + end end -end - -function univariate_lgssm_inference(data, x0, c, P) - return inference(model = univariate_lgssm_model(length(data), x0, c, P), data = (y = data,), free_energy = true) -end -@testset "Univariate Linear Gaussian State Space Model" begin + function univariate_lgssm_inference(data, x0, c, P) + return inference(model = 
univariate_lgssm_model(length(data), x0, c, P), data = (y = data,), free_energy = true) + end ## Data creation rng = StableRNG(123) @@ -62,6 +58,4 @@ end end @test_benchmark "models" "ulgssm" univariate_lgssm_inference($data, $x0_prior, 1.0, $P) -end - -end +end \ No newline at end of file diff --git a/test/score/test_actor.jl b/test/score/actor_tests.jl similarity index 95% rename from test/score/test_actor.jl rename to test/score/actor_tests.jl index 54c544db9..070a5c01c 100644 --- a/test/score/test_actor.jl +++ b/test/score/actor_tests.jl @@ -1,12 +1,9 @@ -module RxInferScoreActorTest +@testitem "ScoreActor tests" begin + using Random + using RxInfer -using Test, Random -using RxInfer - -import RxInfer: ScoreActor, score_snapshot, score_snapshot_final, score_snapshot_iterations -import Rocket: release! - -@testset "ScoreActor tests" begin + import RxInfer: ScoreActor, score_snapshot, score_snapshot_final, score_snapshot_iterations + import Rocket: release! @testset "Basic functionality #1" begin actor = ScoreActor(Float64, 10, 1) @@ -174,6 +171,4 @@ import Rocket: release! @test length(aggregated) === 10 @test aggregated == 21:30 end -end - -end +end \ No newline at end of file diff --git a/test/score/test_bfe.jl b/test/score/bfe_tests.jl similarity index 93% rename from test/score/test_bfe.jl rename to test/score/bfe_tests.jl index aee8139e9..f7ed29acc 100644 --- a/test/score/test_bfe.jl +++ b/test/score/bfe_tests.jl @@ -1,12 +1,7 @@ -module RxInferScoreTest - -using Test, Random -using RxInfer - -import RxInfer: get_skip_strategy, get_scheduler, apply_diagnostic_check -import ReactiveMP: CountingReal, FactorNodeCreationOptions, make_node, activate! - -@testset "BetheFreeEnergy score tests" begin +@testitem "BetheFreeEnergy score tests" begin + using Random + import RxInfer: get_skip_strategy, get_scheduler, apply_diagnostic_check + import ReactiveMP: CountingReal, FactorNodeCreationOptions, make_node, activate! @testset "Diagnostic check tests" begin @testset "`BetheFreeEnergyCheckInfs` diagnostic" begin stream = Subject(Any) @@ -95,6 +90,4 @@ import ReactiveMP: CountingReal, FactorNodeCreationOptions, make_node, activate! 
@test length(events) === 4 end end -end - -end +end \ No newline at end of file From 5e6896ee3abb53f621931ca27aaf37f6f02cb609 Mon Sep 17 00:00:00 2001 From: MarcoH Date: Mon, 27 Nov 2023 15:58:57 +0100 Subject: [PATCH 5/9] Adding ReTestItems to the Project.toml extras --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index f8d8871cd..9cb66a192 100644 --- a/Project.toml +++ b/Project.toml @@ -58,6 +58,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" [targets] test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers"] From fa171d4f751fc62225b728391270dbd02ad63fb3 Mon Sep 17 00:00:00 2001 From: MarcoH Date: Mon, 27 Nov 2023 16:08:09 +0100 Subject: [PATCH 6/9] Fix code style issues --- .../factorisation_constraints_tests.jl | 1 - .../form/form_sample_list_tests.jl | 1 - test/constraints/meta_constraints_tests.jl | 2 +- test/helpers_tests.jl | 1 - test/model_tests.jl | 1 - test/models/aliases/aliases_binary_tests.jl | 2 +- test/models/aliases/aliases_normal_tests.jl | 2 -- test/models/autoregressive/ar_tests.jl | 6 ++--- test/models/autoregressive/lar_tests.jl | 3 +-- test/models/datavars/fn_datavars_tests.jl | 24 +++++++++---------- .../iid/mv_iid_covariance_known_mean_tests.jl | 3 +-- test/models/iid/mv_iid_covariance_tests.jl | 12 +++++----- .../iid/mv_iid_precision_known_mean_tests.jl | 6 ++--- test/models/iid/mv_iid_precision_tests.jl | 2 +- .../models/mixtures/gmm_multivariate_tests.jl | 2 +- test/models/mixtures/gmm_univariate_tests.jl | 4 ++-- test/models/mixtures/mixture_tests.jl | 2 +- test/models/nonlinear/cvi_tests.jl | 2 +- .../nonlinear/generic_applicability_tests.jl | 14 +++++++---- test/models/regression/linreg_tests.jl | 2 +- test/models/statespace/hgf_tests.jl | 2 +- test/models/statespace/hmm_tests.jl | 2 +- test/models/statespace/mlgssm_test.jl | 2 +- test/models/statespace/probit_tests.jl | 2 +- test/models/statespace/ulgssm_tests.jl | 2 +- test/node_tests.jl | 2 +- test/runtests.jl | 8 +++---- test/score/actor_tests.jl | 2 +- test/score/bfe_tests.jl | 4 ++-- 29 files changed, 56 insertions(+), 62 deletions(-) diff --git a/test/constraints/factorisation_constraints_tests.jl b/test/constraints/factorisation_constraints_tests.jl index 03a45f266..665654d72 100644 --- a/test/constraints/factorisation_constraints_tests.jl +++ b/test/constraints/factorisation_constraints_tests.jl @@ -630,4 +630,3 @@ @test_throws ErrorException ReactiveMP.resolve_factorisation(cs, getvariables(model), fform, (x, y)) end end - diff --git a/test/constraints/form/form_sample_list_tests.jl b/test/constraints/form/form_sample_list_tests.jl index 45f6ca69b..bd13fbd46 100644 --- a/test/constraints/form/form_sample_list_tests.jl +++ b/test/constraints/form/form_sample_list_tests.jl @@ -51,4 +51,3 @@ end end end - diff --git a/test/constraints/meta_constraints_tests.jl b/test/constraints/meta_constraints_tests.jl index f83168628..06feabd66 100644 --- a/test/constraints/meta_constraints_tests.jl +++ b/test/constraints/meta_constraints_tests.jl @@ -225,4 +225,4 @@ @test resolve_meta(meta, SomeNode, (y, z)) === nothing @test resolve_meta(meta, SomeNode, (z,)) === nothing end -end \ 
No newline at end of file +end diff --git a/test/helpers_tests.jl b/test/helpers_tests.jl index 230a6cc99..987f492b4 100644 --- a/test/helpers_tests.jl +++ b/test/helpers_tests.jl @@ -33,4 +33,3 @@ end @test_throws ErrorException unval(()) @test_throws ErrorException unval(nothing) end - diff --git a/test/model_tests.jl b/test/model_tests.jl index d24af5292..73f227de8 100644 --- a/test/model_tests.jl +++ b/test/model_tests.jl @@ -182,4 +182,3 @@ @test_throws ErrorException ReactiveMP.make_node(FactorGraphModel(), FactorNodeCreationOptions(), DummyDistributionTestModelError3, randomvar(:θ)) end end - diff --git a/test/models/aliases/aliases_binary_tests.jl b/test/models/aliases/aliases_binary_tests.jl index 673646bb8..edb81a805 100644 --- a/test/models/aliases/aliases_binary_tests.jl +++ b/test/models/aliases/aliases_binary_tests.jl @@ -18,4 +18,4 @@ # Here we simply test that it ran and gave some output @test mean(results.posteriors[:x1]) ≈ 0.5 @test first(results.free_energy) ≈ 0.6931471805599454 -end \ No newline at end of file +end diff --git a/test/models/aliases/aliases_normal_tests.jl b/test/models/aliases/aliases_normal_tests.jl index fb38b5eff..6df36e707 100644 --- a/test/models/aliases/aliases_normal_tests.jl +++ b/test/models/aliases/aliases_normal_tests.jl @@ -41,5 +41,3 @@ @test first(mean(result.posteriors[:x1])) ≈ 0.04182509505703423 @test first(result.free_energy) ≈ 2.319611135721246 end - - diff --git a/test/models/autoregressive/ar_tests.jl b/test/models/autoregressive/ar_tests.jl index f05b3045a..7e9a8cc42 100644 --- a/test/models/autoregressive/ar_tests.jl +++ b/test/models/autoregressive/ar_tests.jl @@ -1,6 +1,6 @@ @testitem "Autoregressive model" begin using StableRNGs, BenchmarkTools - + # `include(test/utiltests.jl)` include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) @@ -40,8 +40,6 @@ end rng = StableRNG(1234) - - ## Inference execution and test inference results for order in 1:5 series = randn(rng, 1_000) @@ -63,4 +61,4 @@ inputs5, outputs5 = ar_ssm(benchrng, 5) @test_benchmark "models" "ar" ar_inference($inputs5, $outputs5, 5, 15) -end \ No newline at end of file +end diff --git a/test/models/autoregressive/lar_tests.jl b/test/models/autoregressive/lar_tests.jl index 661e2e77b..cb1113f25 100644 --- a/test/models/autoregressive/lar_tests.jl +++ b/test/models/autoregressive/lar_tests.jl @@ -106,7 +106,6 @@ return states[(1 + 3order):end], observations[(1 + 3order):end] end - # Seed for reproducibility rng = StableRNG(123) @@ -178,4 +177,4 @@ end @test_benchmark "models" "lar" lar_inference($observations, length($real_θ), Multivariate, ARsafe(), 15, $real_τ) -end \ No newline at end of file +end diff --git a/test/models/datavars/fn_datavars_tests.jl b/test/models/datavars/fn_datavars_tests.jl index e79c1921f..46162ed38 100644 --- a/test/models/datavars/fn_datavars_tests.jl +++ b/test/models/datavars/fn_datavars_tests.jl @@ -1,47 +1,47 @@ @testitem "datavars" begin - using StableRNGs + using StableRNGs # Please use StableRNGs for random number generators - + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - + ## Model definition @model function sum_datavars_as_gaussian_mean_1() a = datavar(Float64) b = datavar(Float64) y = datavar(Float64) - + x ~ Normal(mean = a + b, variance = 1.0) y ~ Normal(mean = x, variance = 1.0) end - + @model function sum_datavars_as_gaussian_mean_2() a = datavar(Float64) b = datavar(Float64) c = constvar(0.0) # Should not change the result y = datavar(Float64) - + x ~ Normal(mean = (a + b) + c, variance = 1.0) y ~ Normal(mean = 
x, variance = 1.0) end - + @model function ratio_datavars_as_gaussian_mean() a = datavar(Float64) b = datavar(Float64) y = datavar(Float64) - + x ~ Normal(mean = a / b, variance = 1.0) y ~ Normal(mean = x, variance = 1.0) end - + @model function idx_datavars_as_gaussian_mean() a = datavar(Vector{Float64}) b = datavar(Matrix{Float64}) y = datavar(Float64) - + x ~ Normal(mean = dot(a[1:2], b[1:2, 1]), variance = 1.0) y ~ Normal(mean = x, variance = 1.0) end - + # Inference function function fn_datavars_inference(modelfn, adata, bdata, ydata) return inference(model = modelfn(), data = (a = adata, b = bdata, y = ydata), free_energy = true) @@ -95,4 +95,4 @@ A_data = [1.0, 2.0, 3.0] B_data = [1.0 0.5; 0.5 1.0] @test_broken result = fn_datavars_inference(idx_datavars_as_gaussian_mean, A_data, B_data, ydata) -end \ No newline at end of file +end diff --git a/test/models/iid/mv_iid_covariance_known_mean_tests.jl b/test/models/iid/mv_iid_covariance_known_mean_tests.jl index 5c7ccda03..ca4608b04 100644 --- a/test/models/iid/mv_iid_covariance_known_mean_tests.jl +++ b/test/models/iid/mv_iid_covariance_known_mean_tests.jl @@ -20,7 +20,6 @@ return inference(model = mv_iid_inverse_wishart_known_mean(mean, n, d), data = (y = data,), iterations = 10, returnvars = KeepLast(), free_energy = Float64) end - ## Data creation rng = StableRNG(123) @@ -51,4 +50,4 @@ end @test_benchmark "models" "iid_inverse_wishart_known_mean" inference_mv_inverse_wishart_known_mean($m, $data, $n, $d) -end \ No newline at end of file +end diff --git a/test/models/iid/mv_iid_covariance_tests.jl b/test/models/iid/mv_iid_covariance_tests.jl index 266bbef14..9cfe7fc60 100644 --- a/test/models/iid/mv_iid_covariance_tests.jl +++ b/test/models/iid/mv_iid_covariance_tests.jl @@ -1,24 +1,24 @@ @testitem "Multivariate IID: Covariance parametrisation" begin using StableRNGs, Plots, BenchmarkTools - + # `include(test/utiltests.jl)` include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) - + @model function mv_iid_inverse_wishart(n, d) m ~ MvNormal(mean = zeros(d), precision = 100 * diageye(d)) C ~ InverseWishart(d + 1, diageye(d)) - + y = datavar(Vector{Float64}, n) - + for i in 1:n y[i] ~ MvNormal(mean = m, covariance = C) end end - + @constraints function constraints_mv_iid_inverse_wishart() q(m, C) = q(m)q(C) end - + function inference_mv_inverse_wishart(data, n, d) return inference( model = mv_iid_inverse_wishart(n, d), diff --git a/test/models/iid/mv_iid_precision_known_mean_tests.jl b/test/models/iid/mv_iid_precision_known_mean_tests.jl index b6c3f36c3..2145ec6c9 100644 --- a/test/models/iid/mv_iid_precision_known_mean_tests.jl +++ b/test/models/iid/mv_iid_precision_known_mean_tests.jl @@ -1,5 +1,5 @@ -@testitem "Multivariate IID: Precision parametrisation with known mean" begin - using StableRNGs, BenchmarkTools, Plots +@testitem "Multivariate IID: Precision parametrisation with known mean" begin + using StableRNGs, BenchmarkTools, Plots # Please use StableRNGs for random number generators # `include(test/utiltests.jl)` @@ -53,4 +53,4 @@ end @test_benchmark "models" "iid_wishart_known_mean" inference_mv_wishart_known_mean($m, $data, $n, $d) -end \ No newline at end of file +end diff --git a/test/models/iid/mv_iid_precision_tests.jl b/test/models/iid/mv_iid_precision_tests.jl index edead3842..2f0f83191 100644 --- a/test/models/iid/mv_iid_precision_tests.jl +++ b/test/models/iid/mv_iid_precision_tests.jl @@ -67,4 +67,4 @@ end @test_benchmark "models" "iid_mv_wishart" inference_mv_wishart($data, $n, $d) -end \ No newline at end of file 
+end diff --git a/test/models/mixtures/gmm_multivariate_tests.jl b/test/models/mixtures/gmm_multivariate_tests.jl index 36fcf548e..46a99450b 100644 --- a/test/models/mixtures/gmm_multivariate_tests.jl +++ b/test/models/mixtures/gmm_multivariate_tests.jl @@ -160,4 +160,4 @@ end @test_benchmark "models" "gmm_multivariate" inference_multivariate(StableRNG(123), $L, $nmixtures, $y, 25, MeanField()) -end \ No newline at end of file +end diff --git a/test/models/mixtures/gmm_univariate_tests.jl b/test/models/mixtures/gmm_univariate_tests.jl index ee17887e3..bd8b7a87e 100644 --- a/test/models/mixtures/gmm_univariate_tests.jl +++ b/test/models/mixtures/gmm_univariate_tests.jl @@ -1,4 +1,4 @@ -@testitem "Univariate Gaussian Mixture model " begin +@testitem "Univariate Gaussian Mixture model " begin using BenchmarkTools, Plots, LinearAlgebra, StableRNGs # `include(test/utiltests.jl)` @@ -129,4 +129,4 @@ end @test_benchmark "models" "gmm_univariate" inference_univariate($y, 10, MeanField()) -end \ No newline at end of file +end diff --git a/test/models/mixtures/mixture_tests.jl b/test/models/mixtures/mixture_tests.jl index d40d9e340..223cdceba 100644 --- a/test/models/mixtures/mixture_tests.jl +++ b/test/models/mixtures/mixture_tests.jl @@ -104,4 +104,4 @@ return p end end -end \ No newline at end of file +end diff --git a/test/models/nonlinear/cvi_tests.jl b/test/models/nonlinear/cvi_tests.jl index 546b5ec4b..26db4b50d 100644 --- a/test/models/nonlinear/cvi_tests.jl +++ b/test/models/nonlinear/cvi_tests.jl @@ -119,4 +119,4 @@ @test_benchmark "models" "cvi" inference_cvi($transformed, $rng, 110) end -end \ No newline at end of file +end diff --git a/test/models/nonlinear/generic_applicability_tests.jl b/test/models/nonlinear/generic_applicability_tests.jl index 0a70d7114..1252b43b3 100644 --- a/test/models/nonlinear/generic_applicability_tests.jl +++ b/test/models/nonlinear/generic_applicability_tests.jl @@ -105,7 +105,9 @@ ) return map(metas) do meta - return inference(model = delta_1input(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())) + return inference( + model = delta_1input(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs()) + ) end end @@ -120,7 +122,9 @@ ) return map(metas) do meta - return inference(model = delta_2inputs(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())) + return inference( + model = delta_2inputs(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs()) + ) end end @@ -128,7 +132,9 @@ metas = (DeltaMeta(method = Linearization()), DeltaMeta(method = Unscented()), Linearization(), Unscented()) return map(metas) do meta - return inference(model = delta_3inputs(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())) + return inference( + model = delta_3inputs(meta), data = (y2 = data,), free_energy = true, free_energy_diagnostics = (BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs()) + ) end end @@ -155,4 +161,4 @@ ## All models have been created. 
The inference finished without errors ## @test true end -end \ No newline at end of file +end diff --git a/test/models/regression/linreg_tests.jl b/test/models/regression/linreg_tests.jl index cf3c0f0ef..7518d09b6 100644 --- a/test/models/regression/linreg_tests.jl +++ b/test/models/regression/linreg_tests.jl @@ -75,4 +75,4 @@ @test_benchmark "models" "linreg" linreg_inference(linear_regression, 25, $xdata, $ydata) @test_benchmark "models" "linreg_broadcasted" linreg_inference(linear_regression_broadcasted, 25, $xdata, $ydata) -end \ No newline at end of file +end diff --git a/test/models/statespace/hgf_tests.jl b/test/models/statespace/hgf_tests.jl index 8a760d0f9..03feda172 100644 --- a/test/models/statespace/hgf_tests.jl +++ b/test/models/statespace/hgf_tests.jl @@ -141,4 +141,4 @@ end @test_benchmark "models" "hgf" hgf_online_inference($y, $vmp_iters, $real_k, $real_w, $z_variance, $y_variance) -end \ No newline at end of file +end diff --git a/test/models/statespace/hmm_tests.jl b/test/models/statespace/hmm_tests.jl index 70f64e573..3437a0a79 100644 --- a/test/models/statespace/hmm_tests.jl +++ b/test/models/statespace/hmm_tests.jl @@ -102,4 +102,4 @@ end @test_benchmark "models" "mlgssm" hidden_markov_model_inference($x_data, 20) -end \ No newline at end of file +end diff --git a/test/models/statespace/mlgssm_test.jl b/test/models/statespace/mlgssm_test.jl index 0c336fa92..14f763dbd 100644 --- a/test/models/statespace/mlgssm_test.jl +++ b/test/models/statespace/mlgssm_test.jl @@ -106,4 +106,4 @@ end @test_benchmark "models" "mlgssm" multivariate_lgssm_inference($y, $x0, $A, $B, $Q, $P) -end \ No newline at end of file +end diff --git a/test/models/statespace/probit_tests.jl b/test/models/statespace/probit_tests.jl index d73a947d5..fa5c2042f 100644 --- a/test/models/statespace/probit_tests.jl +++ b/test/models/statespace/probit_tests.jl @@ -82,4 +82,4 @@ end @test_benchmark "models" "probit" probit_inference($data_y) -end \ No newline at end of file +end diff --git a/test/models/statespace/ulgssm_tests.jl b/test/models/statespace/ulgssm_tests.jl index be1b124f3..19df19486 100644 --- a/test/models/statespace/ulgssm_tests.jl +++ b/test/models/statespace/ulgssm_tests.jl @@ -58,4 +58,4 @@ end @test_benchmark "models" "ulgssm" univariate_lgssm_inference($data, $x0_prior, 1.0, $P) -end \ No newline at end of file +end diff --git a/test/node_tests.jl b/test/node_tests.jl index 0435e468b..2cacaa006 100644 --- a/test/node_tests.jl +++ b/test/node_tests.jl @@ -709,4 +709,4 @@ end end end -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index d284e0f9d..eb4af05b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,11 @@ -using Aqua, CpuId,ReTestItems,RxInfer +using Aqua, CpuId, ReTestItems, RxInfer # runtests( # "./"; # ) -Aqua.test_all(RxInfer; ambiguities=false, piracies=false, deps_compat = (; check_extras = false, check_weakdeps = true)) +Aqua.test_all(RxInfer; ambiguities = false, piracies = false, deps_compat = (; check_extras = false, check_weakdeps = true)) nthreads = max(cputhreads(), 1) ncores = max(cpucores(), 1) -runtests( - RxInfer; nworkers=ncores, nworker_threads=Int(nthreads / ncores), memory_threshold=1.0 -) \ No newline at end of file +runtests(RxInfer; nworkers = ncores, nworker_threads = Int(nthreads / ncores), memory_threshold = 1.0) diff --git a/test/score/actor_tests.jl b/test/score/actor_tests.jl index 070a5c01c..38504342a 100644 --- a/test/score/actor_tests.jl +++ b/test/score/actor_tests.jl @@ -171,4 +171,4 @@ @test 
length(aggregated) === 10 @test aggregated == 21:30 end -end \ No newline at end of file +end diff --git a/test/score/bfe_tests.jl b/test/score/bfe_tests.jl index f7ed29acc..9a37871c8 100644 --- a/test/score/bfe_tests.jl +++ b/test/score/bfe_tests.jl @@ -1,5 +1,5 @@ @testitem "BetheFreeEnergy score tests" begin - using Random + using Random import RxInfer: get_skip_strategy, get_scheduler, apply_diagnostic_check import ReactiveMP: CountingReal, FactorNodeCreationOptions, make_node, activate! @testset "Diagnostic check tests" begin @@ -90,4 +90,4 @@ @test length(events) === 4 end end -end \ No newline at end of file +end From 078fccb41fa77b7a7e63254356115320464aac1e Mon Sep 17 00:00:00 2001 From: MarcoH Date: Mon, 27 Nov 2023 17:39:52 +0100 Subject: [PATCH 7/9] adding ReTestItems to targets --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9cb66a192..0c2ffb246 100644 --- a/Project.toml +++ b/Project.toml @@ -61,4 +61,4 @@ TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" [targets] -test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers"] +test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers", "ReTestItems"] From 44026a1ced3c0400786b02716606a843413ea410 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Tue, 28 Nov 2023 09:15:13 +0100 Subject: [PATCH 8/9] add CpuId to test deps --- Project.toml | 5 +++-- test/runtests.jl | 3 --- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 0c2ffb246..27d44bc5e 100644 --- a/Project.toml +++ b/Project.toml @@ -45,6 +45,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkCI = "20533458-34a3-403d-a444-e18f38190b5b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037" +CpuId = "adafc99b-e345-5852-983c-f28acb93d879" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -54,11 +55,11 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" [targets] -test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "Dates", "Distributed", "Documenter", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers", "ReTestItems"] +test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "CpuId", "Dates", "Distributed", "Documenter", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers", "ReTestItems"] diff --git a/test/runtests.jl b/test/runtests.jl index eb4af05b8..94b191171 100644 --- 
a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,8 +1,5 @@
 using Aqua, CpuId, ReTestItems, RxInfer
 
-# runtests(
-#     "./";
-# )
 Aqua.test_all(RxInfer; ambiguities = false, piracies = false, deps_compat = (; check_extras = false, check_weakdeps = true))
 
 nthreads = max(cputhreads(), 1)

From d6d8e23c65e163654d1962da44507e55e275fe77 Mon Sep 17 00:00:00 2001
From: Bagaev Dmitry
Date: Tue, 28 Nov 2023 09:38:27 +0100
Subject: [PATCH 9/9] adjust Makefile

---
 Makefile         |  6 +++---
 test/runtests.jl | 23 +++++++++++++++++++++++
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/Makefile b/Makefile
index 56fc788d2..aede6746d 100644
--- a/Makefile
+++ b/Makefile
@@ -79,11 +79,11 @@ devdocs: dev_doc_init ## Same as `make docs` but uses `dev-ed` versions of core
 
 .PHONY: test
-test: ## Run tests, use test_args="folder1:test1 folder2:test2" argument to run reduced testset, use dev=true to use `dev-ed` version of core packages
-	julia -e 'ENV["USE_DEV"]="$(dev)"; import Pkg; Pkg.activate("."); Pkg.test(test_args = split("$(test_args)") .|> string)'
+test: ## Run tests, use dev=true to use `dev-ed` version of core packages
+	julia -e 'ENV["USE_DEV"]="$(dev)"; import Pkg; Pkg.activate("."); Pkg.test()'
 
 devtest: ## Alias for the `make test dev=true ...`
-	julia -e 'ENV["USE_DEV"]="true"; import Pkg; Pkg.activate("."); Pkg.test(test_args = split("$(test_args)") .|> string)'
+	julia -e 'ENV["USE_DEV"]="true"; import Pkg; Pkg.activate("."); Pkg.test()'
 
 clean: ## Clean documentation build, precompiled examples, benchmark output from tests
 	$(foreach file, $(ALL_TMP_FILES), $(RM) $(file))
diff --git a/test/runtests.jl b/test/runtests.jl
index 94b191171..bb6616bc5 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,8 +1,31 @@
 using Aqua, CpuId, ReTestItems, RxInfer
 
+const IS_USE_DEV = get(ENV, "USE_DEV", "false") == "true"
+const IS_BENCHMARK = get(ENV, "BENCHMARK", "false") == "true"
+
+import Pkg
+
+if IS_USE_DEV
+    Pkg.rm("ReactiveMP")
+    Pkg.rm("GraphPPL")
+    Pkg.rm("Rocket")
+    Pkg.develop(Pkg.PackageSpec(path = joinpath(Pkg.devdir(), "ReactiveMP.jl")))
+    Pkg.develop(Pkg.PackageSpec(path = joinpath(Pkg.devdir(), "GraphPPL.jl")))
+    Pkg.develop(Pkg.PackageSpec(path = joinpath(Pkg.devdir(), "Rocket.jl")))
+    Pkg.resolve()
+    Pkg.update()
+end
+
 Aqua.test_all(RxInfer; ambiguities = false, piracies = false, deps_compat = (; check_extras = false, check_weakdeps = true))
 
 nthreads = max(cputhreads(), 1)
 ncores = max(cpucores(), 1)
 
+# We use only `1` runner when benchmarks are enabled to improve the
+# quality of the benchmarking procedure
+if IS_BENCHMARK
+    nthreads = 1
+    ncores = 1
+end
+
 runtests(RxInfer; nworkers = ncores, nworker_threads = Int(nthreads / ncores), memory_threshold = 1.0)
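With the full series applied, both `make test` and `Pkg.test()` funnel into `test/runtests.jl`, so the suite runs through ReTestItems end to end. For a quicker local loop, `runtests` can also be called directly from a REPL. The sketch below is illustrative rather than part of the patch series; the `name` keyword for filtering items by their `@testitem` label is assumed from ReTestItems' documented API:

    using ReTestItems, RxInfer

    # Discover and run every `*_test.jl` / `*_tests.jl` test item of the package,
    # sequentially in the current process (no extra worker processes)
    runtests(RxInfer)

    # Run a single item by its `@testitem` name, e.g. only the probit model test
    runtests(RxInfer; name = "Probit Model")

One detail of the entry point worth noting: `Int(nthreads / ncores)` assumes `cputhreads()` is an even multiple of `cpucores()` (the typical SMT case of two threads per core) and would throw an `InexactError` otherwise; `div(nthreads, ncores)` would be the defensive spelling.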