diff --git a/ext/ClimaTimeSteppersBenchmarkToolsExt.jl b/ext/ClimaTimeSteppersBenchmarkToolsExt.jl index 9a08994d..16bba021 100644 --- a/ext/ClimaTimeSteppersBenchmarkToolsExt.jl +++ b/ext/ClimaTimeSteppersBenchmarkToolsExt.jl @@ -35,8 +35,8 @@ n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict( "T_exp_T_lim!" => 4, "lim!" => 4, "dss!" => 4, - "post_explicit!" => 3, - "post_implicit!" => 4, + "pre_explicit!" => 3, + "pre_implicit!" => 4, "step!" => 1, ) function n_calls_per_step(alg::CTS.RosenbrockAlgorithm) @@ -47,8 +47,8 @@ function n_calls_per_step(alg::CTS.RosenbrockAlgorithm) "T_exp_T_lim!" => CTS.n_stages(alg.tableau), "lim!" => 0, "dss!" => CTS.n_stages(alg.tableau), - "post_explicit!" => 0, - "post_implicit!" => CTS.n_stages(alg.tableau), + "pre_explicit!" => 0, + "pre_implicit!" => CTS.n_stages(alg.tableau), "step!" => 1, ) end @@ -60,7 +60,7 @@ function maybe_push!(trials₀, name, f!, args, kwargs, only) end const allowed_names = - ["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "post_explicit!", "post_implicit!", "step!"] + ["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "pre_explicit!", "pre_implicit!", "step!"] """ benchmark_step( @@ -89,8 +89,8 @@ Benchmark a DistributedODEIntegrator given: - "T_exp_T_lim!" - "lim!" - "dss!" - - "post_explicit!" - - "post_implicit!" + - "pre_explicit!" + - "pre_implicit!" - "step!" """ function CTS.benchmark_step( @@ -123,8 +123,8 @@ function CTS.benchmark_step( maybe_push!(trials₀, "T_exp_T_lim!", remaining_fun(integrator), remaining_args(integrator), kwargs, only) maybe_push!(trials₀, "lim!", f.lim!, (Xlim, p, t, u), kwargs, only) maybe_push!(trials₀, "dss!", f.dss!, (u, p, t), kwargs, only) - maybe_push!(trials₀, "post_explicit!", f.post_explicit!, (u, p, t), kwargs, only) - maybe_push!(trials₀, "post_implicit!", f.post_implicit!, (u, p, t), kwargs, only) + maybe_push!(trials₀, "pre_explicit!", f.pre_explicit!, (u, p, t), kwargs, only) + maybe_push!(trials₀, "pre_implicit!", f.pre_implicit!, (u, p, t), kwargs, only) maybe_push!(trials₀, "step!", SciMLBase.step!, (integrator, ), kwargs, only) #! format: on diff --git a/src/functions.jl b/src/functions.jl index 04c92f4d..d8a6c5bb 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -11,8 +11,8 @@ struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFuncti T_imp!::TI lim!::L dss!::D - post_explicit!::PE - post_implicit!::PI + pre_explicit!::PE + pre_implicit!::PI function ClimaODEFunction(; T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ... T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ... @@ -20,10 +20,10 @@ struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFuncti T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ... lim! = (u, p, t, u_ref) -> nothing, dss! = (u, p, t) -> nothing, - post_explicit! = (u, p, t) -> nothing, - post_implicit! = (u, p, t) -> nothing, + pre_explicit! = (u, p, t) -> nothing, + pre_implicit! = (u, p, t) -> nothing, ) - args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!) + args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, pre_explicit!, pre_implicit!) if !isnothing(T_exp_T_lim!) @assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`" diff --git a/src/integrators.jl b/src/integrators.jl index 09345304..24128005 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -147,8 +147,8 @@ function DiffEqBase.__init( tdir, ) if prob.f isa ClimaODEFunction - (; post_explicit!) = prob.f - isnothing(post_explicit!) || post_explicit!(u0, p, t0) + (; pre_explicit!) = prob.f + isnothing(pre_explicit!) || pre_explicit!(u0, p, t0) end DiffEqBase.initialize!(callback, u0, t0, integrator) return integrator diff --git a/src/nl_solvers/newtons_method.jl b/src/nl_solvers/newtons_method.jl index 479487e4..e0bcf795 100644 --- a/src/nl_solvers/newtons_method.jl +++ b/src/nl_solvers/newtons_method.jl @@ -130,7 +130,7 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov method without directly using the Jacobian `j(x[n])`, and instead only using `x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by -calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)`. +calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, pre_implicit!)`. The `jΔx` passed to a Jacobian-free JVP is modified in-place. The `cache` can be obtained with `allocate_cache(::JacobianFreeJVP, x_prototype)`, where `x_prototype` is `similar` to `x` (and also to `Δx` and `f`). @@ -151,13 +151,13 @@ end allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = similar(x_prototype), f2 = similar(x_prototype)) -function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, post_implicit!) +function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, pre_implicit!) (; default_step, step_adjustment) = alg (; x2, f2) = cache FT = eltype(x) ε = FT(step_adjustment) * default_step(Δx, x) @. x2 = x + ε * Δx - isnothing(post_implicit!) || post_implicit!(x2) + isnothing(pre_implicit!) || pre_implicit!(x2) f!(f2, x2) @. jΔx = (f2 - f) / ε end @@ -343,7 +343,7 @@ end Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the value of the forcing term on iteration `n`. This is done by calling -`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)`, +`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, pre_implicit!, j = nothing)`, where `f` is `f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an approximation of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place. The `cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`, @@ -428,14 +428,14 @@ function allocate_cache(alg::KrylovMethod, x_prototype) ) end -NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing) +NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, pre_implicit!, j = nothing) (; jacobian_free_jvp, forcing_term, solve_kwargs) = alg (; disable_preconditioner, debugger) = alg type = solver_type(alg) (; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache jΔx!(jΔx, Δx) = isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) : - jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, post_implicit!) + jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, pre_implicit!) opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!) M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j print_debug!(debugger, debugger_cache, opj, M) @@ -567,32 +567,22 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing) ) end -solve_newton!( - alg::NewtonsMethod, - cache::Nothing, - x, - f!, - j! = nothing, - post_implicit! = nothing, - post_implicit_last! = nothing, -) = nothing - -NVTX.@annotate function solve_newton!( - alg::NewtonsMethod, - cache, - x, - f!, - j! = nothing, - post_implicit! = nothing, - post_implicit_last! = nothing, -) +solve_newton!(alg::NewtonsMethod, cache::Nothing, x, f!, j! = nothing, pre_implicit! = nothing) = nothing + +NVTX.@annotate function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, pre_implicit! = nothing) (; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg (; krylov_method_cache, convergence_checker_cache) = cache (; Δx, f, j) = cache - if (!isnothing(j)) && needs_update!(update_j, NewNewtonSolve()) - j!(j, x) + if !isnothing(pre_implicit!) && !isempty(1:max_iters) + pre_implicit!(x) + if (!isnothing(j)) && needs_update!(update_j, NewNewtonSolve()) + j!(j, x) + end end for n in 1:max_iters + if !isnothing(pre_implicit!) + n ≠ 1 && pre_implicit!(x) + end # Compute Δx[n]. if (!isnothing(j)) && needs_update!(update_j, NewNewtonIteration()) j!(j, x) @@ -605,7 +595,7 @@ NVTX.@annotate function solve_newton!( ldiv!(Δx, j, f) end else - solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, post_implicit!, j) + solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, pre_implicit!, j) end is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))" @@ -613,12 +603,7 @@ NVTX.@annotate function solve_newton!( # Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed. # Check for convergence if necessary. if is_converged!(convergence_checker, convergence_checker_cache, x, Δx, n) - isnothing(post_implicit_last!) || post_implicit_last!(x) break - elseif n == max_iters - isnothing(post_implicit_last!) || post_implicit_last!(x) - else - isnothing(post_implicit!) || post_implicit!(x) end if is_verbose(verbose) && n == max_iters @warn "Newton's method did not converge within $n iterations: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))" diff --git a/src/solvers/hard_coded_ars343.jl b/src/solvers/hard_coded_ars343.jl index f82ca418..d3a091e3 100644 --- a/src/solvers/hard_coded_ars343.jl +++ b/src/solvers/hard_coded_ars343.jl @@ -4,7 +4,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) (; u, p, t, dt, sol, alg) = integrator (; f) = sol.prob (; T_imp!, lim!, dss!) = f - (; post_explicit!, post_implicit!) = f + (; pre_explicit!, pre_implicit!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau (; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache @@ -34,7 +34,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) lim!(U, p, t_exp, u) @. U += dt * a_exp[i, 1] * T_exp[1] dss!(U, p, t_exp) - post_explicit!(U, p, t_exp) @. temp = U # used in closures let i = i @@ -46,8 +45,8 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) implicit_equation_jacobian! = (jacobian, Ui) -> begin T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) end - call_post_implicit! = Ui -> begin - post_implicit!(Ui, p, t_imp) + call_pre_implicit! = Ui -> begin + pre_implicit!(Ui, p, t_imp) end solve_newton!( newtons_method, @@ -55,12 +54,13 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, + call_pre_implicit!, ) end @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) + pre_explicit!(U, p, t_exp) T_lim!(T_lim[i], U, p, t_exp) T_exp!(T_exp[i], U, p, t_exp) @@ -70,7 +70,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) lim!(U, p, t_exp, u) @. U += dt * a_exp[i, 1] * T_exp[1] + dt * a_exp[i, 2] * T_exp[2] + dt * a_imp[i, 2] * T_imp[2] dss!(U, p, t_exp) - post_explicit!(U, p, t_exp) @. temp = U # used in closures let i = i @@ -82,8 +81,8 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) implicit_equation_jacobian! = (jacobian, Ui) -> begin T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) end - call_post_implicit! = Ui -> begin - post_implicit!(Ui, p, t_imp) + call_pre_implicit! = Ui -> begin + pre_implicit!(Ui, p, t_imp) end solve_newton!( newtons_method, @@ -91,12 +90,13 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, + call_pre_implicit!, ) end @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) + pre_explicit!(U, p, t_exp) T_lim!(T_lim[i], U, p, t_exp) T_exp!(T_exp[i], U, p, t_exp) i = 4 @@ -110,7 +110,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) dt * a_imp[i, 2] * T_imp[2] + dt * a_imp[i, 3] * T_imp[3] dss!(U, p, t_exp) - post_explicit!(U, p, t_exp) @. temp = U # used in closures let i = i @@ -122,8 +121,8 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) implicit_equation_jacobian! = (jacobian, Ui) -> begin T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) end - call_post_implicit! = Ui -> begin - post_implicit!(Ui, p, t_imp) + call_pre_implicit! = Ui -> begin + pre_implicit!(Ui, p, t_imp) end solve_newton!( newtons_method, @@ -131,12 +130,13 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, + call_pre_implicit!, ) end @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) + pre_explicit!(U, p, t_exp) T_lim!(T_lim[i], U, p, t_exp) T_exp!(T_exp[i], U, p, t_exp) @@ -155,6 +155,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) dt * b_imp[3] * T_imp[3] + dt * b_imp[4] * T_imp[4] dss!(u, p, t_final) - post_explicit!(u, p, t_final) + pre_explicit!(U, p, t_final) return u end diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index 4c2d24f9..ab494495 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -49,7 +49,7 @@ end function step_u!(integrator, cache::IMEXARKCache) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob - (; post_explicit!, post_implicit!) = f + (; pre_explicit!, pre_implicit!) = f (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau @@ -63,6 +63,7 @@ function step_u!(integrator, cache::IMEXARKCache) if γ isa Nothing sdirk_error(name) else + pre_implicit!(u, p, t) T_imp!.Wfact(jacobian, u, p, dt * γ, t) end end @@ -83,7 +84,9 @@ function step_u!(integrator, cache::IMEXARKCache) isnothing(T_imp!) || fused_increment!(u, dt, b_imp, T_imp, Val(s)) dss!(u, p, t_final) - post_explicit!(u, p, t_final) + # this `pre_explicit!` call perpares the cache `p` for both + # the callbacks and the beginning of the next timestep + pre_explicit!(u, p, t_final) return u end @@ -98,7 +101,7 @@ end @inline function update_stage!(integrator, cache::IMEXARKCache, i::Int) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob - (; post_explicit!, post_implicit!) = f + (; pre_explicit!, pre_implicit!) = f (; T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau @@ -116,34 +119,22 @@ end end # Update based on tendencies from previous stages + i ≠ 1 && pre_explicit!(U, p, t_exp) # pre_explicit! was called at the end of the previous step! has_T_exp(f) && fused_increment!(U, dt, a_exp, T_exp, Val(i)) isnothing(T_imp!) || fused_increment!(U, dt, a_imp, T_imp, Val(i)) i ≠ 1 && dss!(U, p, t_exp) - if !(!isnothing(T_imp!) && !iszero(a_imp[i, i])) - i ≠ 1 && post_explicit!(U, p, t_imp) - else # Implicit solve + if (!isnothing(T_imp!) && !iszero(a_imp[i, i])) # Implicit solve @assert !isnothing(newtons_method) @. temp = U - i ≠ 1 && post_explicit!(U, p, t_imp) # TODO: can/should we remove these closures? implicit_equation_residual! = (residual, Ui) -> begin T_imp!(residual, Ui, p, t_imp) @. residual = temp + dt * a_imp[i, i] * residual - Ui end implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) - call_post_implicit! = Ui -> begin - post_implicit!(Ui, p, t_imp) - end - call_post_implicit_last! = Ui -> begin - if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !iszero(a_imp[i, i]) - # If T_imp[i] is being treated implicitly, ensure that it - # exactly satisfies the implicit equation. - @. T_imp[i] = (Ui - temp) / (dt * a_imp[i, i]) - end - post_implicit!(Ui, p, t_imp) - end + call_pre_implicit! = Ui -> pre_implicit!(Ui, p, t_imp) solve_newton!( newtons_method, @@ -151,8 +142,7 @@ end U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, - call_post_implicit_last!, + call_pre_implicit!, ) end @@ -160,18 +150,25 @@ end # give the same results for redundant columns (as long as the implicit # tendency only acts in the vertical direction). - if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]) - if iszero(a_imp[i, i]) && !isnothing(T_imp!) + if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !isnothing(T_imp!) + if iszero(a_imp[i, i]) # If its coefficient is 0, T_imp[i] is effectively being # treated explicitly. T_imp!(T_imp[i], U, p, t_imp) + else + # If T_imp[i] is being treated implicitly, ensure that it + # exactly satisfies the implicit equation. + @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) end end - if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]) if !isnothing(T_exp_T_lim!) + pre_explicit!(U, p, t_exp) T_exp_T_lim!(T_exp[i], T_lim[i], U, p, t_exp) else + if !isnothing(T_lim!) || !isnothing(T_exp!) + pre_explicit!(U, p, t_exp) + end isnothing(T_lim!) || T_lim!(T_lim[i], U, p, t_exp) isnothing(T_exp!) || T_exp!(T_exp[i], U, p, t_exp) end diff --git a/src/solvers/imex_ssprk.jl b/src/solvers/imex_ssprk.jl index 0395991e..8a89a5cd 100644 --- a/src/solvers/imex_ssprk.jl +++ b/src/solvers/imex_ssprk.jl @@ -55,7 +55,7 @@ end function step_u!(integrator, cache::IMEXSSPRKCache) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob - (; post_explicit!, post_implicit!) = f + (; pre_explicit!, pre_implicit!) = f (; T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_imp, b_imp, c_exp, c_imp) = tableau @@ -104,30 +104,16 @@ function step_u!(integrator, cache::IMEXSSPRKCache) end end - if !(!isnothing(T_imp!) && !iszero(a_imp[i, i])) - i ≠ 1 && post_explicit!(U, p, t_imp) - else # Implicit solve + if (!isnothing(T_imp!) && !iszero(a_imp[i, i])) # Implicit solve @assert !isnothing(newtons_method) @. temp = U - post_explicit!(U, p, t_imp) # TODO: can/should we remove these closures? implicit_equation_residual! = (residual, Ui) -> begin T_imp!(residual, Ui, p, t_imp) @. residual = temp + dt * a_imp[i, i] * residual - Ui end implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) - call_post_implicit! = Ui -> begin - post_implicit!(Ui, p, t_imp) - end - call_post_implicit_last! = - Ui -> begin - if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !iszero(a_imp[i, i]) - # If T_imp[i] is being treated implicitly, ensure that it - # exactly satisfies the implicit equation. - @. T_imp[i] = (Ui - temp) / (dt * a_imp[i, i]) - end - post_implicit!(Ui, p, t_imp) - end + call_pre_implicit! = Ui -> pre_implicit!(Ui, p, t_imp) solve_newton!( newtons_method, @@ -135,8 +121,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache) U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, - call_post_implicit_last!, + call_pre_implicit!, ) end @@ -144,18 +129,26 @@ function step_u!(integrator, cache::IMEXSSPRKCache) # give the same results for redundant columns (as long as the implicit # tendency only acts in the vertical direction). - if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]) - if iszero(a_imp[i, i]) && !isnothing(T_imp!) + if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !isnothing(T_imp!) + if iszero(a_imp[i, i]) # If its coefficient is 0, T_imp[i] is effectively being # treated explicitly. T_imp!(T_imp[i], U, p, t_imp) + else + # If T_imp[i] is being treated implicitly, ensure that it + # exactly satisfies the implicit equation. + @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) end end if !iszero(β[i]) if !isnothing(T_exp_T_lim!) + pre_explicit!(U, p, t_exp) T_exp_T_lim!(T_lim, T_exp, U, p, t_exp) else + if !isnothing(T_lim!) || !isnothing(T_exp!) + pre_explicit!(U, p, t_exp) + end isnothing(T_lim!) || T_lim!(T_lim, U, p, t_exp) isnothing(T_exp!) || T_exp!(T_exp, U, p, t_exp) end @@ -184,7 +177,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache) end dss!(u, p, t_final) - post_explicit!(u, p, t_final) + pre_explicit!(u, p, t_final) return u end diff --git a/src/solvers/rosenbrock.jl b/src/solvers/rosenbrock.jl index 78f8d9e9..34bb636d 100644 --- a/src/solvers/rosenbrock.jl +++ b/src/solvers/rosenbrock.jl @@ -123,7 +123,7 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages} T_exp_lim! = int.sol.prob.f.T_exp_T_lim! tgrad! = isnothing(T_imp!) ? nothing : T_imp!.tgrad - (; post_explicit!, post_implicit!, dss!) = int.sol.prob.f + (; pre_explicit!, pre_implicit!, dss!) = int.sol.prob.f # TODO: This is only valid when Γ[i, i] is constant, otherwise we have to # move this in the for loop @@ -150,13 +150,12 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages} U .+= A[i, j] .* k[j] end - # NOTE: post_implicit! is a misnomer - if !isnothing(post_implicit!) + if !isnothing(pre_explicit!) # We update p on every stage but the first, and at the end of each # timestep. Since the first stage is unchanged from the end of the # previous timestep, this order of operations ensures that p is # always consistent with the state, including between timesteps. - (i != 1) && post_implicit!(U, p, t + αi * dt) + (i != 1) && pre_explicit!(U, p, t + αi * dt) end if !isnothing(T_imp!) @@ -205,7 +204,7 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages} end dss!(u, p, t + dt) - post_implicit!(u, p, t + dt) + pre_explicit!(u, p, t + dt) return nothing end diff --git a/test/problems.jl b/test/problems.jl index 74155321..3660988a 100644 --- a/test/problems.jl +++ b/test/problems.jl @@ -493,8 +493,8 @@ function climacore_2Dheat_test_cts(::Type{FT}) where {FT} # we add implicit pieces here for inference analysis T_lim! = (Yₜ, u, _, t) -> nothing - post_implicit! = (u, _, t) -> nothing - post_explicit! = (u, _, t) -> nothing + pre_implicit! = (u, _, t) -> nothing + pre_explicit! = (u, _, t) -> nothing jacobian = ClimaCore.MatrixFields.FieldMatrix((@name(u), @name(u)) => FT(-1) * LinearAlgebra.I) @@ -505,7 +505,7 @@ function climacore_2Dheat_test_cts(::Type{FT}) where {FT} tgrad = (∂Y∂t, Y, p, t) -> (∂Y∂t .= 0), ) - tendency_func = ClimaODEFunction(; T_exp!, T_imp!, dss!, post_implicit!, post_explicit!) + tendency_func = ClimaODEFunction(; T_exp!, T_imp!, dss!, pre_implicit!, pre_explicit!) split_tendency_func = tendency_func make_prob(func) = ODEProblem(func, init_state, (FT(0), t_end), nothing) IntegratorTestCase(