From b81b3925d320ac821dec04f6876ec4c2975bccfc Mon Sep 17 00:00:00 2001 From: Alex Kerney Date: Wed, 14 Jun 2023 21:10:39 -0400 Subject: [PATCH] Experimenting with Dask integration Adds two local plugins and associated infrastructure for most hooks to be able to use Dask. In most cases for different types of Dask infrastructure, a plugin that provides a `get_dask_cluster()` method should do the trick. The hook is set up to only return one result, and the built-in plugin will be the last. The Dask client plugin in theory should work with different types of clusters, but is similarly set up to be able to be overridden (dask-on-ray?). The client can be both sync and async, and once it gets accessed, it's cached on `xpublish.Rest`. For hooks that have access to `deps` (which now includes dataset providers), `deps.dask_sync_client` and `deps.dask_async_client` now should give you the client. The async client may need to be passed the current event loop. It appears the way to access the event loop varies by server, so that will probably take some research. 
- https://github.com/tiangolo/fastapi/discussions/7876 - https://github.com/encode/uvicorn/issues/706 - https://stackoverflow.com/questions/66275747/how-to-use-event-loop-created-by-uvicorn --- requirements.txt | 1 + setup.py | 2 ++ xpublish/plugins/hooks.py | 19 +++++++++-- xpublish/plugins/included/dask_client.py | 31 +++++++++++++++++ .../plugins/included/dask_local_cluster.py | 20 +++++++++++ xpublish/rest.py | 34 +++++++++++++++++-- 6 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 xpublish/plugins/included/dask_client.py create mode 100644 xpublish/plugins/included/dask_local_cluster.py diff --git a/requirements.txt b/requirements.txt index 17e13aca..5a02afc7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ cachey dask +distributed fastapi numcodecs numpy diff --git a/setup.py b/setup.py index ec6978d5..4856e26a 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,8 @@ 'zarr = xpublish.plugins.included.zarr:ZarrPlugin', 'module_version = xpublish.plugins.included.module_version:ModuleVersionPlugin', 'plugin_info = xpublish.plugins.included.plugin_info:PluginInfoPlugin', + # 'dask_client = xpublish.plugins.included.dask_client:DaskClientPlugin', + # 'dask_local_cluster = xpublish.plugins.included.dask_local_cluster:DaskLocalClusterPlugin', ] }, ) diff --git a/xpublish/plugins/hooks.py b/xpublish/plugins/hooks.py index e59d1aa8..5ebca0c0 100644 --- a/xpublish/plugins/hooks.py +++ b/xpublish/plugins/hooks.py @@ -3,6 +3,7 @@ import cachey # type: ignore import pluggy # type: ignore import xarray as xr +from dask import distributed from fastapi import APIRouter from pydantic import BaseModel, Field @@ -40,6 +41,8 @@ class Dependencies(BaseModel): plugin_manager: Callable[..., pluggy.PluginManager] = Field( get_plugin_manager, description='The plugin manager itself, allowing for maximum creativity' ) + dask_sync_client: Callable[..., distributed.Client] + dask_async_client: Callable[..., distributed.Client] def __hash__(self): 
"""Dependency functions aren't easy to hash""" @@ -111,11 +114,11 @@ def dataset_router(self, deps: Dependencies) -> APIRouter: # type: ignore """ @hookspec - def get_datasets(self) -> Iterable[str]: # type: ignore + def get_datasets(self, deps: Dependencies) -> Iterable[str]: # type: ignore """Return an iterable of dataset ids that the plugin can provide""" @hookspec(firstresult=True) - def get_dataset(self, dataset_id: str) -> Optional[xr.Dataset]: # type: ignore + def get_dataset(self, dataset_id: str, deps: Dependencies) -> Optional[xr.Dataset]: # type: ignore """Return a dataset by requested dataset_id. If the plugin does not have the dataset, return None @@ -124,3 +127,15 @@ def get_dataset(self, dataset_id: str) -> Optional[xr.Dataset]: # type: ignore @hookspec def register_hookspec(self): # type: ignore """Return additional hookspec class to register with the plugin manager""" + + @hookspec(firstresult=True) + def get_dask_cluster(self) -> distributed.SpecCluster: + """Return the active dask cluster""" + + @hookspec(firstresult=True) + def get_dask_sync_client(self, cluster: distributed.SpecCluster) -> distributed.Client: + """Return a synchronous Dask client""" + + @hookspec(firstresult=True) + def get_dask_async_client(self, cluster: distributed.SpecCluster) -> distributed.Client: + """Return an async Dask client""" diff --git a/xpublish/plugins/included/dask_client.py b/xpublish/plugins/included/dask_client.py new file mode 100644 index 00000000..9c172046 --- /dev/null +++ b/xpublish/plugins/included/dask_client.py @@ -0,0 +1,31 @@ +""" +Default Dask clients +""" +from dask import distributed +from pydantic import Field + +from .. 
import Plugin, hookimpl + + +class DaskClientPlugin(Plugin): + name = 'dask_client' + + sync_kwargs: dict = Field( + default_factory=dict, description='Keyword arguments for syncronous Dask distributed.Client' + ) + async_kwargs: dict = Field( + default_factory=dict, + description='Keyword arguments for asyncronous Dask distributed.Client', + ) + + @hookimpl(trylast=True) + def get_dask_sync_client(self, cluster: distributed.SpecCluster): + client = distributed.Client(cluster, **self.sync_kwargs) + + return client + + @hookimpl(trylast=True) + def get_dask_async_client(self, cluster: distributed.SpecCluster): + client = distributed.Client(cluster, asynchronous=True, **self.async_kwargs) + + return client diff --git a/xpublish/plugins/included/dask_local_cluster.py b/xpublish/plugins/included/dask_local_cluster.py new file mode 100644 index 00000000..117d8176 --- /dev/null +++ b/xpublish/plugins/included/dask_local_cluster.py @@ -0,0 +1,20 @@ +""" +Default Dask local cluster +""" +from dask import distributed + +from .. 
import Plugin, hookimpl + + +class DaskLocalClusterPlugin(Plugin): + name = 'dask_local_cluster' + + @hookimpl(trylast=True) + def get_dask_cluster(self): + """Creates a local Dask cluster""" + try: + return self._cluster + except AttributeError: + cluster = distributed.LocalCluster() + self._cluster = cluster + return cluster diff --git a/xpublish/rest.py b/xpublish/rest.py index 65f8581e..7c80b4a3 100644 --- a/xpublish/rest.py +++ b/xpublish/rest.py @@ -4,6 +4,7 @@ import pluggy import uvicorn import xarray as xr +from dask import distributed from fastapi import APIRouter, FastAPI, HTTPException from .dependencies import get_cache, get_dataset, get_dataset_ids, get_plugin_manager @@ -113,7 +114,7 @@ def get_datasets_from_plugins(self) -> List[str]: """ dataset_ids = list(self._datasets) - for plugin_dataset_ids in self.pm.hook.get_datasets(): + for plugin_dataset_ids in self.pm.hook.get_datasets(deps=self.dependencies()): dataset_ids.extend(plugin_dataset_ids) return dataset_ids @@ -132,7 +133,7 @@ def get_dataset_from_plugins(self, dataset_id: str) -> xr.Dataset: Raises: FastAPI.HTTPException: When a dataset is not found a 404 error is returned. 
""" - dataset = self.pm.hook.get_dataset(dataset_id=dataset_id) + dataset = self.pm.hook.get_dataset(dataset_id=dataset_id, deps=self.dependencies()) if dataset: return dataset @@ -231,6 +232,33 @@ def plugins(self) -> Dict[str, Plugin]: """Returns the loaded plugins""" return dict(self.pm.list_name_plugin()) + def dask_cluster(self) -> distributed.SpecCluster: + """Active Dask cluster from the get_dask_cluster hook, cached after first use""" + try: + return self._dask_cluster + except AttributeError: + self._dask_cluster = self.pm.hook.get_dask_cluster() + + return self._dask_cluster + + def dask_sync_client(self) -> distributed.Client: + """Synchronous Dask client for the active cluster, cached after first use""" + try: + return self._dask_sync_client + except AttributeError: + self._dask_sync_client = self.pm.hook.get_dask_sync_client(cluster=self.dask_cluster()) + + return self._dask_sync_client + + def dask_async_client(self) -> distributed.Client: + """Asynchronous Dask client for the active cluster, cached after first use""" + try: + return self._dask_async_client + except AttributeError: + self._dask_async_client = self.pm.hook.get_dask_async_client(cluster=self.dask_cluster()) + + return self._dask_async_client + def _init_routers(self, dataset_routers: Optional[APIRouter]): """Setup plugin and dataset routers. Needs to run after dataset and plugin setup""" app_routers, plugin_dataset_routers = self.plugin_routers() @@ -276,6 +304,8 @@ def dependencies(self) -> Dependencies: cache=lambda: self.cache, plugins=lambda: self.plugins, plugin_manager=lambda: self.pm, + dask_sync_client=self.dask_sync_client, + dask_async_client=self.dask_async_client, ) return deps