Skip to content

Commit

Permalink
Add function measure_execution_time.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Sep 6, 2023
1 parent 0914b66 commit df5d67d
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion tests/value/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from __future__ import annotations

import time
from copy import deepcopy
from typing import Callable, Tuple
from functools import wraps
from logging import getLogger
from typing import Callable, Optional, Tuple, TypeVar

from pydvl.utils.types import Seed

logger = getLogger(__name__)

ReturnType = TypeVar("ReturnType")


def call_fn_multiple_seeds(
fn: Callable, *args, seeds: Tuple[Seed, ...], **kwargs
Expand All @@ -23,3 +30,37 @@ def call_fn_multiple_seeds(
A tuple of the results of the function.
"""
return tuple(fn(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds)


def measure_execution_time(
func: Callable[..., ReturnType]
) -> Callable[..., Tuple[Optional[ReturnType], float]]:
"""
Takes a function `func` and returns a function with the same input arguments and
the original return value along with the execution time.
Args:
func: The function to be measured, accepting arbitrary arguments and returning
any type.
Returns:
A wrapped function that, when called, returns a tuple containing the original
function's result and its execution time in seconds. The decorated function
will have the same input arguments and return type as the original function.
"""

@wraps(func)
def wrapper(*args, **kwargs) -> Tuple[Optional[ReturnType], float]:
result = None
start_time = time.time()
try:
result = func(*args, **kwargs)
except Exception as e:
logger.error(f"Error in {func.__name__}: {e}")
finally:
end_time = time.time()
execution_time = end_time - start_time
logger.info(f"{func.__name__} took {execution_time:.5f} seconds.")
return result, execution_time

return wrapper

0 comments on commit df5d67d

Please sign in to comment.