Skip to content

Commit

Permalink
no short func
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Jan 23, 2025
1 parent a5dc936 commit ab11209
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/exts/mlj_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)

tst = @timed if model.m.compute_mode isa VectorMode
logp̂x = broadcast(
(x, y) -> first(inference(model.m, TestMode(), x, y, ps, st)),
function (x, y)
first(inference(model.m, TestMode(), x, y, ps, st))
end,
eachcol(xnew),
eachcol(ynew),
)
Expand Down
4 changes: 3 additions & 1 deletion src/exts/mlj_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)
(ps, st) = fitresult

tst = @timed if model.m.compute_mode isa VectorMode
logp̂x = broadcast(x -> first(inference(model.m, TestMode(), x, ps, st)), eachcol(xnew))
logp̂x = broadcast(function (x)
first(inference(model.m, TestMode(), x, ps, st))
end, eachcol(xnew))
elseif model.m.compute_mode isa MatrixMode
logp̂x = first(inference(model.m, TestMode(), xnew, ps, st))
else
Expand Down
16 changes: 12 additions & 4 deletions test/call_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,12 @@ Test.@testset "Call Tests" begin
)
Test.@test !isnothing(icnf((r, r2), ps, st))

diff_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, r, r2, x, st)
diff2_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, x, r2, ps, st)
diff_loss = function (x)
ContinuousNormalizingFlows.loss(icnf, omode, r, r2, x, st)
end
diff2_loss = function (x)
ContinuousNormalizingFlows.loss(icnf, omode, x, r2, ps, st)
end
else
Test.@test !isnothing(
ContinuousNormalizingFlows.inference(icnf, omode, r, ps, st),
Expand All @@ -209,8 +213,12 @@ Test.@testset "Call Tests" begin
Test.@test !isnothing(ContinuousNormalizingFlows.loss(icnf, omode, r, ps, st))
Test.@test !isnothing(icnf(r, ps, st))

diff_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, r, x, st)
diff2_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, x, ps, st)
diff_loss = function (x)
ContinuousNormalizingFlows.loss(icnf, omode, r, x, st)
end
diff2_loss = function (x)
ContinuousNormalizingFlows.loss(icnf, omode, x, ps, st)
end
end

if mt <: Union{
Expand Down

0 comments on commit ab11209

Please sign in to comment.