Skip to content

Commit

Permalink
Merge pull request #12 from firedrakeproject/firedrake-fix-float128an…
Browse files Browse the repository at this point in the history
…dcomplex256

Firedrake fix float128andcomplex256
  • Loading branch information
kaushikcfd authored Mar 17, 2021
2 parents cbbdc7e + e55d6f7 commit 7665f16
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
9 changes: 6 additions & 3 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,9 @@ class CMathCallable(ScalarCallable):
_real_map = {
np.dtype(np.complex64): np.dtype(np.float32),
np.dtype(np.complex128): np.dtype(np.float64),
np.dtype(np.complex256): np.dtype(np.float128)
}
if hasattr(np, "complex256"):
_real_map[np.dtype(np.complex256)] = np.dtype(np.float128)

def generate_preambles(self, target):
if self.name_in_target.startswith("loopy_" + self.name):
Expand Down Expand Up @@ -495,7 +496,8 @@ def with_types(self, arg_id_to_dtype, caller_kernel, callables_table):
pass # fabs
elif dtype in [np.float32, np.complex64]:
name = name + "f" # fminf
elif dtype in [np.float128, np.complex256]:
elif ((hasattr(np, "float128") and dtype == np.float128) or
(hasattr(np, "complex256") and dtype == np.complex256)): # pylint:disable=no-member
name = name + "l" # fminl
else:
raise LoopyTypeError("%s does not support type %s" % (name,
Expand Down Expand Up @@ -548,7 +550,8 @@ def with_types(self, arg_id_to_dtype, caller_kernel, callables_table):
pass # fabs
elif dtype in [np.float32, np.complex64]:
name = name + "f" # fminf
elif dtype in [np.float128, np.complex256]:
elif ((hasattr(np, "float128") and dtype == np.float128) or
(hasattr(np, "complex256") and dtype == np.complex256)): # pylint:disable=no-member
name = name + "l" # fminl
else:
raise LoopyTypeError("%s does not support type %s"
Expand Down
35 changes: 31 additions & 4 deletions loopy/target/c/c_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,31 @@ def __init__(self, toolchain=None,
library_dirs=library_dirs, defines=defines, source_suffix=source_suffix)


class IDIToCDLL(object):
# {{{ placeholder till ctypes fixes: bugs.python.org/issue16899

class Complex64(ctypes.Structure):
_fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]


class Complex128(ctypes.Structure):
_fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


class Complex256(ctypes.Structure):
_fields_ = [("real", ctypes.c_longdouble), ("imag", ctypes.c_longdouble)]


_NUMPY_COMPLEX_TYPE_TO_CTYPE = {
np.complex64: Complex64,
np.complex128: Complex128,
}
if hasattr(np, "complex256"):
_NUMPY_COMPLEX_TYPE_TO_CTYPE[np.complex256] = Complex256

# }}}


class IDIToCDLL:
"""
A utility class that extracts arguement and return type info from a
:class:`ImplementedDataInfo` in order to create a :class:`ctype.CDLL`
Expand All @@ -322,9 +346,12 @@ def __call__(self, knl, idi):

def _dtype_to_ctype(self, dtype, pointer=False):
"""Map NumPy dtype to equivalent ctypes type."""
typename = self.registry.dtype_to_ctype(dtype)
typename = {'unsigned': 'uint'}.get(typename, typename)
basetype = getattr(ctypes, 'c_' + typename)
if dtype.is_complex():
# complex ctypes aren't exposed
np_dtype = dtype.numpy_dtype.type
basetype = _NUMPY_COMPLEX_TYPE_TO_CTYPE[np_dtype]
else:
basetype = np.ctypeslib.as_ctypes_type(dtype)
if pointer:
return ctypes.POINTER(basetype)
return basetype
Expand Down

0 comments on commit 7665f16

Please sign in to comment.