Skip to content

Commit

Permalink
cleaning up of the code
Browse files Browse the repository at this point in the history
calling del on dict items of the call_with_super_check's global when the counter reaches zero
calling del on dict items of the ProcessListener's global implemented in the test suite
  • Loading branch information
rikigigi committed Sep 25, 2023
1 parent 820d7aa commit abbab03
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/plumpy/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None:
assert _SUPER_COUNTERS.get(counter_name, 0) >= 1, msg
wrapped(self, *args, **kwargs)
_SUPER_COUNTERS[counter_name] -= 1
#print('finished:', counter_name, _SUPER_COUNTERS[counter_name])

wrapper.__name__ = wrapped.__name__
return wrapper
Expand All @@ -36,10 +35,11 @@ def call_with_super_check(wrapped: Callable[..., Any], *args: Any, **kwargs: Any
counter_name = _get_counter_name(self, wrapped)
call_count = _SUPER_COUNTERS.get(counter_name, 0)
_SUPER_COUNTERS[counter_name] = call_count + 1
#print(counter_name, call_count)
wrapped(*args, **kwargs)
msg = f"Base '{wrapped.__name__}' was not called from '{self.__class__}'\nHint: Did you forget to call the super?"
assert _SUPER_COUNTERS[counter_name] == call_count, msg
if call_count == 0:
del _SUPER_COUNTERS[counter_name]


def _get_counter_name(self: Any, wrapped: Callable[..., Any]) -> str:
Expand Down
19 changes: 14 additions & 5 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def run(self):
self.out('test', 5)
return process_states.Continue(self.middle_step)

def middle_step(self,):
def middle_step(self):
return process_states.Continue(self.last_step)

def last_step(self):
Expand Down Expand Up @@ -269,6 +269,14 @@ class ProcessSaver(plumpy.ProcessListener):
Save the instance state of a process each time it is about to enter a new state
"""

def __del__(self):
global _ProcessSaver_Saver
global _ProcessSaverProcReferences
if _ProcessSaverProcReferences is not None and id(self) in _ProcessSaverProcReferences:
del _ProcessSaverProcReferences[id(self)]
if _ProcessSaver_Saver is not None and id(self) in _ProcessSaver_Saver:
del _ProcessSaver_Saver[id(self)]

def get_process(self):
global _ProcessSaverProcReferences
return _ProcessSaverProcReferences[id(self)]
Expand Down Expand Up @@ -365,11 +373,12 @@ def check_process_against_snapshots(loop, proc_class, snapshots):
"""
for i, bundle in zip(list(range(0, len(snapshots))), snapshots):
loaded = bundle.unbundle(plumpy.LoadSaveContext(loop=loop))
#the process listeners are persisted
# the process listeners are persisted
saver = list(loaded._event_helper._listeners)[0]
assert type(saver) == ProcessSaver
#process cannot be persisted because of a circular reference. So we load it there
#also the saver is not persisted for the same reason. We load it manually
# the process reference inside this particular implementation of process listener
# cannot be persisted because of a circular reference. So we load it there
# also the saver is not persisted for the same reason. We load it manually
saver.init_not_persistent(loaded)
saver.capture()

Expand Down Expand Up @@ -416,7 +425,7 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None):
for vv1, vv2 in zip(v1, v2):
compare_value(bundle1, bundle2, vv1, vv2, exclude)
elif isinstance(v1, set) and isinstance(v2, set) and len(v1) == len(v2) and len(v1) <= 1:
#TODO: implement sets with more than one element
# TODO: implement sets with more than one element
compare_value(bundle1, bundle2, list(v1), list(v2), exclude)
elif isinstance(v1, set) and isinstance(v2, set):
raise NotImplementedError('Comparison between sets not implemented')
Expand Down

0 comments on commit abbab03

Please sign in to comment.