Skip to content

Commit

Permalink
test enzyme with reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Jan 12, 2025
1 parent 5b1c25e commit b9bb0ef
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ icnf = construct(
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
# compute_mode = DIJacVecMatrixMode(AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
# compute_mode = DIVecJacMatrixMode(AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
# inplace = true, # use the inplace version of functions
# resource = CUDALibs(), # process data by GPU
tspan = (0.0f0, 13.0f0), # have bigger time span
Expand Down
8 changes: 4 additions & 4 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ icnf = ContinuousNormalizingFlows.construct(
nn,
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
Expand Down Expand Up @@ -92,9 +92,9 @@ icnf2 = ContinuousNormalizingFlows.construct(
nvars,
naugs;
inplace = true,
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
Expand Down
4 changes: 2 additions & 2 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ function construct(
nvars::Int,
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::ComputeMode = DIJacVecMatrixMode(
compute_mode::ComputeMode = DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
Expand Down
4 changes: 2 additions & 2 deletions test/regression_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ Test.@testset "Regression Tests" begin
nn,
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
Expand Down

0 comments on commit b9bb0ef

Please sign in to comment.