Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: draft of multi-threaded modified_equation #139

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 89 additions & 5 deletions src/BSeries.jl
Original file line number Diff line number Diff line change
Expand Up @@ -989,10 +989,11 @@ function _evaluate(f, u, dt, series, ::EagerEvaluation, reduce_order_by)
end

"""
modified_equation(series_integrator)
modified_equation(series_integrator, thread::Bool = Threads.nthreads() > 1)

Compute the B-series of the modified equation of the time integration method
with B-series `series_integrator`.
with B-series `series_integrator` using multiple threads if Julia is started
with multiple threads and `thread` is set to `true`.

Given an ordinary differential equation (ODE) ``u'(t) = f(u(t))`` and a
Runge-Kutta method, the idea is to interpret the numerical solution with
Expand All @@ -1014,11 +1015,41 @@ Section 3.2 of
Foundations of Computational Mathematics
[DOI: 10.1007/s10208-010-9065-1](https://doi.org/10.1007/s10208-010-9065-1)
"""
function modified_equation(series_integrator)
_modified_equation(series_integrator, evaluation_type(series_integrator))
function modified_equation(series_integrator, thread::Bool = Threads.nthreads() > 1)
if thread
_modified_equation_thread(series_integrator,
evaluation_type(series_integrator))
else
_modified_equation_serial(series_integrator,
evaluation_type(series_integrator))
end
end

function _modified_equation_serial(series_integrator, ::EagerEvaluation)
# Setup shared between the serial and threaded versions
series, series_keys, series_ex, iter = _modified_equation_shared(series_integrator)

# Recursively solve
# substitute(series, series_ex, t) == series_integrator[t]
# This works because
# substitute(series, series_ex, t) = series[t] + lower order terms

# Since the `keys` are ordered, we don't need to use nested loops of the form
# for o in 2:order
# for _t in RootedTreeIterator(o)
# t = copy(_t)
# which are slightly less efficient due to additional computations and
# allocations.
while iter !== nothing
t, t_state = iter
series[t] += series_integrator[t] - substitute(series, series_ex, t)
iter = iterate(series_keys, t_state)
end

return series
end

function _modified_equation(series_integrator, ::EagerEvaluation)
@inline function _modified_equation_shared(series_integrator)
V = valtype(series_integrator)

# B-series of the exact solution
Expand Down Expand Up @@ -1050,10 +1081,24 @@ function _modified_equation(series_integrator, ::EagerEvaluation)
iter = iterate(series_keys, t_state)
end

return series, series_keys, series_ex, iter
end

function _modified_equation_thread(series_integrator, ::EagerEvaluation)
# Setup shared between the serial and threaded versions
series, series_keys, series_ex, iter = _modified_equation_shared(series_integrator)

# Recursively solve
# substitute(series, series_ex, t) == series_integrator[t]
# This works because
# substitute(series, series_ex, t) = series[t] + lower order terms

# Here, we use the serial version up to a specified `cutoff_order`, i.e.,
# for low-order trees, since it avoids the parallel overhead. We only use
# the parallel (threaded) version for trees of an order of at least
# `cutoff_order`.
cutoff_order = 5

# Since the `keys` are ordered, we don't need to use nested loops of the form
# for o in 2:order
# for _t in RootedTreeIterator(o)
Expand All @@ -1062,10 +1107,49 @@ function _modified_equation(series_integrator, ::EagerEvaluation)
# allocations.
while iter !== nothing
t, t_state = iter
order(t) >= cutoff_order && break
series[t] += series_integrator[t] - substitute(series, series_ex, t)
iter = iterate(series_keys, t_state)
end

# The algorithm has a data dependency: It is assumed that the coefficients
# of the new `series` are already computed for all trees with a lower order
# than the current tree. Thus, we can use threaded parallelism only over a
# set of trees of the same order.
# for o in cutoff_order:order(series_integrator)
# # We need to collect the trees we will iterate over in a vector for
# # threaded parallelism.
# # TODO: This should be the iterator type specified by the keys
# # of the series_integrator
# trees = map(copy, RootedTreeIterator(o))
# Threads.@threads for t in trees
# series[t] += series_integrator[t] - substitute(series, series_ex, t)
# end
# end

idx_stop = findfirst(==(cutoff_order) ∘ order, series_integrator.coef.keys)
if idx_stop === nothing
return series
else
idx_stop = idx_stop - 1
end
for o in cutoff_order:order(series_integrator)
# TODO: This uses internal implementation details...
idx_start = findnext(==(o) ∘ order, series_integrator.coef.keys, idx_stop)
idx_stop = findnext(==(o + 1) ∘ order, series_integrator.coef.keys, idx_start)
if idx_stop === nothing
idx_stop = lastindex(series_integrator.coef.keys)
end
# We iterate over the indices instead of the trees in the threaded
# loop since that is slightly more efficient at the time of writing
# due to less allocations.
indices = idx_start:idx_stop
Threads.@threads for i in indices
t = @inbounds series_integrator.coef.keys[i]
series[t] += series_integrator[t] - substitute(series, series_ex, t)
end
end

return series
end

Expand Down