Skip to content

Commit

Permalink
Add support for async T_exp and T_lim
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Feb 29, 2024
1 parent 31f0bc9 commit 9b0e8a2
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
CudaExt = "CUDA"

[compat]
ClimaComms = "0.4, 0.5"
Colors = "0.12"
Expand Down
60 changes: 60 additions & 0 deletions ext/CudaExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module CudaExt

import CUDA
import ClimaComms: SingletonCommsContext, CUDADevice
import ClimaTimeSteppers: compute_T_lim_T_exp!

@inline function compute_T_lim_T_exp!(T_lim, T_exp, U, p, t, T_lim!, T_exp!, ::SingletonCommsContext{CUDADevice})
# TODO: we should benchmark these two options to
# see if one is preferrable over the other
if Base.Threads.nthreads() > 1
compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)
else
compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)
end
end

@inline function compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)
event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
CUDA.record(event, CUDA.stream()) # record event on main stream

stream1 = CUDA.CuStream() # make a stream
local event1
CUDA.stream!(stream1) do # work to be done by stream1
CUDA.wait(event, stream1) # make stream1 wait on event (host continues)
T_lim!(T_lim, U, p, t)
event1 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
end
CUDA.record(event1, stream1) # record event1 on stream1

stream2 = CUDA.CuStream() # make a stream
local event2
CUDA.stream!(stream2) do # work to be done by stream2
CUDA.wait(event, stream2) # make stream2 wait on event (host continues)
T_exp!(T_exp, U, p, t)
event2 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
end
CUDA.record(event2, stream2) # record event2 on stream2

CUDA.wait(event1, CUDA.stream()) # make main stream wait on event1
CUDA.wait(event2, CUDA.stream()) # make main stream wait on event2
end

@inline function compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!)

CUDA.synchronize()
CUDA.@sync begin
Base.Threads.@spawn begin
T_lim!(T_lim, U, p, t)
CUDA.synchronize()
nothing
end
Base.Threads.@spawn begin
T_exp!(T_exp, U, p, t)
CUDA.synchronize()
nothing
end
end
end

end
22 changes: 22 additions & 0 deletions src/solvers/compute_T_exp_T_lim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,25 @@
T_lim!(T_lim, U, p, t)
T_exp!(T_exp, U, p, t)
end

@inline function compute_T_lim_T_exp!(
T_lim,
T_exp,
U,
p,
t,
T_lim!,
T_exp!,
::ClimaComms.SingletonCommsContext{ClimaComms.CPUMultiThreaded},
)
Base.@sync begin
Base.Threads.@spawn begin
T_lim!(T_lim, U, p, t)
nothing
end
Base.Threads.@spawn begin
T_exp!(T_exp, U, p, t)
nothing
end
end
end

0 comments on commit 9b0e8a2

Please sign in to comment.