Skip to content

Commit

Permalink
Fixing the test
Browse files Browse the repository at this point in the history
apparently there was an issue with the logic of super_check:
the variable to store the counter was shared between all functions, but
this was not working when multiple functions calls super_check, and use
it again inside at some points. Now we store the counters in a global
dict to not pollute the class with a lot of names. This can be done
better.

There was a circular reference issue in the test listener that was
storing a reference to the process inside it, making its serialization
impossible. To fix the tests an ugli hack was used. Storing the
reference to the process outside the class in a global dict using id as
keys. Some more ugly hacks are needed to check correctly the equality of
two processes. We must ignore the fact that the instances if the
listener are different.
  • Loading branch information
rikigigi committed Sep 19, 2023
1 parent 4fd678b commit 820d7aa
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 18 deletions.
29 changes: 22 additions & 7 deletions src/plumpy/base/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
from typing import Any, Callable
from typing import Any, Callable, Dict

__all__ = ['super_check', 'call_with_super_check']

_SUPER_COUNTERS: Dict[str, int] = {}


def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]:
"""
Expand All @@ -11,21 +13,34 @@ def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]:
"""

def wrapper(self: Any, *args: Any, **kwargs: Any) -> None:
msg = f"The function '{wrapped.__name__}' was not called through call_with_super_check"
assert getattr(self, '_called', 0) >= 1, msg
# pylint: disable=locally-disabled, global-variable-not-assigned
global _SUPER_COUNTERS
counter_name = _get_counter_name(self, wrapped)
msg = f"The function '{wrapped.__name__}' was not called through call_with_super_check ({counter_name})"
assert _SUPER_COUNTERS.get(counter_name, 0) >= 1, msg
wrapped(self, *args, **kwargs)
self._called -= 1
_SUPER_COUNTERS[counter_name] -= 1
#print('finished:', counter_name, _SUPER_COUNTERS[counter_name])

wrapper.__name__ = wrapped.__name__
return wrapper


def call_with_super_check(wrapped: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""
Call a class method checking that all subclasses called super along the way
"""
# pylint: disable=locally-disabled, global-variable-not-assigned
global _SUPER_COUNTERS
self = wrapped.__self__ # type: ignore # should actually be MethodType, but mypy does not handle this
call_count = getattr(self, '_called', 0)
self._called = call_count + 1
counter_name = _get_counter_name(self, wrapped)
call_count = _SUPER_COUNTERS.get(counter_name, 0)
_SUPER_COUNTERS[counter_name] = call_count + 1
#print(counter_name, call_count)
wrapped(*args, **kwargs)
msg = f"Base '{wrapped.__name__}' was not called from '{self.__class__}'\nHint: Did you forget to call the super?"
assert self._called == call_count, msg
assert _SUPER_COUNTERS[counter_name] == call_count, msg


def _get_counter_name(self: Any, wrapped: Callable[..., Any]) -> str:
return f'{id(self)}_called_' + wrapped.__name__
2 changes: 1 addition & 1 deletion src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
_LOGGER = logging.getLogger(__name__)


@persistence.auto_persist('_listeners')
@persistence.auto_persist('_listeners', '_listener_type')
class EventHelper(persistence.Savable):

def __init__(self, listener_type: 'Type[ProcessListener]'):
Expand Down
7 changes: 5 additions & 2 deletions test/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,9 @@ def test_instance_state_with_outputs(self):
# Check that it is a copy
self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
# Check the contents are the same
self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
#we remove the ProcessSaver instance that is an object used only for testing
utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'})
#self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))

self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {}))

Expand Down Expand Up @@ -875,7 +877,8 @@ def _check_round_trip(self, proc1):
bundle2 = plumpy.Bundle(proc2)

self.assertEqual(proc1.pid, proc2.pid)
self.assertDictEqual(bundle1, bundle2)
#self.assertDictEqual(bundle1, bundle2)
utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'})


class TestProcessNamespace(unittest.TestCase):
Expand Down
60 changes: 52 additions & 8 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,25 +260,55 @@ def _save(self, p):
self.outputs.append(p.outputs.copy())


class ProcessSaver(plumpy.ProcessListener, Saver):
_ProcessSaverProcReferences = {}
_ProcessSaver_Saver = {}


class ProcessSaver(plumpy.ProcessListener):
"""
Save the instance state of a process each time it is about to enter a new state
"""

def get_process(self):
global _ProcessSaverProcReferences
return _ProcessSaverProcReferences[id(self)]

def _save(self, p):
global _ProcessSaver_Saver
_ProcessSaver_Saver[id(self)]._save(p)

def set_process(self, process):
global _ProcessSaverProcReferences
_ProcessSaverProcReferences[id(self)] = process

def __init__(self, proc):
plumpy.ProcessListener.__init__(self)
Saver.__init__(self)
self.process = proc
proc.add_process_listener(self)
self.init_not_persistent(proc)

def init_not_persistent(self, proc):
global _ProcessSaver_Saver
_ProcessSaver_Saver[id(self)] = Saver()
self.set_process(proc)

def capture(self):
self._save(self.process)
if not self.process.has_terminated():
self._save(self.get_process())
if not self.get_process().has_terminated():
try:
self.process.execute()
self.get_process().execute()
except Exception:
pass

@property
def snapshots(self):
global _ProcessSaver_Saver
return _ProcessSaver_Saver[id(self)].snapshots

@property
def outputs(self):
global _ProcessSaver_Saver
return _ProcessSaver_Saver[id(self)].outputs

@utils.override
def on_process_running(self, process):
self._save(process)
Expand Down Expand Up @@ -335,7 +365,12 @@ def check_process_against_snapshots(loop, proc_class, snapshots):
"""
for i, bundle in zip(list(range(0, len(snapshots))), snapshots):
loaded = bundle.unbundle(plumpy.LoadSaveContext(loop=loop))
saver = ProcessSaver(loaded)
#the process listeners are persisted
saver = list(loaded._event_helper._listeners)[0]
assert type(saver) == ProcessSaver
#process cannot be persisted because of a circular reference. So we load it there
#also the saver is not persisted for the same reason. We load it manually
saver.init_not_persistent(loaded)
saver.capture()

# Now check going backwards until running that the saved states match
Expand All @@ -345,7 +380,11 @@ def check_process_against_snapshots(loop, proc_class, snapshots):
break

compare_dictionaries(
snapshots[-j], saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], exclude={'exception'}
snapshots[-j],
saver.snapshots[-j],
snapshots[-j],
saver.snapshots[-j],
exclude={'exception', '_listeners'}
)
j += 1

Expand Down Expand Up @@ -376,6 +415,11 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None):
elif isinstance(v1, list) and isinstance(v2, list):
for vv1, vv2 in zip(v1, v2):
compare_value(bundle1, bundle2, vv1, vv2, exclude)
elif isinstance(v1, set) and isinstance(v2, set) and len(v1) == len(v2) and len(v1) <= 1:
#TODO: implement sets with more than one element
compare_value(bundle1, bundle2, list(v1), list(v2), exclude)
elif isinstance(v1, set) and isinstance(v2, set):
raise NotImplementedError('Comparison between sets not implemented')
else:
if v1 != v2:
raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}')
Expand Down

0 comments on commit 820d7aa

Please sign in to comment.