Skip to content

Commit

Permalink
Merge pull request #27 from CosmologicalEmulators/develop
Browse files Browse the repository at this point in the history
Fixing a problem with loading of preprocessing
  • Loading branch information
marcobonici authored Apr 14, 2024
2 parents c593718 + d84d50a commit 08a955d
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- '1.7'
- '1.8'
- '1.9'
- '~1.10.0-0'
- '1.10'
os:
- ubuntu-latest
arch:
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Capse"
uuid = "994f66c3-d2c7-4ba6-88fb-4a10f50800ba"
authors = ["marcobonici <[email protected]>"]
version = "0.3.0"
version = "0.3.2"

[deps]
AbstractCosmologicalEmulators = "c83c1981-e5c4-4837-9eb8-c9b1572acfc6"
Expand All @@ -10,7 +10,7 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"

[compat]
AbstractCosmologicalEmulators = "0.4"
Adapt = "3"
AbstractCosmologicalEmulators = "0.5"
Adapt = "3, 4"
JSON = "0.21"
NPZ = "0.4"
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"

[compat]
BenchmarkTools = "1.3.2"
2 changes: 1 addition & 1 deletion docs/src/assets/capse_benchmark.json

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Cℓ_emu = Capse.CℓEmulator(TrainedEmulator = emu, ℓgrid=ℓgrid, InMinMax =
Postprocessing = postprocessing)
```

`Capse.jl` is a Julia package designed to emulate the computation of the CMB Angular Power Spectrum, with a speedup of several orders of magnitude compared to standard codes.
`Capse.jl` is a Julia package designed to emulate the computation of the CMB Angular Power Spectrum, with a speedup of several orders of magnitude compared to standard codes such as CAMB or CLASS. The core functionalities of `Capse.jl` are inherithed by the upstream library [`AbstractCosmologicalEmulators.jl`](https://github.com/CosmologicalEmulators/AbstractCosmologicalEmulators.jl).

## Installation

Expand Down Expand Up @@ -70,10 +70,10 @@ It is possible to pass an additional argument to the previous function, which is
- [SimpleChains](https://github.com/PumasAI/SimpleChains.jl), which is taylored for small NN running on a CPU
- [Lux](https://github.com/LuxDL/Lux.jl), which can run both on CPUs and GPUs

`SimpleChains.jl` is faster expecially for small NN on the CPU. If you wanna use something running on a GPU, you should use `Lux.jl`, which can be done adding an additional argument to the `load_emulator` function, `Capse.LuxEmulator`
`SimpleChains.jl` is faster expecially for small NN on the CPU. If you wanna use something running on a GPU, you should use `Lux.jl`, which can be loaded adding an additional argument to the `load_emulator` function, `Capse.LuxEmulator`

```julia
Cℓ_emu = Capse.load_emulator(weights_folder, Capse.LuxEmulator);
Cℓ_emu = Capse.load_emulator(weights_folder, emu = Capse.LuxEmulator);
```

Each trained emulator should be shipped with a description within the JSON file. In order to print the description, just runs:
Expand Down Expand Up @@ -108,9 +108,9 @@ Using `Lux.jl`, with the same architecture and weights, we obtain
benchmark[1]["Capse"]["Lux"] # hide
```

`SimpleChains.jl` is about 2 times faster than `Lux.jl` and they give the same result up to floating point precision.
`Lux.jl` is around 20% slower than `SimpleChains.jl` and they give the same result up to floating point precision.

These benchmarks have been performed locally, with a 12th Gen Intel® Core™ i7-1260P.
These benchmarks have been performed locally, with a 13th Gen Intel® Core™ i7-13700H.

Considering that a high-precision settings calculation performed with [`CAMB`](https://github.com/cmbant/CAMB) on the same machine requires around 60 seconds, `Capse.jl` is 5-6 order of magnitudes faster.

Expand Down
15 changes: 6 additions & 9 deletions src/Capse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@ Computes and returns the ``C_\\ell``'s on the ``\\ell``-grid the emulator has be
"""
function get_Cℓ(input_params, Cℓemu::AbstractCℓEmulators)
input = deepcopy(input_params)
maximin_input!(input, Cℓemu.InMinMax)
output = Array(run_emulator(input, Cℓemu.TrainedEmulator))
inv_maximin_output!(output, Cℓemu.OutMinMax)
return Cℓemu.Postprocessing(input_params, output, Cℓemu)
norm_input = maximin_input(input_params, Cℓemu.InMinMax)
output = Array(run_emulator(norm_input, Cℓemu.TrainedEmulator))
norm_output = inv_maximin_output(output, Cℓemu.OutMinMax)
return Cℓemu.Postprocessing(input_params, norm_output, Cℓemu)
end

"""
Expand Down Expand Up @@ -80,21 +79,19 @@ The following keyword arguments are used to specify the name of the files used t
If the corresponding file in the folder you are trying to load have different names,
change the default values accordingly.
"""
function load_emulator(path::String, emu = SimpleChainsEmulator,
function load_emulator(path::String; emu = SimpleChainsEmulator,
ℓ_file = "l.npy", weights_file = "weights.npy", inminmax_file = "inminmax.npy",
outminmax_file = "outminmax.npy", nn_setup_file = "nn_setup.json",
postprocessing_file = "postprocessing.jl")
NN_dict = parsefile(path*nn_setup_file)
= npzread(path*ℓ_file)
include(path*postprocessing_file)
#we assume there is a postprocessing() function in the postprocessing_file

weights = npzread(path*weights_file)
trained_emu = Capse.init_emulator(NN_dict, weights, emu)
Cℓ_emu = Capse.CℓEmulator(TrainedEmulator = trained_emu, ℓgrid = ℓ,
InMinMax = npzread(path*inminmax_file),
OutMinMax = npzread(path*outminmax_file),
Postprocessing = postprocessing)
Postprocessing = include(path*postprocessing_file))
return Cℓ_emu
end

Expand Down

0 comments on commit 08a955d

Please sign in to comment.