Skip to content

Commit

Permalink
Fixed a bug relating to constraints checking, and added and updated t…
Browse files Browse the repository at this point in the history
…ests to detect this in the future
  • Loading branch information
fjwillemsen committed Oct 12, 2023
1 parent 30f8568 commit a4a284b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 21 deletions.
12 changes: 5 additions & 7 deletions kernel_tuner/runners/simulation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" The simulation runner for sequentially tuning the parameter space based on cached data """
"""The simulation runner for sequentially tuning the parameter space based on cached data."""
import logging
from collections import namedtuple
from time import perf_counter
Expand All @@ -10,7 +10,7 @@


class SimulationDevice(_SimulationDevice):
""" Simulated device used by simulation runner """
"""Simulated device used by simulation runner."""

@property
def name(self):
Expand All @@ -27,10 +27,10 @@ def get_environment(self):


class SimulationRunner(Runner):
""" SimulationRunner is used for tuning with a single process/thread """
"""SimulationRunner is used for tuning with a single process/thread."""

def __init__(self, kernel_source, kernel_options, device_options, iterations, observers):
""" Instantiate the SimulationRunner
"""Instantiate the SimulationRunner.
:param kernel_source: The kernel source
:type kernel_source: kernel_tuner.core.KernelSource
Expand All @@ -46,7 +46,6 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob
each kernel instance.
:type iterations: int
"""

self.quiet = device_options.quiet
self.dev = SimulationDevice(1024, dict(device_name="Simulation"), self.quiet)

Expand All @@ -66,7 +65,7 @@ def get_environment(self, tuning_options):
return env

def run(self, parameter_space, tuning_options):
""" Iterate through the entire parameter space using a single Python process
"""Iterate through the entire parameter space using a single Python process.
:param parameter_space: The parameter space as an iterable.
:type parameter_space: iterable
Expand All @@ -78,7 +77,6 @@ def run(self, parameter_space, tuning_options):
:returns: A list of dictionaries for executed kernel configurations and their
execution times.
:rtype: dict()
"""
logging.debug('simulation runner started for ' + self.kernel_options.kernel_name)

Expand Down
20 changes: 13 additions & 7 deletions kernel_tuner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def check_block_size_params_names_list(block_size_names, tune_params):
)


def check_restrictions(restrictions, params: dict, verbose: bool):
def check_restrictions(restrictions, params: dict, verbose: bool) -> bool:
"""Check whether a specific instance meets the search space restrictions."""
valid = True
if callable(restrictions):
Expand All @@ -263,14 +263,19 @@ def check_restrictions(restrictions, params: dict, verbose: bool):
if not restrict(params.values()):
valid = False
break
continue
# if it's a string, fill in the parameters and evaluate
elif isinstance(restrict, str) and not eval(replace_param_occurrences(restrict, params)):
valid = False
break
elif isinstance(restrict, str):
if not eval(replace_param_occurrences(restrict, params)):
valid = False
break
continue
# if it's a function, call it
elif callable(restrict) and not restrict(params):
valid = False
break
elif callable(restrict):
if not restrict(**params):
valid = False
break
continue
# if it's a tuple, use only the parameters in the second argument to call the restriction
elif (isinstance(restrict, tuple) and len(restrict) == 2
and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))):
Expand All @@ -282,6 +287,7 @@ def check_restrictions(restrictions, params: dict, verbose: bool):
if not restrict(**selected_params):
valid = False
break
continue
# otherwise, raise an error
else:
raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})")
Expand Down
19 changes: 15 additions & 4 deletions test/test_searchspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from unittest.mock import patch

import numpy as np
from constraint import ExactSumConstraint, FunctionConstraint
from constraint import ExactSumConstraint

from kernel_tuner.interface import Options
from kernel_tuner.searchspace import Searchspace
Expand Down Expand Up @@ -37,13 +37,13 @@

# each GPU must have at least one layer and the sum of all layers must not exceed the total number of layers


def _min_func(gpu1, gpu2, gpu3, gpu4):
return min([gpu1, gpu2, gpu3, gpu4]) >= 1


# test three different types of restrictions: python-constraint, a function and a string
restrict = [ExactSumConstraint(num_layers), FunctionConstraint(_min_func)]
# test two different types of restrictions: a constraint and a callable
assert callable(_min_func)
restrict = [ExactSumConstraint(num_layers), _min_func]

# create the searchspace object
searchspace = Searchspace(tune_params, restrict, max_threads)
Expand Down Expand Up @@ -79,6 +79,17 @@ def test_internal_representation():
for index, dict_config in enumerate(searchspace.get_list_dict().keys()):
assert dict_config == searchspace.list[index]

def test_check_restrictions():
"""Test whether the outcome of restrictions is as expected when using check_restrictions."""
from kernel_tuner.util import check_restrictions

param_config_false = {'x': 1, 'y': 4, 'z': "string_1" }
param_config_true = {'x': 3, 'y': 4, 'z': "string_1" }

assert check_restrictions(simple_searchspace.restrictions, param_config_false, verbose=False) is False
assert check_restrictions(simple_searchspace.restrictions, param_config_true, verbose=False) is True


def test_against_bruteforce():
"""Tests the default Searchspace framework against bruteforcing the searchspace."""
compare_two_searchspace_objects(simple_searchspace, simple_searchspace_bruteforce)
Expand Down
3 changes: 0 additions & 3 deletions test/test_util_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,6 @@ def test_replace_param_occurrences():

def test_check_restrictions():
params = {"a": 7, "b": 4, "c": 3}
print(params.values())
print(params.keys())
restrictions = [
["a==b+c"],
["a==b+c", "b==b", "a-b==c"],
Expand All @@ -238,7 +236,6 @@ def test_check_restrictions():
# test the call returns expected
for r, e in zip(restrictions, expected):
answer = check_restrictions(r, dict(zip(params.keys(), params.values())), False)
print(answer)
assert answer == e


Expand Down

0 comments on commit a4a284b

Please sign in to comment.