diff --git a/Project.toml b/Project.toml index 550d9aaab..e3314c75a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ClimaTimeSteppers" uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79" authors = ["Climate Modeling Alliance"] -version = "0.7.14" +version = "0.7.15" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/src/functions.jl b/src/functions.jl index 4e80446c2..8815c5c21 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -4,14 +4,34 @@ export ClimaODEFunction, ForwardEulerODEFunction abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end -Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction - T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ... - T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ... - T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ... - lim!::L = (u, p, t, u_ref) -> nothing - dss!::D = (u, p, t) -> nothing - post_explicit!::PE = (u, p, t) -> nothing - post_implicit!::PI = (u, p, t) -> nothing +struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI, CC} <: AbstractClimaODEFunction + T_lim!::TL + T_exp!::TE + T_imp!::TI + lim!::L + dss!::D + post_explicit!::PE + post_implicit!::PI + comms_context::CC +end +function ClimaODEFunction(; + T_lim! = nothing, + T_exp! = nothing, + T_imp! = nothing, + lim! = nothing, + dss! = nothing, + post_explicit! = nothing, + post_implicit! = nothing, + comms_context = nothing, +) + isnothing(T_lim!) && (T_lim! = (uₜ, u, p, t) -> nothing) + isnothing(T_exp!) && (T_exp! = (uₜ, u, p, t) -> nothing) + T_imp! = nothing + isnothing(lim!) && (lim! = (u, p, t, u_ref) -> nothing) + isnothing(dss!) && (dss! = (u, p, t) -> nothing) + isnothing(post_explicit!) && (post_explicit! = (u, p, t) -> nothing) + isnothing(post_implicit!) && (post_implicit! = (u, p, t) -> nothing) + return ClimaODEFunction(T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!, comms_context) end # Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work). diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index 32253ecf4..c23db3dbc 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -50,6 +50,7 @@ function step_u!(integrator, cache::IMEXARKCache) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob (; post_explicit!, post_implicit!) = f + (; comms_context) = f (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau @@ -74,19 +75,17 @@ function step_u!(integrator, cache::IMEXARKCache) @. U = u - if !isnothing(T_lim!) # Update based on limited tendencies from previous stages - for j in 1:(i - 1) - iszero(a_exp[i, j]) && continue - @. U += dt * a_exp[i, j] * T_lim[j] - end - lim!(U, p, t_exp, u) + # Update based on limited tendencies from previous stages + for j in 1:(i - 1) + iszero(a_exp[i, j]) && continue + @. U += dt * a_exp[i, j] * T_lim[j] end + lim!(U, p, t_exp, u) - if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages - for j in 1:(i - 1) - iszero(a_exp[i, j]) && continue - @. U += dt * a_exp[i, j] * T_exp[j] - end + # Update based on explicit tendencies from previous stages + for j in 1:(i - 1) + iszero(a_exp[i, j]) && continue + @. U += dt * a_exp[i, j] * T_exp[j] end if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages @@ -147,32 +146,54 @@ function step_u!(integrator, cache::IMEXARKCache) end if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]) - if !isnothing(T_lim!) + if isnothing(comms_context) T_lim!(T_lim[i], U, p, t_exp) - end - if !isnothing(T_exp!) T_exp!(T_exp[i], U, p, t_exp) + else # do asynchronously + + # https://github.com/JuliaLang/julia/issues/40626 + if ClimaComms.device(comms_context) isa CUDA.CUDADevice + CUDA.@sync begin + @async begin + T_lim!(T_lim[i], U, p, t_exp) + nothing + end + @async begin + T_exp!(T_exp[i], U, p, t_exp) + nothing + end + end + else + @sync begin + @async begin + T_lim!(T_lim[i], U, p, t_exp) + nothing + end + @async begin + T_exp!(T_exp[i], U, p, t_exp) + nothing + end + end + end end end end t_final = t + dt - if !isnothing(T_lim!) # Update based on limited tendencies from previous stages - @. temp = u - for j in 1:s - iszero(b_exp[j]) && continue - @. temp += dt * b_exp[j] * T_lim[j] - end - lim!(temp, p, t_final, u) - @. u = temp + # Update based on limited tendencies from previous stages + @. temp = u + for j in 1:s + iszero(b_exp[j]) && continue + @. temp += dt * b_exp[j] * T_lim[j] end + lim!(temp, p, t_final, u) + @. u = temp - if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages - for j in 1:s - iszero(b_exp[j]) && continue - @. u += dt * b_exp[j] * T_exp[j] - end + # Update based on explicit tendencies from previous stages + for j in 1:s + iszero(b_exp[j]) && continue + @. u += dt * b_exp[j] * T_exp[j] end if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages diff --git a/src/solvers/imex_ssprk.jl b/src/solvers/imex_ssprk.jl index 646889ba0..f5e6d1bde 100644 --- a/src/solvers/imex_ssprk.jl +++ b/src/solvers/imex_ssprk.jl @@ -19,13 +19,13 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{SSP} s = length(b_exp) inds = ntuple(i -> i, s) inds_T_imp = filter(i -> !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]), inds) - U = similar(u0) - U_exp = similar(u0) - T_lim = similar(u0) - T_exp = similar(u0) - U_lim = similar(u0) - T_imp = SparseContainer(map(i -> similar(u0), collect(1:length(inds_T_imp))), inds_T_imp) - temp = similar(u0) + U = zero(u0) + U_exp = zero(u0) + T_lim = zero(u0) + T_exp = zero(u0) + U_lim = zero(u0) + T_imp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_imp))), inds_T_imp) + temp = zero(u0) â_exp = vcat(a_exp, b_exp') β = diag(â_exp, -1) for i in 1:length(β) @@ -56,6 +56,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob (; post_explicit!, post_implicit!) = f + (; comms_context) = f (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_imp, b_imp, c_exp, c_imp) = tableau @@ -83,14 +84,10 @@ function step_u!(integrator, cache::IMEXSSPRKCache) if i == 1 @. U_exp = u elseif !iszero(β[i - 1]) - if !isnothing(T_lim!) - @. U_lim = U_exp + dt * T_lim - lim!(U_lim, p, t_exp, U_exp) - @. U_exp = U_lim - end - if !isnothing(T_exp!) - @. U_exp += dt * T_exp - end + @. U_lim = U_exp + dt * T_lim + lim!(U_lim, p, t_exp, U_exp) + @. U_exp = U_lim + @. U_exp += dt * T_exp @. U_exp = (1 - β[i - 1]) * u + β[i - 1] * U_exp end @@ -153,11 +150,36 @@ function step_u!(integrator, cache::IMEXSSPRKCache) end if !iszero(β[i]) - if !isnothing(T_lim!) + if isnothing(comms_context) T_lim!(T_lim, U, p, t_exp) - end - if !isnothing(T_exp!) T_exp!(T_exp, U, p, t_exp) + else + + # https://github.com/JuliaLang/julia/issues/40626 + if ClimaComms.device(comms_context) isa CUDA.CUDADevice + CUDA.@sync begin + @async begin + T_lim!(T_lim, U, p, t_exp) + nothing + end + @async begin + T_exp!(T_exp, U, p, t_exp) + nothing + end + end + else + @sync begin + @async begin + T_lim!(T_lim, U, p, t_exp) + nothing + end + @async begin + T_exp!(T_exp, U, p, t_exp) + nothing + end + end + end + end end end @@ -165,14 +187,10 @@ function step_u!(integrator, cache::IMEXSSPRKCache) t_final = t + dt if !iszero(β[s]) - if !isnothing(T_lim!) - @. U_lim = U_exp + dt * T_lim - lim!(U_lim, p, t_final, U_exp) - @. U_exp = U_lim - end - if !isnothing(T_exp!) - @. U_exp += dt * T_exp - end + @. U_lim = U_exp + dt * T_lim + lim!(U_lim, p, t_final, U_exp) + @. U_exp = U_lim + @. U_exp += dt * T_exp @. u = (1 - β[s]) * u + β[s] * U_exp end