-
-
Notifications
You must be signed in to change notification settings - Fork 117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Adding pooled callbacks to VectorContinousCallback #523
base: master
Are you sure you want to change the base?
Changes from 6 commits
b5df36d
3c85f2e
24add26
09bb173
706f61f
b9ae0ca
4a331b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -122,18 +122,17 @@ function ContinuousCallback(condition,affect!; | |||||||||||||
rootfind,interp_points, | ||||||||||||||
collect(save_positions), | ||||||||||||||
dtrelax,abstol,reltol) | ||||||||||||||
|
||||||||||||||
end | ||||||||||||||
|
||||||||||||||
""" | ||||||||||||||
```julia | ||||||||||||||
VectorContinuousCallback(condition,affect!,affect_neg!,len; | ||||||||||||||
initialize = INITIALIZE_DEFAULT, | ||||||||||||||
idxs = nothing, | ||||||||||||||
rootfind=true, | ||||||||||||||
save_positions=(true,true), | ||||||||||||||
interp_points=10, | ||||||||||||||
abstol=10eps(),reltol=0) | ||||||||||||||
initialize = INITIALIZE_DEFAULT, | ||||||||||||||
idxs = nothing, | ||||||||||||||
rootfind=true, | ||||||||||||||
save_positions=(true,true), | ||||||||||||||
interp_points=10, | ||||||||||||||
abstol=10eps(),reltol=0,pooltol=nothing,pool_events=false) | ||||||||||||||
``` | ||||||||||||||
|
||||||||||||||
```julia | ||||||||||||||
|
@@ -144,7 +143,7 @@ VectorContinuousCallback(condition,affect!,len; | |||||||||||||
save_positions=(true,true), | ||||||||||||||
affect_neg! = affect!, | ||||||||||||||
interp_points=10, | ||||||||||||||
abstol=10eps(),reltol=0) | ||||||||||||||
abstol=10eps(),reltol=0,pooltol=nothing,pool_events=false) | ||||||||||||||
``` | ||||||||||||||
|
||||||||||||||
This is also a subtype of `AbstractContinuousCallback`. `CallbackSet` is not feasible when you have a large number of callbacks, | ||||||||||||||
|
@@ -159,10 +158,13 @@ multiple events. | |||||||||||||
- `affect!`: This is a function `affect!(integrator, event_index)` which lets you modify `integrator` and it tells you about | ||||||||||||||
which event occured using `event_idx` i.e. gives you index `i` for which `out[i]` came out to be zero. | ||||||||||||||
- `len`: Number of callbacks chained. This is compulsory to be specified. | ||||||||||||||
- `pool_events`: Whether multiple concurrent events should be passed as one array of indexs instead of the indexes on a time. | ||||||||||||||
- `pooltol`: Custom limit which values get grouped. Callback accepted if it's absolute value is smaller than pooltol at callback time. | ||||||||||||||
The default value is `eps(integrator.t) + eps(callback_return_type)`. | ||||||||||||||
|
||||||||||||||
Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref). | ||||||||||||||
""" | ||||||||||||||
struct VectorContinuousCallback{F1,F2,F3,F4,T,T2,I,R} <: AbstractContinuousCallback | ||||||||||||||
struct VectorContinuousCallback{F1,F2,F3,F4,T,T2,T3,I,R} <: AbstractContinuousCallback | ||||||||||||||
condition::F1 | ||||||||||||||
affect!::F2 | ||||||||||||||
affect_neg!::F3 | ||||||||||||||
|
@@ -175,15 +177,17 @@ struct VectorContinuousCallback{F1,F2,F3,F4,T,T2,I,R} <: AbstractContinuousCallb | |||||||||||||
dtrelax::R | ||||||||||||||
abstol::T | ||||||||||||||
reltol::T2 | ||||||||||||||
pooltol::T3 | ||||||||||||||
pool_events::Bool | ||||||||||||||
VectorContinuousCallback(condition::F1,affect!::F2,affect_neg!::F3,len::Int, | ||||||||||||||
initialize::F4,idxs::I,rootfind, | ||||||||||||||
interp_points,save_positions,dtrelax::R, | ||||||||||||||
abstol::T,reltol::T2) where {F1,F2,F3,F4,T,T2,I,R} = | ||||||||||||||
new{F1,F2,F3,F4,T,T2,I,R}(condition, | ||||||||||||||
abstol::T,reltol::T2, pooltol::T3, pool_events) where {F1,F2,F3,F4,T,T2,T3,I,R} = | ||||||||||||||
new{F1,F2,F3,F4,T,T2,T3,I,R}(condition, | ||||||||||||||
affect!,affect_neg!,len, | ||||||||||||||
initialize,idxs,rootfind,interp_points, | ||||||||||||||
BitArray(collect(save_positions)), | ||||||||||||||
dtrelax,abstol,reltol) | ||||||||||||||
dtrelax,abstol,reltol,pooltol,pool_events) | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
VectorContinuousCallback(condition,affect!,affect_neg!,len; | ||||||||||||||
|
@@ -193,13 +197,13 @@ VectorContinuousCallback(condition,affect!,affect_neg!,len; | |||||||||||||
save_positions=(true,true), | ||||||||||||||
interp_points=10, | ||||||||||||||
dtrelax=1, | ||||||||||||||
abstol=10eps(),reltol=0) = VectorContinuousCallback( | ||||||||||||||
abstol=10eps(),reltol=0, pooltol=missing, pool_events=false) = VectorContinuousCallback( | ||||||||||||||
condition,affect!,affect_neg!,len, | ||||||||||||||
initialize, | ||||||||||||||
idxs, | ||||||||||||||
rootfind,interp_points, | ||||||||||||||
save_positions,dtrelax, | ||||||||||||||
abstol,reltol) | ||||||||||||||
abstol,reltol, pooltol, pool_events) | ||||||||||||||
|
||||||||||||||
function VectorContinuousCallback(condition,affect!,len; | ||||||||||||||
initialize = INITIALIZE_DEFAULT, | ||||||||||||||
|
@@ -209,14 +213,13 @@ function VectorContinuousCallback(condition,affect!,len; | |||||||||||||
affect_neg! = affect!, | ||||||||||||||
interp_points=10, | ||||||||||||||
dtrelax=1, | ||||||||||||||
abstol=10eps(),reltol=0) | ||||||||||||||
abstol=10eps(),reltol=0, pooltol=missing, pool_events=false) | ||||||||||||||
|
||||||||||||||
VectorContinuousCallback( | ||||||||||||||
condition,affect!,affect_neg!,len,initialize,idxs, | ||||||||||||||
rootfind,interp_points, | ||||||||||||||
collect(save_positions), | ||||||||||||||
dtrelax,abstol,reltol) | ||||||||||||||
|
||||||||||||||
dtrelax,abstol,reltol,pooltol,pool_events) | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
""" | ||||||||||||||
|
@@ -754,6 +757,16 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte | |||||||||||||
new_t = integrator.dt | ||||||||||||||
min_event_idx = event_idx[1] | ||||||||||||||
end | ||||||||||||||
if callback.pool_events | ||||||||||||||
tmp = get_condition(integrator, callback, integrator.dt + new_t) | ||||||||||||||
if callback.pooltol isa Missing | ||||||||||||||
# This is still dubious | ||||||||||||||
pool_tol = eps(integrator.t + new_t) + eps(typeof(tmp[end])) | ||||||||||||||
else | ||||||||||||||
pool_tol = callback.pooltol | ||||||||||||||
end | ||||||||||||||
min_event_idx = findall(x-> abs(x) < pool_tol, tmp) | ||||||||||||||
end | ||||||||||||||
Comment on lines
+760
to
+769
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a bit surprised that one has to redo all calculations here. Is it not possible to change lines such as DiffEqBase.jl/src/callbacks.jl Lines 735 to 738 in a13b48b
min_t and min_event_idx are not replaced but that one gets a list of min_t for every index, which then allows to select all indices within some tolerance later? And to change DiffEqBase.jl/src/callbacks.jl Line 751 in a13b48b
DiffEqBase.jl/src/callbacks.jl Line 755 in a13b48b
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes we discussed that, it is the initial PR, so lets get that perfect There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code is not optimized since it does not work yet. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this is necessarily the right direction. You're not looking for if the other values are near zero at the same time, you're looking for if there is another zero crossing nearby. One decent approximation would be to find the earliest event, and the move forward in time by some tolerance, and check the condition. You'd then take all of the events that were triggered in that time, and you'd know whether they are up or down crossings. Multiple crossings are extremely unlikely if this is on the size of a*eps(t). This would be insensitive to the relative size of the values. |
||||||||||||||
end | ||||||||||||||
else | ||||||||||||||
new_t = zero(typeof(integrator.t)) | ||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO usually in DiffEq default values are set to
nothing
(as indicated by the docstring) or a sensible default, maybe one could use the same asabstol
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very prototypie but thanks for your input, the doc strings aren't really updated ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The convention for the pooltol is also still in discussion, so i won't merge that proposal but keep in mind that doc strings would need to be added