From ca0644e2c4419441a7fb15639d621d62c3afdef3 Mon Sep 17 00:00:00 2001 From: Riccardo Bertossa Date: Thu, 26 Oct 2023 08:29:43 +0000 Subject: [PATCH] reverted call_with_super_check changes apparently the error was generated from something else --- src/plumpy/base/utils.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/src/plumpy/base/utils.py b/src/plumpy/base/utils.py index 8e645670..f90fdd39 100644 --- a/src/plumpy/base/utils.py +++ b/src/plumpy/base/utils.py @@ -1,10 +1,8 @@ # -*- coding: utf-8 -*- -from typing import Any, Callable, Dict +from typing import Any, Callable __all__ = ['super_check', 'call_with_super_check'] -_SUPER_COUNTERS: Dict[str, int] = {} - def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]: """ @@ -13,14 +11,11 @@ def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]: """ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None: - # pylint: disable=locally-disabled, global-variable-not-assigned - global _SUPER_COUNTERS - counter_name = _get_counter_name(self, wrapped) - msg = f"The function '{wrapped.__name__}' was not called through call_with_super_check ({counter_name})" - assert _SUPER_COUNTERS.get(counter_name, 0) >= 1, msg + msg = f"The function '{wrapped.__name__}' was not called through call_with_super_check" + assert getattr(self, '_called', 0) >= 1, msg wrapped(self, *args, **kwargs) - _SUPER_COUNTERS[counter_name] -= 1 - + self._called -= 1 + #the following is to show the correct name later in the call_with_super_check error message wrapper.__name__ = wrapped.__name__ return wrapper @@ -29,18 +24,9 @@ def call_with_super_check(wrapped: Callable[..., Any], *args: Any, **kwargs: Any """ Call a class method checking that all subclasses called super along the way """ - # pylint: disable=locally-disabled, global-variable-not-assigned - global _SUPER_COUNTERS self = wrapped.__self__ # type: ignore # should actually be MethodType, but mypy does not handle this - counter_name = _get_counter_name(self, wrapped) - call_count = _SUPER_COUNTERS.get(counter_name, 0) - _SUPER_COUNTERS[counter_name] = call_count + 1 + call_count = getattr(self, '_called', 0) + self._called = call_count + 1 wrapped(*args, **kwargs) msg = f"Base '{wrapped.__name__}' was not called from '{self.__class__}'\nHint: Did you forget to call the super?" - assert _SUPER_COUNTERS[counter_name] == call_count, msg - if call_count == 0: - del _SUPER_COUNTERS[counter_name] - - -def _get_counter_name(self: Any, wrapped: Callable[..., Any]) -> str: - return f'{id(self)}_called_' + wrapped.__name__ + assert self._called == call_count, msg \ No newline at end of file