Skip to content

Commit

Permalink
Add SummaryCallback (#48)
Browse files Browse the repository at this point in the history
* add SummaryCallback

* format

* bump compat of TimerOutputs to v0.5.23

* fix test printing SummaryCallback

* put timer inside function
  • Loading branch information
JoshuaLampert authored May 11, 2024
1 parent cc91146 commit 9c5db6a
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 17 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
TypedPolynomials = "afbbf031-7a57-5f58-a1b9-b774a0fad08d"
WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192"

Expand All @@ -28,6 +29,7 @@ SciMLBase = "2.26"
SimpleUnPack = "1.1"
SpecialFunctions = "2"
StaticArrays = "1"
TimerOutputs = "0.5.23"
TypedPolynomials = "0.4.1"
WriteVTK = "1.18"
julia = "1.10"
6 changes: 4 additions & 2 deletions examples/PDEs/advection_diffusion_2d_basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ sd = Semidiscretization(pde, nodeset_inner, g, nodeset_boundary, u, kernel)
tspan = (0.0, 1.0)
ode = semidiscretize(sd, tspan)

callback = SaveSolutionCallback(dt = 0.01)
sol = solve(ode, Rosenbrock23(), saveat = 0.01, callback = callback)
save_solution = SaveSolutionCallback(dt = 0.01)
summary = SummaryCallback()
callbacks = CallbackSet(save_solution, summary)
sol = solve(ode, Rosenbrock23(), saveat = 0.01, callback = callbacks)
titp = TemporalInterpolation(sol)

many_nodes = homogeneous_hypercube(20; dim = 2)
Expand Down
3 changes: 2 additions & 1 deletion src/KernelInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using SciMLBase: ODEFunction, ODEProblem, ODESolution, DiscreteCallback, u_modif
using SimpleUnPack: @unpack
using SpecialFunctions: besselk, loggamma
using StaticArrays: StaticArrays, MVector
using TimerOutputs: TimerOutputs, TimerOutput, @timeit, print_timer, reset_timer!
using TypedPolynomials: Variable, monomials, degree
using WriteVTK: WriteVTK, vtk_grid, paraview_collection, MeshCell, VTKCellTypes,
CollectionFile
Expand Down Expand Up @@ -43,7 +44,7 @@ export interpolation_kernel, nodeset, coefficients, kernel_coefficients,
polynomial_coefficients, polynomial_basis, polyvars, system_matrix,
interpolate, solve_stationary, kernel_inner_product, kernel_norm,
TemporalInterpolation
export SaveSolutionCallback
export SaveSolutionCallback, SummaryCallback
export vtk_save, vtk_read
export examples_dir, get_examples, default_example, include_example

Expand Down
1 change: 1 addition & 0 deletions src/callbacks_step/callbacks_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
end

include("save_solution.jl")
include("summary.jl")
23 changes: 13 additions & 10 deletions src/callbacks_step/save_solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,19 @@ end

# this method is called when the callback is activated
function (solution_callback::SaveSolutionCallback)(integrator)
@unpack pvd, output_directory, extra_functions, keys = solution_callback
semi = integrator.p
@unpack nodeset_inner, nodeset_boundary = semi.spatial_discretization
nodeset = merge(nodeset_inner, nodeset_boundary)
A = semi.cache.kernel_matrix
u = A * integrator.u
t = integrator.t
iter = integrator.stats.naccept
filename = joinpath(solution_callback.output_directory, @sprintf("solution_%06d", iter))
add_to_pvd(filename, pvd, t, nodeset, u, extra_functions...; keys = keys)
@timeit timer() "save solution" begin
@unpack pvd, output_directory, extra_functions, keys = solution_callback
semi = integrator.p
@unpack nodeset_inner, nodeset_boundary = semi.spatial_discretization
nodeset = merge(nodeset_inner, nodeset_boundary)
A = semi.cache.kernel_matrix
u = A * integrator.u
t = integrator.t
iter = integrator.stats.naccept
filename = joinpath(solution_callback.output_directory,
@sprintf("solution_%06d", iter))
add_to_pvd(filename, pvd, t, nodeset, u, extra_functions...; keys = keys)
end

# avoid re-evaluating possible FSAL stages
u_modified!(integrator, false)
Expand Down
43 changes: 43 additions & 0 deletions src/callbacks_step/summary.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
SummaryCallback(io::IO = stdout)
Create and return a callback that resets the timer at the beginning of
a simulation and prints the timer values at the end of the simulation.
"""
struct SummaryCallback
io::IO

function SummaryCallback(io::IO = stdout)
summary_callback = new(io)
# SummaryCallback is never called during the simulation
condition = (u, t, integrator) -> false
DiscreteCallback(condition, summary_callback,
save_positions = (false, false),
initialize = initialize_summary_callback,
finalize = finalize_summary_callback)
end
end

function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SummaryCallback})
@nospecialize cb # reduce precompilation time

print(io, "SummaryCallback")
end

function initialize_summary_callback(cb::DiscreteCallback, u, t, integrator)
reset_timer!(timer())
return nothing
end

# the summary callback does nothing when called accidentally
(cb::SummaryCallback)(integrator) = u_modified!(integrator, false)

# At the end of the simulation, the timer is printed
function finalize_summary_callback(cb::DiscreteCallback, u, t, integrator)
io = cb.affect!.io
TimerOutputs.complement!(timer())
print_timer(io, timer(), title = "KernelInterpolation",
allocations = true, linechars = :unicode, compact = false)
println(io)
return nothing
end
12 changes: 8 additions & 4 deletions src/discretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,14 @@ Base.eltype(semi::Semidiscretization) = eltype(semi.spatial_discretization)
function rhs!(dc, c, semi, t)
@unpack pde_boundary_matrix = semi.cache
@unpack equations, nodeset_inner, boundary_condition, nodeset_boundary = semi.spatial_discretization
rhs_vector = [rhs(t, nodeset_inner, equations);
boundary_condition.(Ref(t), nodeset_boundary)]
# dc = -pde_boundary_matrix * c + rhs_vector
dc[:] = muladd(pde_boundary_matrix, -c, rhs_vector)
@timeit timer() "rhs!" begin
@timeit timer() "rhs vector" begin
rhs_vector = [rhs(t, nodeset_inner, equations);
boundary_condition.(Ref(t), nodeset_boundary)]
end
# dc = -pde_boundary_matrix * c + rhs_vector
@timeit timer() "muladd" dc[:]=muladd(pde_boundary_matrix, -c, rhs_vector)
end
return nothing
end

Expand Down
6 changes: 6 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,9 @@ end
# https://github.com/JuliaAlgebra/TypedPolynomials.jl/issues/51, instead use the
# workaround from there
polyvars(d) = ntuple(i -> Variable{Symbol("x[", i, "]")}(), d)

# Store main timer for global timing of functions
const main_timer = TimerOutput()

# Always call timer() to hide implementation details
timer() = main_timer
3 changes: 3 additions & 0 deletions test/test_unit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,9 @@ using Plots
@test_nowarn println(save_solution_callback)
@test_nowarn display(save_solution_callback)
@test_throws ArgumentError SaveSolutionCallback(interval = 10, dt = 0.1)
summary_callback = SummaryCallback()
@test_nowarn println(summary_callback)
@test_nowarn display(summary_callback)
end

@testset "Visualization" begin
Expand Down

0 comments on commit 9c5db6a

Please sign in to comment.