From ff0142d40f5a4e46de0c1e8adfbea773669b306f Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Mon, 9 Sep 2024 00:50:37 +0330 Subject: [PATCH] revert `function_annotation` --- README.md | 2 +- benchmark/benchmarks.jl | 8 ++------ src/base_icnf.jl | 4 +--- test/call_tests.jl | 16 ++++------------ test/fit_tests.jl | 16 ++++------------ test/instability_tests.jl | 4 +--- test/regression_tests.jl | 4 +--- 7 files changed, 14 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 6cd6e3f2..b7b6b0a5 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ icnf = construct( nn, nvars, # number of variables naugs; # number of augmented dimensions - compute_mode = DIJacVecMatrixMode(AutoEnzyme(; function_annotation = Enzyme.Const)), # process data in batches + compute_mode = DIJacVecMatrixMode(AutoEnzyme()), # process data in batches tspan = (0.0f0, 13.0f0), # have bigger time span steer_rate = 1.0f-1, # add random noise to end of the time span # resource = CUDALibs(), # process data by GPU diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 4143546a..ef8f85ba 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -36,9 +36,7 @@ icnf = ContinuousNormalizingFlows.construct( nn, nvars, naugs; - compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), + compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoEnzyme()), tspan = (0.0f0, 13.0f0), steer_rate = 1.0f-1, λ₃ = 1.0f-2, @@ -84,9 +82,7 @@ icnf2 = ContinuousNormalizingFlows.construct( nvars, naugs; inplace = true, - compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), + compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoEnzyme()), tspan = (0.0f0, 13.0f0), steer_rate = 1.0f-1, λ₃ = 1.0f-2, diff --git a/src/base_icnf.jl b/src/base_icnf.jl index a875f269..3b1b08c1 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -4,9 +4,7 @@ function construct( nvars::Int, naugmented::Int = 0; data_type::Type{<:AbstractFloat} = Float32, - compute_mode::ComputeMode = DIJacVecMatrixMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), + compute_mode::ComputeMode = DIJacVecMatrixMode(ADTypes.AutoEnzyme()), inplace::Bool = false, cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar}, resource::ComputationalResources.AbstractResource = ComputationalResources.CPU1(), diff --git a/test/call_tests.jl b/test/call_tests.jl index 19d08501..70b47cf5 100644 --- a/test/call_tests.jl +++ b/test/call_tests.jl @@ -35,18 +35,10 @@ Test.@testset "Call Tests" begin ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIVecJacVectorMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), - ContinuousNormalizingFlows.DIJacVecVectorMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), - ContinuousNormalizingFlows.DIVecJacMatrixMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), - ContinuousNormalizingFlows.DIJacVecMatrixMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), + ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoEnzyme()), + ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoEnzyme()), + ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoEnzyme()), + ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoEnzyme()), ] data_types = Type{<:AbstractFloat}[Float32] resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()] diff --git a/test/fit_tests.jl b/test/fit_tests.jl index 60f0e833..77320cdd 100644 --- a/test/fit_tests.jl +++ b/test/fit_tests.jl @@ -32,18 +32,10 @@ Test.@testset "Fit Tests" begin ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIVecJacVectorMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), - ContinuousNormalizingFlows.DIJacVecVectorMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), - ContinuousNormalizingFlows.DIVecJacMatrixMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), - ContinuousNormalizingFlows.DIJacVecMatrixMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), + ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoEnzyme()), + ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoEnzyme()), + ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoEnzyme()), + ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoEnzyme()), ] data_types = Type{<:AbstractFloat}[Float32] resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()] diff --git a/test/instability_tests.jl b/test/instability_tests.jl index e229f72d..5bec9149 100644 --- a/test/instability_tests.jl +++ b/test/instability_tests.jl @@ -16,9 +16,7 @@ Test.@testset "Instability" begin nn, nvars, naugs; - compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), + compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoEnzyme()), tspan = (0.0f0, 13.0f0), steer_rate = 1.0f-1, λ₃ = 1.0f-2, diff --git a/test/regression_tests.jl b/test/regression_tests.jl index ccdc433b..4f47b945 100644 --- a/test/regression_tests.jl +++ b/test/regression_tests.jl @@ -11,9 +11,7 @@ Test.@testset "Regression Tests" begin nn, nvars, naugs; - compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode( - ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const), - ), + compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoEnzyme()), tspan = (0.0f0, 13.0f0), steer_rate = 1.0f-1, λ₃ = 1.0f-2,