diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1fe3a22..f0792a0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -20,13 +20,9 @@ jobs: python -m pip install pre-commit tox - name: run pre-commit run: pre-commit run -a - - name: mypy (sdk) + - name: mypy run: | - cd compute_sdk - tox -e mypy - - name: mypy (endpoint) - run: | - cd compute_endpoint + cd flox tox -e mypy safety-check: @@ -39,7 +35,7 @@ jobs: - name: install requirements run: | python -m pip install --upgrade pip setuptools - python -m pip install './flox' + python -m pip install '.[all]' python -m pip install safety - name: run safety check run: safety check diff --git a/docs/getting_started/flock.md b/docs/getting_started/flock.md index e41cb93..ececda4 100644 --- a/docs/getting_started/flock.md +++ b/docs/getting_started/flock.md @@ -6,7 +6,7 @@ The ``flox.flock`` module contains the code needed to define your own ``Flock`` 1. interactive mode 2. file mode -Interactive mode involves creating a ``NetworkX.DiGraph()`` object directly and then passing that into the ``Flock`` constructor. This is **not** recommended. +Interactive mode involves creating a ``NetworkX.DiGraph()`` object directly and then passing that into the ``Flock`` constructor. This is **not** recommended. The recommended approach is ***file mode***. In this mode, you define the Flock network using a supported file type (e.g., `*.yaml`) and simply use it to create the Flock instance. diff --git a/docs/getting_started/strategies/callbacks.md b/docs/getting_started/strategies/callbacks.md index 681e454..42ce006 100644 --- a/docs/getting_started/strategies/callbacks.md +++ b/docs/getting_started/strategies/callbacks.md @@ -1,7 +1,7 @@ # Strategy Callbacks ## How are Strategies defined in FLoX? -FLoX was designed to support modularity to enable creative and novel solutions for FL research. Therefore, in FLoX, we define a base ``Strategy`` class which serves as a class of callbacks. 
Classes that extend this base class (e.g., `FedAvg` extends `Strategy`) can implement their own unique logic which is seamlessly incorporated into the FL process. +FLoX was designed to support modularity to enable creative and novel solutions for FL research. Therefore, in FLoX, we define a base ``Strategy`` class which serves as a class of callbacks. Classes that extend this base class (e.g., `FedAvg` extends `Strategy`) can implement their own unique logic which is seamlessly incorporated into the FL process. ```python from flox.strategies import Strategy diff --git a/docs/getting_started/strategies/custom.md b/docs/getting_started/strategies/custom.md index 4f674d9..92653b3 100644 --- a/docs/getting_started/strategies/custom.md +++ b/docs/getting_started/strategies/custom.md @@ -1,12 +1,12 @@ # Defining Your Own Custom Strategies -FLoX was designed with customizability in mind. FL is a new research area that invites countless questions about how to -best perform FL. Additionally, the best FL approach will vary depending on the data, network connectivity, other +FLoX was designed with customizability in mind. FL is a new research area that invites countless questions about how to +best perform FL. Additionally, the best FL approach will vary depending on the data, network connectivity, other requirements, etc. As such, we aimed to make defining original Strategies to be as pain-free as possible. -Implementing a custom ``Strategy`` simply requires defining a new class that extends/subclasses the ``Strategy`` protocol +Implementing a custom ``Strategy`` simply requires defining a new class that extends/subclasses the ``Strategy`` protocol (as seen above). The ``Strategy`` protocol provides a handful of callbacks for you to inject custom logic to adjust how the -FL process runs. +FL process runs. As an example, let's use our source code for the implementation of ``FedProx`` as an example. 
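Review note on the docs touched above: `callbacks.md` and `custom.md` describe a base class whose callbacks subclasses override. A small standalone sketch of that pattern follows. Everything in it is hypothetical: `SketchStrategy`, `MeanStrategy`, `ClippedMeanStrategy`, the hook names, and the lowercased registry keys are illustrative stand-ins rather than the `flox.strategies` API, and plain lists of floats stand in for model parameters.

```python
from abc import ABC, abstractmethod
from collections.abc import Mapping


class SketchStrategy(ABC):
    """Hypothetical callback-style base class (not flox's ``Strategy``)."""

    # Assumption: subclasses are registered under their lowercased class name.
    registry: dict[str, type["SketchStrategy"]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.registry[cls.__name__.lower()] = cls

    @abstractmethod
    def aggregate(self, child_params: Mapping[int, list[float]]) -> list[float]:
        """Required callback: combine the parameters reported by each child."""

    def before_share(self, params: list[float]) -> list[float]:
        """Optional callback: the default passes the parameters through."""
        return params


class MeanStrategy(SketchStrategy):
    def aggregate(self, child_params):
        vectors = list(child_params.values())
        # Element-wise mean across the children's parameter vectors.
        return [sum(column) / len(vectors) for column in zip(*vectors)]


class ClippedMeanStrategy(MeanStrategy):
    """Overrides only the optional hook; aggregation is inherited."""

    def before_share(self, params):
        return [max(-1.0, min(1.0, p)) for p in params]


strategy = SketchStrategy.registry["clippedmeanstrategy"]()
averaged = strategy.aggregate({0: [0.5, 3.0], 1: [1.5, 5.0]})
shared = strategy.before_share(averaged)
```

Marking the required callback with `@abstractmethod` means a subclass that forgets it fails at instantiation time, while optional hooks keep harmless defaults.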
diff --git a/docs/getting_started/strategies/index.md b/docs/getting_started/strategies/index.md index 85cadf1..5f83063 100644 --- a/docs/getting_started/strategies/index.md +++ b/docs/getting_started/strategies/index.md @@ -8,7 +8,7 @@ Some prominent examples from the literature of what we consider "Strategies" inc * FedAvg with Proximal Term (`FedProx`)[^fedprox] ## What _exactly_ do Strategies do? -In a nutshell, a lot. Federated Learning is a complex process with tasks being done on the worker nodes and the aggregator node(s). Thus, Strategies can touch a lot of different parts of the entire logic of an FL process. +In a nutshell, a lot. Federated Learning is a complex process with tasks being done on the worker nodes and the aggregator node(s). Thus, Strategies can touch a lot of different parts of the entire logic of an FL process. ### Model Parameter Aggregation ... diff --git a/docs/index.md b/docs/index.md index b2a1349..3b11e15 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,15 +1,15 @@ # Welcome to FLoX ### Getting Started -FLoX is a simple, highly customizable, and easy-to-deploy framework for launching Federated Learning processes across a -decentralized network. It is designed to simulate FL workflows while also making it trivially easy to deploy them on -real-world devices (e.g., Internet-of-Things and edge devices). Built on top of _Globus Compute_ (formerly known as -_funcX_), FLoX is designed to run on anything that can be started as a Globus Compute Endpoint. +FLoX is a simple, highly customizable, and easy-to-deploy framework for launching Federated Learning processes across a +decentralized network. It is designed to simulate FL workflows while also making it trivially easy to deploy them on +real-world devices (e.g., Internet-of-Things and edge devices). Built on top of _Globus Compute_ (formerly known as +_funcX_), FLoX is designed to run on anything that can be started as a Globus Compute Endpoint. ### What can FLoX do? 
-FLoX is supports several state-of-the-art approaches for FL processes, including hierarchical and asynchronous FL. +FLoX is supports several state-of-the-art approaches for FL processes, including hierarchical and asynchronous FL. | | 2-tier | Hierarhchical | | --: |:----------------:|:-----------------:| diff --git a/flox/data/utils.py b/flox/data/utils.py index bbfe442..63c48ba 100644 --- a/flox/data/utils.py +++ b/flox/data/utils.py @@ -60,16 +60,20 @@ def federated_split( sample_distr = stats.dirichlet(np.full(num_workers, samples_alpha)) label_distr = stats.dirichlet(np.full(num_classes, labels_alpha)) - num_samples_for_workers = (sample_distr.rvs()[0] * len(data)).astype(int) + # pytorch intentionally doesn't define an empty __len__ for DataSet, even though + # most subclasses implement it + data_count = len(data) # type: ignore + + num_samples_for_workers = (sample_distr.rvs()[0] * data_count).astype(int) num_samples_for_workers = { worker.idx: num_samples for worker, num_samples in zip(flock.workers, num_samples_for_workers) } label_probs = {w.idx: label_distr.rvs()[0] for w in flock.workers} - indices: dict[int, list[int]] = defaultdict(list) + indices: dict[FlockNodeID, list[int]] = defaultdict(list) loader = DataLoader(data, batch_size=1) - worker_samples = defaultdict(int) + worker_samples: Counter[FlockNodeID] = Counter() for idx, batch in enumerate(loader): _, y = batch label = y.item() @@ -90,11 +94,11 @@ def federated_split( ) raise err - probs = np.array(probs) - probs = probs / probs.sum() + probs_norm = np.array(probs) + probs_norm = probs_norm / probs_norm.sum() if len(temp_workers) > 0: - chosen_worker = np.random.choice(temp_workers, p=probs) + chosen_worker = np.random.choice(temp_workers, p=probs_norm) indices[chosen_worker].append(idx) worker_samples[chosen_worker] += 1 diff --git a/flox/flock/flock.py b/flox/flock/flock.py index d44722c..908bf50 100644 --- a/flox/flock/flock.py +++ b/flox/flock/flock.py @@ -58,7 +58,6 @@ def 
__init__(self, topo: nx.DiGraph | None = None, _src: Path | str | None = Non """ self.node_counter: int = 0 self._src = "interactive" if _src is None else _src - self.leader = None if topo is None: # By default (i.e., `topo is None`), @@ -87,6 +86,8 @@ def __init__(self, topo: nx.DiGraph | None = None, _src: Path | str | None = Non raise ValueError( "A legal Flock cannot have more than one leader." ) + if not found_leader: + raise ValueError("A legal Flock must have a leader.") def add_node( self, @@ -105,7 +106,7 @@ def add_node( proxystore_endpoint_id=proxystore_endpoint_id, ) self.node_counter += 1 - return FlockNodeID(idx) + return idx def add_edge(self, u: FlockNodeID, v: FlockNodeID, **attrs) -> None: """ @@ -226,7 +227,7 @@ def validate_topo(self) -> bool: return True - def children(self, node: FlockNode | FlockNodeID | int) -> Generator[FlockNode]: + def children(self, node: FlockNode | FlockNodeID | int) -> Iterator[FlockNode]: if isinstance(node, FlockNode): idx = node.idx else: @@ -316,7 +317,7 @@ def from_dict(content: dict[str, Any], _src: Path | str | None = None) -> "Flock return Flock(topo=topo, _src=_src) @staticmethod - def from_json(path: Path | str) -> "Flock": + def from_json(path: Path | str) -> Flock: """Imports a .json file as a Flock. Examples: @@ -329,12 +330,12 @@ def from_json(path: Path | str) -> "Flock": An instance of a Flock. """ # TODO: Figure out how to address the issue of JSON requiring string keys for `from_json()`. - with open(path, "r") as f: + with open(path) as f: content = json.load(f) return Flock.from_dict(content, _src=path) @staticmethod - def from_yaml(path: Path | str) -> "Flock": + def from_yaml(path: Path | str) -> Flock: """Imports a `.yaml` file as a Flock. Examples: @@ -346,7 +347,7 @@ def from_yaml(path: Path | str) -> "Flock": Returns: An instance of a Flock. 
""" - with open(path, "r") as f: + with open(path) as f: content = yaml.safe_load(f) return Flock.from_dict(content, _src=path) @@ -360,7 +361,7 @@ def globus_compute_ready(self) -> bool: """ # TODO: The leader does NOT need a Globus Compute endpoint. key = "globus_compute_endpoint" - for idx, data in self.topo.nodes(data=True): + for _idx, data in self.topo.nodes(data=True): value = data[key] if any([value is None, isinstance(value, UUID) is False]): return False @@ -377,7 +378,7 @@ def proxystore_ready(self) -> bool: in size) with Globus Compute. """ key = "proxystore_endpoint" - for idx, data in self.topo.nodes(data=True): + for _idx, data in self.topo.nodes(data=True): value = data[key] try: @@ -392,7 +393,7 @@ def proxystore_ready(self) -> bool: # return self.nodes(by_kind=FlockNodeKind.LEADER) @property - def aggregators(self) -> Generator[FlockNode]: + def aggregators(self) -> Iterator[FlockNode]: """ The aggregator nodes of the Flock. @@ -402,7 +403,7 @@ def aggregators(self) -> Generator[FlockNode]: return self.nodes(by_kind=FlockNodeKind.AGGREGATOR) @property - def workers(self) -> Generator[FlockNode]: + def workers(self) -> Iterator[FlockNode]: """ The worker nodes of the Flock. 
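Review note on the `Generator[FlockNode]` to `Iterator[FlockNode]` annotation change in `flock.py`: `Generator` is parameterized by yield, send, and return types, so for a function that only yields, `Iterator[T]` is the shorter conventional annotation. A minimal sketch, assuming a hypothetical `Node` dataclass and `nodes_of_kind` helper in place of `FlockNode` and `Flock.nodes`:

```python
from collections.abc import Iterator
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a FlockNode; only the fields used below."""

    idx: int
    kind: str


def nodes_of_kind(nodes: list[Node], kind: str) -> Iterator[Node]:
    # A generator function that only yields values can be annotated Iterator[T].
    for node in nodes:
        if node.kind == kind:
            yield node


topology = [Node(0, "leader"), Node(1, "worker"), Node(2, "worker")]
worker_ids = [node.idx for node in nodes_of_kind(topology, "worker")]
```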
diff --git a/flox/nn/logger/base.py b/flox/nn/logger/base.py index d9338d6..a6f211f 100644 --- a/flox/nn/logger/base.py +++ b/flox/nn/logger/base.py @@ -19,7 +19,7 @@ def clear(self) -> None: def to_pandas(self) -> pd.DataFrame: df = pd.DataFrame.from_records(self.records) - for col in df: + for col in df.columns: if "time" in col: df[col] = pd.to_datetime(df[col]) return df diff --git a/flox/nn/model.py b/flox/nn/model.py index b6bc0ee..f939876 100644 --- a/flox/nn/model.py +++ b/flox/nn/model.py @@ -1,7 +1,7 @@ import torch -class FloxModule(torch.nn.Module): +class FloxModule(torch.nn.Module, ABC): """ The ``FloxModule`` is a wrapper for the standard ``torch.nn.Module`` class from PyTorch, with a lot of inspiration from the ``lightning.LightningModule`` class from PyTorch Lightning. @@ -10,6 +10,7 @@ class FloxModule(torch.nn.Module): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @abstractmethod def training_step( self, batch: torch.Tensor | tuple[torch.Tensor, torch.Tensor], batch_idx: int ) -> torch.Tensor: @@ -24,6 +25,7 @@ def training_step( Loss from the training step. """ + @abstractmethod def configure_optimizers(self) -> torch.optim.Optimizer: """Configures, initializes, and returns the optimizer used to train the model. diff --git a/flox/nn/trainer.py b/flox/nn/trainer.py index a29ca4f..d303f6a 100644 --- a/flox/nn/trainer.py +++ b/flox/nn/trainer.py @@ -58,10 +58,13 @@ def fit( loss.backward() try: + assert strategy is not None + assert node_state is not None strategy.wrk_after_train_step(node_state, loss) - except NotImplementedError: + except (AttributeError, AssertionError): """ - The current strategy does not override the `wrk_after_train_step()` callback. + ``node_state`` is None, ``strategy`` is None, or ``strategy`` doesn't + implement ``wrk_after_train_step()``. 
""" pass diff --git a/flox/strategies/__init__.py b/flox/strategies/__init__.py index 2f6f918..25d0a12 100644 --- a/flox/strategies/__init__.py +++ b/flox/strategies/__init__.py @@ -3,8 +3,8 @@ """ from flox.strategies.base import Strategy -from flox.strategies.registry.fedsgd import FedSGD from flox.strategies.registry.fedavg import FedAvg from flox.strategies.registry.fedprox import FedProx +from flox.strategies.registry.fedsgd import FedSGD __all__ = ["Strategy", "FedSGD", "FedAvg", "FedProx"] diff --git a/flox/strategies/base.py b/flox/strategies/base.py index 78ee666..4b29e03 100644 --- a/flox/strategies/base.py +++ b/flox/strategies/base.py @@ -5,11 +5,12 @@ from flox.flock import FlockNode, FlockNodeID from flox.flock.states import FloxWorkerState, FloxAggregatorState, NodeState from flox.typing import StateDict +from abc import abstractmethod Loss: TypeAlias = torch.Tensor -class Strategy: +class Strategy(ABC): """Base class for the logical blocks of a FL process. A ``Strategy`` in FLoX is used to implement the logic of an FL process. A ``Strategy`` provides @@ -21,7 +22,7 @@ class Strategy: they are run in an FL process. """ - registry = {} + registry: dict[str, type["Strategy"]] = {} @classmethod def get_strategy(cls, name: str) -> type["Strategy"]: @@ -42,7 +43,7 @@ def get_strategy(cls, name: str) -> type["Strategy"]: if name in cls.registry: return cls.registry[name] else: - raise KeyError(f"Strategy name ({name=}) is not in the Strategy registry.") + raise KeyError(f"Strategy name ({name}) is not in the Strategy registry.") def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -89,7 +90,7 @@ def agg_before_share_params( return state_dict #################################################################################### - # AGGREGATOR CALLBACKS. # + # CLIENT CALLBACKS. 
# #################################################################################### def agg_before_round(self, state: FloxAggregatorState) -> None: @@ -99,8 +100,9 @@ def agg_before_round(self, state: FloxAggregatorState) -> None: Args: state (FloxAggregatorState): The current state of the Aggregator FloxNode. """ + raise NotImplementedError() - # @required + @abstractmethod def agg_param_aggregation( self, state: FloxAggregatorState, @@ -122,15 +124,76 @@ def agg_param_aggregation( StateDict """ - # @required + def agg_worker_selection( + self, state: FloxAggregatorState, children: Iterable[FlockNode], *args, **kwargs + ) -> Iterable[FlockNode]: + """ + + Args: + state (): + children (): + *args (): + **kwargs (): + + Returns: + List of selected nodes that are children of the aggregator. + """ + + def agg_before_share_params( + self, state: FloxAggregatorState, state_dict: StateDict, *args, **kwargs + ) -> StateDict: + """Callback before sharing parameters to child nodes. + + This is mostly done is modify the global model's StateDict. This can be done to encrypt the + model parameters, apply noise, personalize, etc. + + Args: + state (FloxAggregatorState): The current state of the aggregator. + state_dict (StateDict): The global model's current StateDict (i.e., parameters) before + sharing with workers. + + Returns: + The global global_module StateDict. + """ + return state_dict #################################################################################### - # WORKER CALLBACKS. # + # AGGREGATOR CALLBACKS. # #################################################################################### - def wrk_on_recv_params( - self, state: FloxWorkerState, params: StateDict, *args, **kwargs - ): + def agg_before_round(self, state: FloxAggregatorState) -> None: + """ + Some process to run at the start of a round. + + Args: + state (FloxAggregatorState): The current state of the Aggregator FloxNode. 
+ """ + raise NotImplementedError() + + def agg_param_aggregation( + self, + state: FloxAggregatorState, + children_states: Mapping[FlockNodeID, NodeState], + children_state_dicts: Mapping[FlockNodeID, StateDict], + *args, + **kwargs, + ) -> StateDict: + """ + + Args: + state (FloxAggregatorState): + children_states (Mapping[FlockNodeID, NodeState]): + children_state_dicts (Mapping[FlockNodeID, NodeState]): + *args (): + **kwargs (): + + Returns: + StateDict + """ + + def agg_worker_selection( + self, state: FloxAggregatorState, children: Iterable[FlockNode], *args, **kwargs + ) -> Iterable[FlockNode]: """ Args: @@ -142,7 +205,7 @@ def wrk_on_recv_params( Returns: """ - return params + return children def wrk_before_train_step(self, state: FloxWorkerState, *args, **kwargs): """ @@ -155,7 +218,7 @@ def wrk_before_train_step(self, state: FloxWorkerState, *args, **kwargs): Returns: """ - pass + raise NotImplementedError() def wrk_after_train_step( self, state: FloxWorkerState, loss: Loss, *args, **kwargs @@ -186,4 +249,20 @@ def wrk_before_submit_params( Returns: """ - return state.post_local_train_model.state_dict() + raise NotImplementedError() + + def wrk_on_recv_params( + self, state: FloxWorkerState, params: StateDict, *args, **kwargs + ): + """ + + Args: + state (): + params (): + *args (): + **kwargs (): + + Returns: + + """ + return params diff --git a/flox/strategies/commons/averaging.py b/flox/strategies/commons/averaging.py index a43b097..1803620 100644 --- a/flox/strategies/commons/averaging.py +++ b/flox/strategies/commons/averaging.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + import numpy as np import torch @@ -6,8 +8,8 @@ def average_state_dicts( - state_dicts: dict[FlockNodeID, StateDict], - weights: dict[FlockNodeID, float] | None = None, + state_dicts: Mapping[FlockNodeID, StateDict], + weights: Mapping[FlockNodeID, float] | None = None, ) -> StateDict: """Averages the parameters given by ``global_module.state_dict()`` from a set of 
``FlockNodes``. @@ -25,7 +27,7 @@ def average_state_dicts( with torch.no_grad(): avg_weights = {} for node, state_dict in state_dicts.items(): - w = 1 / num_nodes if weights is None else weights[node] / weight_sum + w = 1 / num_nodes if weights is None else weights[node] / weight_sum # type: ignore for name, value in state_dict.items(): value = w * torch.clone(value) if name not in avg_weights: diff --git a/flox/strategies/commons/worker_selection.py b/flox/strategies/commons/worker_selection.py index bcda33b..0dc60e3 100644 --- a/flox/strategies/commons/worker_selection.py +++ b/flox/strategies/commons/worker_selection.py @@ -1,15 +1,18 @@ -import numpy as np +from collections.abc import Iterable +from typing import cast + from numpy.random import RandomState +from numpy.typing import NDArray from flox.flock import FlockNode, FlockNodeKind def random_worker_selection( - children: list[FlockNode], + children: Iterable[FlockNode], participation: float = 1.0, probabilistic: bool = False, always_include_child_aggregators: bool = True, - seed: int = None, + seed: int | None = None, ) -> list[FlockNode]: """ @@ -31,9 +34,9 @@ def random_worker_selection( def fixed_random_worker_selection( - children: list[FlockNode], + children: Iterable[FlockNode], participation: float = 1.0, - seed: int = None, + seed: int | None = None, ) -> list[FlockNode]: """ @@ -48,12 +51,14 @@ def fixed_random_worker_selection( children = np.array(children) rand_state = RandomState(seed) num_selected = min(1, int(participation) * len(list(children))) - selected_children = rand_state.choice(children, size=num_selected, replace=False) + # numpy annotates RandomState.choice too narrowly; need this to satisfy mypy + achildren = cast(NDArray, children) + selected_children = rand_state.choice(achildren, size=num_selected, replace=False) return list(selected_children) def prob_random_worker_selection( - children: list[FlockNode], + children: Iterable[FlockNode], participation: float = 1.0, 
always_include_child_aggregators: bool = True, seed: int | None = None, @@ -78,7 +83,9 @@ def prob_random_worker_selection( selected_children.append(child) if len(selected_children) == 0: - random_child = rand_state.choice(children) + # numpy annotates RandomState.choice too narrowly; need this to satisfy mypy + achildren = cast(NDArray, children) + random_child = rand_state.choice(achildren) selected_children.append(random_child) return selected_children diff --git a/flox/strategies/registry/fedavg.py b/flox/strategies/registry/fedavg.py index 1125d93..1db3045 100644 --- a/flox/strategies/registry/fedavg.py +++ b/flox/strategies/registry/fedavg.py @@ -1,6 +1,7 @@ +from collections.abc import Mapping + from flox.flock import FlockNodeID -from flox.flock.states import FloxWorkerState -from flox.flock.states import NodeState, FloxAggregatorState +from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState from flox.strategies.commons.averaging import average_state_dicts from flox.strategies.registry.fedsgd import FedSGD from flox.typing import StateDict @@ -47,10 +48,10 @@ def wrk_before_train_step(self, state: FloxWorkerState, *args, **kwargs): def agg_param_aggregation( self, state: FloxAggregatorState, - children_states: dict[FlockNodeID, NodeState], - children_state_dicts: dict[FlockNodeID, StateDict], - *args, - **kwargs, + children_states: Mapping[FlockNodeID, NodeState], + children_state_dicts: Mapping[FlockNodeID, StateDict], + *_args, + **_kwargs, ): weights = {} for node, child_state in children_states.items(): diff --git a/flox/strategies/registry/fedprox.py b/flox/strategies/registry/fedprox.py index b8b7141..968173c 100644 --- a/flox/strategies/registry/fedprox.py +++ b/flox/strategies/registry/fedprox.py @@ -27,7 +27,7 @@ def __init__( participation: float = 1.0, probabilistic: bool = False, always_include_child_aggregators: bool = True, - seed: int = None, + seed: int | None = None, ): """ @@ -74,6 +74,8 @@ def wrk_after_train_step( 
""" global_model = state.pre_local_train_model local_model = state.post_local_train_model + assert global_model is not None + assert local_model is not None params = list(local_model.state_dict().values()) params0 = list(global_model.state_dict().values()) diff --git a/flox/strategies/registry/fedsgd.py b/flox/strategies/registry/fedsgd.py index 377e076..1c88d29 100644 --- a/flox/strategies/registry/fedsgd.py +++ b/flox/strategies/registry/fedsgd.py @@ -1,5 +1,7 @@ from __future__ import annotations +from collections.abc import Iterable, Mapping + from flox.flock import FlockNode, FlockNodeID from flox.flock.states import FloxAggregatorState, NodeState from flox.strategies.base import Strategy @@ -26,7 +28,7 @@ def __init__( participation: float = 1.0, probabilistic: bool = True, always_include_child_aggregators: bool = True, - seed: int = None, + seed: int | None = None, ): """Initializes the FedSGD strategy with the desired parameters. @@ -49,7 +51,11 @@ def __init__( self.seed = seed def agg_worker_selection( - self, state: FloxAggregatorState, children: list[FlockNode], **kwargs + self, + state: FloxAggregatorState, + children: Iterable[FlockNode], + *_args, + **_kwargs, ) -> list[FlockNode]: """Performs a simple average of the model weights returned by the child nodes. @@ -82,10 +88,10 @@ def agg_worker_selection( def agg_param_aggregation( self, state: FloxAggregatorState, - children_states: dict[FlockNodeID, NodeState], - children_state_dicts: dict[FlockNodeID, StateDict], - *args, - **kwargs, + children_states: Mapping[FlockNodeID, NodeState], + children_state_dicts: Mapping[FlockNodeID, StateDict], + *_args, + **_kwargs, ) -> StateDict: """Runs simple, unweighted averaging of ``StateDict`` objects from each child node. 
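Review note on the `dict` to `Mapping` parameter changes in `averaging.py`, `fedavg.py`, and `fedsgd.py`: annotating inputs as `Mapping` lets callers pass any read-only mapping, not only a `dict`. The sketch below mirrors the weighting rule visible in `average_state_dicts` (`1 / num_nodes` when no weights are given, otherwise `weights[node] / weight_sum`), but it is a simplification under stated assumptions: `average_params` is a hypothetical name, `weight_sum` is assumed to be the sum of the supplied weights, integer node IDs stand in for `FlockNodeID`, and plain floats stand in for tensors.

```python
from __future__ import annotations

from collections.abc import Mapping
from types import MappingProxyType


def average_params(
    params_by_node: Mapping[int, Mapping[str, float]],
    weights: Mapping[int, float] | None = None,
) -> dict[str, float]:
    """Weighted average of per-node parameters (hypothetical helper)."""
    num_nodes = len(params_by_node)
    if weights is None:
        # No weights supplied: every node contributes equally.
        node_weights = {node: 1 / num_nodes for node in params_by_node}
    else:
        # Assumption: weights are normalized by their sum.
        weight_sum = sum(weights.values())
        node_weights = {node: weights[node] / weight_sum for node in params_by_node}

    averaged: dict[str, float] = {}
    for node, params in params_by_node.items():
        for name, value in params.items():
            averaged[name] = averaged.get(name, 0.0) + node_weights[node] * value
    return averaged


# Any Mapping is accepted, including a read-only view over a dict.
frozen = MappingProxyType({0: {"w": 1.0, "b": 0.0}, 1: {"w": 3.0, "b": 2.0}})
uniform = average_params(frozen)
weighted = average_params(frozen, weights={0: 1.0, 1: 3.0})
```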
diff --git a/flox/utils/random/flock.py b/flox/utils/random/flock.py index 4331af0..3b19750 100644 --- a/flox/utils/random/flock.py +++ b/flox/utils/random/flock.py @@ -1,6 +1,7 @@ import networkx as nx from flox.flock import Flock +from flox.flock.flock import REQUIRED_ATTRS def random_flock(num_nodes: int, seed: int | None = None) -> Flock: @@ -16,7 +17,7 @@ def random_flock(num_nodes: int, seed: int | None = None) -> Flock: # TODO: Finish this and create a test. tree = nx.random_tree(n=num_nodes, seed=seed, create_using=nx.DiGraph) for node in tree.nodes(): - for attr in Flock.required_attrs: + for attr in REQUIRED_ATTRS: tree.nodes[node][attr] = None flock = Flock(tree) return flock diff --git a/pyproject.toml b/pyproject.toml index 8607cb9..2232f3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,10 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["black", "coverage", "jupyterlab", "matplotlib", "numpy", "pytest", "seaborn", "tensorboard", "torchvision"] +dev = ["black", "coverage", "jupyterlab", "matplotlib", "numpy", "pytest", "seaborn", "tensorboard", "torchvision", "matplotlib-stubs", "pandas-stubs", "networkx-stubs"] monitoring = ["tensorboard"] +proxystore = ["proxystore"] +all = ["flox[dev,monitoring,proxystore]"] [tool.pytest.ini_options] addopts = [ diff --git a/quickstart/README.md b/quickstart/README.md index b79ccb8..87cc00f 100644 --- a/quickstart/README.md +++ b/quickstart/README.md @@ -1,8 +1,8 @@ # Quickstart Examples This directory has some simple scripts playing around with different parts of FLoX. -They can be used to learn the package as a whole or to use as a starting point for writing +They can be used to learn the package as a whole or to use as a starting point for writing your own custom FLoX-based scripts! 
Note, if you want to run anything in this directory without installing FLoX (i.e., you'e just -cloned the repo to your local machine), then you will need to use `import sys; +cloned the repo to your local machine), then you will need to use `import sys; sys.path.append("..")` at the top of the script and run it from the `quickstart` directory. \ No newline at end of file diff --git a/quickstart/fashion_mnist_demo.py b/quickstart/fashion_mnist_demo.py index 9e930a6..adb2469 100644 --- a/quickstart/fashion_mnist_demo.py +++ b/quickstart/fashion_mnist_demo.py @@ -1,21 +1,23 @@ +import os import sys +from pathlib import Path -sys.path.append("..") - -import os import pandas as pd import torch - -from flox.flock import Flock -from flox.run import federated_fit -from flox.nn import FloxModule -from flox.strategies import FedProx -from flox.data.utils import federated_split -from pathlib import Path from torch import nn from torchvision.datasets import FashionMNIST from torchvision.transforms import ToTensor +try: + sys.path.append("..") + from flox.data.utils import federated_split + from flox.flock import Flock + from flox.nn import FloxModule + from flox.run import federated_fit + from flox.strategies import FedProx +except Exception as e: + raise ImportError("unable to import FloX libraries") from e + class MyModule(FloxModule): def __init__(self, lr: float = 0.01): diff --git a/quickstart/fed_data_distributions.py b/quickstart/fed_data_distributions.py index 3a75fda..e7236fb 100644 --- a/quickstart/fed_data_distributions.py +++ b/quickstart/fed_data_distributions.py @@ -1,15 +1,17 @@ +import os import sys -sys.path.append("..") - import matplotlib.pyplot as plt -import os - -from flox.flock import Flock -from flox.data.utils import federated_split, fed_barplot from torchvision.datasets import FashionMNIST from torchvision.transforms import ToTensor +try: + sys.path.append("..") + from flox.data.utils import fed_barplot, federated_split + from flox.flock import Flock 
+except Exception as e: + raise ImportError("unable to import FloX libraries") from e + plt.style.use("ggplot") if __name__ == "__main__": diff --git a/quickstart/flox_module.py b/quickstart/flox_module.py index 648c8f9..9cbb48c 100644 --- a/quickstart/flox_module.py +++ b/quickstart/flox_module.py @@ -1,10 +1,11 @@ import os -import torch +import torch from torch import nn from torch.utils.data import DataLoader from torchvision.datasets import MNIST from torchvision.transforms import ToTensor + from flox.nn import FloxModule, Trainer diff --git a/tests/data/test_datasets.py b/tests/data/test_datasets.py index f2bacf7..2246cd8 100644 --- a/tests/data/test_datasets.py +++ b/tests/data/test_datasets.py @@ -3,6 +3,7 @@ import pandas as pd import torch from sklearn.datasets import make_classification + # TODO: Get rid of `sklearn` as a dependency. from torch.utils.data import Dataset @@ -42,7 +43,7 @@ def test_dir_datasets(tmpdir): client_dir = (data_dir / f"{worker.idx}").mkdir() client_path = client_dir / "data.csv" with open(client_path, "w") as file: - print(f"x1, x2, y", file=file) + print("x1, x2, y", file=file) num_samples = rand_state.randint(low=1, high=1000) for _ in range(num_samples): a = rand_state.randint(low=-1000, high=1000) diff --git a/tests/fit/test_fit_process.py b/tests/fit/test_fit_process.py index c11bb9a..c582818 100644 --- a/tests/fit/test_fit_process.py +++ b/tests/fit/test_fit_process.py @@ -3,13 +3,15 @@ import pandas as pd import pytest import torch + from torch import nn from torchvision.datasets import MNIST from torchvision.transforms import ToTensor -from flox import Flock, federated_fit -from flox.data.utils import federated_split +from flox.flock import Flock from flox.nn import FloxModule +from flox.run import federated_fit +from flox.data.utils import federated_split class MyModule(FloxModule): diff --git a/tox.ini b/tox.ini index e3628fb..583b8c1 100644 --- a/tox.ini +++ b/tox.ini @@ -16,6 +16,9 @@ commands = [testenv:mypy] deps = 
mypy>=1.6.1 +extras = + dev + proxystore commands = mypy --install-types --non-interactive -p flox {posargs}
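Review note on `fixed_random_worker_selection` in `worker_selection.py`: the unchanged context line `num_selected = min(1, int(participation) * len(list(children)))` can never exceed 1, and it evaluates to 0 for any `participation` strictly between 0 and 1 because `int(participation)` truncates before the multiplication. If the intent is to pick a `participation` fraction of the children but never fewer than one, one possible shape is sketched below. This is a guess at the intent, not the project's fix: `pick_fixed_fraction` is a hypothetical name, and the standard library's `random.Random` is swapped in for `numpy.random.RandomState` to keep the sketch dependency-free.

```python
from __future__ import annotations

import random
from collections.abc import Iterable


def pick_fixed_fraction(
    children: Iterable[str],
    participation: float = 1.0,
    seed: int | None = None,
) -> list[str]:
    """Select a fixed fraction of children without replacement (hypothetical)."""
    pool = list(children)
    # Multiply first, then truncate, and never go below one selection.
    num_selected = max(1, int(participation * len(pool)))
    # Guard against asking for more children than exist (or an empty pool).
    num_selected = min(num_selected, len(pool))
    return random.Random(seed).sample(pool, num_selected)


half = pick_fixed_fraction(["a", "b", "c", "d"], participation=0.5, seed=0)
at_least_one = pick_fixed_fraction(["a", "b", "c", "d"], participation=0.1, seed=0)
```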