Skip to content

Commit

Permalink
Merge pull request #221 from stijnh/fix-snap-nearest-on-strings
Browse files Browse the repository at this point in the history
Fix `snap_to_nearest_config` on non-numeric parameters
  • Loading branch information
stijnh authored Oct 12, 2023
2 parents a4a284b + 1233407 commit 97ed8ca
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
10 changes: 8 additions & 2 deletions kernel_tuner/strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,14 @@ def snap_to_nearest_config(x, tune_params):
"""Helper func that for each param selects the closest actual value."""
params = []
for i, k in enumerate(tune_params.keys()):
values = np.array(tune_params[k])
idx = np.abs(values - x[i]).argmin()
values = tune_params[k]

# if `x[i]` is in `values`, use that value, otherwise find the closest match
if x[i] in values:
idx = values.index(x[i])
else:
idx = np.argmin([abs(v - x[i]) for v in values])

params.append(values[idx])
return params

Expand Down
5 changes: 3 additions & 2 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def test_snap_to_nearest_config():
tune_params['x'] = [0, 1, 2, 3, 4, 5]
tune_params['y'] = [0, 1, 2, 3, 4, 5]
tune_params['z'] = [0, 1, 2, 3, 4, 5]
tune_params['w'] = ['a', 'b', 'c']

x = [-5.7, 3.14, 1e6]
expected = [0, 3, 5]
x = [-5.7, 3.14, 1e6, 'b']
expected = [0, 3, 5, 'b']

answer = common.snap_to_nearest_config(x, tune_params)
assert answer == expected
Expand Down

0 comments on commit 97ed8ca

Please sign in to comment.