Skip to content

Commit

Permalink
reverted call_with_super_check changes
Browse files Browse the repository at this point in the history
apparently the error was generated from something else
  • Loading branch information
rikigigi committed Oct 26, 2023
1 parent abbab03 commit ca0644e
Showing 1 changed file with 8 additions and 22 deletions.
30 changes: 8 additions & 22 deletions src/plumpy/base/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# -*- coding: utf-8 -*-
from typing import Any, Callable, Dict
from typing import Any, Callable

__all__ = ['super_check', 'call_with_super_check']

_SUPER_COUNTERS: Dict[str, int] = {}


def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]:
"""
Expand All @@ -13,14 +11,11 @@ def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]:
"""

def wrapper(self: Any, *args: Any, **kwargs: Any) -> None:
# pylint: disable=locally-disabled, global-variable-not-assigned
global _SUPER_COUNTERS
counter_name = _get_counter_name(self, wrapped)
msg = f"The function '{wrapped.__name__}' was not called through call_with_super_check ({counter_name})"
assert _SUPER_COUNTERS.get(counter_name, 0) >= 1, msg
msg = f"The function '{wrapped.__name__}' was not called through call_with_super_check"
assert getattr(self, '_called', 0) >= 1, msg
wrapped(self, *args, **kwargs)
_SUPER_COUNTERS[counter_name] -= 1

self._called -= 1
#the following is to show the correct name later in the call_with_super_check error message
wrapper.__name__ = wrapped.__name__
return wrapper

Expand All @@ -29,18 +24,9 @@ def call_with_super_check(wrapped: Callable[..., Any], *args: Any, **kwargs: Any
"""
Call a class method checking that all subclasses called super along the way
"""
# pylint: disable=locally-disabled, global-variable-not-assigned
global _SUPER_COUNTERS
self = wrapped.__self__ # type: ignore # should actually be MethodType, but mypy does not handle this
counter_name = _get_counter_name(self, wrapped)
call_count = _SUPER_COUNTERS.get(counter_name, 0)
_SUPER_COUNTERS[counter_name] = call_count + 1
call_count = getattr(self, '_called', 0)
self._called = call_count + 1
wrapped(*args, **kwargs)
msg = f"Base '{wrapped.__name__}' was not called from '{self.__class__}'\nHint: Did you forget to call the super?"
assert _SUPER_COUNTERS[counter_name] == call_count, msg
if call_count == 0:
del _SUPER_COUNTERS[counter_name]


def _get_counter_name(self: Any, wrapped: Callable[..., Any]) -> str:
return f'{id(self)}_called_' + wrapped.__name__
assert self._called == call_count, msg

0 comments on commit ca0644e

Please sign in to comment.