Skip to content

Commit

Permalink
Support asynchronous T_lim and T_exp
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 10, 2023
1 parent 6839f2e commit c1a1157
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 56 deletions.
7 changes: 4 additions & 3 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ export ClimaODEFunction, ForwardEulerODEFunction

abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end

Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ...
T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ...
Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI, CC} <: AbstractClimaODEFunction
T_lim!::TL = (uₜ, u, p, t) -> nothing
T_exp!::TE = (uₜ, u, p, t) -> nothing
T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ...
lim!::L = (u, p, t, u_ref) -> nothing
dss!::D = (u, p, t) -> nothing
post_explicit!::PE = (u, p, t) -> nothing
post_implicit!::PI = (u, p, t) -> nothing
comms_context::CC = nothing
end

# Don't wrap an AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work).
Expand Down
75 changes: 48 additions & 27 deletions src/solvers/imex_ark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ function step_u!(integrator, cache::IMEXARKCache)
(; u, p, t, dt, alg) = integrator
(; f) = integrator.sol.prob
(; post_explicit!, post_implicit!) = f
(; comms_context) = f
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
(; tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
Expand All @@ -74,19 +75,17 @@ function step_u!(integrator, cache::IMEXARKCache)

@. U = u

if !isnothing(T_lim!) # Update based on limited tendencies from previous stages
for j in 1:(i - 1)
iszero(a_exp[i, j]) && continue
@. U += dt * a_exp[i, j] * T_lim[j]
end
lim!(U, p, t_exp, u)
# Update based on limited tendencies from previous stages
for j in 1:(i - 1)
iszero(a_exp[i, j]) && continue
@. U += dt * a_exp[i, j] * T_lim[j]
end
lim!(U, p, t_exp, u)

if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages
for j in 1:(i - 1)
iszero(a_exp[i, j]) && continue
@. U += dt * a_exp[i, j] * T_exp[j]
end
# Update based on explicit tendencies from previous stages
for j in 1:(i - 1)
iszero(a_exp[i, j]) && continue
@. U += dt * a_exp[i, j] * T_exp[j]
end

if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages
Expand Down Expand Up @@ -147,32 +146,54 @@ function step_u!(integrator, cache::IMEXARKCache)
end

if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i])
if !isnothing(T_lim!)
if isnothing(comms_context)
T_lim!(T_lim[i], U, p, t_exp)
end
if !isnothing(T_exp!)
T_exp!(T_exp[i], U, p, t_exp)
else # do asynchronously

# https://github.com/JuliaLang/julia/issues/40626
if ClimaComms.device(comms_context) isa CUDA.CUDADevice
CUDA.@sync begin
@async begin
T_lim!(T_lim[i], U, p, t_exp)
nothing
end
@async begin
T_exp!(T_exp[i], U, p, t_exp)
nothing
end
end
else
@sync begin
@async begin
T_lim!(T_lim[i], U, p, t_exp)
nothing
end
@async begin
T_exp!(T_exp[i], U, p, t_exp)
nothing
end
end
end
end
end
end

t_final = t + dt

if !isnothing(T_lim!) # Update based on limited tendencies from previous stages
@. temp = u
for j in 1:s
iszero(b_exp[j]) && continue
@. temp += dt * b_exp[j] * T_lim[j]
end
lim!(temp, p, t_final, u)
@. u = temp
# Update based on limited tendencies from previous stages
@. temp = u
for j in 1:s
iszero(b_exp[j]) && continue
@. temp += dt * b_exp[j] * T_lim[j]
end
lim!(temp, p, t_final, u)
@. u = temp

if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages
for j in 1:s
iszero(b_exp[j]) && continue
@. u += dt * b_exp[j] * T_exp[j]
end
# Update based on explicit tendencies from previous stages
for j in 1:s
iszero(b_exp[j]) && continue
@. u += dt * b_exp[j] * T_exp[j]
end

if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages
Expand Down
70 changes: 44 additions & 26 deletions src/solvers/imex_ssprk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{SSP}
s = length(b_exp)
inds = ntuple(i -> i, s)
inds_T_imp = filter(i -> !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]), inds)
U = similar(u0)
U_exp = similar(u0)
T_lim = similar(u0)
T_exp = similar(u0)
U_lim = similar(u0)
T_imp = SparseContainer(map(i -> similar(u0), collect(1:length(inds_T_imp))), inds_T_imp)
temp = similar(u0)
U = zero(u0)
U_exp = zero(u0)
T_lim = zero(u0)
T_exp = zero(u0)
U_lim = zero(u0)
T_imp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_imp))), inds_T_imp)
temp = zero(u0)
â_exp = vcat(a_exp, b_exp')
β = diag(â_exp, -1)
for i in 1:length(β)
Expand Down Expand Up @@ -56,6 +56,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
(; u, p, t, dt, alg) = integrator
(; f) = integrator.sol.prob
(; post_explicit!, post_implicit!) = f
(; comms_context) = f
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
(; tableau, newtons_method) = alg
(; a_imp, b_imp, c_exp, c_imp) = tableau
Expand Down Expand Up @@ -83,14 +84,10 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
if i == 1
@. U_exp = u
elseif !iszero(β[i - 1])
if !isnothing(T_lim!)
@. U_lim = U_exp + dt * T_lim
lim!(U_lim, p, t_exp, U_exp)
@. U_exp = U_lim
end
if !isnothing(T_exp!)
@. U_exp += dt * T_exp
end
@. U_lim = U_exp + dt * T_lim
lim!(U_lim, p, t_exp, U_exp)
@. U_exp = U_lim
@. U_exp += dt * T_exp
@. U_exp = (1 - β[i - 1]) * u + β[i - 1] * U_exp
end

Expand Down Expand Up @@ -153,26 +150,47 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
end

if !iszero(β[i])
if !isnothing(T_lim!)
if isnothing(comms_context)
T_lim!(T_lim, U, p, t_exp)
end
if !isnothing(T_exp!)
T_exp!(T_exp, U, p, t_exp)
else

# https://github.com/JuliaLang/julia/issues/40626
if ClimaComms.device(comms_context) isa CUDA.CUDADevice
CUDA.@sync begin
@async begin
T_lim!(T_lim, U, p, t_exp)
nothing
end
@async begin
T_exp!(T_exp, U, p, t_exp)
nothing
end
end
else
@sync begin
@async begin
T_lim!(T_lim, U, p, t_exp)
nothing
end
@async begin
T_exp!(T_exp, U, p, t_exp)
nothing
end
end
end

end
end
end

t_final = t + dt

if !iszero(β[s])
if !isnothing(T_lim!)
@. U_lim = U_exp + dt * T_lim
lim!(U_lim, p, t_final, U_exp)
@. U_exp = U_lim
end
if !isnothing(T_exp!)
@. U_exp += dt * T_exp
end
@. U_lim = U_exp + dt * T_lim
lim!(U_lim, p, t_final, U_exp)
@. U_exp = U_lim
@. U_exp += dt * T_exp
@. u = (1 - β[s]) * u + β[s] * U_exp
end

Expand Down

0 comments on commit c1a1157

Please sign in to comment.