diff --git a/src/pydvl/influence/array.py b/src/pydvl/influence/array.py index a82b380b8..d549eee9d 100644 --- a/src/pydvl/influence/array.py +++ b/src/pydvl/influence/array.py @@ -7,6 +7,7 @@ using the Zarr library. """ +import logging from abc import ABC, abstractmethod from typing import Callable, Generator, Generic, List, Optional, Tuple, Union @@ -14,6 +15,7 @@ from numpy.typing import NDArray from zarr.storage import StoreLike +from ..utils import log_duration from .base_influence_function_model import TensorType @@ -119,6 +121,7 @@ def __init__( ): self.generator_factory = generator_factory + @log_duration(log_level=logging.INFO) def compute(self, aggregator: Optional[SequenceAggregator] = None): """ Computes and optionally aggregates the chunks of the array using the provided @@ -139,6 +142,7 @@ def compute(self, aggregator: Optional[SequenceAggregator] = None): aggregator = ListAggregator() return aggregator(self.generator_factory()) + @log_duration(log_level=logging.INFO) def to_zarr( self, path_or_url: Union[str, StoreLike], @@ -223,6 +227,7 @@ def __init__( ): self.generator_factory = generator_factory + @log_duration(log_level=logging.INFO) def compute(self, aggregator: Optional[NestedSequenceAggregator] = None): """ Computes and optionally aggregates the chunks of the array using the provided @@ -244,6 +249,7 @@ def compute(self, aggregator: Optional[NestedSequenceAggregator] = None): aggregator = NestedListAggregator() return aggregator(self.generator_factory()) + @log_duration(log_level=logging.INFO) def to_zarr( self, path_or_url: Union[str, StoreLike],