Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Adding pooled callbacks to VectorContinousCallback #523

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,17 @@ function ContinuousCallback(condition,affect!;
rootfind,interp_points,
collect(save_positions),
dtrelax,abstol,reltol)

end

"""
```julia
VectorContinuousCallback(condition,affect!,affect_neg!,len;
initialize = INITIALIZE_DEFAULT,
idxs = nothing,
rootfind=true,
save_positions=(true,true),
interp_points=10,
abstol=10eps(),reltol=0)
initialize = INITIALIZE_DEFAULT,
idxs = nothing,
rootfind=true,
save_positions=(true,true),
interp_points=10,
abstol=10eps(),reltol=0,pooltol=nothing,pool_events=false)
```

```julia
Expand All @@ -144,7 +143,7 @@ VectorContinuousCallback(condition,affect!,len;
save_positions=(true,true),
affect_neg! = affect!,
interp_points=10,
abstol=10eps(),reltol=0)
abstol=10eps(),reltol=0,pooltol=nothing,pool_events=false)
```

This is also a subtype of `AbstractContinuousCallback`. `CallbackSet` is not feasible when you have a large number of callbacks,
Expand All @@ -159,10 +158,13 @@ multiple events.
- `affect!`: This is a function `affect!(integrator, event_index)` which lets you modify `integrator` and it tells you about
which event occured using `event_idx` i.e. gives you index `i` for which `out[i]` came out to be zero.
- `len`: Number of callbacks chained. This is compulsory to be specified.
- `pool_events`: Whether multiple concurrent events should be passed as one array of indexs instead of the indexes on a time.
- `pooltol`: Custom limit which values get grouped. Callback accepted if it's absolute value is smaller than pooltol at callback time.
The default value is `eps(integrator.t) + eps(callback_return_type)`.

Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref).
"""
struct VectorContinuousCallback{F1,F2,F3,F4,T,T2,I,R} <: AbstractContinuousCallback
struct VectorContinuousCallback{F1,F2,F3,F4,T,T2,T3,I,R} <: AbstractContinuousCallback
condition::F1
affect!::F2
affect_neg!::F3
Expand All @@ -175,15 +177,17 @@ struct VectorContinuousCallback{F1,F2,F3,F4,T,T2,I,R} <: AbstractContinuousCallb
dtrelax::R
abstol::T
reltol::T2
pooltol::T3
pool_events::Bool
VectorContinuousCallback(condition::F1,affect!::F2,affect_neg!::F3,len::Int,
initialize::F4,idxs::I,rootfind,
interp_points,save_positions,dtrelax::R,
abstol::T,reltol::T2) where {F1,F2,F3,F4,T,T2,I,R} =
new{F1,F2,F3,F4,T,T2,I,R}(condition,
abstol::T,reltol::T2, pooltol::T3, pool_events) where {F1,F2,F3,F4,T,T2,T3,I,R} =
new{F1,F2,F3,F4,T,T2,T3,I,R}(condition,
affect!,affect_neg!,len,
initialize,idxs,rootfind,interp_points,
BitArray(collect(save_positions)),
dtrelax,abstol,reltol)
dtrelax,abstol,reltol,pooltol,pool_events)
end

VectorContinuousCallback(condition,affect!,affect_neg!,len;
Expand All @@ -193,13 +197,13 @@ VectorContinuousCallback(condition,affect!,affect_neg!,len;
save_positions=(true,true),
interp_points=10,
dtrelax=1,
abstol=10eps(),reltol=0) = VectorContinuousCallback(
abstol=10eps(),reltol=0, pooltol=missing, pool_events=false) = VectorContinuousCallback(
condition,affect!,affect_neg!,len,
initialize,
idxs,
rootfind,interp_points,
save_positions,dtrelax,
abstol,reltol)
abstol,reltol, pooltol, pool_events)

function VectorContinuousCallback(condition,affect!,len;
initialize = INITIALIZE_DEFAULT,
Expand All @@ -209,14 +213,13 @@ function VectorContinuousCallback(condition,affect!,len;
affect_neg! = affect!,
interp_points=10,
dtrelax=1,
abstol=10eps(),reltol=0)
abstol=10eps(),reltol=0, pooltol=missing, pool_events=false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO usually in DiffEq default values are set to nothing (as indicated by the docstring) or a sensible default, maybe one could use the same as abstol?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very prototypie but thanks for your input, the doc strings aren't really updated ...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The convention for the pooltol is also still in discussion, so i won't merge that proposal but keep in mind that doc strings would need to be added


VectorContinuousCallback(
condition,affect!,affect_neg!,len,initialize,idxs,
rootfind,interp_points,
collect(save_positions),
dtrelax,abstol,reltol)

dtrelax,abstol,reltol,pooltol,pool_events)
end

"""
Expand Down Expand Up @@ -754,6 +757,16 @@ function find_callback_time(integrator,callback::VectorContinuousCallback,counte
new_t = integrator.dt
min_event_idx = event_idx[1]
end
if callback.pool_events
tmp = get_condition(integrator, callback, integrator.dt + new_t)
if callback.pooltol isa Missing
# This is still dubious
pool_tol = eps(integrator.t + new_t) + eps(typeof(tmp[end]))
else
pool_tol = callback.pooltol
end
min_event_idx = findall(x-> abs(x) < pool_tol, tmp)
end
Comment on lines +760 to +769
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit surprised that one has to redo all calculations here. Is it not possible to change lines such as

if integrator.tdir * Θ < integrator.tdir * min_t
min_event_idx = idx
min_t = Θ
end
such that min_t and min_event_idx are not replaced but that one gets a list of min_t for every index, which then allows to select all indices within some tolerance later? And to change
min_event_idx = event_idx[1]
and
min_event_idx = event_idx[1]
to not just select the first index?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we discussed that, it is the initial PR, so lets get that perfect

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is not optimized since it does not work yet.
I have decided to keep the pooled callback logic out of branches for different types of time resolving, so there is just a single place to change it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is necessarily the right direction. You're not looking for if the other values are near zero at the same time, you're looking for if there is another zero crossing nearby. One decent approximation would be to find the earliest event, and the move forward in time by some tolerance, and check the condition. You'd then take all of the events that were triggered in that time, and you'd know whether they are up or down crossings. Multiple crossings are extremely unlikely if this is on the size of a*eps(t). This would be insensitive to the relative size of the values.

end
else
new_t = zero(typeof(integrator.t))
Expand Down