From 871742ca33382616ec65e09a9af4ca585e44a8b9 Mon Sep 17 00:00:00 2001 From: Nathaniel Hudson Date: Tue, 27 Feb 2024 12:10:02 -0600 Subject: [PATCH] Fixed mypy errors --- env.yml | 91 +++++++++++++++++++++++++ flox/data/utils.py | 4 +- flox/runtime/fit.py | 33 +++++---- flox/runtime/jobs/aggr.py | 9 +-- flox/runtime/launcher/globus_compute.py | 4 +- flox/runtime/process/proc_async.py | 4 +- flox/runtime/result.py | 14 ++-- flox/strategies/base.py | 8 +-- 8 files changed, 130 insertions(+), 37 deletions(-) diff --git a/env.yml b/env.yml index de92146..7f91dc3 100644 --- a/env.yml +++ b/env.yml @@ -1,5 +1,7 @@ name: flox channels: + - bioconda + - r - defaults dependencies: - bzip2=1.0.8 @@ -17,6 +19,11 @@ dependencies: - xz=5.4.2 - zlib=1.2.13 - pip: + - aiofiles==23.2.1 + - aioice==0.9.0 + - aiortc==1.6.0 + - aiosqlite==0.19.0 + - alabaster==0.7.13 - anyio==4.0.0 - appnope==0.1.3 - argon2-cffi==23.1.0 @@ -24,40 +31,70 @@ dependencies: - arrow==1.3.0 - asttokens==2.4.1 - async-lru==2.0.4 + - async-timeout==4.0.3 - attrs==23.1.0 + - av==11.0.0 - babel==2.13.1 - backcall==0.2.0 - beautifulsoup4==4.12.2 + - black==23.11.0 - bleach==6.1.0 + - blinker==1.7.0 - certifi==2023.7.22 - cffi==1.16.0 - charset-normalizer==3.3.1 - click==8.1.7 + - cloudpickle==3.0.0 - colorama==0.4.6 - comm==0.1.4 - contourpy==1.1.1 + - coverage==6.5.0 + - coveralls==3.3.1 + - cryptography==41.0.5 - cycler==0.12.1 + - dacite==1.8.1 - debugpy==1.8.0 - decorator==5.1.1 - defusedxml==0.7.1 + - dill==0.3.5.1 + - dnspython==2.4.2 + - docopt==0.6.2 + - docutils==0.20.1 - editorconfig==0.12.3 - exceptiongroup==1.1.3 - executing==2.0.0 - fastjsonschema==2.18.1 - filelock==3.12.4 + - flask==3.0.0 - fonttools==4.43.1 - fqdn==1.5.1 - fsspec==2023.10.0 + - furo==2023.9.10 - ghp-import==2.1.0 + - globus-compute-common==0.3.0 + - globus-compute-sdk==2.13.0 + - globus-sdk==3.30.0 + - google-crc32c==1.5.0 - griffe==0.36.9 + - h11==0.14.0 + - h2==4.1.0 + - hpack==4.0.0 + - httptools==0.6.1 + - hypercorn==0.15.0 + - hyperframe==6.0.1 - idna==3.4 + - ifaddr==0.2.0 + - imagesize==1.4.1 + - iniconfig==2.0.0 - ipykernel==6.26.0 - ipython==8.16.1 - ipython-genutils==0.2.0 - ipywidgets==8.1.1 - isoduration==20.11.0 + - itsdangerous==2.1.2 - jedi==0.19.1 - jinja2==3.1.2 + - joblib==1.3.2 - jsbeautifier==1.14.9 - json5==0.9.14 - jsonpointer==2.4 @@ -76,10 +113,16 @@ dependencies: - jupyterlab-server==2.25.0 - jupyterlab-widgets==3.0.9 - kiwisolver==1.4.5 + - lazy-object-proxy==1.9.0 + - lockfile==0.12.2 - markdown==3.5 + - markdown-it-py==3.0.0 - markupsafe==2.1.3 - matplotlib==3.8.0 - matplotlib-inline==0.1.6 + - matplotlib-stubs==0.2.0 + - mdit-py-plugins==0.4.0 + - mdurl==0.1.2 - mergedeep==1.3.4 - mistune==3.0.2 - mkdocs==1.5.3 @@ -93,11 +136,15 @@ dependencies: - mkdocstrings==0.23.0 - mkdocstrings-python==1.7.3 - mpmath==1.3.0 + - mypy==1.6.1 + - mypy-extensions==1.0.0 + - myst-parser==2.0.0 - nbclient==0.8.0 - nbconvert==7.9.2 - nbformat==5.9.2 - nest-asyncio==1.5.8 - networkx==3.2 + - networkx-stubs==0.0.1 - notebook==7.0.6 - notebook-shim==0.2.3 - numpy==1.26.1 @@ -105,24 +152,40 @@ dependencies: - packaging==23.2 - paginate==0.5.6 - pandas==2.1.2 + - pandas-stubs==2.1.1.230928 - pandocfilters==1.5.0 - parso==0.8.3 - pathspec==0.11.2 - pexpect==4.8.0 - pickleshare==0.7.5 + - pika==1.3.2 - pillow==10.1.0 - platformdirs==3.11.0 + - pluggy==1.3.0 + - priority==2.0.0 - prometheus-client==0.17.1 - prompt-toolkit==3.0.39 + - proxystore==0.6.1 - psutil==5.9.6 - ptyprocess==0.7.0 - pure-eval==0.2.2 + - pyarrow==14.0.1 - pycparser==2.21 + - 
pydantic==1.10.13 + - pyee==11.1.0 - pygments==2.16.1 - pygraphviz==1.11 + - pyjwt==2.8.0 + - pylibsrtp==0.9.1 - pymdown-extensions==10.3.1 + - pyopenssl==23.3.0 - pyparsing==3.1.1 + - pystun3==1.0.0 + - pytest==7.4.3 + - python-daemon==3.0.1 - python-dateutil==2.8.2 + - python-dotenv==1.0.0 + - python-graphviz==0.20.1 - python-json-logger==2.0.7 - pytz==2023.3.post1 - pyyaml==6.0.1 @@ -130,37 +193,65 @@ dependencies: - pyzmq==25.1.1 - qtconsole==5.4.4 - qtpy==2.4.1 + - quart==0.19.4 + - redis==5.0.1 - referencing==0.30.2 - regex==2023.10.3 - requests==2.31.0 - rfc3339-validator==0.1.4 - rfc3986-validator==0.1.1 - rpds-py==0.10.6 + - scikit-learn==1.3.2 - scipy==1.11.3 - seaborn==0.13.0 - send2trash==1.8.2 - six==1.16.0 - sniffio==1.3.0 + - snowballstemmer==2.2.0 - soupsieve==2.5 + - sphinx==7.2.6 + - sphinx-basic-ng==1.0.0b2 + - sphinxcontrib-applehelp==1.0.7 + - sphinxcontrib-devhelp==1.0.5 + - sphinxcontrib-htmlhelp==2.0.4 + - sphinxcontrib-jsmath==1.0.1 + - sphinxcontrib-qthelp==1.0.6 + - sphinxcontrib-serializinghtml==1.1.9 - stack-data==0.6.3 - sympy==1.12 + - taskgroup==0.0.0a4 + - tblib==1.7.0 - terminado==0.17.1 + - texttable==1.7.0 + - threadpoolctl==3.2.0 - tinycss2==1.2.1 - tomli==2.0.1 + - tomli-w==1.0.0 - torch==2.1.0 - torchaudio==2.1.0 - torchvision==0.16.0 - tornado==6.3.3 + - tosholi==0.1.0 - tqdm==4.66.1 - traitlets==5.12.0 - types-python-dateutil==2.8.19.14 + - types-pytz==2023.3.1.1 + - types-pyyaml==6.0.12.12 + - types-tqdm==4.66.0.5 - typing-extensions==4.8.0 - tzdata==2023.3 - uri-template==1.3.0 - urllib3==2.0.7 + - uvicorn==0.24.0.post1 + - uvloop==0.19.0 - watchdog==3.0.0 + - watchfiles==0.21.0 - wcwidth==0.2.8 - webcolors==1.13 - webencodings==0.5.1 - websocket-client==1.6.4 + - websockets==10.3 + - werkzeug==3.0.1 - widgetsnbextension==4.0.9 + - wsproto==1.2.0 +prefix: /Users/Nathaniel/miniconda3/envs/flox diff --git a/flox/data/utils.py b/flox/data/utils.py index fa61381..ecefb70 100644 --- a/flox/data/utils.py +++ b/flox/data/utils.py @@ -74,10 +74,10 @@ def federated_split( "Provided ``Dataset`` does not override ``__len__``, which is required for ``federated_split()``." 
) - num_samples_for_workers = (sample_distr * data_count).astype(int) + _num_samples = (sample_distr * data_count).astype(int) num_samples_for_workers = { worker.idx: num_samples - for worker, num_samples in zip(flock.workers, num_samples_for_workers) + for worker, num_samples in zip(flock.workers, _num_samples) } label_probs = {w.idx: label_distr[i] for i, w in enumerate(flock.workers)} diff --git a/flox/runtime/fit.py b/flox/runtime/fit.py index c57a119..ed0f2dd 100644 --- a/flox/runtime/fit.py +++ b/flox/runtime/fit.py @@ -1,13 +1,11 @@ import datetime import typing -import numpy as np from pandas import DataFrame from flox.data import FloxDataset from flox.flock import Flock from flox.nn import FloxModule -# from flox.run.fit_sync import sync_federated_fit from flox.nn.typing import Kind from flox.runtime.launcher import ( GlobusComputeLauncher, @@ -24,20 +22,21 @@ def create_launcher(kind: str, **launcher_cfg) -> Launcher: - if kind == "thread": - return LocalLauncher( - pool="thread", n_workers=launcher_cfg.get("max_workers", 3) - ) - elif kind == "process": - return LocalLauncher( - pool="process", n_workers=launcher_cfg.get("max_workers", 3) - ) - elif kind == "globus-compute": - return GlobusComputeLauncher() - elif kind == "parsl": - return ParslLauncher() - else: - raise ValueError("Illegal value for argument `kind`.") + match kind: + case "thread": + return LocalLauncher( + pool="thread", n_workers=launcher_cfg.get("max_workers", 3) + ) + case "process": + return LocalLauncher( + pool="process", n_workers=launcher_cfg.get("max_workers", 3) + ) + case "globus-compute": + return GlobusComputeLauncher() + case "parsl": + return ParslLauncher() + case _: + raise ValueError("Illegal value for argument `kind`.") def federated_fit( @@ -106,5 +105,5 @@ def federated_fit( start_time = datetime.datetime.now() module, history = process.start(debug_mode) history["train/rel_time"] = history["train/time"] - start_time - history["train/rel_time"] /= np.timedelta64(1, "s") + history["train/rel_time"] = history["train/rel_time"].dt.total_seconds() return module, history diff --git a/flox/runtime/jobs/aggr.py b/flox/runtime/jobs/aggr.py index 64d88e4..1029a5a 100644 --- a/flox/runtime/jobs/aggr.py +++ b/flox/runtime/jobs/aggr.py @@ -1,4 +1,4 @@ -from flox.flock import FlockNode +from flox.flock import FlockNode, FlockNodeID from flox.runtime.result import Result from flox.runtime.transfer import BaseTransfer from flox.strategies import Strategy @@ -19,12 +19,13 @@ def aggregation_job( Aggregation results. 
""" import pandas - from flox.flock.states import FloxAggregatorState + from flox.flock.states import FloxAggregatorState, NodeState from flox.runtime import JobResult - child_states, child_state_dicts = {}, {} + child_states: dict[FlockNodeID, NodeState] = {} + child_state_dicts = {} for result in results: - idx = result.node_idx + idx: FlockNodeID = result.node_idx child_states[idx] = result.node_state child_state_dicts[idx] = result.state_dict diff --git a/flox/runtime/launcher/globus_compute.py b/flox/runtime/launcher/globus_compute.py index da9002e..ce7218e 100644 --- a/flox/runtime/launcher/globus_compute.py +++ b/flox/runtime/launcher/globus_compute.py @@ -12,11 +12,13 @@ from flox.flock import FlockNode + """ NodeCallable: TypeAlias = Union[ Callable[[FlockNode], Any], Callable[[FlockNode, Any], Any], Callable[[FlockNode, Any, ...], Any], ] + """ class GlobusComputeLauncher(Launcher): @@ -33,7 +35,7 @@ def __init__(self): def submit( self, - fn: NodeCallable, # Callable[[FlockNode, Any, ...], Any], + fn: Callable, # NodeCallable, # Callable[[FlockNode, Any, ...], Any], # FIXME node: FlockNode, /, *args, diff --git a/flox/runtime/process/proc_async.py b/flox/runtime/process/proc_async.py index f84ff81..756c3a2 100644 --- a/flox/runtime/process/proc_async.py +++ b/flox/runtime/process/proc_async.py @@ -9,7 +9,7 @@ from flox.data import FloxDataset from flox.flock import Flock, FlockNodeID -from flox.flock.states import FloxAggregatorState, FloxWorkerState +from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState from flox.nn import FloxModule from flox.runtime.jobs import local_training_job from flox.runtime.process.proc import BaseProcess @@ -64,7 +64,7 @@ def start(self, debug_mode: bool = False) -> tuple[FloxModule, DataFrame]: histories: list[DataFrame] = [] worker_rounds: dict[FlockNodeID, int] = {} - worker_states: dict[FlockNodeID, FloxWorkerState] = {} + worker_states: dict[FlockNodeID, NodeState] = {} worker_state_dicts: dict[FlockNodeID, StateDict] = {} for worker in self.flock.workers: worker_rounds[worker.idx] = 0 diff --git a/flox/runtime/result.py b/flox/runtime/result.py index ee0fc4a..f7fb1cd 100644 --- a/flox/runtime/result.py +++ b/flox/runtime/result.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing -from dataclasses import dataclass +from dataclasses import dataclass, field from proxystore.proxy import Proxy @@ -20,22 +20,22 @@ class JobResult: Aggregators and Worker nodes have to return the same type of object to support hierarchical execution. 
""" - node_state: NodeState | None + node_state: NodeState """The state of the ``Flock`` node based on its kind.""" - node_idx: FlockNodeID | None + node_idx: FlockNodeID """The ID of the ``Flock`` node.""" - node_kind: FlockNodeKind | None + node_kind: FlockNodeKind """The kind of the ``Flock`` node.""" - state_dict: StateDict | None + state_dict: StateDict """The ``StateDict`` of the PyTorch global_module (either aggregated or trained locally).""" - history: DataFrame | None + history: DataFrame """The history of results.""" - cache: dict[str, typing.Any] | None = None + cache: dict[str, typing.Any] = field(default_factory=dict) """Miscellaneous data to be returned as part of the ``JobResult``.""" diff --git a/flox/strategies/base.py b/flox/strategies/base.py index 4ef7a89..da37da7 100644 --- a/flox/strategies/base.py +++ b/flox/strategies/base.py @@ -2,11 +2,11 @@ import typing from abc import abstractmethod, ABC -from typing import Iterable, Mapping, TypeAlias if typing.TYPE_CHECKING: import torch + from typing import Iterable, MutableMapping, TypeAlias from flox.flock import FlockNode, FlockNodeID from flox.flock.states import FloxWorkerState, FloxAggregatorState, NodeState from flox.nn.typing import StateDict @@ -26,7 +26,7 @@ class Strategy(ABC): they are run in an FL process. """ - registry: dict[str, type["Strategy"]] = {} + registry: MutableMapping[str, type["Strategy"]] = {} @classmethod def get_strategy(cls, name: str) -> type["Strategy"]: @@ -114,8 +114,8 @@ def agg_before_round(self, state: FloxAggregatorState) -> None: def agg_param_aggregation( self, state: FloxAggregatorState, - children_states: Mapping[FlockNodeID, NodeState], - children_state_dicts: Mapping[FlockNodeID, StateDict], + children_states: MutableMapping[FlockNodeID, NodeState], + children_state_dicts: MutableMapping[FlockNodeID, StateDict], *args, **kwargs, ) -> StateDict: