-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for async T_exp and T_lim
- Loading branch information
1 parent
31f0bc9
commit 9b0e8a2
Showing
3 changed files
with
88 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
module CudaExt | ||
|
||
import CUDA | ||
import ClimaComms: SingletonCommsContext, CUDADevice | ||
import ClimaTimeSteppers: compute_T_lim_T_exp! | ||
|
||
@inline function compute_T_lim_T_exp!(T_lim, T_exp, U, p, t, T_lim!, T_exp!, ::SingletonCommsContext{CUDADevice}) | ||
# TODO: we should benchmark these two options to | ||
# see if one is preferrable over the other | ||
if Base.Threads.nthreads() > 1 | ||
compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!) | ||
else | ||
compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!) | ||
end | ||
end | ||
|
||
@inline function compute_T_lim_T_exp_streams!(T_lim, T_exp, U, p, t, T_lim!, T_exp!) | ||
event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING) | ||
CUDA.record(event, CUDA.stream()) # record event on main stream | ||
|
||
stream1 = CUDA.CuStream() # make a stream | ||
local event1 | ||
CUDA.stream!(stream1) do # work to be done by stream1 | ||
CUDA.wait(event, stream1) # make stream1 wait on event (host continues) | ||
T_lim!(T_lim, U, p, t) | ||
event1 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING) | ||
end | ||
CUDA.record(event1, stream1) # record event1 on stream1 | ||
|
||
stream2 = CUDA.CuStream() # make a stream | ||
local event2 | ||
CUDA.stream!(stream2) do # work to be done by stream2 | ||
CUDA.wait(event, stream2) # make stream2 wait on event (host continues) | ||
T_exp!(T_exp, U, p, t) | ||
event2 = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING) | ||
end | ||
CUDA.record(event2, stream2) # record event2 on stream2 | ||
|
||
CUDA.wait(event1, CUDA.stream()) # make main stream wait on event1 | ||
CUDA.wait(event2, CUDA.stream()) # make main stream wait on event2 | ||
end | ||
|
||
@inline function compute_T_lim_T_exp_spawn!(T_lim, T_exp, U, p, t, T_lim!, T_exp!) | ||
|
||
CUDA.synchronize() | ||
CUDA.@sync begin | ||
Base.Threads.@spawn begin | ||
T_lim!(T_lim, U, p, t) | ||
CUDA.synchronize() | ||
nothing | ||
end | ||
Base.Threads.@spawn begin | ||
T_exp!(T_exp, U, p, t) | ||
CUDA.synchronize() | ||
nothing | ||
end | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters