Skip to content

Commit

Permalink
Merge pull request #26 from firedrakeproject/dham/check_upstream
Browse files Browse the repository at this point in the history
merge upstream
  • Loading branch information
dham authored Jul 17, 2024
2 parents 87c1cd8 + eb5acd1 commit d9876d8
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 14 deletions.
2 changes: 0 additions & 2 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,6 @@ def __init__(self, id, depends_on, depends_on_is_final,
# The Taggable constructor call does extra validation.
tags=tags)

Taggable.__init__(self, tags)

# {{{ abstract interface

def read_dependency_names(self):
Expand Down
2 changes: 1 addition & 1 deletion loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def __init__(self, name, tags):
assert isinstance(tags, frozenset)
assert tags

Taggable.__init__(self, tags)
self.tags = tags

def __getinitargs__(self):
return self.name, self.tags
Expand Down
7 changes: 5 additions & 2 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,8 +547,8 @@ def with_types(self, arg_id_to_dtype, callables_table):
else:
result_dtype = dtype

if dtype.kind == "c" or name in ["real", "imag", "abs"]:
if name != "conj":
if dtype.kind == "c" or name in ["real", "realf", "imag", "imagf"]:
if name not in ["conj", "conjf"]:
name = "c" + name

return (
Expand Down Expand Up @@ -684,6 +684,9 @@ def generate_preambles(self, target):
return 0;
}}""")

if isinstance(target, CTarget):
yield ("50_cmath", "#include <math.h>")


class GNULibcCallable(ScalarCallable):
def with_types(self, arg_id_to_dtype, callables_table):
Expand Down
14 changes: 9 additions & 5 deletions loopy/target/c/c_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,16 +318,19 @@ def _tempname(self, name):
return os.path.join(self.tempdir, name)

def build(self, name, code, debug=False, wait_on_error=None,
debug_recompile=True):
debug_recompile=True, extra_build_options: Sequence[str] = ()):
"""Compile code, build and load shared library."""
logger.debug(code)
c_fname = self._tempname("code." + self.source_suffix)

# build object
_, mod_name, ext_file, recompiled = \
compile_from_string(self.toolchain, name, code, c_fname,
self.tempdir, debug, wait_on_error,
debug_recompile, False)
compile_from_string(
self.toolchain.copy(
cflags=self.toolchain.cflags+list(extra_build_options)),
name, code, c_fname,
self.tempdir, debug, wait_on_error,
debug_recompile, False)

if recompiled:
logger.debug(f"Kernel {name} compiled from source")
Expand Down Expand Up @@ -426,7 +429,8 @@ def __init__(self, kernel: LoopKernel, devprog: GeneratedProgram,
# get code and build
self.code = dev_code
self.comp = comp if comp is not None else CCompiler()
self.dll = self.comp.build(devprog.name, self.code)
self.dll = self.comp.build(devprog.name, self.code,
extra_build_options=kernel.options.build_options)

# get the function declaration for interface with ctypes
self._fn = getattr(self.dll, devprog.name)
Expand Down
10 changes: 6 additions & 4 deletions loopy/target/pyopencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ def with_types(self, arg_id_to_dtype, callables_table):
np.dtype(dtype.numpy_dtype.type(0).real))}),
callables_table)

if name in ["real", "imag"]:
if name in ["real", "imag", "conj"]:
if not dtype.is_complex():
tpname = dtype.numpy_dtype.type.__name__
return (
self.copy(
name_in_target=f"lpy_{name}_{tpname}",
name_in_target=f"_lpy_{name}_{tpname}",
arg_id_to_dtype={0: dtype, -1: dtype}),
callables_table)

Expand Down Expand Up @@ -154,8 +154,10 @@ def with_types(self, arg_id_to_dtype, callables_table):

def generate_preambles(self, target):
name = self.name_in_target
if name.startswith("lpy_real") or name.startswith("lpy_imag"):
if name.startswith("lpy_real"):
if (name.startswith("_lpy_real")
or name.startswith("_lpy_conj")
or name.startswith("_lpy_imag")):
if name.startswith("_lpy_real") or name.startswith("_lpy_conj"):
ret = "x"
else:
ret = "0"
Expand Down
75 changes: 75 additions & 0 deletions test/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,81 @@ def test_np_bool_handling(ctx_factory):
assert out.get().item() is True


@pytest.mark.parametrize("target", [lp.PyOpenCLTarget, lp.ExecutableCTarget])
def test_complex_functions_with_real_args(ctx_factory, target):
# Reported by David Ham. See <https://github.com/inducer/loopy/issues/851>
t_unit = lp.make_kernel(
"{[i]: 0<=i<10}",
"""
y1[i] = abs(c64[i])
y2[i] = real(c64[i])
y3[i] = imag(c64[i])
y4[i] = conj(c64[i])
y5[i] = abs(c128[i])
y6[i] = real(c128[i])
y7[i] = imag(c128[i])
y8[i] = conj(c128[i])
y9[i] = abs(f32[i])
y10[i] = real(f32[i])
y11[i] = imag(f32[i])
y12[i] = conj(f32[i])
y13[i] = abs(f64[i])
y14[i] = real(f64[i])
y15[i] = imag(f64[i])
y16[i] = conj(f64[i])
""",
target=target())

t_unit = lp.add_dtypes(t_unit,
{"y9,y10,y11,y12": np.complex64,
"y13,y14,y15,y16": np.complex128,
"c64": np.complex64,
"c128": np.complex128,
"f64": np.float64,
"f32": np.float32})
t_unit = lp.set_options(t_unit, return_dict=True)

from numpy.random import default_rng
rng = default_rng(0)
c64 = (rng.random(10, dtype=np.float32)
+ np.csingle(1j)*rng.random(10, dtype=np.float32))
c128 = (rng.random(10, dtype=np.float64)
+ np.cdouble(1j)*rng.random(10, dtype=np.float64))
f32 = rng.random(10, dtype=np.float32)
f64 = rng.random(10, dtype=np.float64)

if target == lp.PyOpenCLTarget:
cl_ctx = ctx_factory()
with cl.CommandQueue(cl_ctx) as queue:
evt, out = t_unit(queue, c64=c64, c128=c128, f32=f32, f64=f64)
elif target == lp.ExecutableCTarget:
t_unit = lp.set_options(t_unit, build_options=["-Werror"])
evt, out = t_unit(c64=c64, c128=c128, f32=f32, f64=f64)
else:
raise NotImplementedError("unsupported target")

np.testing.assert_allclose(out["y1"], np.abs(c64), rtol=1e-6)
np.testing.assert_allclose(out["y2"], np.real(c64), rtol=1e-6)
np.testing.assert_allclose(out["y3"], np.imag(c64), rtol=1e-6)
np.testing.assert_allclose(out["y4"], np.conj(c64), rtol=1e-6)
np.testing.assert_allclose(out["y5"], np.abs(c128), rtol=1e-6)
np.testing.assert_allclose(out["y6"], np.real(c128), rtol=1e-6)
np.testing.assert_allclose(out["y7"], np.imag(c128), rtol=1e-6)
np.testing.assert_allclose(out["y8"], np.conj(c128), rtol=1e-6)
np.testing.assert_allclose(out["y9"], np.abs(f32), rtol=1e-6)
np.testing.assert_allclose(out["y10"], np.real(f32), rtol=1e-6)
np.testing.assert_allclose(out["y11"], np.imag(f32), rtol=1e-6)
np.testing.assert_allclose(out["y12"], np.conj(f32), rtol=1e-6)
np.testing.assert_allclose(out["y13"], np.abs(f64), rtol=1e-6)
np.testing.assert_allclose(out["y14"], np.real(f64), rtol=1e-6)
np.testing.assert_allclose(out["y15"], np.imag(f64), rtol=1e-6)
np.testing.assert_allclose(out["y16"], np.conj(f64), rtol=1e-6)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit d9876d8

Please sign in to comment.