Skip to content

Commit

Permalink
Merge pull request #136 from SciML/fm/states_pred
Browse files Browse the repository at this point in the history
Adding `save_states`
  • Loading branch information
MartinuzziFrancesco authored Sep 17, 2022
2 parents 4814d55 + acf5776 commit c916c45
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ReservoirComputing"
uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294"
authors = ["Francesco Martinuzzi"]
version = "0.9.0"
version = "0.9.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
4 changes: 4 additions & 0 deletions docs/src/esn_tutorials/lorenz_basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ Once the ```OutputLayer``` has been obtained the prediction can be done followin
output = esn(Generative(predict_len), output_layer)
```
both the training method and the output layer are needed in this call. The number of steps for the prediction must be specified to the ```Generative``` method. The output results are given in a matrix.

!!! info "Saving the states during prediction"
While the states are saved in the `ESN` struct for the training, for the prediction they are not saved by default. To inspect the states it is necessary to pass the boolean keyword argument `save_states` to the prediction call, in this example using `esn(... ; save_states=true)`. This returns a tuple `(output, states)` where `size(states) = res_size, prediction_len`

To inspect the results they can easily be plotted using an external library. In this case ```Plots``` is adopted:
```julia
using Plots, Plots.PlotMeasures
Expand Down
8 changes: 4 additions & 4 deletions src/esn/echostatenetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ end

function (esn::ESN)(prediction::AbstractPrediction,
output_layer::AbstractOutputLayer;
initial_conditions = output_layer.last_value,
last_state = esn.states[:, [end]])
last_state = esn.states[:, [end]],
kwargs...)
variation = esn.variation
pred_len = prediction.prediction_len

Expand All @@ -178,10 +178,10 @@ function (esn::ESN)(prediction::AbstractPrediction,
model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end]
return obtain_esn_prediction(esn, prediction, last_state, output_layer,
model_pred_data;
initial_conditions = initial_conditions)
kwargs...)
else
return obtain_esn_prediction(esn, prediction, last_state, output_layer;
initial_conditions = initial_conditions)
kwargs...)
end
end

Expand Down
14 changes: 10 additions & 4 deletions src/esn/esn_predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ function obtain_esn_prediction(esn,
x,
output_layer,
args...;
initial_conditions = output_layer.last_value)
initial_conditions = output_layer.last_value,
save_states = false)
out_size = output_layer.out_size
training_method = output_layer.training_method
prediction_len = prediction.prediction_len

output = output_storing(training_method, out_size, prediction_len, typeof(esn.states))
out = initial_conditions
states = similar(esn.states, size(esn.states, 1), prediction_len)

out_pad = allocate_outpad(esn.variation, esn.states_type, out)
tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size)
Expand All @@ -20,23 +22,26 @@ function obtain_esn_prediction(esn,
args...)
out_tmp = get_prediction(output_layer.training_method, output_layer, x_new)
out = store_results!(output_layer.training_method, out_tmp, output, i)
states[:, i] = x
end

return output
save_states ? (output, states) : output
end

function obtain_esn_prediction(esn,
prediction::Predictive,
x,
output_layer,
args...;
initial_conditions = output_layer.last_value)
initial_conditions = output_layer.last_value,
save_states = false)
out_size = output_layer.out_size
training_method = output_layer.training_method
prediction_len = prediction.prediction_len

output = output_storing(training_method, out_size, prediction_len, typeof(esn.states))
out = initial_conditions
states = similar(esn.states, size(esn.states, 1), prediction_len)

out_pad = allocate_outpad(esn.variation, esn.states_type, out)
tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size)
Expand All @@ -47,9 +52,10 @@ function obtain_esn_prediction(esn,
out_pad, i, tmp_array, args...)
out_tmp = get_prediction(training_method, output_layer, x_new)
out = store_results!(training_method, out_tmp, output, i)
states[:, i] = x
end

return output
save_states ? (output, states) : output
end

#prediction dispatch on esn
Expand Down
6 changes: 6 additions & 0 deletions test/esn/test_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,9 @@ for t in training_methods
output = esn(Predictive(input_data), output_layer)
@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.21
end

for t in training_methods
output_layer = train(esn, target_data, t)
output, states = esn(Predictive(input_data), output_layer, save_states = true)
@test size(states) == (res_size, size(input_data, 2))
end

2 comments on commit c916c45

@MartinuzziFrancesco
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/68467

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.1 -m "<description of version>" c916c45aa9d46f7448bc309896d8f2e581fe2d0a
git push origin v0.9.1

Please sign in to comment.