Skip to content

Commit

Permalink
Fixed a logic bug in parsing multidimensional sizes, and added a test…
Browse files Browse the repository at this point in the history
… to catch the bug.
  • Loading branch information
isazi committed Jun 6, 2024
1 parent 7fcabad commit 8336cd0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
16 changes: 10 additions & 6 deletions kernel_tuner/utils/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,16 @@ def parse_size(size: Any, preprocessor: list = None, dimensions: dict = None) ->
except ValueError:
# If size cannot be natively converted to an int, we try to derive it from the preprocessor
if preprocessor is not None:
if "," in size:
ret_size = 1
for dimension in size.split(","):
ret_size *= find_size_in_preprocessor(dimension, preprocessor)
else:
ret_size = find_size_in_preprocessor(size, preprocessor)
try:
if "," in size:
ret_size = 1
for dimension in size.split(","):
ret_size *= find_size_in_preprocessor(dimension, preprocessor)
else:
ret_size = find_size_in_preprocessor(size, preprocessor)
except TypeError:
# preprocessor is available but does not contain the dimensions
pass
# If size cannot be natively converted, nor retrieved from the preprocessor, we check user provided values
if dimensions is not None:
if size in dimensions.keys():
Expand Down
1 change: 1 addition & 0 deletions test/utils/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_parse_size():
assert parse_size("m", ["#define size 512\n"], {"n": 32}) is None
assert parse_size("rows,cols", dimensions={"rows": 16, "cols": 8}) == 128
assert parse_size("n_rows,n_cols", ["#define n_cols 16\n", "#define n_rows 32\n"]) == 512
assert parse_size("rows,cols", [], dimensions={"rows": 16, "cols": 8}) == 128


def test_wrap_timing():
Expand Down

0 comments on commit 8336cd0

Please sign in to comment.