Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix actions #27

Merged
merged 7 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,9 @@ jobs:
python -m pip install pre-commit tox
- name: run pre-commit
run: pre-commit run -a
- name: mypy (sdk)
- name: mypy
run: |
cd compute_sdk
tox -e mypy
- name: mypy (endpoint)
run: |
cd compute_endpoint
cd flox
tox -e mypy

safety-check:
Expand All @@ -39,7 +35,7 @@ jobs:
- name: install requirements
run: |
python -m pip install --upgrade pip setuptools
python -m pip install './flox'
python -m pip install '.[all]'
python -m pip install safety
- name: run safety check
run: safety check
Expand Down
2 changes: 1 addition & 1 deletion docs/getting_started/flock.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The ``flox.flock`` module contains the code needed to define your own ``Flock``
1. interactive mode
2. file mode

Interactive mode involves creating a ``NetworkX.DiGraph()`` object directly and then passing that into the ``Flock`` constructor. This is **not** recommended.
Interactive mode involves creating a ``NetworkX.DiGraph()`` object directly and then passing that into the ``Flock`` constructor. This is **not** recommended.

The recommended approach is ***file mode***. In this mode, you define the Flock network using a supported file type (e.g., `*.yaml`) and simply use it to create the Flock instance.

Expand Down
2 changes: 1 addition & 1 deletion docs/getting_started/strategies/callbacks.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Strategy Callbacks

## How are Strategies defined in FLoX?
FLoX was designed to support modularity to enable creative and novel solutions for FL research. Therefore, in FLoX, we define a base ``Strategy`` class which serves as a class of callbacks. Classes that extend this base class (e.g., `FedAvg` extends `Strategy`) can implement their own unique logic which is seamlessly incorporated into the FL process.
FLoX was designed to support modularity to enable creative and novel solutions for FL research. Therefore, in FLoX, we define a base ``Strategy`` class which serves as a class of callbacks. Classes that extend this base class (e.g., `FedAvg` extends `Strategy`) can implement their own unique logic which is seamlessly incorporated into the FL process.

```python
from flox.strategies import Strategy
Expand Down
8 changes: 4 additions & 4 deletions docs/getting_started/strategies/custom.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Defining Your Own Custom Strategies

FLoX was designed with customizability in mind. FL is a new research area that invites countless questions about how to
best perform FL. Additionally, the best FL approach will vary depending on the data, network connectivity, other
FLoX was designed with customizability in mind. FL is a new research area that invites countless questions about how to
best perform FL. Additionally, the best FL approach will vary depending on the data, network connectivity, other
requirements, etc. As such, we aimed to make defining original Strategies to be as pain-free as possible.

Implementing a custom ``Strategy`` simply requires defining a new class that extends/subclasses the ``Strategy`` protocol
Implementing a custom ``Strategy`` simply requires defining a new class that extends/subclasses the ``Strategy`` protocol
(as seen above). The ``Strategy`` protocol provides a handful of callbacks for you to inject custom logic to adjust how the
FL process runs.
FL process runs.

As an example, let's use our source code for the implementation of ``FedProx`` as an example.

Expand Down
2 changes: 1 addition & 1 deletion docs/getting_started/strategies/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Some prominent examples from the literature of what we consider "Strategies" inc
* FedAvg with Proximal Term (`FedProx`)[^fedprox]

## What _exactly_ do Strategies do?
In a nutshell, a lot. Federated Learning is a complex process with tasks being done on the worker nodes and the aggregator node(s). Thus, Strategies can touch a lot of different parts of the entire logic of an FL process.
In a nutshell, a lot. Federated Learning is a complex process with tasks being done on the worker nodes and the aggregator node(s). Thus, Strategies can touch a lot of different parts of the entire logic of an FL process.

### Model Parameter Aggregation
...
Expand Down
10 changes: 5 additions & 5 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Welcome to FLoX

### Getting Started
FLoX is a simple, highly customizable, and easy-to-deploy framework for launching Federated Learning processes across a
decentralized network. It is designed to simulate FL workflows while also making it trivially easy to deploy them on
real-world devices (e.g., Internet-of-Things and edge devices). Built on top of _Globus Compute_ (formerly known as
_funcX_), FLoX is designed to run on anything that can be started as a Globus Compute Endpoint.
FLoX is a simple, highly customizable, and easy-to-deploy framework for launching Federated Learning processes across a
decentralized network. It is designed to simulate FL workflows while also making it trivially easy to deploy them on
real-world devices (e.g., Internet-of-Things and edge devices). Built on top of _Globus Compute_ (formerly known as
_funcX_), FLoX is designed to run on anything that can be started as a Globus Compute Endpoint.


### What can FLoX do?

FLoX is supports several state-of-the-art approaches for FL processes, including hierarchical and asynchronous FL.
FLoX is supports several state-of-the-art approaches for FL processes, including hierarchical and asynchronous FL.

| | 2-tier | Hierarhchical |
| --: |:----------------:|:-----------------:|
Expand Down
16 changes: 10 additions & 6 deletions flox/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,20 @@ def federated_split(
sample_distr = stats.dirichlet(np.full(num_workers, samples_alpha))
label_distr = stats.dirichlet(np.full(num_classes, labels_alpha))

num_samples_for_workers = (sample_distr.rvs()[0] * len(data)).astype(int)
# pytorch intentionally doesn't define an empty __len__ for DataSet, even though
# most subclasses implement it
data_count = len(data) # type: ignore

num_samples_for_workers = (sample_distr.rvs()[0] * data_count).astype(int)
num_samples_for_workers = {
worker.idx: num_samples
for worker, num_samples in zip(flock.workers, num_samples_for_workers)
}
label_probs = {w.idx: label_distr.rvs()[0] for w in flock.workers}

indices: dict[int, list[int]] = defaultdict(list)
indices: dict[FlockNodeID, list[int]] = defaultdict(list)
loader = DataLoader(data, batch_size=1)
worker_samples = defaultdict(int)
worker_samples: Counter[FlockNodeID] = Counter()
for idx, batch in enumerate(loader):
_, y = batch
label = y.item()
Expand All @@ -90,11 +94,11 @@ def federated_split(
)
raise err

probs = np.array(probs)
probs = probs / probs.sum()
probs_norm = np.array(probs)
probs_norm = probs_norm / probs_norm.sum()

if len(temp_workers) > 0:
chosen_worker = np.random.choice(temp_workers, p=probs)
chosen_worker = np.random.choice(temp_workers, p=probs_norm)
indices[chosen_worker].append(idx)
worker_samples[chosen_worker] += 1

Expand Down
23 changes: 12 additions & 11 deletions flox/flock/flock.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __init__(self, topo: nx.DiGraph | None = None, _src: Path | str | None = Non
"""
self.node_counter: int = 0
self._src = "interactive" if _src is None else _src
self.leader = None

if topo is None:
# By default (i.e., `topo is None`),
Expand Down Expand Up @@ -87,6 +86,8 @@ def __init__(self, topo: nx.DiGraph | None = None, _src: Path | str | None = Non
raise ValueError(
"A legal Flock cannot have more than one leader."
)
if not found_leader:
raise ValueError("A legal Flock must have a leader.")

def add_node(
self,
Expand All @@ -105,7 +106,7 @@ def add_node(
proxystore_endpoint_id=proxystore_endpoint_id,
)
self.node_counter += 1
return FlockNodeID(idx)
return idx

def add_edge(self, u: FlockNodeID, v: FlockNodeID, **attrs) -> None:
"""
Expand Down Expand Up @@ -226,7 +227,7 @@ def validate_topo(self) -> bool:

return True

def children(self, node: FlockNode | FlockNodeID | int) -> Generator[FlockNode]:
def children(self, node: FlockNode | FlockNodeID | int) -> Iterator[FlockNode]:
if isinstance(node, FlockNode):
idx = node.idx
else:
Expand Down Expand Up @@ -316,7 +317,7 @@ def from_dict(content: dict[str, Any], _src: Path | str | None = None) -> "Flock
return Flock(topo=topo, _src=_src)

@staticmethod
def from_json(path: Path | str) -> "Flock":
def from_json(path: Path | str) -> Flock:
"""Imports a .json file as a Flock.

Examples:
Expand All @@ -329,12 +330,12 @@ def from_json(path: Path | str) -> "Flock":
An instance of a Flock.
"""
# TODO: Figure out how to address the issue of JSON requiring string keys for `from_json()`.
with open(path, "r") as f:
with open(path) as f:
content = json.load(f)
return Flock.from_dict(content, _src=path)

@staticmethod
def from_yaml(path: Path | str) -> "Flock":
def from_yaml(path: Path | str) -> Flock:
"""Imports a `.yaml` file as a Flock.

Examples:
Expand All @@ -346,7 +347,7 @@ def from_yaml(path: Path | str) -> "Flock":
Returns:
An instance of a Flock.
"""
with open(path, "r") as f:
with open(path) as f:
content = yaml.safe_load(f)
return Flock.from_dict(content, _src=path)

Expand All @@ -360,7 +361,7 @@ def globus_compute_ready(self) -> bool:
"""
# TODO: The leader does NOT need a Globus Compute endpoint.
key = "globus_compute_endpoint"
for idx, data in self.topo.nodes(data=True):
for _idx, data in self.topo.nodes(data=True):
value = data[key]
if any([value is None, isinstance(value, UUID) is False]):
return False
Expand All @@ -377,7 +378,7 @@ def proxystore_ready(self) -> bool:
in size) with Globus Compute.
"""
key = "proxystore_endpoint"
for idx, data in self.topo.nodes(data=True):
for _idx, data in self.topo.nodes(data=True):
value = data[key]

try:
Expand All @@ -392,7 +393,7 @@ def proxystore_ready(self) -> bool:
# return self.nodes(by_kind=FlockNodeKind.LEADER)

@property
def aggregators(self) -> Generator[FlockNode]:
def aggregators(self) -> Iterator[FlockNode]:
"""
The aggregator nodes of the Flock.

Expand All @@ -402,7 +403,7 @@ def aggregators(self) -> Generator[FlockNode]:
return self.nodes(by_kind=FlockNodeKind.AGGREGATOR)

@property
def workers(self) -> Generator[FlockNode]:
def workers(self) -> Iterator[FlockNode]:
"""
The worker nodes of the Flock.

Expand Down
2 changes: 1 addition & 1 deletion flox/nn/logger/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def clear(self) -> None:

def to_pandas(self) -> pd.DataFrame:
df = pd.DataFrame.from_records(self.records)
for col in df:
for col in df.columns:
if "time" in col:
df[col] = pd.to_datetime(df[col])
return df
4 changes: 3 additions & 1 deletion flox/nn/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch


class FloxModule(torch.nn.Module):
class FloxModule(torch.nn.Module, ABC):
"""
The ``FloxModule`` is a wrapper for the standard ``torch.nn.Module`` class from PyTorch, with
a lot of inspiration from the ``lightning.LightningModule`` class from PyTorch Lightning.
Expand All @@ -10,6 +10,7 @@ class FloxModule(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@abstractmethod
def training_step(
self, batch: torch.Tensor | tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> torch.Tensor:
Expand All @@ -24,6 +25,7 @@ def training_step(
Loss from the training step.
"""

@abstractmethod
def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configures, initializes, and returns the optimizer used to train the model.

Expand Down
7 changes: 5 additions & 2 deletions flox/nn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ def fit(
loss.backward()

try:
assert strategy is not None
assert node_state is not None
strategy.wrk_after_train_step(node_state, loss)
except NotImplementedError:
except (AttributeError, AssertionError):
"""
The current strategy does not override the `wrk_after_train_step()` callback.
``node_state`` is None, ``strategy`` is None, or ``strategy`` doesn't
implement ``wrk_after_train_step()``.
"""
pass

Expand Down
2 changes: 1 addition & 1 deletion flox/strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
"""

from flox.strategies.base import Strategy
from flox.strategies.registry.fedsgd import FedSGD
from flox.strategies.registry.fedavg import FedAvg
from flox.strategies.registry.fedprox import FedProx
from flox.strategies.registry.fedsgd import FedSGD

__all__ = ["Strategy", "FedSGD", "FedAvg", "FedProx"]
Loading
Loading