From ab1120963b6d3ceea71af2bebe37ce4f399507f0 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 23 Jan 2025 19:42:15 +0330 Subject: [PATCH] no short func --- src/exts/mlj_ext/core_cond_icnf.jl | 4 +++- src/exts/mlj_ext/core_icnf.jl | 4 +++- test/call_tests.jl | 16 ++++++++++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 42dfdc2f..fe795c2d 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -108,7 +108,9 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) tst = @timed if model.m.compute_mode isa VectorMode logp̂x = broadcast( - (x, y) -> first(inference(model.m, TestMode(), x, y, ps, st)), + function (x, y) + first(inference(model.m, TestMode(), x, y, ps, st)) + end, eachcol(xnew), eachcol(ynew), ) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 3c2c68df..5b33ec09 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -102,7 +102,9 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) (ps, st) = fitresult tst = @timed if model.m.compute_mode isa VectorMode - logp̂x = broadcast(x -> first(inference(model.m, TestMode(), x, ps, st)), eachcol(xnew)) + logp̂x = broadcast(function (x) + first(inference(model.m, TestMode(), x, ps, st)) + end, eachcol(xnew)) elseif model.m.compute_mode isa MatrixMode logp̂x = first(inference(model.m, TestMode(), xnew, ps, st)) else diff --git a/test/call_tests.jl b/test/call_tests.jl index 2b74f1ac..c4a79918 100644 --- a/test/call_tests.jl +++ b/test/call_tests.jl @@ -190,8 +190,12 @@ Test.@testset "Call Tests" begin ) Test.@test !isnothing(icnf((r, r2), ps, st)) - diff_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, r, r2, x, st) - diff2_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, x, r2, ps, st) + diff_loss = function (x) + ContinuousNormalizingFlows.loss(icnf, omode, r, r2, x, st) + end + diff2_loss = function (x) + ContinuousNormalizingFlows.loss(icnf, omode, x, r2, ps, st) + end else Test.@test !isnothing( ContinuousNormalizingFlows.inference(icnf, omode, r, ps, st), @@ -209,8 +213,12 @@ Test.@testset "Call Tests" begin Test.@test !isnothing(ContinuousNormalizingFlows.loss(icnf, omode, r, ps, st)) Test.@test !isnothing(icnf(r, ps, st)) - diff_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, r, x, st) - diff2_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, x, ps, st) + diff_loss = function (x) + ContinuousNormalizingFlows.loss(icnf, omode, r, x, st) + end + diff2_loss = function (x) + ContinuousNormalizingFlows.loss(icnf, omode, x, ps, st) + end end if mt <: Union{