From df5d67d98369e36029c7f1e0e16ba4a8b38c44c9 Mon Sep 17 00:00:00 2001 From: Markus Semmler Date: Wed, 6 Sep 2023 02:21:34 +0200 Subject: [PATCH] Add function `measure_execution_time`. --- tests/value/utils.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/value/utils.py b/tests/value/utils.py index 7c38e344f..6681a5835 100644 --- a/tests/value/utils.py +++ b/tests/value/utils.py @@ -1,10 +1,17 @@ from __future__ import annotations +import time from copy import deepcopy -from typing import Callable, Tuple +from functools import wraps +from logging import getLogger +from typing import Callable, Optional, Tuple, TypeVar from pydvl.utils.types import Seed +logger = getLogger(__name__) + +ReturnType = TypeVar("ReturnType") + def call_fn_multiple_seeds( fn: Callable, *args, seeds: Tuple[Seed, ...], **kwargs @@ -23,3 +30,37 @@ def call_fn_multiple_seeds( A tuple of the results of the function. """ return tuple(fn(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds) + + +def measure_execution_time( + func: Callable[..., ReturnType] +) -> Callable[..., Tuple[Optional[ReturnType], float]]: + """ + Takes a function `func` and returns a function with the same input arguments and + the original return value along with the execution time. + + Args: + func: The function to be measured, accepting arbitrary arguments and returning + any type. + + Returns: + A wrapped function that, when called, returns a tuple containing the original + function's result and its execution time in seconds. The decorated function + will have the same input arguments and return type as the original function. + """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Tuple[Optional[ReturnType], float]: + result = None + start_time = time.time() + try: + result = func(*args, **kwargs) + except Exception as e: + logger.error(f"Error in {func.__name__}: {e}") + finally: + end_time = time.time() + execution_time = end_time - start_time + logger.info(f"{func.__name__} took {execution_time:.5f} seconds.") + return result, execution_time + + return wrapper