From 820d7aa2f2256b8ba14276a0c4b9930d14b2c9d1 Mon Sep 17 00:00:00 2001 From: Riccardo Bertossa Date: Tue, 19 Sep 2023 13:10:26 +0200 Subject: [PATCH] Fixing the test apparently there was an issue with the logic of super_check: the variable to store the counter was shared between all functions, but this was not working when multiple functions calls super_check, and use it again inside at some points. Now we store the counters in a global dict to not pollute the class with a lot of names. This can be done better. There was a circular reference issue in the test listener that was storing a reference to the process inside it, making its serialization impossible. To fix the tests an ugli hack was used. Storing the reference to the process outside the class in a global dict using id as keys. Some more ugly hacks are needed to check correctly the equality of two processes. We must ignore the fact that the instances if the listener are different. --- src/plumpy/base/utils.py | 29 +++++++++++++----- src/plumpy/event_helper.py | 2 +- test/test_processes.py | 7 +++-- test/utils.py | 60 +++++++++++++++++++++++++++++++++----- 4 files changed, 80 insertions(+), 18 deletions(-) diff --git a/src/plumpy/base/utils.py b/src/plumpy/base/utils.py index c4820f1b..7397aa19 100644 --- a/src/plumpy/base/utils.py +++ b/src/plumpy/base/utils.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- -from typing import Any, Callable +from typing import Any, Callable, Dict __all__ = ['super_check', 'call_with_super_check'] +_SUPER_COUNTERS: Dict[str, int] = {} + def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]: """ @@ -11,11 +13,16 @@ def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]: """ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None: - msg = f"The function '{wrapped.__name__}' was not called through call_with_super_check" - assert getattr(self, '_called', 0) >= 1, msg + # pylint: disable=locally-disabled, global-variable-not-assigned + global _SUPER_COUNTERS + counter_name = _get_counter_name(self, wrapped) + msg = f"The function '{wrapped.__name__}' was not called through call_with_super_check ({counter_name})" + assert _SUPER_COUNTERS.get(counter_name, 0) >= 1, msg wrapped(self, *args, **kwargs) - self._called -= 1 + _SUPER_COUNTERS[counter_name] -= 1 + #print('finished:', counter_name, _SUPER_COUNTERS[counter_name]) + wrapper.__name__ = wrapped.__name__ return wrapper @@ -23,9 +30,17 @@ def call_with_super_check(wrapped: Callable[..., Any], *args: Any, **kwargs: Any """ Call a class method checking that all subclasses called super along the way """ + # pylint: disable=locally-disabled, global-variable-not-assigned + global _SUPER_COUNTERS self = wrapped.__self__ # type: ignore # should actually be MethodType, but mypy does not handle this - call_count = getattr(self, '_called', 0) - self._called = call_count + 1 + counter_name = _get_counter_name(self, wrapped) + call_count = _SUPER_COUNTERS.get(counter_name, 0) + _SUPER_COUNTERS[counter_name] = call_count + 1 + #print(counter_name, call_count) wrapped(*args, **kwargs) msg = f"Base '{wrapped.__name__}' was not called from '{self.__class__}'\nHint: Did you forget to call the super?" - assert self._called == call_count, msg + assert _SUPER_COUNTERS[counter_name] == call_count, msg + + +def _get_counter_name(self: Any, wrapped: Callable[..., Any]) -> str: + return f'{id(self)}_called_' + wrapped.__name__ diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index ed54f495..b36cd602 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -10,7 +10,7 @@ _LOGGER = logging.getLogger(__name__) -@persistence.auto_persist('_listeners') +@persistence.auto_persist('_listeners', '_listener_type') class EventHelper(persistence.Savable): def __init__(self, listener_type: 'Type[ProcessListener]'): diff --git a/test/test_processes.py b/test/test_processes.py index 737b463d..4b7494f5 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -800,7 +800,9 @@ def test_instance_state_with_outputs(self): # Check that it is a copy self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {})) # Check the contents are the same - self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {})) + #we remove the ProcessSaver instance that is an object used only for testing + utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'}) + #self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {})) self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {})) @@ -875,7 +877,8 @@ def _check_round_trip(self, proc1): bundle2 = plumpy.Bundle(proc2) self.assertEqual(proc1.pid, proc2.pid) - self.assertDictEqual(bundle1, bundle2) + #self.assertDictEqual(bundle1, bundle2) + utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'}) class TestProcessNamespace(unittest.TestCase): diff --git a/test/utils.py b/test/utils.py index feb3d1c8..c528c52f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -260,25 +260,55 @@ def _save(self, p): self.outputs.append(p.outputs.copy()) -class ProcessSaver(plumpy.ProcessListener, Saver): +_ProcessSaverProcReferences = {} +_ProcessSaver_Saver = {} + + +class ProcessSaver(plumpy.ProcessListener): """ Save the instance state of a process each time it is about to enter a new state """ + def get_process(self): + global _ProcessSaverProcReferences + return _ProcessSaverProcReferences[id(self)] + + def _save(self, p): + global _ProcessSaver_Saver + _ProcessSaver_Saver[id(self)]._save(p) + + def set_process(self, process): + global _ProcessSaverProcReferences + _ProcessSaverProcReferences[id(self)] = process + def __init__(self, proc): plumpy.ProcessListener.__init__(self) - Saver.__init__(self) - self.process = proc proc.add_process_listener(self) + self.init_not_persistent(proc) + + def init_not_persistent(self, proc): + global _ProcessSaver_Saver + _ProcessSaver_Saver[id(self)] = Saver() + self.set_process(proc) def capture(self): - self._save(self.process) - if not self.process.has_terminated(): + self._save(self.get_process()) + if not self.get_process().has_terminated(): try: - self.process.execute() + self.get_process().execute() except Exception: pass + @property + def snapshots(self): + global _ProcessSaver_Saver + return _ProcessSaver_Saver[id(self)].snapshots + + @property + def outputs(self): + global _ProcessSaver_Saver + return _ProcessSaver_Saver[id(self)].outputs + @utils.override def on_process_running(self, process): self._save(process) @@ -335,7 +365,12 @@ def check_process_against_snapshots(loop, proc_class, snapshots): """ for i, bundle in zip(list(range(0, len(snapshots))), snapshots): loaded = bundle.unbundle(plumpy.LoadSaveContext(loop=loop)) - saver = ProcessSaver(loaded) + #the process listeners are persisted + saver = list(loaded._event_helper._listeners)[0] + assert type(saver) == ProcessSaver + #process cannot be persisted because of a circular reference. So we load it there + #also the saver is not persisted for the same reason. We load it manually + saver.init_not_persistent(loaded) saver.capture() # Now check going backwards until running that the saved states match @@ -345,7 +380,11 @@ def check_process_against_snapshots(loop, proc_class, snapshots): break compare_dictionaries( - snapshots[-j], saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], exclude={'exception'} + snapshots[-j], + saver.snapshots[-j], + snapshots[-j], + saver.snapshots[-j], + exclude={'exception', '_listeners'} ) j += 1 @@ -376,6 +415,11 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None): elif isinstance(v1, list) and isinstance(v2, list): for vv1, vv2 in zip(v1, v2): compare_value(bundle1, bundle2, vv1, vv2, exclude) + elif isinstance(v1, set) and isinstance(v2, set) and len(v1) == len(v2) and len(v1) <= 1: + #TODO: implement sets with more than one element + compare_value(bundle1, bundle2, list(v1), list(v2), exclude) + elif isinstance(v1, set) and isinstance(v2, set): + raise NotImplementedError('Comparison between sets not implemented') else: if v1 != v2: raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}')