diff --git a/.pylintrc b/.pylintrc
index 990ee559..bd1f175b 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -64,6 +64,7 @@ disable=print-statement,
         too-many-locals,
         too-many-branches,
         too-few-public-methods,
+        too-many-public-methods,
         too-many-lines,
         no-self-use,
         fixme,
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ab62a4ac..c1a1bd7d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,16 @@
 # Changelog
 
+## v1.3.0
+
+* (#198) Introduce process commands to support more interactions with
+  parallel processes. Now all `Process` methods of a parallelized
+  process can be queried from the parent OS process. Users can also add
+  support for custom methods of their processes.
+
+  This change also simplifies the way `Engine` handles parallel
+  processes and warns users when serializers are not being found
+  efficiently.
+
 ## v1.2.8
 
 * (#186) Apply function to data from database emitter in `get_history_data_db`.
diff --git a/doc/guides/processes.rst b/doc/guides/processes.rst
index 2934b540..f63aebc9 100644
--- a/doc/guides/processes.rst
+++ b/doc/guides/processes.rst
@@ -562,3 +562,55 @@ The above pseudocode is simplified, and for all but the most simple
 processes you will be better off using Vivarium's built-in simulation
 capabilities. We hope though that this helps you understand how
 processes are simulated and the purpose of the API we defined.
+
+-------------------
+Parallel Processing
+-------------------
+
+Process Commands
+================
+
+When a :term:`process` is run in parallel, we can't interact with it in
+the normal Python way. Instead, we can only exchange messages with it
+through a pipe. Vivarium structures these exchanges using :term:`process
+commands`.
+
+Vivarium provides some built-in commands, which are documented in
+:py:meth:`vivarium.core.process.Process.send_command`. Also see that
+method's documentation for instructions on how to add support for your
+own commands.
+
+Process commands are designed to be used asynchronously, so to retrieve
+the result of running a command, you need to call
+:py:meth:`vivarium.core.process.Process.get_command_result`. As a
+convenience, you can also call
+:py:meth:`vivarium.core.process.Process.run_command` to send a command
+and get its result as a return value in one function call.
+
+Running Processes in Parallel
+=============================
+
+In normal situations, though, you shouldn't have to worry about process
+commands. Instead, just pass ``'_parallel': True`` in a process's
+configuration dictionary, and the Vivarium Engine will handle the
+parallelization for you. Just remember that parallelization requires
+that processes be serialized and deserialized at the start of the
+simulation, and this serialization only preserves the process
+parameters. This means that if you instantiate a process and then change
+its instance variables, those changes won't be preserved when the
+process gets parallelized.
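+
+If you do need to issue commands yourself, the pattern looks like the
+sketch below, where ``MyProcess`` is a hypothetical process class and
+``states`` is a pre-computed view of the simulation state:
+
+.. code-block:: python
+
+    process = MyProcess()
+
+    # Send a command, then retrieve its result.
+    process.send_command('next_update', (1.0, states))
+    update = process.get_command_result()
+
+    # Equivalently, send a command and get its result in one call.
+    update = process.run_command('next_update', (1.0, states))
diff --git a/doc/reference/api/vivarium.core.process.rst b/doc/reference/api/vivarium.core.process.rst
index 34d1b7b0..6cfa1be1 100644
--- a/doc/reference/api/vivarium.core.process.rst
+++ b/doc/reference/api/vivarium.core.process.rst
@@ -1,4 +1,5 @@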
 .. automodule:: vivarium.core.process
     :members:
     :undoc-members:
+    :private-members: _handle_parallel_process
     :show-inheritance:
diff --git a/doc/reference/glossary.rst b/doc/reference/glossary.rst
index d7f8501e..bdfcac3a 100644
--- a/doc/reference/glossary.rst
+++ b/doc/reference/glossary.rst
@@ -233,6 +233,16 @@ Glossary
       subclass either :py:class:`vivarium.core.process.Process` or
       another process class.
 
+   Process Command
+   Process command
+   process command
+   Process Commands
+   Process commands
+   process commands
+      Instructions that let Vivarium communicate with parallel
+      processes in a remote-procedure-call-like fashion. See :doc:`the
+      processes guide </guides/processes>` for details.
+
    Raw Data
    Raw data
    raw data
diff --git a/doc/vale/styles/Vocab/All/vocab.txt b/doc/vale/styles/Vocab/All/vocab.txt
index 426a62c1..a425c424 100644
--- a/doc/vale/styles/Vocab/All/vocab.txt
+++ b/doc/vale/styles/Vocab/All/vocab.txt
@@ -38,3 +38,5 @@ Agmon
 Spangler
 Skalnik
 Bioinformatics
+parallelization
+deserialized
diff --git a/setup.py b/setup.py
index b503c335..fc516da2 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 from setuptools import setup
 
-VERSION = '1.2.8'
+VERSION = '1.3.0'
 
 
 if __name__ == '__main__':
diff --git a/vivarium/core/engine.py b/vivarium/core/engine.py
index 451ec782..e7fe5abe 100644
--- a/vivarium/core/engine.py
+++ b/vivarium/core/engine.py
@@ -40,6 +40,7 @@
     inverse_topology,
     normalize_path,
 )
+from vivarium.library.dict_utils import apply_func_to_leaves
 from vivarium.core.types import (
     HierarchyPath, Topology, State, Update, Processes, Steps, Flow,
     Schema)
@@ -119,21 +120,6 @@ def timestamp(dt: Optional[Any] = None) -> str:
         dt.hour, dt.minute, dt.second)
 
 
-def invoke_process(
-        process: Process,
-        interval: float,
-        states: State,
-) -> Update:
-    """Compute a process's next update.
-
-    Call the process's
-    :py:meth:`vivarium.core.process.Process.next_update` function with
-    ``interval`` and ``states``.
-    """
-
-    return process.next_update(interval, states)
-
-
 def empty_front(t: float) -> Dict[str, Union[float, dict]]:
     return {
         'time': t,
@@ -154,10 +140,12 @@ def __init__(
         called.
 
         Args:
-            defer: An object with a ``.get()`` method whose output will
-                be passed to the function. For example, the object could
-                be an :py:class:`InvokeProcess` object whose ``.get()``
-                method will return the process update.
+            defer: An object with a ``.get_command_result()`` method
+                whose output will be passed to the function. For
+                example, the object could be an
+                :py:class:`vivarium.core.process.Process` object whose
+                ``.get_command_result()`` method will return the process
+                update.
             function: The function. For example,
                 :py:func:`invert_topology` to transform the returned
                 update.
@@ -174,7 +162,7 @@ def get(self) -> Update:
             The result of calling the function.
         """
         return self.f(
-            self.defer.get(),
+            self.defer.get_command_result(),
             self.args)
 
 
@@ -188,44 +176,6 @@ def get(self) -> Update:
         return {}
 
 
-class InvokeProcess:
-    def __init__(
-            self,
-            process: Process,
-            interval: float,
-            states: State,
-    ) -> None:
-        """A wrapper object that computes an update.
-
-        This class holds the update of a process that is not running in
-        parallel. When instantiated, it immediately computes the
-        process's next update.
-
-        Args:
-            process: The process that will calculate the update.
-            interval: The timestep for the update.
-            states: The simulation state to pass to the process's
-                ``next_update`` function.
- """ - self.process = process - self.interval = interval - self.states = states - self.update = invoke_process( - self.process, - self.interval, - self.states) - - def get(self) -> Update: - """Return the computed update. - - This method is analogous to the ``.get()`` method in - :py:class:`vivarium.core.process.ParallelProcess` so that - parallel and non-parallel updates can be intermixed in the - simulation engine. - """ - return self.update - - class _StepGraph: """A dependency graph of :term:`steps`. @@ -374,7 +324,6 @@ def __init__( emit_topology: bool = True, emit_processes: bool = False, emit_config: bool = False, - invoke: Optional[Any] = None, emit_step: float = 1, display_info: bool = True, progress_bar: bool = False, @@ -409,7 +358,15 @@ def __init__( process for that port. store: A pre-loaded Store. This is an alternative to passing in processes and topology dict, which can not be loaded - at the same time. + at the same time. Note that if you provide this + argument, you must ensure that all parallel processes + (i.e. :py:class:`vivarium.core.process.Process` objects + with the ``parallel`` attribute set to ``True``) are + instances of + :py:class:`vivarium.core.process.ParallelProcess`. This + constructor converts parallel processes to + ``ParallelProcess`` objects automatically if you do not + provide this ``store`` argument. initial_state: By default an empty dictionary, this is the initial state of the simulation. experiment_id: A unique identifier for the experiment. A @@ -466,10 +423,6 @@ def __init__( if self.display_info: self._print_display() - # parallel settings - self.invoke = invoke or InvokeProcess - self.parallel: Dict[HierarchyPath, ParallelProcess] = {} - # get a mapping of all paths to processes self.process_paths: Dict[HierarchyPath, Process] = {} self._step_graph = _StepGraph() @@ -572,6 +525,15 @@ def _make_store( 'load either composite, store, or ' '(processes and topology) into Engine') + self.processes = cast( + Dict[str, Any], + self._parallelize_processes(self.processes) + ) + self.steps = cast( + Dict[str, Any], + self._parallelize_processes(self.steps) + ) + # initialize the store self.state: Store = generate_state( self.processes, @@ -592,6 +554,23 @@ def _make_store( self.flow = self.state.get_flow() or {} self.topology = self.state.get_topology() + def _parallelize_processes( + self, processes: Any) -> Union[dict, Process]: + '''Replace parallel processes with ParallelProcess objects.''' + if isinstance(processes, Process): + if processes.parallel and not isinstance( + processes, ParallelProcess): + processes = ParallelProcess( + processes, bool(self.profiler), self.stats_objs) + elif isinstance(processes, dict): + processes = { + key: self._parallelize_processes(value) + for key, value in processes.items() + } + else: + raise AssertionError(f'Unrecognized collection: {processes}') + return processes + def _add_step_path( self, step: Step, @@ -682,7 +661,6 @@ def _emit_store_data(self) -> None: def _invoke_process( self, process: Process, - path: HierarchyPath, interval: float, states: State, ) -> Any: @@ -694,28 +672,16 @@ def _invoke_process( Args: process: The process. - path: The path at which the process resides. This is used to - track parallel processes in ``self.parallel``. interval: The timestep for which to compute the update. states: The simulation state to pass to :py:meth:`vivarium.core.process.Process.next_update`. 
 
         Returns:
             The deferred simulation update, for example a
-            :py:class:`vivarium.core.process.ParallelProcess` or an
-            :py:class:`InvokeProcess` object.
+            :py:class:`vivarium.core.process.ParallelProcess`.
         """
-        if process.parallel:
-            # add parallel process if it doesn't exist
-            if path not in self.parallel:
-                self.parallel[path] = ParallelProcess(
-                    process, bool(self.profiler))
-            # trigger the computation of the parallel process
-            self.parallel[path].update(interval, states)
-
-            return self.parallel[path]
-        # if not parallel, perform a normal invocation
-        return self.invoke(process, interval, states)
+        process.send_command('next_update', (interval, states))
+        return process
 
     def _process_update(
             self,
@@ -748,7 +714,6 @@
         update = self._invoke_process(
             process,
-            path,
             interval,
             states)
 
@@ -835,6 +800,20 @@ def apply_update(
             flow_updates, deletions, view_expire
         ) = self.state.apply_update(update, state)
 
+        process_updates = [
+            (path, self._parallelize_processes(process))
+            for path, process in process_updates
+        ]
+        step_updates = [
+            (path, self._parallelize_processes(step))
+            for path, step in step_updates
+        ]
+        # Make sure the Store contains the parallelized processes.
+        for path, process in process_updates:
+            self.state.get_path(path).value = process
+        for path, step in step_updates:
+            self.state.get_path(path).value = step
+
         flow_update_dict = dict(flow_updates)
 
         if topology_updates:
@@ -968,16 +947,7 @@ def _check_complete(self) -> None:
                 f"the process at path {path} is an unapplied update"
 
     def _remove_deleted_processes(self) -> None:
-        # find any parallel processes that were removed and terminate them
-        for terminated in self.parallel.keys() - (
-                self.process_paths.keys() | self._step_paths.keys()):
-            self.parallel[terminated].end()
-            stats = self.parallel[terminated].stats
-            if stats:
-                self.stats_objs.append(stats)
-            del self.parallel[terminated]
-
-        # remove deleted process paths from the front
+        '''Remove deleted processes from the front.'''
         self.front = {
             path: progress
            for path, progress in self.front.items()
@@ -1107,6 +1077,12 @@ def run_for(
             if force_complete and self.global_time == end_time:
                 force_complete = False
 
+    @staticmethod
+    def _end_process_if_parallel(process: Process) -> None:
+        if process.parallel:
+            assert isinstance(process, ParallelProcess)
+            process.end()
+
     def end(self) -> None:
         """Terminate all processes running in parallel.
 
@@ -1115,10 +1091,8 @@
         profiling stats, including stats from parallel sub-processes.
         These stats are stored in ``self.stats``.
""" - for parallel in self.parallel.values(): - parallel.end() - if parallel.stats: - self.stats_objs.append(parallel.stats) + apply_func_to_leaves( + self.processes, self._end_process_if_parallel) if self.profiler: self.profiler.disable() total_stats = pstats.Stats(self.profiler) diff --git a/vivarium/core/process.py b/vivarium/core/process.py index 0eccf072..d7bf2433 100644 --- a/vivarium/core/process.py +++ b/vivarium/core/process.py @@ -10,12 +10,14 @@ from multiprocessing import Pipe from multiprocessing import Process as Multiprocess from multiprocessing.connection import Connection +import os import pstats import pickle -from typing import ( - Any, Dict, Optional, Union, List) +from typing import Any, Dict, Optional, Union, List, Tuple from warnings import warn +import pytest + from vivarium.library.dict_utils import ( deep_merge, deep_merge_check, deep_copy_internal) from vivarium.library.topology import assoc_path, get_in @@ -120,6 +122,14 @@ class can provide a ``defaults`` class variable to specify the """ defaults: Dict[str, Any] = {} + METHOD_COMMANDS = ( + 'initial_state', 'generate_processes', 'generate_steps', + 'generate_topology', 'generate_flow', 'merge_overrides', + 'calculate_timestep', 'is_step', 'get_private_state', + 'ports_schema', 'next_update', 'update_condition') + ATTRIBUTE_READ_COMMANDS = ( + 'schema_override', 'parameters', 'condition_path', 'schema') + ATTRIBUTE_WRITE_COMMANDS = ('set_schema',) def __init__(self, parameters: Optional[dict] = None) -> None: parameters = parameters or {} @@ -143,29 +153,209 @@ def __init__(self, parameters: Optional[dict] = None) -> None: elif not hasattr(self, 'name'): self.name = self.__class__.__name__ - self.parameters = copy.deepcopy(self.defaults) - self.parameters = deep_merge(self.parameters, parameters) - self.schema_override = self.parameters.pop('_schema', {}) - self.parallel = self.parameters.pop('_parallel', False) - self.condition_path = None + self._parameters = copy.deepcopy(self.defaults) + self._parameters = deep_merge(self._parameters, parameters) + self._schema_override: Schema = self._parameters.pop('_schema', {}) + self._parallel = self._parameters.pop('_parallel', False) + self._condition_path: Optional[HierarchyPath] = None + self._command_result: Any = None + self._pending_command: Optional[ + Tuple[str, Optional[tuple], Optional[dict]]] = None # set up the conditional state if a condition key is provided - if '_condition' in self.parameters: - self.condition_path = self.parameters.pop('_condition') - if self.condition_path: - self.merge_overrides(assoc_path({}, self.condition_path, { + if '_condition' in self._parameters: + self._condition_path = self._parameters.pop('_condition') + if self._condition_path: + self.merge_overrides(assoc_path({}, self._condition_path, { '_default': True, '_emit': True, '_updater': 'set'})) self._set_timestep() - self.schema: Optional[dict] = None + self._schema: Optional[Schema] = None + + @property + def parameters(self) -> dict: + return self._parameters + + @property + def schema_override(self) -> Schema: + return self._schema_override + + @property + def parallel(self) -> bool: + return self._parallel + + @property + def condition_path(self) -> Optional[HierarchyPath]: + return self._condition_path + + @property + def schema(self) -> Optional[Schema]: + return self._schema + + @schema.setter + def schema(self, value: Optional[Schema]) -> None: + self._schema = value + + def pre_send_command( + self, command: str, args: Optional[tuple], kwargs: + Optional[dict]) -> 
+            Optional[dict]) -> None:
+        '''Run pre-checks before starting a command.
+
+        This method should be called at the start of every
+        implementation of :py:meth:`send_command`.
+
+        Args:
+            command: The name of the command to run.
+            args: A tuple of positional arguments for the command.
+            kwargs: A dictionary of keyword arguments for the command.
+
+        Raises:
+            RuntimeError: Raised when a user tries to send a command
+                while a previous command is still pending (i.e. the user
+                hasn't called :py:meth:`get_command_result` yet for the
+                previous command).
+        '''
+        if self._pending_command:
+            raise RuntimeError(
+                f'Trying to send command {(command, args, kwargs)} but '
+                f'command {self._pending_command} is still pending.')
+        self._pending_command = command, args, kwargs
+
+    def send_command(
+            self, command: str, args: Optional[tuple] = None,
+            kwargs: Optional[dict] = None,
+            run_pre_check: bool = True) -> None:
+        '''Handle :term:`process commands`.
+
+        This method handles the commands listed in
+        :py:attr:`METHOD_COMMANDS` by passing ``args``
+        and ``kwargs`` to the method of ``self`` with the name
+        of the command and saving the return value as the result.
+
+        This method handles the commands listed in
+        :py:attr:`ATTRIBUTE_READ_COMMANDS` by returning the attribute of
+        ``self`` with the name matching the command, and it handles the
+        commands listed in :py:attr:`ATTRIBUTE_WRITE_COMMANDS` by
+        setting the attribute in the command to the first argument in
+        ``args``. The command must be named ``set_attr`` for attribute
+        ``attr``.
+
+        To add support for a custom command, override this function in
+        your subclass. Each command is defined by a name (a string)
+        and accepts both positional and keyword arguments. Any custom
+        commands you add should have associated methods such that:
+
+        * The command name matches the method name.
+        * The command and method accept the same positional and keyword
+          arguments.
+        * The command and method return the same values.
+
+        If all of the above are satisfied, you can use
+        :py:meth:`Process.run_command_method` to handle the command.
+
+        Your implementation of this function needs to handle all the
+        commands you want to support. When presented with an unknown
+        command, you should call the superclass method, which will
+        either handle the command or call its superclass method. At the
+        top of this recursive chain, this ``Process.send_command()``
+        method handles some built-in commands and will raise an error
+        for unknown commands.
+
+        Any overrides of this method must also call
+        :py:meth:`pre_send_command` at the start of the method. This
+        call will check that no command is currently pending to avoid
+        confusing behavior when multiple commands are started without
+        intervening retrievals of command results. Since your overriding
+        method will have already performed the pre-check, it should pass
+        ``run_pre_check=False`` when calling the superclass method.
+
+        Args:
+            command: The name of the command to run.
+            args: A tuple of positional arguments for the command.
+            kwargs: A dictionary of keyword arguments for the command.
+            run_pre_check: Whether to run the pre-checks implemented in
+                :py:meth:`pre_send_command`. This should be left at its
+                default value unless the pre-checks have already been
+                performed (e.g. if this method is being called by a
+                subclass's overriding method).
+
+        Returns:
+            None. This method just starts the command running.
+
+        Raises:
+            ValueError: For unknown commands.
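+
+        For example, a subclass could support a hypothetical
+        ``say_hello`` command by overriding this method as follows
+        (a sketch modeled on :py:class:`ToyParallelProcess` below)::
+
+            def send_command(self, command, args=None, kwargs=None,
+                             run_pre_check=True):
+                if run_pre_check:
+                    self.pre_send_command(command, args, kwargs)
+                if command == 'say_hello':
+                    self._command_result = 'hello!'
+                else:
+                    super().send_command(command, args, kwargs, False)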
+        '''
+        if run_pre_check:
+            self.pre_send_command(command, args, kwargs)
+        args = args or tuple()
+        kwargs = kwargs or {}
+        if command in self.METHOD_COMMANDS:
+            self._command_result = self.run_command_method(
+                command, args, kwargs)
+        elif command in self.ATTRIBUTE_READ_COMMANDS:
+            self._command_result = getattr(self, command)
+        elif command in self.ATTRIBUTE_WRITE_COMMANDS:
+            assert command.startswith('set_')
+            assert args
+            setattr(self, command[len('set_'):], args[0])
+        else:
+            raise ValueError(
+                f'Process {self} does not understand the process '
+                f'command {command}')
+
+    def run_command_method(
+            self, command: str, args: tuple, kwargs: dict) -> Any:
+        '''Run a command whose name and interface match a method.
+
+        Args:
+            command: The command name, which must be the name of a
+                method of ``self``.
+            args: The positional arguments to pass to the method.
+            kwargs: The keyword arguments for the method.
+
+        Returns:
+            The result of calling the method of ``self`` named
+            ``command`` with ``args`` and ``kwargs``.
+        '''
+        return getattr(self, command)(*args, **kwargs)
+
+    def get_command_result(self) -> Any:
+        '''Retrieve the result from the last-run command.
+
+        Returns:
+            The result of the last command run. Note that this method
+            should only be called once immediately after each call to
+            :py:meth:`send_command`.
+
+        Raises:
+            RuntimeError: When there is no command pending. This can
+                happen when this method is called twice without an
+                intervening call to :py:meth:`send_command`.
+        '''
+        if not self._pending_command:
+            raise RuntimeError(
+                'Trying to retrieve command result, but no command is '
+                'pending.')
+        self._pending_command = None
+        result = self._command_result
+        self._command_result = None
+        return result
+
+    def run_command(
+            self, command: str, args: Optional[tuple] = None,
+            kwargs: Optional[dict] = None) -> Any:
+        '''Helper function that sends a command and returns its result.'''
+        self.send_command(command, args, kwargs)
+        return self.get_command_result()
 
     def _set_timestep(self) -> None:
-        self.parameters.setdefault('timestep', DEFAULT_TIME_STEP)
-        if self.parameters.get('time_step'):
-            self.parameters['timestep'] = self.parameters['time_step']
+        self._parameters.setdefault('timestep', DEFAULT_TIME_STEP)
+        if self._parameters.get('time_step'):
+            self._parameters['timestep'] = self._parameters['time_step']
 
     def __getstate__(self) -> dict:
         """Return parameters
@@ -292,7 +482,7 @@ def merge_overrides(self, override: Schema) -> None:
         Args:
             override: The schema override to add.
         """
-        deep_merge(self.schema_override, override)
+        deep_merge(self._schema_override, override)
 
     def ports(self) -> Dict[str, List[str]]:
         """Get ports and each port's variables.
@@ -485,24 +675,43 @@ def is_step(self) -> bool:
 
 
 Deriver = Step
 
 
-def _run_update(
+def _handle_parallel_process(
         connection: Connection, process: Process,
         profile: bool) -> None:
+    '''Handle a parallel Vivarium :term:`process`.
+
+    This function is designed to be passed as ``target`` to
+    ``Multiprocess()``. In a loop, it receives :term:`process commands`
+    from a pipe, passes those commands to the parallel process, and
+    passes the result back along the pipe.
+
+    The special command ``end`` is handled directly by this function.
+    This command causes the function to exit and therefore shut down the
+    OS process created by multiprocessing.
+
+    Args:
+        connection: The child end of a multiprocessing pipe. Each
All + communications received from the pipe should be a 3-tuple of + the form ``(command, args, kwargs)``, and the tuple contents + will be passed to :py:meth:`Process.run_command`. The + result, which may be of any type, will be sent back through + the pipe. + process: The process running in parallel. + profile: Whether to profile the process. + ''' if profile: profiler = cProfile.Profile() profiler.enable() running = True while running: - interval, states = connection.recv() + command, args, kwargs = connection.recv() - # stop process by sending -1 as the interval - if interval == -1: + if command == 'end': running = False - else: - update = process.next_update(interval, states) - connection.send(update) + result = process.run_command(command, args, kwargs) + connection.send(result) if profile: profiler.disable() @@ -512,8 +721,10 @@ def _run_update( connection.close() -class ParallelProcess: - def __init__(self, process: Process, profile: bool = False) -> None: +class ParallelProcess(Process): + def __init__( + self, process: Process, profile: bool = False, + stats_objs: Optional[List[pstats.Stats]] = None) -> None: """Wraps a :py:class:`Process` for multiprocessing. To run a simulation distributed across multiple processors, we @@ -522,38 +733,124 @@ def __init__(self, process: Process, profile: bool = False) -> None: process and the child process with the :py:class:`Process` that this object manages. + Most methods pass their name and arguments to + :py:class:`Process.run_command`. + Args: process: The Process to manage. profile: Whether to use cProfile to profile the subprocess. + stats_objs: List to add cProfile stats objs to when process + is deleted. Only used if ``profile`` is true. """ + super().__init__({ + '_no_original_parameters': True, + 'name': process.name, + '_parallel': True, + }) self.process = process self.profile = profile - self.stats: Optional[pstats.Stats] = None + self._stats_objs = stats_objs + assert not self.profile or self._stats_objs is not None self.parent, self.child = Pipe() self.multiprocess = Multiprocess( - target=_run_update, + target=_handle_parallel_process, args=(self.child, self.process, self.profile)) self.multiprocess.start() - - def update( - self, interval: Union[float, int], states: State) -> None: - """Request an update from the process. - - Args: - interval: The length of the timestep for which the update - should be computed. - states: The pre-update state of the simulation. - """ - self.parent.send((interval, states)) - - def get(self) -> Update: - """Get an update from the process. + self._ended = False + self._pending_command: Optional[ + Tuple[str, Optional[tuple], Optional[dict]]] = None + + def send_command( + self, command: str, args: Optional[tuple] = None, + kwargs: Optional[dict] = None, + run_pre_check: bool = True) -> None: + '''Send a command to the parallel process. + + See :py:func:``_handle_parallel_process`` for details on how the + command will be handled. + ''' + if run_pre_check: + self.pre_send_command(command, args, kwargs) + self.parent.send((command, args, kwargs)) + + def get_command_result(self) -> Update: + """Get the result of a command sent to the parallel process. + + Commands and their results work like a queue, so unlike + :py:class:`Process`, you can technically call this method + multiple times and get different return values each time. + This behavior is subject to change, so you should not rely on + it. Returns: - The update from the process. + The command result. 
""" + if not self._pending_command: + raise RuntimeError( + 'Trying to retrieve command result, but no command is ' + 'pending.') + self._pending_command = None return self.parent.recv() + def initial_state(self, config: Optional[dict] = None) -> State: + return self.run_command('initial_state', (config,)) + + def generate_processes( + self, config: Optional[dict] = None) -> Dict[str, Any]: + return self.run_command('generate_processes', (config,)) + + def generate_steps( + self, config: Optional[dict] = None) -> Dict[str, Any]: + return self.run_command('generate_steps', (config,)) + + def generate_topology( + self, config: Optional[dict] = None) -> Topology: + return self.run_command('generate_topology', (config,)) + + def generate_flow(self, config: Optional[dict] = None) -> Flow: + return self.run_command('generate_flow', (config,)) + + @property + def schema_override(self) -> Schema: + return self.run_command('schema_override') + + @property + def parameters(self) -> Dict[str, Any]: + return self.run_command('parameters') + + @property + def condition_path(self) -> Optional[HierarchyPath]: + return self.run_command('condition_path') + + @property + def schema(self) -> Schema: + return self.run_command('schema') + + @schema.setter + def schema(self, value: Schema) -> None: + self.run_command('set_schema', (value,)) + + def merge_overrides(self, override: Schema) -> None: + self.run_command('merge_overrides', (override,)) + + def calculate_timestep(self, states: Optional[State]) -> float: + return self.run_command('calculate_timestep', (states,)) + + def is_step(self) -> bool: + return self.run_command('is_step') + + def get_private_state(self) -> State: + return self.run_command('get_private_state') + + def ports_schema(self) -> Schema: + return self.run_command('ports_schema') + + def next_update(self, timestep: float, states: State) -> Update: + return self.run_command('next_update', (timestep, states)) + + def update_condition(self, timestep: float, states: State) -> bool: + return self.run_command('update_condition', (timestep, states)) + def end(self) -> None: """End the child process. @@ -561,11 +858,21 @@ def end(self) -> None: will compile its profiling stats and send those to the parent. The parent then saves those stats in ``self.stats``. """ - self.parent.send((-1, None)) + # Only end once. 
+        if self._ended:
+            return
+        self.send_command('end')
         if self.profile:
-            self.stats = pstats.Stats()
-            self.stats.stats = self.parent.recv()  # type: ignore
+            stats = pstats.Stats()
+            stats.stats = self.get_command_result()  # type: ignore
+            assert self._stats_objs is not None
+            self._stats_objs.append(stats)
         self.multiprocess.join()
+        self.multiprocess.close()
+        self._ended = True
+
+    def __del__(self) -> None:
+        self.end()
 
 
 class ToySerializedProcess(Process):
@@ -605,6 +912,31 @@ def next_update(self, timestep: float, states: State) -> Update:
         return {}
 
 
+class ToyParallelProcess(Process):
+
+    def compare_pid(self, pid: float) -> bool:
+        return os.getpid() == pid
+
+    def send_command(
+            self, command: str, args: Optional[tuple] = None,
+            kwargs: Optional[dict] = None,
+            run_pre_check: bool = True) -> None:
+        if run_pre_check:
+            self.pre_send_command(command, args, kwargs)
+        args = args or tuple()
+        kwargs = kwargs or {}
+        if command == 'compare_pid':
+            self._command_result = self.compare_pid(*args, **kwargs)
+        else:
+            super().send_command(command, args, kwargs, False)
+
+    def ports_schema(self) -> Schema:
+        return {}
+
+    def next_update(self, timestep: float, states: State) -> Update:
+        return {}
+
+
 def test_serialize_process() -> None:
     proc = ToySerializedProcess()
     proc_pickle = pickle.loads(pickle.dumps(proc))
@@ -619,3 +951,44 @@ def test_serialize_process_inheritance() -> None:
     a = ToySerializedProcessInheritance({'1': 0})
     a2 = pickle.loads(pickle.dumps(a))
     assert a2.parameters['2'] == 0
+
+
+def test_process_commands_pending_safeguard() -> None:
+    process = ToySerializedProcess()
+    process.send_command('calculate_timestep', (None,))
+    with pytest.raises(RuntimeError) as exception:
+        process.send_command('next_update', (1, {}))
+    msg = "command ('calculate_timestep', (None,), None) is still pending"
+    assert msg in str(exception.value)
+
+
+def test_parallel_process_commands_pending_safeguard() -> None:
+    process = ParallelProcess(ToySerializedProcess())
+    process.send_command('calculate_timestep', (None,))
+    with pytest.raises(RuntimeError) as exception:
+        process.send_command('next_update', (1, {}))
+    msg = "command ('calculate_timestep', (None,), None) is still pending"
+    assert msg in str(exception.value)
+    # Reset Process._pending_command so that no warning is thrown when
+    # __del__() sends the 'end' command.
+    process.get_command_result()
+
+
+def test_parallel_commands() -> None:
+    proc = ToyParallelProcess()
+    parallel_proc = ParallelProcess(proc)
+
+    assert proc.compare_pid(os.getpid())
+    proc.send_command('compare_pid', (os.getpid(),))
+    assert proc.get_command_result()
+
+    parallel_proc.send_command('compare_pid', (os.getpid(),))
+    assert not parallel_proc.get_command_result()
+
+
+def test_invalid_command() -> None:
+    proc = ToyParallelProcess()
+    with pytest.raises(ValueError) as exception:
+        proc.send_command('missing_command')
+    msg = 'does not understand the process command missing_command'
+    assert msg in str(exception.value)
diff --git a/vivarium/core/serialize.py b/vivarium/core/serialize.py
index 81d17df9..7ab90b8d 100644
--- a/vivarium/core/serialize.py
+++ b/vivarium/core/serialize.py
@@ -37,6 +37,16 @@ def serialize_value(value: Any) -> Any:
             f'Multiple serializers ({compatible_serializers}) found '
             f'for {value} of type {type(value)}')
     serializer = compatible_serializers[0]
+    if not isinstance(value, Process):
+        # We don't warn for processes: since their types are based
+        # on their subclasses, it's not possible to avoid searching
+        # through the serializers.
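+        # (Serializers can avoid triggering this search by declaring
+        # ``exclusive_types``, as IdentitySerializer does below.)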
+        warnings.warn(
+            f'Searched through serializers to find {serializer} '
+            f'for a value of type {type(value)}. This is '
+            f'inefficient.')
     return serializer.serialize(value)
 
 
@@ -66,7 +74,8 @@ class IdentitySerializer(Serializer):  # pylint: disable=abstract-method
     '''Serializer for base types that get serialized as themselves.'''
 
     def __init__(self) -> None:
-        super().__init__(exclusive_types=(int, float, bool, str))
+        super().__init__(
+            exclusive_types=(int, float, bool, str, type(None)))
 
     def can_serialize(self, data: Any) -> bool:
         if (
@@ -381,6 +390,10 @@ class FunctionSerializer(Serializer):
     Currently only supports serialization (for emitting simulation
     configs).
     """
+    def __init__(self) -> None:
+        super().__init__(exclusive_types=(
+            type(serialize_value),  # Get the function type.
+        ))
 
     def can_serialize(self, data: Any) -> bool:
         return callable(data)
diff --git a/vivarium/core/serialize_test.py b/vivarium/core/serialize_test.py
index d7a16060..2cc27278 100644
--- a/vivarium/core/serialize_test.py
+++ b/vivarium/core/serialize_test.py
@@ -23,7 +23,7 @@ def serialize_function() -> None:
     pass
 
 
-class TestSerializer(Serializer):
+class ToySerializer(Serializer):
 
     def __init__(self, prefix: str = '', suffix: str = '') -> None:
         super().__init__()
@@ -43,50 +43,50 @@ def deserialize_from_string(self, data: str) -> str:
 
 
 def test_serialized_in_serializer_string() -> None:
-    serializer = TestSerializer(prefix='![', suffix=']')
+    serializer = ToySerializer(prefix='![', suffix=']')
     serialized = serializer.serialize('hi there!')
     assert serializer.deserialize(serialized) == 'hi there!'
 
 
 def test_unmatched_closing_bracket_in_serializer_string() -> None:
-    serializer = TestSerializer(prefix='', suffix=']')
+    serializer = ToySerializer(prefix='', suffix=']')
     serialized = serializer.serialize('hi there!')
     assert serializer.deserialize(serialized) == 'hi there!'
 
 
 def test_unmatched_opening_bracket_in_serializer_string() -> None:
-    serializer = TestSerializer(prefix='[', suffix='')
+    serializer = ToySerializer(prefix='[', suffix='')
     serialized = serializer.serialize('hi there!')
     print(serialized)
     assert serializer.deserialize(serialized) == 'hi there!'
 
 
 def test_open_bracket_deep_in_serializer_string() -> None:
-    serializer = TestSerializer(prefix='abc[', suffix='')
+    serializer = ToySerializer(prefix='abc[', suffix='')
     serialized = serializer.serialize('hi there!')
     assert serializer.deserialize(serialized) == 'hi there!'
 
 
 def test_close_bracket_deep_in_serializer_string() -> None:
-    serializer = TestSerializer(prefix='abc]', suffix='')
+    serializer = ToySerializer(prefix='abc]', suffix='')
     serialized = serializer.serialize('hi there!')
     assert serializer.deserialize(serialized) == 'hi there!'
 
 
 def test_serialized_prefixing_serializer_string() -> None:
-    serializer = TestSerializer(prefix='!TestSerializer[test]', suffix='')
+    serializer = ToySerializer(prefix='!ToySerializer[test]', suffix='')
     serialized = serializer.serialize('hi there!')
     assert serializer.deserialize(serialized) == 'hi there!'
 
 
 def test_exclamation_point_prefixing_serializer_string() -> None:
-    serializer = TestSerializer(prefix='!', suffix='')
+    serializer = ToySerializer(prefix='!', suffix='')
     serialized = serializer.serialize('hi there!')
     assert serializer.deserialize(serialized) == 'hi there!'
 
 
 def test_exclamation_point_suffixing_serializer_string() -> None:
-    serializer = TestSerializer(prefix='', suffix='!')
+    serializer = ToySerializer(prefix='', suffix='!')
     serialized = serializer.serialize('hi there!')
     assert serializer.deserialize(serialized) == 'hi there!'
diff --git a/vivarium/experiments/test_profiler.py b/vivarium/experiments/test_profiler.py
index 69c44360..9bb3812b 100644
--- a/vivarium/experiments/test_profiler.py
+++ b/vivarium/experiments/test_profiler.py
@@ -57,3 +57,7 @@
 
     assert 0.6 <= process_a_runtime <= 0.7
     assert 0.3 <= process_b_runtime <= 0.4
+
+
+if __name__ == '__main__':
+    test_profiler()
diff --git a/vivarium/library/dict_utils.py b/vivarium/library/dict_utils.py
index 26e62901..452cb620 100644
--- a/vivarium/library/dict_utils.py
+++ b/vivarium/library/dict_utils.py
@@ -4,7 +4,7 @@
 from functools import reduce
 import operator
 import traceback
-from typing import Optional
+from typing import Optional, Any, Callable
 
 from vivarium.library.units import Quantity
 
@@ -314,6 +314,22 @@ def make_path_dict(embedded_dict):
     return path_dict
 
 
+def apply_func_to_leaves(root: Any, func: Callable[[Any], None]) -> None:
+    '''Apply a function to every leaf node in a nested dictionary.
+
+    >>> root = {1: [], 2: {3: [], 4: []}}
+    >>> func = lambda x: x.append(True)
+    >>> apply_func_to_leaves(root, func)
+    >>> root
+    {1: [True], 2: {3: [True], 4: [True]}}
+    '''
+    if not isinstance(root, dict):
+        func(root)
+        return
+    for child in root.values():
+        apply_func_to_leaves(child, func)
+
+
 def test_deep_copy_internal():
     l = [1, 2, 3]
     d = {1: {2: l}, 3: True}
diff --git a/vivarium/processes/meta_division.py b/vivarium/processes/meta_division.py
index ab426ba6..ab388ed8 100644
--- a/vivarium/processes/meta_division.py
+++ b/vivarium/processes/meta_division.py
@@ -3,7 +3,8 @@
 import logging as log
 
 from vivarium.core.process import (
-    Step
+    Step,
+    ParallelProcess,
 )
 from vivarium.core.composer import Composer
 from vivarium.core.directories import (
@@ -101,7 +102,9 @@ def next_update(self, timestep, states):
 class ToyAgent(Composer):
     defaults = {
         'exchange': {'uptake_rate': 0.1},
-        'agents_path': ('..', '..', 'agents')}
+        'agents_path': ('..', '..', 'agents'),
+        'parallel': False,
+    }
 
     def generate_processes(self, config):
         agent_id = config['agent_id']
@@ -109,6 +112,8 @@ def generate_processes(self, config):
             {},
             agent_id=agent_id,
             composer=self)
+        config['exchange']['_parallel'] = config['parallel']
+        division_config['_parallel'] = config['parallel']
 
         return {
             'exchange': ExchangeA(config['exchange']),
@@ -127,12 +132,7 @@ def generate_topology(self, config):
         }
 
 
-def test_division():
-    agent_id = '1'
-
-    # timeline triggers division
-    time_divide = 5
-    time_total = 10
+def _get_toy_experiment(agent_id, time_divide, time_total, parallel):
     timeline = [
         (0, {('agents', agent_id, 'global', 'divide'): False}),
         (time_divide, {('agents', agent_id, 'global', 'divide'): True}),
@@ -140,7 +140,7 @@ def test_division():
 
     # create the processes
     timeline_process = TimelineProcess({'timeline': timeline})
-    agent = ToyAgent({'agent_id': agent_id})
+    agent = ToyAgent({'agent_id': agent_id, 'parallel': parallel})
 
     # compose
     composite = agent.generate(path=('agents', agent_id))
@@ -169,9 +169,50 @@ def test_division():
         composite=composite,
         initial_state=initial_state
     )
+    return experiment
+
+
+def test_division():
+    agent_id = '1'
+    time_divide = 5
+    time_total = 10
+    experiment = _get_toy_experiment(
+        agent_id, time_divide, time_total, False)
+
+    # run simulation
+    experiment.update(time_total)
+    output = experiment.emitter.get_data()
+    experiment.end()
+
+    # before division, the mother agent is the only agent; at the
+    # division timestep it is replaced by its two daughters
+    assert list(output[time_divide]['agents'].keys()) == [agent_id]
+    assert agent_id not in list(output[time_divide + 1]['agents'].keys())
+    assert len(output[time_divide]['agents']) == 1
+    assert len(output[time_divide + 1]['agents']) == 2
+
+    return output
+
+
+def test_division_parallel():
+    agent_id = '1'
+    time_divide = 5
+    time_total = 10
+    experiment = _get_toy_experiment(
+        agent_id, time_divide, time_total, True)
 
     # run simulation
     experiment.update(time_total)
+    for agent in experiment.state.get_path(('agents',)).inner:
+        assert isinstance(
+            experiment.processes['agents'][agent]['exchange'],
+            ParallelProcess,
+        )
+        assert isinstance(
+            experiment.state.get_path(('agents', agent, 'exchange')).value,
+            ParallelProcess,
+        )
     output = experiment.emitter.get_data()
     experiment.end()
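For reference, the parallel workflow exercised above reduces to the
following minimal sketch. ``Engine`` and ``ToyParallelProcess`` come
from the modules patched in this diff; the ``'toy'`` store name and the
empty topology are illustrative assumptions:

    from vivarium.core.engine import Engine
    from vivarium.core.process import ToyParallelProcess

    # '_parallel': True tells Engine to wrap the process in a
    # ParallelProcess, which forwards method calls as process
    # commands over a pipe to a child OS process.
    process = ToyParallelProcess({'_parallel': True})
    sim = Engine(processes={'toy': process}, topology={'toy': {}})
    sim.update(10)  # exchanges 'next_update' commands over the pipe
    sim.end()       # sends the 'end' command and joins the child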