diff --git a/requirements.txt b/requirements.txt index 17e13aca..5a02afc7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ cachey dask +distributed fastapi numcodecs numpy diff --git a/setup.py b/setup.py index ec6978d5..4856e26a 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,8 @@ 'zarr = xpublish.plugins.included.zarr:ZarrPlugin', 'module_version = xpublish.plugins.included.module_version:ModuleVersionPlugin', 'plugin_info = xpublish.plugins.included.plugin_info:PluginInfoPlugin', + # 'dask_client = xpublish.plugins.included.dask_client:DaskClientPlugin', + # 'dask_local_cluster = xpublish.plugins.included.dask_local_cluster:DaskLocalClusterPlugin', ] }, ) diff --git a/xpublish/plugins/hooks.py b/xpublish/plugins/hooks.py index e59d1aa8..5ebca0c0 100644 --- a/xpublish/plugins/hooks.py +++ b/xpublish/plugins/hooks.py @@ -3,6 +3,7 @@ import cachey # type: ignore import pluggy # type: ignore import xarray as xr +from dask import distributed from fastapi import APIRouter from pydantic import BaseModel, Field @@ -40,6 +41,8 @@ class Dependencies(BaseModel): plugin_manager: Callable[..., pluggy.PluginManager] = Field( get_plugin_manager, description='The plugin manager itself, allowing for maximum creativity' ) + dask_sync_client: Callable[..., distributed.Client] + dask_async_client: Callable[..., distributed.Client] def __hash__(self): """Dependency functions aren't easy to hash""" @@ -111,11 +114,11 @@ def dataset_router(self, deps: Dependencies) -> APIRouter: # type: ignore """ @hookspec - def get_datasets(self) -> Iterable[str]: # type: ignore + def get_datasets(self, deps: Dependencies) -> Iterable[str]: # type: ignore """Return an iterable of dataset ids that the plugin can provide""" @hookspec(firstresult=True) - def get_dataset(self, dataset_id: str) -> Optional[xr.Dataset]: # type: ignore + def get_dataset(self, dataset_id: str, deps: Dependencies) -> Optional[xr.Dataset]: # type: ignore """Return a dataset by requested dataset_id. If the plugin does not have the dataset, return None @@ -124,3 +127,15 @@ def get_dataset(self, dataset_id: str) -> Optional[xr.Dataset]: # type: ignore @hookspec def register_hookspec(self): # type: ignore """Return additional hookspec class to register with the plugin manager""" + + @hookspec(firstresult=True) + def get_dask_cluster(self) -> distributed.SpecCluster: + """Return the active dask cluster""" + + @hookspec(firstresult=True) + def get_dask_sync_client(self, cluster: distributed.SpecCluster) -> distributed.Client: + """Return a synchronous Dask client""" + + @hookspec(firstresult=True) + def get_dask_async_client(self, cluster: distributed.SpecCluster) -> distributed.Client: + """Return an async Dask client""" diff --git a/xpublish/plugins/included/dask_client.py b/xpublish/plugins/included/dask_client.py new file mode 100644 index 00000000..9c172046 --- /dev/null +++ b/xpublish/plugins/included/dask_client.py @@ -0,0 +1,31 @@ +""" +Default Dask clients +""" +from dask import distributed +from pydantic import Field + +from .. import Plugin, hookimpl + + +class DaskClientPlugin(Plugin): + name = 'dask_client' + + sync_kwargs: dict = Field( + default_factory=dict, description='Keyword arguments for syncronous Dask distributed.Client' + ) + async_kwargs: dict = Field( + default_factory=dict, + description='Keyword arguments for asyncronous Dask distributed.Client', + ) + + @hookimpl(trylast=True) + def get_dask_sync_client(self, cluster: distributed.SpecCluster): + client = distributed.Client(cluster, **self.sync_kwargs) + + return client + + @hookimpl(trylast=True) + def get_dask_async_client(self, cluster: distributed.SpecCluster): + client = distributed.Client(cluster, asynchronous=True, **self.async_kwargs) + + return client diff --git a/xpublish/plugins/included/dask_local_cluster.py b/xpublish/plugins/included/dask_local_cluster.py new file mode 100644 index 00000000..117d8176 --- /dev/null +++ b/xpublish/plugins/included/dask_local_cluster.py @@ -0,0 +1,20 @@ +""" +Default Dask local cluster +""" +from dask import distributed + +from .. import Plugin, hookimpl + + +class DaskLocalClusterPlugin(Plugin): + name = 'dask_local_cluster' + + @hookimpl(trylast=True) + def get_dask_cluster(self): + """Creates a local Dask cluster""" + try: + return self._cluster + except AttributeError: + cluster = distributed.LocalCluster() + self._cluster = cluster + return cluster diff --git a/xpublish/rest.py b/xpublish/rest.py index 65f8581e..7c80b4a3 100644 --- a/xpublish/rest.py +++ b/xpublish/rest.py @@ -4,6 +4,7 @@ import pluggy import uvicorn import xarray as xr +from dask import distributed from fastapi import APIRouter, FastAPI, HTTPException from .dependencies import get_cache, get_dataset, get_dataset_ids, get_plugin_manager @@ -113,7 +114,7 @@ def get_datasets_from_plugins(self) -> List[str]: """ dataset_ids = list(self._datasets) - for plugin_dataset_ids in self.pm.hook.get_datasets(): + for plugin_dataset_ids in self.pm.hook.get_datasets(deps=self.dependencies()): dataset_ids.extend(plugin_dataset_ids) return dataset_ids @@ -132,7 +133,7 @@ def get_dataset_from_plugins(self, dataset_id: str) -> xr.Dataset: Raises: FastAPI.HTTPException: When a dataset is not found a 404 error is returned. """ - dataset = self.pm.hook.get_dataset(dataset_id=dataset_id) + dataset = self.pm.hook.get_dataset(dataset_id=dataset_id, deps=self.dependencies()) if dataset: return dataset @@ -231,6 +232,33 @@ def plugins(self) -> Dict[str, Plugin]: """Returns the loaded plugins""" return dict(self.pm.list_name_plugin()) + def dask_cluster(self) -> distributed.SpecCluster: + """Currently active Dask cluster""" + try: + return self._dask_cluster + except AttributeError: + self._dask_cluster = self.pm.hook.get_dask_cluster() + + return self._dask_cluster + + def dask_sync_client(self) -> distributed.Client: + """Syncronous Dask client""" + try: + return self._dask_sync_client + except AttributeError: + self._dask_sync_client = self.pm.hook.get_dask_sync_client(cluster=self.dask_cluster()) + + return self._dask_client + + def dask_async_client(self) -> distributed.Client: + """Asyncronous Dask client""" + try: + return self._dask_async_client + except AttributeError: + self._dask_async_client = self.pm.hook.get_dask_async_client() + + return self._dask_async_client + def _init_routers(self, dataset_routers: Optional[APIRouter]): """Setup plugin and dataset routers. Needs to run after dataset and plugin setup""" app_routers, plugin_dataset_routers = self.plugin_routers() @@ -276,6 +304,8 @@ def dependencies(self) -> Dependencies: cache=lambda: self.cache, plugins=lambda: self.plugins, plugin_manager=lambda: self.pm, + dask_sync_client=self.dask_sync_client, + dask_async_client=self.dask_async_client, ) return deps