From abbab0383bb811a787d305b2333c0264fe497a13 Mon Sep 17 00:00:00 2001 From: Riccardo Bertossa Date: Mon, 25 Sep 2023 09:40:24 +0200 Subject: [PATCH] cleaning up of the code calling del on dict items of the call_with_super_check's global when the counter reaches zero calling del on dict items of the ProcessListener's global implemented in the test suite --- src/plumpy/base/utils.py | 4 ++-- test/utils.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/plumpy/base/utils.py b/src/plumpy/base/utils.py index 7397aa19..8e645670 100644 --- a/src/plumpy/base/utils.py +++ b/src/plumpy/base/utils.py @@ -20,7 +20,6 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None: assert _SUPER_COUNTERS.get(counter_name, 0) >= 1, msg wrapped(self, *args, **kwargs) _SUPER_COUNTERS[counter_name] -= 1 - #print('finished:', counter_name, _SUPER_COUNTERS[counter_name]) wrapper.__name__ = wrapped.__name__ return wrapper @@ -36,10 +35,11 @@ def call_with_super_check(wrapped: Callable[..., Any], *args: Any, **kwargs: Any counter_name = _get_counter_name(self, wrapped) call_count = _SUPER_COUNTERS.get(counter_name, 0) _SUPER_COUNTERS[counter_name] = call_count + 1 - #print(counter_name, call_count) wrapped(*args, **kwargs) msg = f"Base '{wrapped.__name__}' was not called from '{self.__class__}'\nHint: Did you forget to call the super?" assert _SUPER_COUNTERS[counter_name] == call_count, msg + if call_count == 0: + del _SUPER_COUNTERS[counter_name] def _get_counter_name(self: Any, wrapped: Callable[..., Any]) -> str: diff --git a/test/utils.py b/test/utils.py index c528c52f..e3017c90 100644 --- a/test/utils.py +++ b/test/utils.py @@ -185,7 +185,7 @@ def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) - def middle_step(self,): + def middle_step(self): return process_states.Continue(self.last_step) def last_step(self): @@ -269,6 +269,14 @@ class ProcessSaver(plumpy.ProcessListener): Save the instance state of a process each time it is about to enter a new state """ + def __del__(self): + global _ProcessSaver_Saver + global _ProcessSaverProcReferences + if _ProcessSaverProcReferences is not None and id(self) in _ProcessSaverProcReferences: + del _ProcessSaverProcReferences[id(self)] + if _ProcessSaver_Saver is not None and id(self) in _ProcessSaver_Saver: + del _ProcessSaver_Saver[id(self)] + def get_process(self): global _ProcessSaverProcReferences return _ProcessSaverProcReferences[id(self)] @@ -365,11 +373,12 @@ def check_process_against_snapshots(loop, proc_class, snapshots): """ for i, bundle in zip(list(range(0, len(snapshots))), snapshots): loaded = bundle.unbundle(plumpy.LoadSaveContext(loop=loop)) - #the process listeners are persisted + # the process listeners are persisted saver = list(loaded._event_helper._listeners)[0] assert type(saver) == ProcessSaver - #process cannot be persisted because of a circular reference. So we load it there - #also the saver is not persisted for the same reason. We load it manually + # the process reference inside this particular implementation of process listener + # cannot be persisted because of a circular reference. So we load it there + # also the saver is not persisted for the same reason. We load it manually saver.init_not_persistent(loaded) saver.capture() @@ -416,7 +425,7 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None): for vv1, vv2 in zip(v1, v2): compare_value(bundle1, bundle2, vv1, vv2, exclude) elif isinstance(v1, set) and isinstance(v2, set) and len(v1) == len(v2) and len(v1) <= 1: - #TODO: implement sets with more than one element + # TODO: implement sets with more than one element compare_value(bundle1, bundle2, list(v1), list(v2), exclude) elif isinstance(v1, set) and isinstance(v2, set): raise NotImplementedError('Comparison between sets not implemented')