Fixing what is (probably) a long-standing bug: observers were ignored for the Compiler backend.
isazi committed Sep 17, 2024
1 parent 939f0c3 commit 06c6074
Showing 2 changed files with 43 additions and 108 deletions.
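
For context, the sketch below shows the user-facing behaviour this commit restores: observers passed to tune_kernel with the C (Compiler) backend now actually reach the backend and contribute their metrics to the tuning results. This is a minimal, illustrative example assuming kernel_tuner's public tune_kernel and BenchmarkObserver interfaces; the vector_add kernel, the LaunchCounter observer, and the "block" tunable are made up for illustration and are not part of this commit.

import numpy as np
from kernel_tuner import tune_kernel
from kernel_tuner.observers import BenchmarkObserver

class LaunchCounter(BenchmarkObserver):
    """Toy observer that counts how often each configuration is launched."""
    def __init__(self):
        self.launches = 0
    def after_finish(self):
        self.launches += 1
    def get_results(self):
        results = {"launches": self.launches}
        self.launches = 0
        return results

# illustrative host-code kernel for the Compiler backend, not from the commit
kernel_string = """
extern "C" void vector_add(float *c, float *a, float *b, int n) {
    for (int i = 0; i < n; i += block) {
        for (int j = i; j < i + block && j < n; j++) {
            c[j] = a[j] + b[j];
        }
    }
}
"""

size = 100000
a = np.random.randn(size).astype(np.float32)
b = np.random.randn(size).astype(np.float32)
c = np.zeros_like(a)
n = np.int32(size)

# "block" is injected into the source as a #define for every configuration
tune_params = {"block": [1, 2, 4, 8]}

# Before this commit the observer below was silently dropped for lang="C";
# with the fix its "launches" metric should appear in each result entry.
results, env = tune_kernel("vector_add", kernel_string, size, [c, a, b, n],
                           tune_params, lang="C", observers=[LaunchCounter()])
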
1 change: 0 additions & 1 deletion kernel_tuner/backends/compiler.py
@@ -93,7 +93,6 @@ def __init__(self, iterations=7, compiler_options=None, compiler=None, observers
self.lib = None
self.using_openmp = False
self.using_openacc = False
- self.observers = [CompilerRuntimeObserver(self)]
self.last_result = None

if self.compiler == "g++":
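
In isolation the deletion above may look odd, so here is a hedged sketch of how the two changed files fit together after this commit. The removed assignment replaced whatever CompilerFunctions.observers held with only the built-in CompilerRuntimeObserver; combined with core.py not forwarding observers at all (fixed in the next file), user-supplied observers never took effect for this backend. The snippet paraphrases the diff and abridges the surrounding code; it is not a verbatim copy of the repository.

# kernel_tuner/core.py, DeviceInterface.__init__ (abridged): user observers are
# now forwarded to the Compiler backend instead of being dropped.
dev = CompilerFunctions(
    compiler=compiler,
    compiler_options=compiler_options,
    iterations=iterations,
    observers=observers,  # added by this commit
)

# kernel_tuner/backends/compiler.py, CompilerFunctions.__init__ (abridged): the
# line below is the one removed by this commit; it unconditionally reset
# self.observers to only the built-in runtime observer.
# self.observers = [CompilerRuntimeObserver(self)]
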
150 changes: 43 additions & 107 deletions kernel_tuner/core.py
@@ -80,9 +80,7 @@ def __init__(self, kernel_name, kernel_sources, lang, defines=None):
self.defines = defines
if lang is None:
if callable(self.kernel_sources[0]):
- raise TypeError(
- "Please specify language when using a code generator function"
- )
+ raise TypeError("Please specify language when using a code generator function")
kernel_string = self.get_kernel_string(0)
lang = util.detect_language(kernel_string)

@@ -109,9 +107,7 @@ def get_kernel_string(self, index=0, params=None):
kernel_source = self.kernel_sources[index]
return util.get_kernel_string(kernel_source, params)

- def prepare_list_of_files(
- self, kernel_name, params, grid, threads, block_size_names
- ):
+ def prepare_list_of_files(self, kernel_name, params, grid, threads, block_size_names):
"""prepare the kernel string along with any additional files
The first file in the list is allowed to include or read in the others
@@ -147,9 +143,7 @@ def prepare_list_of_files(

for i, f in enumerate(self.kernel_sources):
if i > 0 and not util.looks_like_a_filename(f):
- raise ValueError(
- "When passing multiple kernel sources, the secondary entries must be filenames"
- )
+ raise ValueError("When passing multiple kernel sources, the secondary entries must be filenames")

ks = self.get_kernel_string(i, params)
# add preprocessor statements
@@ -183,9 +177,7 @@ def prepare_list_of_files(

def get_user_suffix(self, index=0):
"""Get the suffix of the kernel filename, if the user specified one. Return None otherwise."""
- if util.looks_like_a_filename(self.kernel_sources[index]) and (
- "." in self.kernel_sources[index]
- ):
+ if util.looks_like_a_filename(self.kernel_sources[index]) and ("." in self.kernel_sources[index]):
return "." + self.kernel_sources[index].split(".")[-1]
return None

@@ -214,13 +206,9 @@ def check_argument_lists(self, kernel_name, arguments):
"""
for i, f in enumerate(self.kernel_sources):
if not callable(f):
- util.check_argument_list(
- kernel_name, self.get_kernel_string(i), arguments
- )
+ util.check_argument_list(kernel_name, self.get_kernel_string(i), arguments)
else:
- logging.debug(
- "Checking of arguments list not supported yet for code generators."
- )
+ logging.debug("Checking of arguments list not supported yet for code generators.")


class DeviceInterface(object):
@@ -304,6 +292,7 @@ def __init__(
compiler=compiler,
compiler_options=compiler_options,
iterations=iterations,
+ observers=observers,
)
elif lang.upper() == "HIP":
dev = HipFunctions(
@@ -313,7 +302,9 @@
observers=observers,
)
else:
raise ValueError("Sorry, support for languages other than CUDA, OpenCL, HIP, C, and Fortran is not implemented yet")
raise ValueError(
"Sorry, support for languages other than CUDA, OpenCL, HIP, C, and Fortran is not implemented yet"
)
self.dev = dev

# look for NVMLObserver in observers, if present, enable special tunable parameters through nvml
@@ -443,9 +434,7 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett
obs.results = result
duration = max(duration, obs.continuous_duration)

- self.benchmark_continuous(
- func, gpu_args, instance.threads, instance.grid, result, duration
- )
+ self.benchmark_continuous(func, gpu_args, instance.threads, instance.grid, result, duration)

except Exception as e:
# some launches may fail because too many registers are required
@@ -458,9 +447,7 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett
"INVALID_WORK_GROUP_SIZE",
]
if any([skip_str in str(e) for skip_str in skippable_exceptions]):
- logging.debug(
- "benchmark fails due to runtime failure too many resources required"
- )
+ logging.debug("benchmark fails due to runtime failure too many resources required")
if verbose:
print(
f"skipping config {util.get_instance_string(instance.params)} reason: too many resources requested for launch"
@@ -472,13 +459,11 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett
raise e
return result

- def check_kernel_output(
- self, func, gpu_args, instance, answer, atol, verify, verbose
- ):
+ def check_kernel_output(self, func, gpu_args, instance, answer, atol, verify, verbose):
"""runs the kernel once and checks the result against answer"""
logging.debug("check_kernel_output")

- #if not using custom verify function, check if the length is the same
+ # if not using custom verify function, check if the length is the same
if answer:
if len(instance.arguments) != len(answer):
raise TypeError("The length of argument list and provided results do not match.")
@@ -507,7 +492,7 @@ def check_kernel_output(
self.dev.memcpy_dtoh(result_host[-1], gpu_args[i])
elif isinstance(arg, torch.Tensor) and isinstance(answer[i], torch.Tensor):
if not answer[i].is_cuda:
- #if the answer is on the host, copy gpu output to host as well
+ # if the answer is on the host, copy gpu output to host as well
result_host.append(torch.zeros_like(answer[i]))
self.dev.memcpy_dtoh(result_host[-1], gpu_args[i].tensor)
else:
@@ -535,10 +520,7 @@ def check_kernel_output(
correct = True

if not correct:
- raise RuntimeError(
- "Kernel result verification failed for: "
- + util.get_config_string(instance.params)
- )
+ raise RuntimeError("Kernel result verification failed for: " + util.get_config_string(instance.params))

def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, to):
# reset previous timers
@@ -552,7 +534,7 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options,
# Compile and benchmark a kernel instance based on kernel strings and parameters
instance_string = util.get_instance_string(params)

- logging.debug('compile_and_benchmark ' + instance_string)
+ logging.debug("compile_and_benchmark " + instance_string)

instance = self.create_kernel_instance(kernel_source, kernel_options, params, verbose)
if isinstance(instance, util.ErrorConfig):
@@ -570,9 +552,7 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options,
else:
# add shared memory arguments to compiled module
if kernel_options.smem_args is not None:
- self.dev.copy_shared_memory_args(
- util.get_smem_args(kernel_options.smem_args, params)
- )
+ self.dev.copy_shared_memory_args(util.get_smem_args(kernel_options.smem_args, params))
# add constant memory arguments to compiled module
if kernel_options.cmem_args is not None:
self.dev.copy_constant_memory_args(kernel_options.cmem_args)
@@ -586,12 +566,8 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options,
# test kernel for correctness
if func and (to.answer or to.verify or self.output_observers):
start_verification = time.perf_counter()
- self.check_kernel_output(
- func, gpu_args, instance, to.answer, to.atol, to.verify, verbose
- )
- last_verification_time = 1000 * (
- time.perf_counter() - start_verification
- )
+ self.check_kernel_output(func, gpu_args, instance, to.answer, to.atol, to.verify, verbose)
+ last_verification_time = 1000 * (time.perf_counter() - start_verification)

# benchmark
if func:
@@ -607,10 +583,7 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options,
except Exception as e:
# dump kernel sources to temp file
temp_filenames = instance.prepare_temp_files_for_error_msg()
- print(
- "Error while compiling or benchmarking, see source files: "
- + " ".join(temp_filenames)
- )
+ print("Error while compiling or benchmarking, see source files: " + " ".join(temp_filenames))
raise e

# clean up any temporary files, if no error occured
@@ -639,9 +612,7 @@ def compile_kernel(self, instance, verbose):
"local memory limit exceeded",
]
if any(msg in str(e) for msg in shared_mem_error_messages):
- logging.debug(
- "compile_kernel failed due to kernel using too much shared memory"
- )
+ logging.debug("compile_kernel failed due to kernel using too much shared memory")
if verbose:
print(
f"skipping config {util.get_instance_string(instance.params)} reason: too much shared memory used"
@@ -654,7 +625,7 @@ def compile_kernel(self, instance, verbose):

@staticmethod
def preprocess_gpu_arguments(old_arguments, params):
""" Get a flat list of arguments based on the configuration given by `params` """
"""Get a flat list of arguments based on the configuration given by `params`"""
return _preprocess_gpu_arguments(old_arguments, params)

def copy_shared_memory_args(self, smem_args):
@@ -690,9 +661,7 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose)
)
if np.prod(threads) > self.dev.max_threads:
if verbose:
- print(
- f"skipping config {util.get_instance_string(params)} reason: too many threads per block"
- )
+ print(f"skipping config {util.get_instance_string(params)} reason: too many threads per block")
return util.InvalidConfig()

# obtain the kernel_string and prepare additional files, if any
@@ -711,7 +680,7 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose)
# Preprocess GPU arguments. Require for handling `Tunable` arguments
arguments = _preprocess_gpu_arguments(kernel_options.arguments, params)

- #collect everything we know about this instance and return it
+ # collect everything we know about this instance and return it
return KernelInstance(name, kernel_source, kernel_string, temp_files, threads, grid, params, arguments)

def get_environment(self):
@@ -758,12 +727,8 @@ def run_kernel(self, func, gpu_args, instance):
try:
self.dev.run_kernel(func, gpu_args, instance.threads, instance.grid)
except Exception as e:
if "too many resources requested for launch" in str(
e
) or "OUT_OF_RESOURCES" in str(e):
logging.debug(
"ignoring runtime failure due to too many resources required"
)
if "too many resources requested for launch" in str(e) or "OUT_OF_RESOURCES" in str(e):
logging.debug("ignoring runtime failure due to too many resources required")
return False
else:
logging.debug("encountered unexpected runtime failure: " + str(e))
@@ -772,7 +737,7 @@ def run_kernel(self, func, gpu_args, instance):


def _preprocess_gpu_arguments(old_arguments, params):
""" Get a flat list of arguments based on the configuration given by `params` """
"""Get a flat list of arguments based on the configuration given by `params`"""
new_arguments = []

for argument in old_arguments:
@@ -789,15 +754,11 @@ def _default_verify_function(instance, answer, result_host, atol, verbose):

# first check if the length is the same
if len(instance.arguments) != len(answer):
- raise TypeError(
- "The length of argument list and provided results do not match."
- )
+ raise TypeError("The length of argument list and provided results do not match.")
# for each element in the argument list, check if the types match
for i, arg in enumerate(instance.arguments):
if answer[i] is not None: # skip None elements in the answer list
- if isinstance(answer[i], (np.ndarray, cp.ndarray)) and isinstance(
- arg, (np.ndarray, cp.ndarray)
- ):
+ if isinstance(answer[i], (np.ndarray, cp.ndarray)) and isinstance(arg, (np.ndarray, cp.ndarray)):
if answer[i].dtype != arg.dtype:
raise TypeError(
f"Element {i} of the expected results list is not of the same dtype as the kernel output: "
@@ -845,16 +806,14 @@ def _default_verify_function(instance, answer, result_host, atol, verbose):
)
else:
# either answer[i] and argument have different types or answer[i] is not a numpy type
- if not isinstance(
- answer[i], (np.ndarray, cp.ndarray, torch.Tensor)
- ) or not isinstance(answer[i], np.number):
+ if not isinstance(answer[i], (np.ndarray, cp.ndarray, torch.Tensor)) or not isinstance(
+ answer[i], np.number
+ ):
raise TypeError(
f"Element {i} of expected results list is not a numpy/cupy ndarray, torch Tensor or numpy scalar."
)
else:
- raise TypeError(
- f"Element {i} of expected results list and kernel arguments have different types."
- )
+ raise TypeError(f"Element {i} of expected results list and kernel arguments have different types.")

def _ravel(a):
if hasattr(a, "ravel") and len(a.shape) > 1:
@@ -874,26 +833,15 @@ def _flatten(a):
expected = _flatten(expected)
if any([isinstance(array, cp.ndarray) for array in [expected, result]]):
output_test = cp.allclose(expected, result, atol=atol)
- elif isinstance(expected, torch.Tensor) and isinstance(
- result, torch.Tensor
- ):
+ elif isinstance(expected, torch.Tensor) and isinstance(result, torch.Tensor):
output_test = torch.allclose(expected, result, atol=atol)
else:
output_test = np.allclose(expected, result, atol=atol)

if not output_test and verbose:
- print(
- "Error: "
- + util.get_config_string(instance.params)
- + " detected during correctness check"
- )
- print(
- "this error occured when checking value of the %oth kernel argument"
- % (i,)
- )
- print(
- "Printing kernel output and expected result, set verbose=False to suppress this debug print"
- )
+ print("Error: " + util.get_config_string(instance.params) + " detected during correctness check")
+ print("this error occured when checking value of the %oth kernel argument" % (i,))
+ print("Printing kernel output and expected result, set verbose=False to suppress this debug print")
np.set_printoptions(edgeitems=50)
print("Kernel output:")
print(result)
@@ -928,11 +876,7 @@ def apply_template_typenames(type_list, templated_typenames):
def replace_typename_token(matchobj):
"""function for a whitespace preserving token regex replace"""
# replace only the match, leaving the whitespace around it as is
- return (
- matchobj.group(1)
- + templated_typenames[matchobj.group(2)]
- + matchobj.group(3)
- )
+ return matchobj.group(1) + templated_typenames[matchobj.group(2)] + matchobj.group(3)

for i, arg_type in enumerate(type_list):
for k, v in templated_typenames.items():
@@ -963,25 +907,19 @@ def wrap_templated_kernel(kernel_string, kernel_name):
# relatively strict regex that does not allow nested template parameters like vector<TF>
# within the template parameter list
regex = (
r"template\s*<([^>]*?)>\s*__global__\s+void\s+(__launch_bounds__\([^\)]+?\)\s+)?"
+ name
+ r"\s*\((.*?)\)\s*\{"
r"template\s*<([^>]*?)>\s*__global__\s+void\s+(__launch_bounds__\([^\)]+?\)\s+)?" + name + r"\s*\((.*?)\)\s*\{"
)
match = re.search(regex, kernel_string, re.S)
if not match:
raise ValueError("could not find templated kernel definition")

template_parameters = match.group(1).split(",")
argument_list = match.group(3).split(",")
- argument_list = [
- s.strip() for s in argument_list
- ] # remove extra whitespace around 'type name' strings
+ argument_list = [s.strip() for s in argument_list] # remove extra whitespace around 'type name' strings

type_list, name_list = split_argument_list(argument_list)

- templated_typenames = get_templated_typenames(
- template_parameters, template_arguments
- )
+ templated_typenames = get_templated_typenames(template_parameters, template_arguments)
apply_template_typenames(type_list, templated_typenames)

# replace __global__ with __device__ in the templated kernel definition
@@ -995,9 +933,7 @@ def wrap_templated_kernel(kernel_string, kernel_name):
launch_bounds = match.group(2)

# generate code for the compile-time template instantiation
- template_instantiation = (
- f"template __device__ void {kernel_name}(" + ", ".join(type_list) + ");\n"
- )
+ template_instantiation = f"template __device__ void {kernel_name}(" + ", ".join(type_list) + ");\n"

# generate code for the wrapper kernel
new_arg_list = ", ".join([" ".join((a, b)) for a, b in zip(type_list, name_list)])
