Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check cl_khr_fp16 support #225

Closed
wants to merge 9 commits into from
5 changes: 5 additions & 0 deletions clpy/backend/compiler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ import functools
import operator
cimport function

import clpy.backend.opencl.api
cimport clpy.backend.opencl.api
cimport clpy.backend.opencl.env
cimport clpy.backend.opencl.utility

cpdef function.Module compile_with_cache(
str source, tuple options=(), arch=None, cache_dir=None,
extra_source=None):
options += (' -cl-fp32-correctly-rounded-divide-sqrt', )
if clpy.backend.opencl.env.supports_cl_khr_fp16() == \
clpy.backend.opencl.api.TRUE:
options += (' -D__CLPY_ENABLE_CL_KHR_FP16', )
optionStr = functools.reduce(operator.add, options)

device = clpy.backend.opencl.env.get_device()
Expand Down
92 changes: 92 additions & 0 deletions clpy/backend/opencl/env.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
import atexit
import locale
import logging
import numpy
import re

from clpy.backend.opencl import api
from clpy.backend.opencl cimport api
from clpy.backend.opencl import exceptions
from cython.view cimport array as cython_array

from libc.stdlib cimport malloc
from libc.stdlib cimport free
from libc.string cimport memset
from libc.string cimport memcpy

cdef interpret_versionstr(versionstr):
version_detector = re.compile('''OpenCL (\d+)\.(\d+)''')
Expand Down Expand Up @@ -137,6 +143,10 @@ for id in range(__num_devices):
CL_QUEUE_PROFILING_ENABLE)
logging.info("SUCCESS")

# Per-device cache of the cl_khr_fp16 probe result, one cl_char per device.
# Every byte is initialised to -1 (memset with -1), which
# supports_cl_khr_fp16() treats as "not probed yet"; that function later
# overwrites the entry with the value returned by check_cl_khr_fp16().
# NOTE(review): the malloc result is not checked for NULL, and no matching
# free() is visible in this part of the file -- confirm both.
cdef cl_char* __supports_cl_khr_fp16 = \
<cl_char*>malloc(sizeof(cl_char)*__num_devices)
memset(__supports_cl_khr_fp16, -1, sizeof(cl_char)*__num_devices)

##########################################
# Functions
##########################################
Expand Down Expand Up @@ -164,6 +174,88 @@ cdef cl_device_id get_device():
global __current_device_id
return __devices[__current_device_id]

cdef cl_char check_cl_khr_fp16():
    """Probe the current device for working cl_khr_fp16 support.

    Builds a tiny kernel that enables the ``cl_khr_fp16`` extension and
    stores the half-precision bit pattern ``0x5140`` into a one-element
    buffer, runs it with a single work-item on the current device, and reads
    the value back into a ``numpy.float16`` array.

    Returns:
        1 if the value read back equals ``numpy.float16(42)``; 0 if it does
        not, or if any of the OpenCL calls raises ``OpenCLRuntimeError``.
    """
    code = b'''
#pragma OPENCL EXTENSION cl_khr_fp16: enable
__kernel void check_cl_khr_fp16(__global half* v){
ushort x = 0x5140;
*v = *(const half*)&x;
}
'''
    cdef size_t length
    cdef char* src
    cdef char* options
    cdef cl_mem buf = NULL
    cdef size_t global_work_size[3]
    cdef size_t ptr
    try:
        device = __devices[__current_device_id]
        context = __contexts[__current_device_id]
        queue = __command_queues[__current_device_id]

        # Borrow the bytes object's internal buffer instead of making a
        # malloc'ed copy: `code` stays alive for the whole function, and the
        # explicit length is passed, so no scratch allocation (and no free)
        # is needed.
        length = len(code)
        src = code

        program = api.CreateProgramWithSource(
            context=context,
            count=<size_t>1,
            strings=&src,
            lengths=&length)
        options = b'-cl-fp32-correctly-rounded-divide-sqrt'
        api.BuildProgram(
            program,
            1,
            &device,
            options,
            <void*>NULL,
            <void*>NULL)
        kernel = api.CreateKernel(
            program, b'check_cl_khr_fp16')
        # NOTE(review): `program` and `kernel` are never released here;
        # confirm whether the api module offers release wrappers for them.

        # Only the first dimension is used (work_dim=1, one work-item).
        global_work_size[0] = 1
        global_work_size[1] = 0
        global_work_size[2] = 0

        # Host-side destination for the single half value written by the
        # kernel.
        v = numpy.array([0], dtype=numpy.float16)
        ptr = <size_t>v.ctypes.get_as_parameter().value
        buf = api.CreateBuffer(
            context,
            CL_MEM_WRITE_ONLY,
            <size_t>v.nbytes,
            <void*>NULL)
        api.SetKernelArg(
            kernel,
            0,
            sizeof(cl_mem),
            &buf)
        api.EnqueueNDRangeKernel(
            command_queue=queue,
            kernel=kernel,
            work_dim=1,
            global_work_offset=<size_t*>NULL,
            global_work_size=&global_work_size[0],
            local_work_size=<size_t*>NULL)
        api.EnqueueReadBuffer(
            command_queue=queue,
            buffer=buf,
            blocking_read=api.BLOCKING,
            offset=0,
            cb=<size_t>v.nbytes,
            host_ptr=<void*>ptr)
        result = 1 if v[0] == numpy.float16(42) else 0
    except exceptions.OpenCLRuntimeError:
        # Any OpenCL failure (e.g. the build rejecting the extension) is
        # reported as "not supported".
        result = 0
    finally:
        # Release the device buffer on every exit path, including
        # exceptions other than OpenCLRuntimeError.
        if buf != NULL:
            api.ReleaseMemObject(buf)
    return result

cpdef cl_bool supports_cl_khr_fp16():
    """Report whether the current device supports cl_khr_fp16.

    The answer is looked up in the per-device cache; a cached value of -1
    means the device has not been probed yet, in which case
    ``check_cl_khr_fp16()`` is run once and its result is stored.

    Returns:
        ``CL_TRUE`` when the cached probe result is 1, ``CL_FALSE`` otherwise.
    """
    global __current_device_id
    cdef cl_char state = __supports_cl_khr_fp16[__current_device_id]
    if state == -1:
        # First query for this device: run the probe and remember the answer.
        state = check_cl_khr_fp16()
        __supports_cl_khr_fp16[__current_device_id] = state
    if state == 1:
        return CL_TRUE
    return CL_FALSE


def release():
"""Release command_queue and context automatically."""
Expand Down
6 changes: 6 additions & 0 deletions clpy/core/include/clpy/carray.clh
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once
#ifdef __CLPY_ENABLE_CL_KHR_FP16
#pragma OPENCL EXTENSION cl_khr_fp16: enable
#endif

// TODO: Implement common functions in OpenCL C
#if 0
Expand Down Expand Up @@ -532,13 +534,16 @@ static void __clpy_end_print_out() __attribute__((annotate("clpy_end_print_out")
#ifdef __ULTIMA
__attribute__((annotate("clpy_no_mangle"))) static half convert_float_to_half(float x);
#else
#ifdef __CLPY_ENABLE_CL_KHR_FP16
#include "fp16.clh"
typedef half __clpy__half;
#endif
#endif

#ifdef __ULTIMA
__attribute__((annotate("clpy_no_mangle"))) static half clpy_nextafter_fp16(half x1, half x2);
#else
#ifdef __CLPY_ENABLE_CL_KHR_FP16
static int isnan_fp16(half x){
unsigned short const* x_raw = (unsigned short const*)&x;
return (*x_raw & 0x7c00u) == 0x7c00u && (*x_raw & 0x03ffu) != 0x0000u;
Expand Down Expand Up @@ -581,3 +586,4 @@ static half clpy_nextafter_fp16(half x1, half x2){
return *(half*)&ret_raw_;
}
#endif
#endif
2 changes: 2 additions & 0 deletions clpy/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from clpy.testing.array import assert_array_max_ulp # NOQA
from clpy.testing.attr import gpu # NOQA
from clpy.testing.attr import multi_gpu # NOQA
from clpy.testing.attr import skip_when_disabled_cl_khr_fp16 # NOQA
from clpy.testing.attr import slow # NOQA
from clpy.testing.bufio import readbuf # NOQA
from clpy.testing.bufio import writebuf # NOQA
from clpy.testing.helper import assert_warns # NOQA
from clpy.testing.helper import for_8bit_integer_dtypes # NOQA
from clpy.testing.helper import for_all_dtypes # NOQA
from clpy.testing.helper import for_all_dtypes_combination # NOQA
from clpy.testing.helper import for_CF_orders # NOQA
Expand Down
8 changes: 8 additions & 0 deletions clpy/testing/attr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import clpy
import os
import unittest

Expand Down Expand Up @@ -65,3 +66,10 @@ def gpu(f):

check_available()
return multi_gpu(1)(pytest.mark.gpu(f))


def skip_when_disabled_cl_khr_fp16(f):
    """Skip the decorated test unless cl_khr_fp16 is supported.

    ``check_available()`` is called and the device is queried when the
    decorator is applied, not when the test runs.

    Args:
        f: The test callable to decorate.

    Returns:
        The result of applying ``unittest.skipUnless`` to ``f``.
    """
    check_available()
    fp16_enabled = (
        clpy.backend.opencl.env.supports_cl_khr_fp16() ==
        clpy.backend.opencl.api.TRUE)
    # A non-empty reason makes skipped tests self-explanatory in reports.
    return unittest.skipUnless(
        fp16_enabled,
        "cl_khr_fp16 is not supported by the current OpenCL device")(f)
93 changes: 64 additions & 29 deletions clpy/testing/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,24 +509,31 @@ def test_func(self, *args, **kw):
_complex_dtypes = (numpy.complex64, numpy.complex128)
_regular_float_dtypes = (numpy.float64, numpy.float32)
_float_dtypes = _regular_float_dtypes + (numpy.float16,)
_signed_dtypes = tuple(numpy.dtype(i).type for i in 'bhilq')
_unsigned_dtypes = tuple(numpy.dtype(i).type for i in 'BHILQ')
_signed_dtypes_without_int8 = tuple(numpy.dtype(i).type for i in 'hilq')
_signed_dtypes = (numpy.int8,) + _signed_dtypes_without_int8
_unsigned_dtypes_without_uint8 = tuple(numpy.dtype(i).type for i in 'HILQ')
_unsigned_dtypes = (numpy.uint8,) + _unsigned_dtypes_without_uint8
_int_dtypes_without_8bit = _signed_dtypes_without_int8 + \
_unsigned_dtypes_without_uint8
_int_dtypes = _signed_dtypes + _unsigned_dtypes
_int_bool_dtypes = _int_dtypes + (numpy.bool_,)
_regular_dtypes = _regular_float_dtypes + _int_bool_dtypes
_dtypes = _float_dtypes + _int_bool_dtypes
_8bit_int_dtypes = tuple(numpy.dtype(i).type for i in 'bB')
_8bit_int_bool_dtypes = tuple(numpy.dtype(i).type for i in '?bB')


def _make_all_dtypes(no_float16, no_bool, no_complex):
def _make_all_dtypes(no_float16, no_bool, no_complex, no_8bit_integer):
if no_float16:
dtypes = _regular_float_dtypes
else:
dtypes = _float_dtypes

if no_bool:
dtypes += _int_dtypes
else:
dtypes += _int_bool_dtypes
dtypes += _int_dtypes_without_8bit

if not no_8bit_integer:
dtypes += _8bit_int_dtypes
if not no_bool:
dtypes += (numpy.bool_,)

# if not no_complex:
# # TODO(LWisteria): Support complex in OpenCL
Expand All @@ -536,7 +543,7 @@ def _make_all_dtypes(no_float16, no_bool, no_complex):


def for_all_dtypes(name='dtype', no_float16=True, no_bool=False,
no_complex=False):
no_complex=False, no_8bit_integer=False):
"""Decorator that checks the fixture with all dtypes.

Args:
Expand Down Expand Up @@ -591,8 +598,10 @@ def for_all_dtypes(name='dtype', no_float16=True, no_bool=False,

.. seealso:: :func:`clpy.testing.for_dtypes`
"""
return for_dtypes(_make_all_dtypes(no_float16, no_bool, no_complex),
name=name)
return for_dtypes(_make_all_dtypes(
no_float16, no_bool, no_complex,
no_8bit_integer),
name=name)


def for_float_dtypes(name='dtype', no_float16=True):
Expand All @@ -615,7 +624,7 @@ def for_float_dtypes(name='dtype', no_float16=True):
return for_dtypes(_float_dtypes, name=name)


def for_signed_dtypes(name='dtype'):
def for_signed_dtypes(name='dtype', no_int8=False):
"""Decorator that checks the fixture with signed dtypes.

Args:
Expand All @@ -627,10 +636,13 @@ def for_signed_dtypes(name='dtype'):
.. seealso:: :func:`clpy.testing.for_dtypes`,
:func:`clpy.testing.for_all_dtypes`
"""
return for_dtypes(_signed_dtypes, name=name)
if no_int8:
return for_dtypes(_signed_dtypes_without_int8, name=name)
else:
return for_dtypes(_signed_dtypes, name=name)


def for_unsigned_dtypes(name='dtype'):
def for_unsigned_dtypes(name='dtype', no_uint8=False):
"""Decorator that checks the fixture with all dtypes.

Args:
Expand All @@ -643,10 +655,13 @@ def for_unsigned_dtypes(name='dtype'):
.. seealso:: :func:`clpy.testing.for_dtypes`,
:func:`clpy.testing.for_all_dtypes`
"""
return for_dtypes(_unsigned_dtypes, name=name)
if no_uint8:
return for_dtypes(_unsigned_dtypes_without_uint8, name=name)
else:
return for_dtypes(_unsigned_dtypes, name=name)


def for_int_dtypes(name='dtype', no_bool=False):
def for_int_dtypes(name='dtype', no_bool=False, no_8bit_integer=False):
"""Decorator that checks the fixture with integer and optionally bool dtypes.

Args:
Expand All @@ -662,10 +677,17 @@ def for_int_dtypes(name='dtype', no_bool=False):
.. seealso:: :func:`clpy.testing.for_dtypes`,
:func:`clpy.testing.for_all_dtypes`
"""
if no_bool:
return for_dtypes(_int_dtypes, name=name)
if no_8bit_integer:
return for_dtypes(_int_dtypes_without_8bit, name=name)
else:
return for_dtypes(_int_bool_dtypes, name=name)
if no_bool:
return for_dtypes(_int_dtypes, name=name)
else:
return for_dtypes(_int_bool_dtypes, name=name)


def for_8bit_integer_dtypes(name='dtype'):
    """Decorator that checks the fixture with the 8-bit dtypes.

    The dtypes come from ``_8bit_int_bool_dtypes``, so despite the name the
    boolean dtype is included alongside the 8-bit integer dtypes.

    Args:
         name(str): Argument name to which specified dtypes are passed.

    .. seealso:: :func:`clpy.testing.for_dtypes`
    """
    decorator = for_dtypes(_8bit_int_bool_dtypes, name=name)
    return decorator


def for_dtypes_combination(types, names=('dtype',), full=None):
Expand Down Expand Up @@ -743,7 +765,7 @@ def test_func(self, *args, **kw):

def for_all_dtypes_combination(names=('dtyes',),
no_float16=True, no_bool=False, full=None,
no_complex=False):
no_complex=False, no_8bit_integer=False):
"""Decorator that checks the fixture with a product set of all dtypes.

Args:
Expand All @@ -761,11 +783,11 @@ def for_all_dtypes_combination(names=('dtyes',),

.. seealso:: :func:`clpy.testing.for_dtypes_combination`
"""
types = _make_all_dtypes(no_float16, no_bool, no_complex)
types = _make_all_dtypes(no_float16, no_bool, no_complex, no_8bit_integer)
return for_dtypes_combination(types, names, full)


def for_signed_dtypes_combination(names=('dtype',), full=None):
def for_signed_dtypes_combination(names=('dtype',), no_int8=False, full=None):
"""Decorator for parameterized test w.r.t. the product set of signed dtypes.

Args:
Expand All @@ -777,10 +799,15 @@ def for_signed_dtypes_combination(names=('dtype',), full=None):

.. seealso:: :func:`clpy.testing.for_dtypes_combination`
"""
return for_dtypes_combination(_signed_dtypes, names=names, full=full)
if no_int8:
types = _signed_dtypes_without_int8
else:
types = _signed_dtypes
return for_dtypes_combination(types, names=names, full=full)


def for_unsigned_dtypes_combination(names=('dtype',), full=None):
def for_unsigned_dtypes_combination(names=('dtype',), no_uint8=False,
full=None):
"""Decorator for parameterized test w.r.t. the product set of unsigned dtypes.

Args:
Expand All @@ -792,10 +819,15 @@ def for_unsigned_dtypes_combination(names=('dtype',), full=None):

.. seealso:: :func:`clpy.testing.for_dtypes_combination`
"""
return for_dtypes_combination(_unsigned_dtypes, names=names, full=full)
if no_uint8:
types = _unsigned_dtypes_without_uint8
else:
types = _unsigned_dtypes
return for_dtypes_combination(types, names=names, full=full)


def for_int_dtypes_combination(names=('dtype',), no_bool=False, full=None):
def for_int_dtypes_combination(names=('dtype',), no_bool=False,
no_8bit_integer=False, full=None):
"""Decorator for parameterized test w.r.t. the product set of int and boolean.

Args:
Expand All @@ -809,10 +841,13 @@ def for_int_dtypes_combination(names=('dtype',), no_bool=False, full=None):

.. seealso:: :func:`clpy.testing.for_dtypes_combination`
"""
if no_bool:
types = _int_dtypes
if no_8bit_integer:
types = _int_dtypes_without_8bit
else:
types = _int_bool_dtypes
if no_bool:
types = _int_dtypes
else:
types = _int_bool_dtypes
return for_dtypes_combination(types, names, full)


Expand Down
Loading