diff --git a/openfl/experimental/interface/fl_spec.py b/openfl/experimental/interface/fl_spec.py index aec582580b..257284ea98 100644 --- a/openfl/experimental/interface/fl_spec.py +++ b/openfl/experimental/interface/fl_spec.py @@ -23,7 +23,6 @@ class FLSpec: - _clones = [] _initial_state = None @@ -65,6 +64,16 @@ def run(self) -> None: print(f"Created flow {self.__class__.__name__}") try: self.start() + + # execute_task_args will be updated in self.start() + # after the next function is executed + self.runtime.execute_task( + self, + self.execute_task_args[0], + self.execute_task_args[1], + self.execute_task_args[2], + **self.execute_task_args[3], + ) except Exception as e: if "cannot pickle" in str(e) or "Failed to unpickle" in str(e): msg = ( @@ -130,9 +139,7 @@ def _is_at_transition_point(self, f: Callable, parent_func: Callable) -> bool: if parent_func.__name__ in self._foreach_methods: self._foreach_methods.append(f.__name__) if should_transfer(f, parent_func): - print( - f"Should transfer from {parent_func.__name__} to {f.__name__}" - ) + print(f"Should transfer from {parent_func.__name__} to {f.__name__}") self.execute_next = f.__name__ return True return False @@ -171,16 +178,7 @@ def next(self, f: Callable, **kwargs) -> None: # Remove included / excluded attributes from next task filter_attributes(self, f, **kwargs) - if self._is_at_transition_point(f, parent_func): - # Collaborator is done executing for now - return - self._display_transition_logs(f, parent_func) - self._runtime.execute_task( - self, - f, - parent_func, - instance_snapshot=agg_to_collab_ss, - **kwargs, - ) + # update parameters for execute_task function + self.execute_task_args = [f, parent_func, agg_to_collab_ss, kwargs] diff --git a/openfl/experimental/placement/placement.py b/openfl/experimental/placement/placement.py index 0662137add..6a8afd7671 100644 --- a/openfl/experimental/placement/placement.py +++ b/openfl/experimental/placement/placement.py @@ -27,7 +27,9 @@ def ray_call_put(self, ctx, func): def get_remote_clones(self): clones = deepcopy(ray.get(self.remote_functions)) + # delete remote_functions to free ray memory and reinitialize del self.remote_functions + self.remote_functions = [] # Remove clones from ray object store for ctx in self.remote_contexts: ray.cancel(ctx) diff --git a/openfl/experimental/runtime/local_runtime.py b/openfl/experimental/runtime/local_runtime.py index efac60efae..0bddb126d6 100644 --- a/openfl/experimental/runtime/local_runtime.py +++ b/openfl/experimental/runtime/local_runtime.py @@ -9,6 +9,7 @@ import gc from openfl.experimental.runtime import Runtime from typing import TYPE_CHECKING + if TYPE_CHECKING: from openfl.experimental.interface import Aggregator, Collaborator, FLSpec from openfl.experimental.placement import RayExecutor @@ -21,6 +22,7 @@ from typing import List from typing import Type from typing import Callable +import importlib class LocalRuntime(Runtime): @@ -101,9 +103,7 @@ def collaborators(self, collaborators: List[Type[Collaborator]]): } def restore_instance_snapshot( - self, - ctx: Type[FLSpec], - instance_snapshot: List[Type[FLSpec]] + self, ctx: Type[FLSpec], instance_snapshot: List[Type[FLSpec]] ): """Restores attributes from backup (in instance snapshot) to ctx""" for backup in instance_snapshot: @@ -118,115 +118,209 @@ def execute_task( f: Callable, parent_func: Callable, instance_snapshot: List[Type[FLSpec]] = [], - **kwargs + **kwargs, ): """ - Performs the execution of a task as defined by the - implementation and underlying backend (single_process, ray, etc) - on a single node + Defines which function to be executed based on name and kwargs + Updates the arguments and executes until end is not reached Args: flspec_obj: Reference to the FLSpec (flow) object. Contains information - about task sequence, flow attributes, that are needed to - execute a future task + about task sequence, flow attributes. f: The next task to be executed within the flow parent_func: The prior task executed in the flow instance_snapshot: A prior FLSpec state that needs to be restored from (i.e. restoring aggregator state after collaborator execution) """ - from openfl.experimental.interface import ( - FLSpec, - final_attributes, - ) + + while f.__name__ != "end": + if "foreach" in kwargs: + f, parent_func, instance_snapshot, kwargs = self.execute_foreach_task( + flspec_obj, f, parent_func, instance_snapshot, **kwargs + ) + else: + f, parent_func, instance_snapshot, kwargs = self.execute_agg_task( + flspec_obj, f + ) + else: + self.execute_end_task(flspec_obj, f) + + def execute_agg_task(self, flspec_obj, f): + """ + Performs execution of aggregator task + Args: + flspec_obj : Reference to the FLSpec (flow) object + f : The task to be executed within the flow + + Returns: + list: updated arguments to be executed + """ + + to_exec = getattr(flspec_obj, f.__name__) + to_exec() + return flspec_obj.execute_task_args + + def execute_end_task(self, flspec_obj, f): + """ + Performs execution of end task + Args: + flspec_obj : Reference to the FLSpec (flow) object + f : The task to be executed within the flow + + Returns: + list: updated arguments to be executed + """ global final_attributes + final_attr_module = importlib.import_module("openfl.experimental.interface") + final_attributes = getattr(final_attr_module, "final_attributes") - if "foreach" in kwargs: - flspec_obj._foreach_methods.append(f.__name__) - selected_collaborators = flspec_obj.__getattribute__( - kwargs["foreach"] - ) + to_exec = getattr(flspec_obj, f.__name__) + to_exec() + checkpoint(flspec_obj, f) + artifacts_iter, _ = generate_artifacts(ctx=flspec_obj) + final_attributes = artifacts_iter() + return - for col in selected_collaborators: - clone = FLSpec._clones[col] - if ( - "exclude" in kwargs and hasattr(clone, kwargs["exclude"][0]) - ) or ( - "include" in kwargs and hasattr(clone, kwargs["include"][0]) - ): - filter_attributes(clone, f, **kwargs) - artifacts_iter, _ = generate_artifacts(ctx=flspec_obj) - for name, attr in artifacts_iter(): - setattr(clone, name, deepcopy(attr)) - clone._foreach_methods = flspec_obj._foreach_methods - - for col in selected_collaborators: - clone = FLSpec._clones[col] - clone.input = col - if aggregator_to_collaborator(f, parent_func): - # remove private aggregator state - for attr in self._aggregator.private_attributes: - self._aggregator.private_attributes[attr] = getattr( - flspec_obj, attr - ) - if hasattr(clone, attr): - delattr(clone, attr) - - func = None - if self.backend == "ray": - ray_executor = RayExecutor() - for col in selected_collaborators: - clone = FLSpec._clones[col] - # Set new LocalRuntime for clone as it is required - # for calling execute_task and also new runtime - # object will not contain private attributes of - # aggregator or other collaborators - clone.runtime = LocalRuntime(backend="single_process") - for name, attr in self.__collaborators[ - clone.input - ].private_attributes.items(): - setattr(clone, name, attr) + def execute_foreach_task( + self, flspec_obj, f, parent_func, instance_snapshot, **kwargs + ): + """ + Performs + 1. Filter include/exclude + 2. Remove aggregator private attributes + 3. Set runtime, collab private attributes , metaflow_interface + 4. Execution of all collaborator for each task + 5. Remove collaborator private attributes + 6. Execute the next function after transition + + Args: + flspec_obj : Reference to the FLSpec (flow) object + f : The task to be executed within the flow + parent_func : The prior task executed in the flow + instance_snapshot : A prior FLSpec state that needs to be restored + + Returns: + list: updated arguments to be executed + """ + + flspec_module = importlib.import_module("openfl.experimental.interface") + flspec_class = getattr(flspec_module, "FLSpec") + flspec_obj._foreach_methods.append(f.__name__) + selected_collaborators = getattr(flspec_obj, kwargs["foreach"]) + + # filter exclude/include attributes for clone + self.filter_exclude_include(flspec_obj, f, selected_collaborators, **kwargs) + + # Remove aggregator private attributes + for col in selected_collaborators: + clone = flspec_class._clones[col] + if aggregator_to_collaborator(f, parent_func): + for attr in self._aggregator.private_attributes: + self._aggregator.private_attributes[attr] = getattr( + flspec_obj, attr + ) + if hasattr(clone, attr): + delattr(clone, attr) + + if self.backend == "ray": + ray_executor = RayExecutor() + + # set runtime,collab private attributes and metaflowinterface + for col in selected_collaborators: + clone = flspec_class._clones[col] + # Set new LocalRuntime for clone as it is required + # new runtime object will not contain private attributes of + # aggregator or other collaborators + clone.runtime = LocalRuntime(backend="single_process") + + # set collab private attributes + for name, attr in self.__collaborators[ + clone.input + ].private_attributes.items(): + setattr(clone, name, attr) + + # write the clone to the object store + # ensure clone is getting latest _metaflow_interface + clone._metaflow_interface = flspec_obj._metaflow_interface + + # For initial step assume there is no trasition from collab_to_agg + not_at_transition_point = True + + # loop until there is no transition + while not_at_transition_point: + # execute to_exec for for each collab + for collab in selected_collaborators: + clone = flspec_class._clones[collab] + # get the function to be executed to_exec = getattr(clone, f.__name__) - # write the clone to the object store - # ensure clone is getting latest _metaflow_interface - clone._metaflow_interface = flspec_obj._metaflow_interface + if self.backend == "ray": ray_executor.ray_call_put(clone, to_exec) else: to_exec() + if self.backend == "ray": + # Execute the collab steps clones = ray_executor.get_remote_clones() - FLSpec._clones.update(zip(selected_collaborators, clones)) - del ray_executor - del clones - gc.collect() - for col in selected_collaborators: - clone = FLSpec._clones[col] - func = clone.execute_next - for attr in self.__collaborators[ - clone.input - ].private_attributes: - if hasattr(clone, attr): - self.__collaborators[clone.input].private_attributes[ - attr - ] = getattr(clone, attr) - delattr(clone, attr) - # Restore the flspec_obj state if back-up is taken - self.restore_instance_snapshot(flspec_obj, instance_snapshot) - del instance_snapshot + flspec_class._clones.update(zip(selected_collaborators, clones)) + + # update the next arguments + f, parent_func, _, kwargs = flspec_class._clones[collab].execute_task_args - g = getattr(flspec_obj, func) - # remove private collaborator state + # check for transition + if flspec_class._clones[collab]._is_at_transition_point(f, parent_func): + not_at_transition_point = False + + # remove clones after transition + if self.backend == "ray": + del ray_executor + del clones gc.collect() - g([FLSpec._clones[col] for col in selected_collaborators]) - else: - to_exec = getattr(flspec_obj, f.__name__) - to_exec() - if f.__name__ == "end": - checkpoint(flspec_obj, f) - artifacts_iter, _ = generate_artifacts(ctx=flspec_obj) - final_attributes = artifacts_iter() + + # Removes collaborator private attributes after transition + for col in selected_collaborators: + clone = flspec_class._clones[col] + for attr in self.__collaborators[clone.input].private_attributes: + if hasattr(clone, attr): + self.__collaborators[clone.input].private_attributes[ + attr + ] = getattr(clone, attr) + delattr(clone, attr) + + # Restore the flspec_obj state if back-up is taken + self.restore_instance_snapshot(flspec_obj, instance_snapshot) + del instance_snapshot + + g = getattr(flspec_obj, f.__name__) + gc.collect() + g([flspec_class._clones[col] for col in selected_collaborators]) + return flspec_obj.execute_task_args + + def filter_exclude_include(self, flspec_obj, f, selected_collaborators, **kwargs): + """ + This function filters exclude/include attributes + Args: + flspec_obj : Reference to the FLSpec (flow) object + f : The task to be executed within the flow + selected_collaborators : all collaborators + """ + + flspec_module = importlib.import_module("openfl.experimental.interface") + flspec_class = getattr(flspec_module, "FLSpec") + + for col in selected_collaborators: + clone = flspec_class._clones[col] + clone.input = col + if ("exclude" in kwargs and hasattr(clone, kwargs["exclude"][0])) or ( + "include" in kwargs and hasattr(clone, kwargs["include"][0]) + ): + filter_attributes(clone, f, **kwargs) + artifacts_iter, _ = generate_artifacts(ctx=flspec_obj) + for name, attr in artifacts_iter(): + setattr(clone, name, deepcopy(attr)) + clone._foreach_methods = flspec_obj._foreach_methods def __repr__(self): return "LocalRuntime"