Skip to content

Commit

Permalink
Fixed HIP import error, made backend import error messages point to d…
Browse files Browse the repository at this point in the history
…ocumentation
  • Loading branch information
fjwillemsen committed Oct 3, 2023
1 parent 33a2382 commit 8e20ccc
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 119 deletions.
39 changes: 17 additions & 22 deletions kernel_tuner/backends/cupy.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
"""This module contains all Cupy specific kernel_tuner functions"""
"""This module contains all Cupy specific kernel_tuner functions."""
from __future__ import print_function


import logging
import time
import numpy as np

from kernel_tuner.backends.backend import GPUBackend
from kernel_tuner.observers.cupy import CupyRuntimeObserver


# embedded in try block to be able to generate documentation
# and run tests without cupy installed
try:
Expand All @@ -19,10 +15,10 @@


class CupyFunctions(GPUBackend):
"""Class that groups the Cupy functions on maintains state about the device"""
"""Class that groups the Cupy functions on maintains state about the device."""

def __init__(self, device=0, iterations=7, compiler_options=None, observers=None):
"""instantiate CupyFunctions object used for interacting with the CUDA device
"""Instantiate CupyFunctions object used for interacting with the CUDA device.
Instantiating this object will inspect and store certain device properties at
runtime, which are used during compilation and/or execution of kernels by the
Expand All @@ -39,8 +35,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
self.texrefs = []
if not cp:
raise ImportError(
"Error: cupy not installed, please install e.g. "
+ "using 'pip install cupy', please check https://github.com/cupy/cupy."
"cupy not installed, install using 'pip install cupy', or check https://kerneltuner.github.io/kernel_tuner/stable/install.html#cuda-and-pycuda."
)

# select device
Expand Down Expand Up @@ -88,7 +83,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
self.name = env["device_name"]

def ready_argument_list(self, arguments):
"""ready argument list to be passed to the kernel, allocates gpu mem
"""Ready argument list to be passed to the kernel, allocates gpu mem.
:param arguments: List of arguments to be passed to the kernel.
The order should match the argument list on the CUDA kernel.
Expand All @@ -111,7 +106,7 @@ def ready_argument_list(self, arguments):
return gpu_args

def compile(self, kernel_instance):
"""call the CUDA compiler to compile the kernel, return the device function
"""Call the CUDA compiler to compile the kernel, return the device function.
:param kernel_name: The name of the kernel to be compiled, used to lookup the
function after compilation.
Expand Down Expand Up @@ -140,23 +135,23 @@ def compile(self, kernel_instance):
return self.func

def start_event(self):
"""Records the event that marks the start of a measurement"""
"""Records the event that marks the start of a measurement."""
self.start.record(stream=self.stream)

def stop_event(self):
"""Records the event that marks the end of a measurement"""
"""Records the event that marks the end of a measurement."""
self.end.record(stream=self.stream)

def kernel_finished(self):
"""Returns True if the kernel has finished, False otherwise"""
"""Returns True if the kernel has finished, False otherwise."""
return self.end.done

def synchronize(self):
"""Halts execution until device has finished its tasks"""
"""Halts execution until device has finished its tasks."""
self.dev.synchronize()

def copy_constant_memory_args(self, cmem_args):
"""adds constant memory arguments to the most recently compiled module
"""Adds constant memory arguments to the most recently compiled module.
:param cmem_args: A dictionary containing the data to be passed to the
device constant memory. The format to be used is as follows: A
Expand All @@ -171,11 +166,11 @@ def copy_constant_memory_args(self, cmem_args):
constant_mem[:] = cp.asarray(v)

def copy_shared_memory_args(self, smem_args):
"""add shared memory arguments to the kernel"""
"""Add shared memory arguments to the kernel."""
self.smem_size = smem_args["size"]

def copy_texture_memory_args(self, texmem_args):
"""adds texture memory arguments to the most recently compiled module
"""Adds texture memory arguments to the most recently compiled module.
:param texmem_args: A dictionary containing the data to be passed to the
device texture memory. See tune_kernel().
Expand All @@ -184,7 +179,7 @@ def copy_texture_memory_args(self, texmem_args):
raise NotImplementedError("CuPy backend does not support texture memory")

def run_kernel(self, func, gpu_args, threads, grid, stream=None):
"""runs the CUDA kernel passed as 'func'
"""Runs the CUDA kernel passed as 'func'.
:param func: A cupy kernel compiled for this specific kernel configuration
:type func: cupy.RawKernel
Expand All @@ -205,7 +200,7 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
func(grid, threads, gpu_args, stream=stream, shared_mem=self.smem_size)

def memset(self, allocation, value, size):
"""set the memory in allocation to the value in value
"""Set the memory in allocation to the value in value.
:param allocation: A GPU memory allocation unit
:type allocation: cupy.ndarray
Expand All @@ -220,7 +215,7 @@ def memset(self, allocation, value, size):
allocation[:] = value

def memcpy_dtoh(self, dest, src):
"""perform a device to host memory copy
"""Perform a device to host memory copy.
:param dest: A numpy array in host memory to store the data
:type dest: numpy.ndarray
Expand All @@ -237,7 +232,7 @@ def memcpy_dtoh(self, dest, src):
raise ValueError("dest type not supported")

def memcpy_htod(self, dest, src):
"""perform a host to device memory copy
"""Perform a host to device memory copy.
:param dest: A GPU memory allocation unit
:type dest: cupy.ndarray
Expand Down
79 changes: 41 additions & 38 deletions kernel_tuner/backends/hip.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
"""This module contains all HIP specific kernel_tuner functions"""
"""This module contains all HIP specific kernel_tuner functions."""

import numpy as np
import ctypes
import ctypes.util
import sys
import logging

import numpy as np

from kernel_tuner.backends.backend import GPUBackend
from kernel_tuner.observers.hip import HipRuntimeObserver

# embedded in try block to be able to generate documentation
# and run tests without pyhip installed
try:
from pyhip import hip, hiprtc
except ImportError:
print("Not able to import pyhip, check if PYTHONPATH includes PyHIP")
hip = None
hiprtc = None

Expand All @@ -35,10 +32,10 @@
hipSuccess = 0

class HipFunctions(GPUBackend):
"""Class that groups the HIP functions on maintains state about the device"""
"""Class that groups the HIP functions on maintains state about the device."""

def __init__(self, device=0, iterations=7, compiler_options=None, observers=None):
"""instantiate HipFunctions object used for interacting with the HIP device
"""Instantiate HipFunctions object used for interacting with the HIP device.
Instantiating this object will inspect and store certain device properties at
runtime, which are used during compilation and/or execution of kernels by the
Expand All @@ -51,8 +48,13 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
:param iterations: Number of iterations used while benchmarking a kernel, 7 by default.
:type iterations: int
"""
if not hip or not hiprtc:
raise ImportError("Unable to import PyHIP, make sure PYTHONPATH includes PyHIP, or check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-pyhip.")

# embedded in try block to be able to generate documentation
# and run tests without pyhip installed
logging.debug("HipFunction instantiated")

self.hipProps = hip.hipGetDeviceProperties(device)

self.name = self.hipProps._name.decode('utf-8')
Expand Down Expand Up @@ -85,13 +87,13 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None


def ready_argument_list(self, arguments):
"""ready argument list to be passed to the HIP function
"""Ready argument list to be passed to the HIP function.
:param arguments: List of arguments to be passed to the HIP function.
The order should match the argument list on the HIP function.
Allowed values are np.ndarray, and/or np.int32, np.float32, and so on.
:type arguments: list(numpy objects)
:returns: Ctypes structure of arguments to be passed to the HIP function.
:rtype: ctypes structure
"""
Expand All @@ -109,11 +111,11 @@ def ready_argument_list(self, arguments):
hip.hipMemcpy_htod(device_ptr, data_ctypes, arg.nbytes)
ctype_args.append(device_ptr)
else:
raise TypeError("unknown dtype for ndarray")
# Convert valid non-array arguments to ctypes
raise TypeError("unknown dtype for ndarray")
# Convert valid non-array arguments to ctypes
elif isinstance(arg, np.generic):
data_ctypes = dtype_map[dtype_str](arg)
ctype_args.append(data_ctypes)
ctype_args.append(data_ctypes)

# Determine the types of the fields in the structure
field_types = [type(x) for x in ctype_args]
Expand All @@ -122,17 +124,17 @@ class ArgListStructure(ctypes.Structure):
_fields_ = [(f'field{i}', t) for i, t in enumerate(field_types)]
def __getitem__(self, key):
return getattr(self, self._fields_[key][0])

return ArgListStructure(*ctype_args)


def compile(self, kernel_instance):
"""call the HIP compiler to compile the kernel, return the function
"""Call the HIP compiler to compile the kernel, return the function.
:param kernel_instance: An object representing the specific instance of the tunable kernel
in the parameter space.
:type kernel_instance: kernel_tuner.core.KernelInstance
:returns: An ctypes function that can be called directly.
:rtype: ctypes._FuncPtr
"""
Expand All @@ -144,7 +146,7 @@ def compile(self, kernel_instance):
if 'extern "C"' not in kernel_string:
kernel_string = 'extern "C" {\n' + kernel_string + "\n}"
kernel_ptr = hiprtc.hiprtcCreateProgram(kernel_string, kernel_name, [], [])

try:
#Compile based on device (Not yet tested for non-AMD devices)
plat = hip.hipGetPlatformName()
Expand All @@ -156,7 +158,7 @@ def compile(self, kernel_instance):
options_list = []
options_list.extend(self.compiler_options)
hiprtc.hiprtcCompileProgram(kernel_ptr, options_list)

#Get module and kernel from compiled kernel string
code = hiprtc.hiprtcGetCode(kernel_ptr)
module = hip.hipModuleLoadData(code)
Expand All @@ -167,36 +169,36 @@ def compile(self, kernel_instance):
log = hiprtc.hiprtcGetProgramLog(kernel_ptr)
print(log)
raise e

return kernel

def start_event(self):
"""Records the event that marks the start of a measurement"""
"""Records the event that marks the start of a measurement."""
logging.debug("HipFunction start_event called")

hip.hipEventRecord(self.start, self.stream)

def stop_event(self):
"""Records the event that marks the end of a measurement"""
"""Records the event that marks the end of a measurement."""
logging.debug("HipFunction stop_event called")

hip.hipEventRecord(self.end, self.stream)

def kernel_finished(self):
"""Returns True if the kernel has finished, False otherwise"""
"""Returns True if the kernel has finished, False otherwise."""
logging.debug("HipFunction kernel_finished called")

# Query the status of the event
return hip.hipEventQuery(self.end)

def synchronize(self):
"""Halts execution until device has finished its tasks"""
"""Halts execution until device has finished its tasks."""
logging.debug("HipFunction synchronize called")

hip.hipDeviceSynchronize()

def run_kernel(self, func, gpu_args, threads, grid, stream=None):
"""runs the HIP kernel passed as 'func'
"""Runs the HIP kernel passed as 'func'.
:param func: A HIP kernel compiled for this specific kernel configuration
:type func: ctypes pionter
Expand All @@ -219,15 +221,15 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
if stream is None:
stream = self.stream

hip.hipModuleLaunchKernel(func,
grid[0], grid[1], grid[2],
hip.hipModuleLaunchKernel(func,
grid[0], grid[1], grid[2],
threads[0], threads[1], threads[2],
self.smem_size,
stream,
gpu_args)

def memset(self, allocation, value, size):
"""set the memory in allocation to the value in value
"""Set the memory in allocation to the value in value.
:param allocation: A GPU memory allocation unit
:type allocation: ctypes ptr
Expand All @@ -240,11 +242,11 @@ def memset(self, allocation, value, size):
"""
logging.debug("HipFunction memset called")

hip.hipMemset(allocation, value, size)

def memcpy_dtoh(self, dest, src):
"""perform a device to host memory copy
"""Perform a device to host memory copy.
:param dest: A numpy array in host memory to store the data
:type dest: numpy.ndarray
Expand All @@ -260,7 +262,7 @@ def memcpy_dtoh(self, dest, src):
hip.hipMemcpy_dtoh(dest_c, src, dest.nbytes)

def memcpy_htod(self, dest, src):
"""perform a host to device memory copy
"""Perform a host to device memory copy.
:param dest: A GPU memory allocation unit
:type dest: ctypes ptr
Expand All @@ -276,7 +278,7 @@ def memcpy_htod(self, dest, src):
hip.hipMemcpy_htod(dest, src_c, src.nbytes)

def copy_constant_memory_args(self, cmem_args):
"""adds constant memory arguments to the most recently compiled module
"""Adds constant memory arguments to the most recently compiled module.
:param cmem_args: A dictionary containing the data to be passed to the
device constant memory. The format to be used is as follows: A
Expand All @@ -298,12 +300,13 @@ def copy_constant_memory_args(self, cmem_args):
hip.hipMemcpy_htod(symbol_ptr, v_c, v.nbytes)

def copy_shared_memory_args(self, smem_args):
"""add shared memory arguments to the kernel"""
"""Add shared memory arguments to the kernel."""
logging.debug("HipFunction copy_shared_memory_args called")

self.smem_size = smem_args["size"]

def copy_texture_memory_args(self, texmem_args):
"""Copy texture memory arguments. Not yet implemented."""
logging.debug("HipFunction copy_texture_memory_args called")

raise NotImplementedError("HIP backend does not support texture memory")
Expand Down
Loading

0 comments on commit 8e20ccc

Please sign in to comment.