Skip to content

Commit

Permalink
Fixed mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
nathaniel-hudson committed Feb 27, 2024
1 parent d999d07 commit 871742c
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 37 deletions.
91 changes: 91 additions & 0 deletions env.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
name: flox
channels:
- bioconda
- r
- defaults
dependencies:
- bzip2=1.0.8
Expand All @@ -17,47 +19,82 @@ dependencies:
- xz=5.4.2
- zlib=1.2.13
- pip:
- aiofiles==23.2.1
- aioice==0.9.0
- aiortc==1.6.0
- aiosqlite==0.19.0
- alabaster==0.7.13
- anyio==4.0.0
- appnope==0.1.3
- argon2-cffi==23.1.0
- argon2-cffi-bindings==21.2.0
- arrow==1.3.0
- asttokens==2.4.1
- async-lru==2.0.4
- async-timeout==4.0.3
- attrs==23.1.0
- av==11.0.0
- babel==2.13.1
- backcall==0.2.0
- beautifulsoup4==4.12.2
- black==23.11.0
- bleach==6.1.0
- blinker==1.7.0
- certifi==2023.7.22
- cffi==1.16.0
- charset-normalizer==3.3.1
- click==8.1.7
- cloudpickle==3.0.0
- colorama==0.4.6
- comm==0.1.4
- contourpy==1.1.1
- coverage==6.5.0
- coveralls==3.3.1
- cryptography==41.0.5
- cycler==0.12.1
- dacite==1.8.1
- debugpy==1.8.0
- decorator==5.1.1
- defusedxml==0.7.1
- dill==0.3.5.1
- dnspython==2.4.2
- docopt==0.6.2
- docutils==0.20.1
- editorconfig==0.12.3
- exceptiongroup==1.1.3
- executing==2.0.0
- fastjsonschema==2.18.1
- filelock==3.12.4
- flask==3.0.0
- fonttools==4.43.1
- fqdn==1.5.1
- fsspec==2023.10.0
- furo==2023.9.10
- ghp-import==2.1.0
- globus-compute-common==0.3.0
- globus-compute-sdk==2.13.0
- globus-sdk==3.30.0
- google-crc32c==1.5.0
- griffe==0.36.9
- h11==0.14.0
- h2==4.1.0
- hpack==4.0.0
- httptools==0.6.1
- hypercorn==0.15.0
- hyperframe==6.0.1
- idna==3.4
- ifaddr==0.2.0
- imagesize==1.4.1
- iniconfig==2.0.0
- ipykernel==6.26.0
- ipython==8.16.1
- ipython-genutils==0.2.0
- ipywidgets==8.1.1
- isoduration==20.11.0
- itsdangerous==2.1.2
- jedi==0.19.1
- jinja2==3.1.2
- joblib==1.3.2
- jsbeautifier==1.14.9
- json5==0.9.14
- jsonpointer==2.4
Expand All @@ -76,10 +113,16 @@ dependencies:
- jupyterlab-server==2.25.0
- jupyterlab-widgets==3.0.9
- kiwisolver==1.4.5
- lazy-object-proxy==1.9.0
- lockfile==0.12.2
- markdown==3.5
- markdown-it-py==3.0.0
- markupsafe==2.1.3
- matplotlib==3.8.0
- matplotlib-inline==0.1.6
- matplotlib-stubs==0.2.0
- mdit-py-plugins==0.4.0
- mdurl==0.1.2
- mergedeep==1.3.4
- mistune==3.0.2
- mkdocs==1.5.3
Expand All @@ -93,74 +136,122 @@ dependencies:
- mkdocstrings==0.23.0
- mkdocstrings-python==1.7.3
- mpmath==1.3.0
- mypy==1.6.1
- mypy-extensions==1.0.0
- myst-parser==2.0.0
- nbclient==0.8.0
- nbconvert==7.9.2
- nbformat==5.9.2
- nest-asyncio==1.5.8
- networkx==3.2
- networkx-stubs==0.0.1
- notebook==7.0.6
- notebook-shim==0.2.3
- numpy==1.26.1
- overrides==7.4.0
- packaging==23.2
- paginate==0.5.6
- pandas==2.1.2
- pandas-stubs==2.1.1.230928
- pandocfilters==1.5.0
- parso==0.8.3
- pathspec==0.11.2
- pexpect==4.8.0
- pickleshare==0.7.5
- pika==1.3.2
- pillow==10.1.0
- platformdirs==3.11.0
- pluggy==1.3.0
- priority==2.0.0
- prometheus-client==0.17.1
- prompt-toolkit==3.0.39
- proxystore==0.6.1
- psutil==5.9.6
- ptyprocess==0.7.0
- pure-eval==0.2.2
- pyarrow==14.0.1
- pycparser==2.21
- pydantic==1.10.13
- pyee==11.1.0
- pygments==2.16.1
- pygraphviz==1.11
- pyjwt==2.8.0
- pylibsrtp==0.9.1
- pymdown-extensions==10.3.1
- pyopenssl==23.3.0
- pyparsing==3.1.1
- pystun3==1.0.0
- pytest==7.4.3
- python-daemon==3.0.1
- python-dateutil==2.8.2
- python-dotenv==1.0.0
- python-graphviz==0.20.1
- python-json-logger==2.0.7
- pytz==2023.3.post1
- pyyaml==6.0.1
- pyyaml-env-tag==0.1
- pyzmq==25.1.1
- qtconsole==5.4.4
- qtpy==2.4.1
- quart==0.19.4
- redis==5.0.1
- referencing==0.30.2
- regex==2023.10.3
- requests==2.31.0
- rfc3339-validator==0.1.4
- rfc3986-validator==0.1.1
- rpds-py==0.10.6
- scikit-learn==1.3.2
- scipy==1.11.3
- seaborn==0.13.0
- send2trash==1.8.2
- six==1.16.0
- sniffio==1.3.0
- snowballstemmer==2.2.0
- soupsieve==2.5
- sphinx==7.2.6
- sphinx-basic-ng==1.0.0b2
- sphinxcontrib-applehelp==1.0.7
- sphinxcontrib-devhelp==1.0.5
- sphinxcontrib-htmlhelp==2.0.4
- sphinxcontrib-jsmath==1.0.1
- sphinxcontrib-qthelp==1.0.6
- sphinxcontrib-serializinghtml==1.1.9
- stack-data==0.6.3
- sympy==1.12
- taskgroup==0.0.0a4
- tblib==1.7.0
- terminado==0.17.1
- texttable==1.7.0
- threadpoolctl==3.2.0
- tinycss2==1.2.1
- tomli==2.0.1
- tomli-w==1.0.0
- torch==2.1.0
- torchaudio==2.1.0
- torchvision==0.16.0
- tornado==6.3.3
- tosholi==0.1.0
- tqdm==4.66.1
- traitlets==5.12.0
- types-python-dateutil==2.8.19.14
- types-pytz==2023.3.1.1
- types-pyyaml==6.0.12.12
- types-tqdm==4.66.0.5
- typing-extensions==4.8.0
- tzdata==2023.3
- uri-template==1.3.0
- urllib3==2.0.7
- uvicorn==0.24.0.post1
- uvloop==0.19.0
- watchdog==3.0.0
- watchfiles==0.21.0
- wcwidth==0.2.8
- webcolors==1.13
- webencodings==0.5.1
- websocket-client==1.6.4
- websockets==10.3
- werkzeug==3.0.1
- widgetsnbextension==4.0.9
- wsproto==1.2.0
prefix: /Users/Nathaniel/miniconda3/envs/flox
4 changes: 2 additions & 2 deletions flox/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def federated_split(
"Provided ``Dataset`` does not override ``__len__``, which is required for ``federated_split()``."
)

num_samples_for_workers = (sample_distr * data_count).astype(int)
_num_samples = (sample_distr * data_count).astype(int)
num_samples_for_workers = {
worker.idx: num_samples
for worker, num_samples in zip(flock.workers, num_samples_for_workers)
for worker, num_samples in zip(flock.workers, _num_samples)
}
label_probs = {w.idx: label_distr[i] for i, w in enumerate(flock.workers)}

Expand Down
33 changes: 16 additions & 17 deletions flox/runtime/fit.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import datetime
import typing

import numpy as np
from pandas import DataFrame

from flox.data import FloxDataset
from flox.flock import Flock
from flox.nn import FloxModule
# from flox.run.fit_sync import sync_federated_fit
from flox.nn.typing import Kind
from flox.runtime.launcher import (
GlobusComputeLauncher,
Expand All @@ -24,20 +22,21 @@


def create_launcher(kind: str, **launcher_cfg) -> Launcher:
if kind == "thread":
return LocalLauncher(
pool="thread", n_workers=launcher_cfg.get("max_workers", 3)
)
elif kind == "process":
return LocalLauncher(
pool="process", n_workers=launcher_cfg.get("max_workers", 3)
)
elif kind == "globus-compute":
return GlobusComputeLauncher()
elif kind == "parsl":
return ParslLauncher()
else:
raise ValueError("Illegal value for argument `kind`.")
match kind:
case "thread":
return LocalLauncher(
pool="thread", n_workers=launcher_cfg.get("max_workers", 3)
)
case "process":
return LocalLauncher(
pool="process", n_workers=launcher_cfg.get("max_workers", 3)
)
case "globus-compute":
return GlobusComputeLauncher()
case "parsl":
return ParslLauncher()
case _:
raise ValueError("Illegal value for argument `kind`.")


def federated_fit(
Expand Down Expand Up @@ -106,5 +105,5 @@ def federated_fit(
start_time = datetime.datetime.now()
module, history = process.start(debug_mode)
history["train/rel_time"] = history["train/time"] - start_time
history["train/rel_time"] /= np.timedelta64(1, "s")
history["train/rel_time"] = history["train/rel_time"].dt.total_seconds()
return module, history
9 changes: 5 additions & 4 deletions flox/runtime/jobs/aggr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flox.flock import FlockNode
from flox.flock import FlockNode, FlockNodeID
from flox.runtime.result import Result
from flox.runtime.transfer import BaseTransfer
from flox.strategies import Strategy
Expand All @@ -19,12 +19,13 @@ def aggregation_job(
Aggregation results.
"""
import pandas
from flox.flock.states import FloxAggregatorState
from flox.flock.states import FloxAggregatorState, NodeState
from flox.runtime import JobResult

child_states, child_state_dicts = {}, {}
child_states: dict[FlockNodeID, NodeState] = {}
child_state_dicts = {}
for result in results:
idx = result.node_idx
idx: FlockNodeID = result.node_idx
child_states[idx] = result.node_state
child_state_dicts[idx] = result.state_dict

Expand Down
4 changes: 3 additions & 1 deletion flox/runtime/launcher/globus_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@

from flox.flock import FlockNode

"""
NodeCallable: TypeAlias = Union[
Callable[[FlockNode], Any],
Callable[[FlockNode, Any], Any],
Callable[[FlockNode, Any, ...], Any],
]
"""


class GlobusComputeLauncher(Launcher):
Expand All @@ -33,7 +35,7 @@ def __init__(self):

def submit(
self,
fn: NodeCallable, # Callable[[FlockNode, Any, ...], Any],
fn: Callable, # NodeCallable, # Callable[[FlockNode, Any, ...], Any], # FIXME
node: FlockNode,
/,
*args,
Expand Down
4 changes: 2 additions & 2 deletions flox/runtime/process/proc_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from flox.data import FloxDataset
from flox.flock import Flock, FlockNodeID
from flox.flock.states import FloxAggregatorState, FloxWorkerState
from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState
from flox.nn import FloxModule
from flox.runtime.jobs import local_training_job
from flox.runtime.process.proc import BaseProcess
Expand Down Expand Up @@ -64,7 +64,7 @@ def start(self, debug_mode: bool = False) -> tuple[FloxModule, DataFrame]:

histories: list[DataFrame] = []
worker_rounds: dict[FlockNodeID, int] = {}
worker_states: dict[FlockNodeID, FloxWorkerState] = {}
worker_states: dict[FlockNodeID, NodeState] = {}
worker_state_dicts: dict[FlockNodeID, StateDict] = {}
for worker in self.flock.workers:
worker_rounds[worker.idx] = 0
Expand Down
Loading

0 comments on commit 871742c

Please sign in to comment.