Skip to content

Commit

Permalink
Merge pull request #13 from HaoZeke/earlySkip
Browse files Browse the repository at this point in the history
ENH: Add a skip decorator
  • Loading branch information
mattip authored Jul 11, 2023
2 parents 4ee482b + 2128416 commit fc1f43a
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 45 deletions.
140 changes: 100 additions & 40 deletions asv_runner/benchmarks/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _get_sourceline_info(obj, basedir):
return ""


def check_num_args(root, benchmark_name, func, min_num_args, max_num_args=None):
def _check_num_args(root, benchmark_name, func, min_num_args, max_num_args=None):
"""
Verifies if the function under benchmarking accepts a correct number of arguments.
Expand Down Expand Up @@ -334,6 +334,79 @@ def _repr_no_address(obj):
return result


def _validate_params(params, param_names, name):
"""
Validates the params and param_names attributes and returns validated lists.
#### Parameters
**params** (`list`)
: List of parameters for the function to be benchmarked.
**param_names** (`list`)
: List of names for the parameters.
**name** (`str`)
: The name of the benchmark.
#### Returns
**params**, **param_names** (`list`, `list`)
: The validated parameter and parameter name lists.
"""

try:
param_names = [str(x) for x in list(param_names)]
except ValueError:
raise ValueError(f"{name}.param_names is not a list of strings")

try:
params = list(params)
except ValueError:
raise ValueError(f"{name}.params is not a list")

if params and not isinstance(params[0], (tuple, list)):
params = [params]
else:
params = [list(entry) for entry in params]

if len(param_names) != len(params):
param_names = param_names[: len(params)]
param_names += [
"param%d" % (k + 1,) for k in range(len(param_names), len(params))
]

return params, param_names


def _unique_param_ids(params):
"""
Processes the params list to handle duplicate names within parameter sets,
ensuring unique IDs.
#### Parameters
**params** (`list`)
: List of parameters. Each entry is a list representing a set of parameters.
#### Returns
**params** (`list`)
: List of parameters with duplicate names within each set handled.
If there are duplicate names, they are renamed with a numerical suffix to
ensure unique IDs.
"""

params = [[_repr_no_address(item) for item in entry] for entry in params]
for i, param in enumerate(params):
if len(param) != len(set(param)):
counter = Counter(param)
dupe_dict = {name: 0 for name, count in counter.items() if count > 1}
for j in range(len(param)):
name = param[j]
if name in dupe_dict:
param[j] = f"{name} ({dupe_dict[name]})"
dupe_dict[name] += 1
params[i] = param
return params


class Benchmark:
"""
Class representing a single benchmark. The class encapsulates
Expand Down Expand Up @@ -409,6 +482,10 @@ def __init__(self, name, func, attr_sources):
**params** (`list`)
: The list of parameters with unique representations for exporting.
**_skip_tuples** (`list`)
: List of tuples representing parameter combinations to be skipped
before calling the setup method.
#### Raises
**ValueError**
: If `param_names` or `_params` is not a list or if the number of
Expand Down Expand Up @@ -437,44 +514,15 @@ def __init__(self, name, func, attr_sources):
self.param_names = _get_first_attr(attr_sources, "param_names", [])
self._current_params = ()

# Enforce params format
try:
self.param_names = [str(x) for x in list(self.param_names)]
except ValueError:
raise ValueError(f"{name}.param_names is not a list of strings")

try:
self._params = list(self._params)
except ValueError:
raise ValueError(f"{name}.params is not a list")

if self._params and not isinstance(self._params[0], (tuple, list)):
# Accept a single list for one parameter only
self._params = [self._params]
else:
self._params = [list(entry) for entry in self._params]
self._params, self.param_names = _validate_params(
self._params, self.param_names, self.name
)

if len(self.param_names) != len(self._params):
self.param_names = self.param_names[: len(self._params)]
self.param_names += [
"param%d" % (k + 1,)
for k in range(len(self.param_names), len(self._params))
]
# Fetch skip parameters
self._skip_tuples = _get_first_attr(attr_sources, "skip_params", [])

# Exported parameter representations
self.params = [
[_repr_no_address(item) for item in entry] for entry in self._params
]
for i, param in enumerate(self.params):
if len(param) != len(set(param)):
counter = Counter(param)
dupe_dict = {name: 0 for name, count in counter.items() if count > 1}
for j in range(len(param)):
name = param[j]
if name in dupe_dict:
param[j] = f"{name} ({dupe_dict[name]})"
dupe_dict[name] += 1
self.params[i] = param
self.params = _unique_param_ids(self._params)

def __repr__(self):
return f"<{self.__class__.__name__} {self.name}>"
Expand Down Expand Up @@ -548,22 +596,22 @@ def check(self, root):
max_num_args = min_num_args

if self.setup_cache_key is not None:
ok = ok and check_num_args(
ok = ok and _check_num_args(
root, f"{self.name}: setup_cache", self._setup_cache, 0
)
max_num_args += 1

for setup in self._setups:
ok = ok and check_num_args(
ok = ok and _check_num_args(
root, f"{self.name}: setup", setup, min_num_args, max_num_args
)

ok = ok and check_num_args(
ok = ok and _check_num_args(
root, f"{self.name}: call", self.func, min_num_args, max_num_args
)

for teardown in self._teardowns:
ok = ok and check_num_args(
ok = ok and _check_num_args(
root,
f"{self.name}: teardown",
teardown,
Expand All @@ -574,6 +622,9 @@ def check(self, root):
return ok

def do_setup(self):
if tuple(self._current_params) in self._skip_tuples:
# Skip
return True
try:
for setup in self._setups:
setup(*self._current_params)
Expand All @@ -591,6 +642,9 @@ def redo_setup(self):
self.do_setup()

def do_teardown(self):
if tuple(self._current_params) in self._skip_tuples:
# Skip
return
for teardown in self._teardowns:
teardown(*self._current_params)

Expand All @@ -599,6 +653,9 @@ def do_setup_cache(self):
return self._setup_cache()

def do_run(self):
if tuple(self._current_params) in self._skip_tuples:
# Skip
return
return self.run(*self._current_params)

def do_profile(self, filename=None):
Expand All @@ -623,6 +680,9 @@ def do_profile(self, filename=None):
raised. If a `filename` is provided, the profiling results will be saved
to that file.
"""
if tuple(self._current_params) in self._skip_tuples:
# Skip
return

def method_caller():
run(*params) # noqa:F821 undefined name
Expand Down
64 changes: 64 additions & 0 deletions asv_runner/benchmarks/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import functools


def skip_for_params(skip_params_list):
"""
Decorator to set skip parameters for a benchmark function.
#### Parameters
**skip_params_dict** (`dict`):
A dictionary specifying the skip parameters for the benchmark function.
The keys represent the parameter names, and the values can be a single value
or a list of values.
#### Returns
**decorator** (`function`):
A decorator function that sets the skip parameters for the benchmark function.
#### Notes
The `skip_benchmark_for_params` decorator can be used to specify skip parameters
for a benchmark function. The skip parameters define combinations of values that
should be skipped when running the benchmark. The decorated function's `skip_params`
attribute will be set with the provided skip parameters, which will be used during
the benchmarking process.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

setattr(wrapper, "skip_params", skip_params_list)
return wrapper

return decorator


def skip_benchmark(func):
"""
Decorator to mark a function as skipped for benchmarking.
#### Parameters
**func** (function)
: The function to be marked as skipped.
#### Returns
**wrapper** (function)
: A wrapped function that is marked to be skipped for benchmarking.
#### Notes
The `skip_benchmark` decorator can be used to mark a specific function as
skipped for benchmarking. When the decorated function is encountered during
benchmarking, it will be skipped and not included in the benchmarking
process.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

setattr(wrapper, "skip_benchmark", True)
return wrapper


__all__ = [skip_for_params, skip_benchmark]
15 changes: 10 additions & 5 deletions asv_runner/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,19 @@ def _get_benchmark(attr_name, module, klass, func):
#### Returns
**benchmark** (Benchmark instance or None)
: A benchmark instance with the name of the benchmark, the function to be
benchmarked, and its sources. Returns None if no matching benchmark is found.
benchmarked, and its sources. Returns None if no matching benchmark is found
or the function is marked to be skipped.
#### Notes
The function tries to get the `benchmark_name` from `func`. If it fails,
it uses `attr_name` to match with the name regex in the benchmark types.
If a match is found, it creates a new benchmark instance and returns it.
If no match is found, it returns None.
The function tries to get the `benchmark_name` from `func`. If it fails, it
uses `attr_name` to match with the name regex in the benchmark types. If a
match is found, it creates a new benchmark instance and returns it. If no
match is found or the function is marked to be skipped, it returns None.
"""
# Check if the function has been marked to be skipped
if getattr(func, "skip_benchmark", False):
return

try:
name = func.benchmark_name
except AttributeError:
Expand Down

0 comments on commit fc1f43a

Please sign in to comment.