Cleanup
Thomas committed Jan 12, 2024
1 parent 21805de commit ceccc7d
Showing 4 changed files with 47 additions and 94 deletions.
2 changes: 1 addition & 1 deletion mpi4py_fft/distarrayCuPy.py
@@ -62,7 +62,7 @@ def __new__(
         subcomm = cls.get_subcomm(subcomm, global_shape, rank, alignment)
         p0, subshape = cls.setup_pencil(subcomm, rank, global_shape, alignment)
 
-        obj = cls.xp.ndarray.__new__(cls, subshape, dtype=dtype, memptr=memptr)
+        obj = cls.xp.ndarray.__new__(cls, subshape, dtype=dtype, memptr=memptr, strides=strides)
         if memptr is None and isinstance(val, Number):
             obj.fill(val)
         obj._p0 = p0
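
Note on the strides fix above: cupy.ndarray assumes a C-contiguous layout for a supplied memory pointer unless strides are forwarded, so wrapping a non-contiguous view would silently read the wrong elements. A minimal sketch in plain CuPy (the names base, view and wrapped are illustrative, not from the patch):

    import cupy as cp

    base = cp.arange(16, dtype=cp.float64).reshape(4, 4)
    view = base[::2]  # non-contiguous view: strides (64, 8) instead of (32, 8)

    # Re-wrap the same device memory; without strides= the constructor would
    # assume C-contiguous strides and read the wrong rows.
    wrapped = cp.ndarray(view.shape, dtype=view.dtype,
                         memptr=view.data, strides=view.strides)
    assert bool((wrapped == view).all())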
55 changes: 13 additions & 42 deletions mpi4py_fft/libfft.py
@@ -80,30 +80,18 @@ def _Xfftn_plan_fftw(shape, axes, dtype, transforms, options):

 def _Xfftn_plan_cupy(shape, axes, dtype, transforms, options):
     import cupy as cp
-    cp.fft.config.enable_nd_planning = True
 
     transforms = {} if transforms is None else transforms
     if tuple(axes) in transforms:
         plan_fwd, plan_bck = transforms[tuple(axes)]
     else:
         if cp.issubdtype(dtype, cp.floating):
-            _plan_fwd = cp.fft.rfftn
-            _plan_bck = cp.fft.irfftn
+            plan_fwd = cp.fft.rfftn
+            plan_bck = cp.fft.irfftn
         else:
-            _plan_fwd = cp.fft.fftn
-            _plan_bck = cp.fft.ifftn
-
-        stream = cp.cuda.stream.Stream()
-        def execute_in_stream(function, *args, **kwargs):
-            with stream:
-                result = function(*args, **kwargs)
-            stream.synchronize()
-            return result
-
-        def plan_fwd(*args, **kwargs):
-            return execute_in_stream(_plan_fwd, *args, **kwargs)
-
-        def plan_bck(*args, **kwargs):
-            return execute_in_stream(_plan_bck, *args, **kwargs)
+            plan_fwd = cp.fft.fftn
+            plan_bck = cp.fft.ifftn
 
     s = tuple(np.take(shape, axes))
     U = cp.array(fftw.aligned(shape, dtype=dtype)) # TODO: avoid going via CPU
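
For context, the simplified plan functions above are the stock CuPy FFT entry points, with s/axes selecting the transform size per axis; the removed stream wrapper added no value because CuPy already runs these on the current stream. A hedged round-trip sketch (shape chosen arbitrarily):

    import numpy as np
    import cupy as cp

    shape, axes = (8, 16), (0, 1)
    s = tuple(np.take(shape, axes))

    U = cp.random.rand(*shape)            # real input, so the rfftn/irfftn pair applies
    V = cp.fft.rfftn(U, s=s, axes=axes)   # forward, real-to-complex
    W = cp.fft.irfftn(V, s=s, axes=axes)  # backward, complex-to-real
    assert bool(cp.allclose(U, W))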
@@ -166,39 +154,22 @@ def _Xfftn_plan_mkl(shape, axes, dtype, transforms, options): #pragma: no cover

 def _Xfftn_plan_cupyx_scipy(shape, axes, dtype, transforms, options):
     import cupy as cp
-    import cupyx.scipy.fft as fft_lib
+    import cupyx.scipy.fftpack as cufft
 
     transforms = {} if transforms is None else transforms
     if tuple(axes) in transforms:
-        _plan_fwd, _plan_bck = transforms[tuple(axes)]
+        plan_fwd, plan_bck = transforms[tuple(axes)]
     else:
-        if cp.issubdtype(dtype, cp.floating):
-            _plan_fwd = fft_lib.rfftn
-            _plan_bck = fft_lib.irfftn
-        else:
-            _plan_fwd = fft_lib.fftn
-            _plan_bck = fft_lib.ifftn
-
-        def swap_shape_for_s(kwargs):
-            _kwargs = {
-                's': kwargs.pop('shape', None),
-                **kwargs,
-            }
-            return _kwargs
-
-        def plan_fwd(*args, **kwargs):
-            return _plan_fwd(*args, **swap_shape_for_s(kwargs))
-
-        def plan_bck(*args, **kwargs):
-            return _plan_bck(*args, **swap_shape_for_s(kwargs))
+        plan_fwd = cufft.fftn
+        plan_bck = cufft.ifftn
 
     s = tuple(np.take(shape, axes))
     U = cp.array(fftw.aligned(shape, dtype=dtype)) # TODO: Skip CPU detour
-    V = plan_fwd(U, s=s, axes=axes)
+    V = plan_fwd(U, shape=s, axes=axes)
     V = cp.array(fftw.aligned_like(V.get())) # TODO: skip CPU detour
     M = np.prod(s)
-    return (_Yfftn_wrap(plan_fwd, U, V, 1, {'shape': s, 'axes': axes}),
-            _Yfftn_wrap(plan_bck, V, U, M, {'shape': s, 'axes': axes}))
+    return (_Yfftn_wrap(plan_fwd, U, V, 1, {'shape': s, 'axes': axes, 'overwrite_x': True}),
+            _Yfftn_wrap(plan_bck, V, U, M, {'shape': s, 'axes': axes, 'overwrite_x': True}))
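
The kwarg rename in the call sites above follows from the import switch: cupyx.scipy.fftpack.fftn spells the size argument shape= (and supports overwrite_x=), while cupyx.scipy.fft.fftn calls the same argument s, which is what the deleted swap_shape_for_s shim papered over. A short sketch of the new call convention (array values arbitrary):

    import cupy as cp
    import cupyx.scipy.fftpack as cufft

    U = cp.random.rand(8, 8) + 1j * cp.random.rand(8, 8)

    # fftpack uses `shape`; overwrite_x=True lets the transform reuse the
    # input buffer instead of allocating a fresh one.
    V = cufft.fftn(U, shape=(8, 8), axes=(0, 1), overwrite_x=True)
    W = cufft.ifftn(V, shape=(8, 8), axes=(0, 1), overwrite_x=True)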

def _Xfftn_plan_scipy(shape, axes, dtype, transforms, options):

@@ -468,7 +439,7 @@ def __init__(self, shape, axes=None, dtype=float, padding=False,
             self.M = 1./np.prod(np.take(self.shape, self.axes))
         else:
             self.M = self.fwd.get_normalization()
-        if backend == 'scipy':
+        if backend in ['scipy', 'cupyx-scipy']:
             self.real_transform = False # No rfftn/irfftn methods
 
         self.padding_factor = 1.0
6 changes: 3 additions & 3 deletions mpi4py_fft/mpifft.py
@@ -130,9 +130,9 @@ class PFFT(object):
         :mod:`.fftw.xfftn` module. See Examples.
     comm_backend : str, optional
         Choose backend for communication. When using GPU based serial backends,
-        the "NCCL" backend can be be used in `Alltoallw` operations to speedup
-        GPU to GPU transfer. Keep in mind that this is used alongside MPI and
-        assumes one GPU per MPI rankMPI is used.
+        the "NCCL" backend or a "customMPI" backend can be used in `Alltoallw`
+        operations to speed up GPU-to-GPU transfer. Keep in mind that this is
+        used alongside MPI and assumes one GPU per MPI rank.
Other Parameters
----------------
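
A hedged usage sketch for this parameter (the comm_backend keyword is this fork's addition; the remaining arguments follow the standard PFFT API, and one GPU per MPI rank is assumed):

    from mpi4py import MPI
    from mpi4py_fft import PFFT

    # CuPy does the local FFTs on each rank's GPU; NCCL handles the
    # Alltoallw transposes between GPUs.
    fft = PFFT(MPI.COMM_WORLD, shape=(128, 128, 128), dtype=complex,
               backend='cupy', comm_backend='NCCL')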
78 changes: 30 additions & 48 deletions mpi4py_fft/pencil.py
@@ -232,7 +232,7 @@ def synchronize_stream():

         rank, size, comm = self.comm.rank, self.comm.size, self.comm
 
-        for i in range(size):
+        for i in range(1, size + 1):
             send_to = (rank + i) % size
             recv_from = (rank -i + size) % size

@@ -242,22 +242,13 @@ def synchronize_stream():
             if send_to == rank:
                 arrayB[sliceB][:] = arrayA[sliceA][:]
             else:
-                # send asynchronously
-                sendbuff = xp.empty(shapeA, dtype=self.dtype)
-                sendbuff[:] = arrayA[sliceA][:]
-                synchronize_stream()
-                req = comm.Isend(sendbuff, dest=send_to)
+                recvbuf = xp.empty(shapeB, dtype=self.dtype)
+                sendbuf = xp.empty(shapeA, dtype=self.dtype)
+                sendbuf[:] = arrayA[sliceA][:]
 
-                # receive
-                recvbuff = xp.empty(shapeB, dtype=self.dtype)
-                comm.Recv(recvbuff, source=recv_from)
                 synchronize_stream()
-                arrayB[sliceB][:] = recvbuff[:]
-
-                # finish send and clean up
-                req.wait()
-                del sendbuff
-                del recvbuff
+                comm.Sendrecv(sendbuf, send_to, recvbuf=recvbuf, source=recv_from)
+                arrayB[sliceB][:] = recvbuf[:]

@staticmethod
def get_slice_and_shape(subtype):
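
The Isend/Recv/wait dance deleted above collapses into MPI's combined send-receive, which posts both sides of the exchange at once, so the ring pattern cannot deadlock and needs no request bookkeeping. A standalone sketch of the same pattern (buffer contents are placeholders):

    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD
    rank, size = comm.rank, comm.size

    for i in range(1, size + 1):
        send_to = (rank + i) % size
        recv_from = (rank - i + size) % size

        sendbuf = np.full(4, rank, dtype='d')
        recvbuf = np.empty(4, dtype='d')

        # One call replaces Isend + Recv + req.wait().
        comm.Sendrecv(sendbuf, send_to, recvbuf=recvbuf, source=recv_from)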
@@ -302,45 +293,36 @@ def Alltoallw(self, arrayA, subtypesA, arrayB, subtypesB):
         iscomplex = cp.iscomplexobj(arrayA)
         NCCL_dtype, real_dtype = self.get_nccl_and_real_dtypes(arrayA)
 
-        def send(array, subtype, send_to, iscomplex, stream):
-            local_slice, shape = self.get_slice_and_shape(subtype)
-            buff = self.get_buffer(shape, iscomplex, real_dtype, stream)
-            self.fill_buffer(buff, array, local_slice, iscomplex)
-            comm.send(buff.data.ptr, buff.size, NCCL_dtype, send_to, stream.ptr)
-
-        events = []
-        streams = [cp.cuda.Stream(null=False) for _ in range(size)]
-        for i, stream in zip(range(size), streams):
-            with stream:
-
-                send_to = (rank + i) % size
-                recv_from = (rank -i + size) % size
-
-                if send_to > rank:
-                    send(arrayA, subtypesA[send_to], send_to, iscomplex, stream)
+        stream = cp.cuda.Stream(null=True)
+        stream.use()
 
-                local_slice, shape = self.get_slice_and_shape(subtypesB[recv_from])
-                buff = self.get_buffer(shape, iscomplex, real_dtype, stream)
+        for i in range(size):
+            send_to = (rank + i) % size
+            recv_from = (rank -i + size) % size
 
-                if recv_from == rank:
-                    send_slice, _ = self.get_slice_and_shape(subtypesA[send_to])
-                    self.fill_buffer(buff, arrayA, send_slice, iscomplex)
-                else:
-                    comm.recv(buff.data.ptr, buff.size, NCCL_dtype, recv_from, stream.ptr)
+            # prepare receive buffer
+            local_slice, shape = self.get_slice_and_shape(subtypesB[recv_from])
+            recv_buff = self.get_buffer(shape, iscomplex, real_dtype)
 
-                self.unpack_buffer(buff, arrayB, local_slice, iscomplex)
+            # prepare send buffer
+            send_slice, send_shape = self.get_slice_and_shape(subtypesA[send_to])
 
-                if send_to < rank:
-                    send(arrayA, subtypesA[send_to], send_to, iscomplex, stream)
+            # send / receive
+            if send_to == rank:
+                self.fill_buffer(recv_buff, arrayA, send_slice, iscomplex)
+            else:
+                send_buff = self.get_buffer(send_shape, iscomplex, real_dtype)
+                self.fill_buffer(send_buff, arrayA, send_slice, iscomplex)
 
-                events += [stream.record()]
+                # perform all sends and receives in a single kernel to allow overlap
+                cp.cuda.nccl.groupStart()
+                comm.recv(recv_buff.data.ptr, recv_buff.size, NCCL_dtype, recv_from, stream.ptr)
+                comm.send(send_buff.data.ptr, send_buff.size, NCCL_dtype, send_to, stream.ptr)
+                cp.cuda.nccl.groupEnd()
 
-        null_stream = cp.cuda.Stream(null=True)
-        null_stream.use()
-        for event in events:
-            null_stream.wait_event(event)
+            self.unpack_buffer(recv_buff, arrayB, local_slice, iscomplex)
 
-        cp.cuda.Device(0).synchronize()
+        cp.cuda.Stream(null=True).use()

@staticmethod
def get_slice_and_shape(subtype):
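
On the groupStart/groupEnd pairing introduced above: individual NCCL point-to-point calls block until the peer posts the matching call, so fusing each rank's send and recv into one group is what lets them proceed concurrently. A minimal hedged sketch of the same idea outside the class (requires CuPy built with NCCL; the bootstrap via MPI broadcast is a common convention, not taken from this patch):

    import cupy as cp
    from cupy.cuda import nccl
    from mpi4py import MPI

    mpi_comm = MPI.COMM_WORLD
    rank, size = mpi_comm.rank, mpi_comm.size
    cp.cuda.Device(rank).use()  # one GPU per rank assumed

    # NCCL bootstrap: rank 0 creates the unique id, MPI distributes it.
    uid = nccl.get_unique_id() if rank == 0 else None
    uid = mpi_comm.bcast(uid, root=0)
    comm = nccl.NcclCommunicator(size, uid, rank)

    stream = cp.cuda.Stream(null=True)
    sendbuf = cp.full(4, float(rank), dtype=cp.float64)
    recvbuf = cp.empty(4, dtype=cp.float64)

    # Paired send/recv inside a group overlap instead of serializing.
    nccl.groupStart()
    comm.send(sendbuf.data.ptr, sendbuf.size, nccl.NCCL_FLOAT64,
              (rank + 1) % size, stream.ptr)
    comm.recv(recvbuf.data.ptr, recvbuf.size, nccl.NCCL_FLOAT64,
              (rank - 1 + size) % size, stream.ptr)
    nccl.groupEnd()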
@@ -376,7 +358,7 @@ def get_nccl_and_real_dtypes(array):
         return nccl_dtypes[array.dtype], real_dtypes[array.dtype]
 
     @staticmethod
-    def get_buffer(shape, iscomplex, real_dtype, stream):
+    def get_buffer(shape, iscomplex, real_dtype):
         """
         Get a buffer for communication. If complex numbers are used, we send
         two real values instead.
