diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fbfd7dd6..3ef309a8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,8 +9,8 @@ repos:
       - id: check-yaml
       - id: double-quote-string-fixer
 
-  - repo: https://github.com/ambv/black
-    rev: 21.6b0
+  - repo: https://github.com/psf/black
+    rev: 22.12.0
     hooks:
       - id: black
         args: ["--line-length", "100", "--skip-string-normalization"]
diff --git a/setup.py b/setup.py
index 7108b2b3..9114a3d5 100644
--- a/setup.py
+++ b/setup.py
@@ -45,4 +45,11 @@
     keywords=['xarray', 'zarr', 'api'],
     use_scm_version={'version_scheme': 'post-release', 'local_scheme': 'dirty-tag'},
     setup_requires=['setuptools_scm', 'setuptools>=30.3.0'],
+    entry_points={
+        'xpublish.plugin': [
+            'base = xpublish.plugins.base:BasePlugin',
+            'zarr = xpublish.plugins.zarr:ZarrPlugin',
+            'module_version = xpublish.plugins.module_version:ModuleVersionPlugin',
+        ]
+    },
 )
diff --git a/xpublish/plugins/__init__.py b/xpublish/plugins/__init__.py
new file mode 100644
index 00000000..153b404b
--- /dev/null
+++ b/xpublish/plugins/__init__.py
@@ -0,0 +1,2 @@
+from .factory import XpublishPluginFactory  # noqa: F401
+from .load import configure_plugins, find_plugins  # noqa: F401
diff --git a/xpublish/routers/base.py b/xpublish/plugins/base.py
similarity index 80%
rename from xpublish/routers/base.py
rename to xpublish/plugins/base.py
index 8538feb2..55ec3bc4 100644
--- a/xpublish/routers/base.py
+++ b/xpublish/plugins/base.py
@@ -1,4 +1,5 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import List
 
 import xarray as xr
 from fastapi import Depends
@@ -6,15 +7,18 @@
 from zarr.storage import attrs_key
 
 from ..dependencies import get_zmetadata, get_zvariables
-from .factory import XpublishFactory
+from .factory import XpublishPluginFactory
 
 
 @dataclass
-class BaseFactory(XpublishFactory):
+class BasePlugin(XpublishPluginFactory):
     """API entry-points providing basic information about the dataset(s)."""
 
+    dataset_router_prefix: str = '/info'
+    dataset_router_tags: List[str] = field(default_factory=lambda: ['info'])
+
     def register_routes(self):
-        @self.router.get('/')
+        @self.dataset_router.get('/')
         def html_representation(
             dataset=Depends(self.dataset_dependency),
         ):
@@ -23,19 +27,19 @@ def html_representation(
             with xr.set_options(display_style='html'):
                 return HTMLResponse(dataset._repr_html_())
 
-        @self.router.get('/keys')
+        @self.dataset_router.get('/keys')
         def list_keys(
             dataset=Depends(self.dataset_dependency),
         ):
             return list(dataset.variables)
 
-        @self.router.get('/dict')
+        @self.dataset_router.get('/dict')
         def to_dict(
             dataset=Depends(self.dataset_dependency),
         ):
             return dataset.to_dict(data=False)
 
-        @self.router.get('/info')
+        @self.dataset_router.get('/info')
         def info(
             dataset=Depends(self.dataset_dependency),
             cache=Depends(self.cache_dependency),
diff --git a/xpublish/plugins/factory.py b/xpublish/plugins/factory.py
new file mode 100644
index 00000000..da844e0e
--- /dev/null
+++ b/xpublish/plugins/factory.py
@@ -0,0 +1,39 @@
+from dataclasses import dataclass, field
+from typing import Callable, List, Optional
+
+import cachey
+import xarray as xr
+from fastapi import APIRouter
+
+from ..dependencies import get_cache, get_dataset, get_dataset_ids
+
+
+@dataclass
+class XpublishPluginFactory:
+    """Xpublish plugin factory.
+
+    Xpublish plugins are designed to be automatically loaded via the entry point
+    group `xpublish.plugin` from any installed package.
+
+    Plugins can define both app (top-level) and dataset based routes, and
+    default prefixes and tags for both.
+    """
+
+    app_router: APIRouter = field(default_factory=APIRouter)
+    app_router_prefix: Optional[str] = None
+    app_router_tags: List[str] = field(default_factory=list)
+
+    dataset_router: APIRouter = field(default_factory=APIRouter)
+    dataset_router_prefix: Optional[str] = None
+    dataset_router_tags: List[str] = field(default_factory=list)
+
+    dataset_ids_dependency: Callable[..., List[str]] = get_dataset_ids
+    dataset_dependency: Callable[..., xr.Dataset] = get_dataset
+    cache_dependency: Callable[..., cachey.Cache] = get_cache
+
+    def __post_init__(self):
+        self.register_routes()
+
+    def register_routes(self):
+        """Register xpublish routes."""
+        raise NotImplementedError()
diff --git a/xpublish/plugins/load.py b/xpublish/plugins/load.py
new file mode 100644
index 00000000..ab3a4f89
--- /dev/null
+++ b/xpublish/plugins/load.py
@@ -0,0 +1,58 @@
+"""
+Load and configure Xpublish plugins from entry point group `xpublish.plugin`
+"""
+from importlib.metadata import entry_points
+from typing import Dict, List, Optional, Type
+
+from .factory import XpublishPluginFactory
+
+PLUGIN_ENTRY_POINT_GROUP = 'xpublish.plugin'
+
+
+def _plugin_entry_points():
+    """Return the entry points registered under the Xpublish plugin group.
+
+    ``entry_points()`` accepts selection keywords only on Python 3.10+, and the
+    dict-style result of the older API is gone in Python 3.12, so try the
+    selection API first and fall back to the dict lookup. Either way a missing
+    group yields no plugins instead of raising ``KeyError``.
+    """
+    try:
+        return entry_points(group=PLUGIN_ENTRY_POINT_GROUP)
+    except TypeError:  # Python 3.8 / 3.9: entry_points() takes no arguments
+        return entry_points().get(PLUGIN_ENTRY_POINT_GROUP, [])
+
+
+def find_plugins(exclude_plugins: Optional[List[str]] = None):
+    """Find Xpublish plugins from entry point group `xpublish.plugin`
+
+    Returns a dict mapping entry point names to the loaded, not yet
+    instantiated plugin classes; names in ``exclude_plugins`` are skipped.
+    """
+    exclude_plugin_names = set(exclude_plugins or [])
+
+    plugins: Dict[str, Type[XpublishPluginFactory]] = {}
+
+    for entry_point in _plugin_entry_points():
+        if entry_point.name not in exclude_plugin_names:
+            plugins[entry_point.name] = entry_point.load()
+
+    return plugins
+
+
+def configure_plugins(
+    plugins: Dict[str, Type[XpublishPluginFactory]], plugin_configs: Optional[Dict] = None
+):
+    """Initialize and configure plugins
+
+    Each plugin class is called with the kwargs stored under its name in
+    ``plugin_configs`` (no kwargs when the name is absent).
+    """
+    initialized_plugins: Dict[str, XpublishPluginFactory] = {}
+    plugin_configs = plugin_configs or {}
+
+    for name, plugin in plugins.items():
+        kwargs = plugin_configs.get(name, {})
+        initialized_plugins[name] = plugin(**kwargs)
+
+    return initialized_plugins
diff --git a/xpublish/plugins/module_version.py b/xpublish/plugins/module_version.py
new file mode 100644
index 00000000..a628fd61
--- /dev/null
+++ b/xpublish/plugins/module_version.py
@@ -0,0 +1,41 @@
+"""
+Version information router
+"""
+import importlib
+import sys
+from dataclasses import dataclass
+
+from ..utils.info import get_sys_info, netcdf_and_hdf5_versions
+from .factory import XpublishPluginFactory
+
+
+@dataclass
+class ModuleVersionPlugin(XpublishPluginFactory):
+    """Module and system version information"""
+
+    def register_routes(self):
+        @self.app_router.get('/versions')
+        def get_versions():
+            versions = dict(get_sys_info() + netcdf_and_hdf5_versions())
+            modules = [
+                'xarray',
+                'zarr',
+                'numcodecs',
+                'fastapi',
+                'starlette',
+                'pandas',
+                'numpy',
+                'dask',
+                'distributed',
+                'uvicorn',
+            ]
+            for modname in modules:
+                try:
+                    if modname in sys.modules:
+                        mod = sys.modules[modname]
+                    else:  # pragma: no cover
+                        mod = importlib.import_module(modname)
+                    versions[modname] = getattr(mod, '__version__', None)
+                except ImportError:  # pragma: no cover
+                    pass
+            return versions
diff --git a/xpublish/routers/zarr.py b/xpublish/plugins/zarr.py
similarity index 88%
rename from xpublish/routers/zarr.py
rename to xpublish/plugins/zarr.py
index d5457b36..eda0185b 100644
--- a/xpublish/routers/zarr.py
+++ b/xpublish/plugins/zarr.py
@@ -1,6 +1,7 @@
 import json
 import logging
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from typing import List
 
 import cachey
 import xarray as xr
@@ -12,17 +13,20 @@
 from ..utils.api import DATASET_ID_ATTR_KEY
 from ..utils.cache import CostTimer
 from ..utils.zarr import encode_chunk, get_data_chunk, jsonify_zmetadata, zarr_metadata_key
-from .factory import XpublishFactory
+from .factory import XpublishPluginFactory
 
 logger = logging.getLogger('zarr_api')
 
 
 @dataclass
-class ZarrFactory(XpublishFactory):
+class ZarrPlugin(XpublishPluginFactory):
     """Provides access to data and metadata through as Zarr compatible API."""
 
+    dataset_router_prefix: str = '/zarr'
+    dataset_router_tags: List[str] = field(default_factory=lambda: ['zarr'])
+
     def register_routes(self):
-        @self.router.get(f'/{zarr_metadata_key}')
+        @self.dataset_router.get(f'/{zarr_metadata_key}')
         def get_zarr_metadata(
             dataset=Depends(self.dataset_dependency),
             cache=Depends(self.cache_dependency),
@@ -34,7 +38,7 @@ def get_zarr_metadata(
 
             return Response(json.dumps(zjson).encode('ascii'), media_type='application/json')
 
-        @self.router.get(f'/{group_meta_key}')
+        @self.dataset_router.get(f'/{group_meta_key}')
         def get_zarr_group(
             dataset=Depends(self.dataset_dependency),
             cache=Depends(self.cache_dependency),
@@ -44,7 +48,7 @@ def get_zarr_group(
 
             return zmetadata['metadata'][group_meta_key]
 
-        @self.router.get(f'/{attrs_key}')
+        @self.dataset_router.get(f'/{attrs_key}')
         def get_zarr_attrs(
             dataset=Depends(self.dataset_dependency),
             cache=Depends(self.cache_dependency),
@@ -54,7 +58,7 @@ def get_zarr_attrs(
 
             return zmetadata['metadata'][attrs_key]
 
-        @self.router.get('/{var}/{chunk}')
+        @self.dataset_router.get('/{var}/{chunk}')
         def get_variable_chunk(
             var: str,
             chunk: str,
diff --git a/xpublish/rest.py b/xpublish/rest.py
index cd248389..3108a2e2 100644
--- a/xpublish/rest.py
+++ b/xpublish/rest.py
@@ -1,10 +1,13 @@
+from typing import Dict, List, Optional
+
 import cachey
 import uvicorn
 import xarray as xr
 from fastapi import FastAPI, HTTPException
 
 from .dependencies import get_cache, get_dataset, get_dataset_ids
-from .routers import BaseFactory, ZarrFactory, common_router, dataset_collection_router
+from .plugins import XpublishPluginFactory, configure_plugins, find_plugins
+from .routers import dataset_collection_router
 from .utils.api import (
     SingleDatasetOpenAPIOverrider,
     check_route_conflicts,
@@ -44,19 +47,9 @@ def _set_app_routers(dataset_routers=None, dataset_route_prefix=''):
 
     app_routers = []
 
-    # top-level api endpoints
-    app_routers.append((common_router, {}))
-
     if dataset_route_prefix:
         app_routers.append((dataset_collection_router, {'tags': ['info']}))
 
-    # dataset-specifc api endpoints
-    if dataset_routers is None:
-        dataset_routers = [
-            (BaseFactory().router, {'tags': ['info']}),
-            (ZarrFactory().router, {'tags': ['zarr']}),
-        ]
-
     app_routers += normalize_app_routers(dataset_routers, dataset_route_prefix)
 
     check_route_conflicts(app_routers)
@@ -79,8 +72,8 @@ class Rest:
         are converted to strings. See also the notes below.
     routers : list, optional
         A list of dataset-specific :class:`fastapi.APIRouter` instances to
-        include in the fastAPI application. If None, the default routers will be
-        included.
+        include in the fastAPI application. These routers are in addition
+        to any loaded via plugins.
         The items of the list may also be tuples with the following format:
         ``[(router1, {'prefix': '/foo', 'tags': ['foo', 'bar']})]``, where
         the 1st tuple element is a :class:`fastapi.APIRouter` instance and the
@@ -94,6 +87,20 @@ class Rest:
     app_kws : dict, optional
         Dictionary of keyword arguments to be passed to
         :meth:`fastapi.FastAPI.__init__()`.
+    plugins : dict, optional
+        Optional dictionary of loaded, configured plugins.
+        Overrides automatic loading of plugins.
+    extend_plugins : dict, optional
+        Optional dictionary of loaded, configured plugins.
+        Instead of skipping the automatic loading of plugins,
+        automatic loading still occurs, then plugins can be
+        manually configured or added.
+        Useful for plugins without entry points.
+    exclude_plugin_names : list, optional
+        Skips automatically loading matching plugins.
+    plugin_configs : dict, optional
+        Plugin kwargs can be set by passing in a dictionary
+        of plugin names to a dict of kwargs.
 
     Notes
     -----
@@ -106,8 +113,88 @@ class Rest:
 
     """
 
-    def __init__(self, datasets, routers=None, cache_kws=None, app_kws=None):
+    def __init__(
+        self,
+        datasets,
+        routers=None,
+        cache_kws=None,
+        app_kws=None,
+        plugins: Optional[Dict[str, XpublishPluginFactory]] = None,
+        extend_plugins: Optional[Dict[str, XpublishPluginFactory]] = None,
+        exclude_plugin_names: Optional[List[str]] = None,
+        plugin_configs: Optional[Dict] = None,
+    ):
+
+        dataset_route_prefix = self.init_datasets(datasets)
+
+        # a provided ``plugins`` dict (even an empty one) replaces entry point discovery
+        if plugins is None:
+            self.load_plugins(exclude_plugins=exclude_plugin_names, plugin_configs=plugin_configs)
+        else:
+            self._plugins = dict(plugins)  # copy so the caller's dict is not mutated
+
+        # both ``extend_plugins`` and ``routers`` default to None
+        self._plugins.update(extend_plugins or {})
+
+        plugin_app_routers, plugin_dataset_routers = self.plugin_routers()
+
+        self._app_routers = plugin_app_routers
+        self._app_routers.extend(
+            _set_app_routers(plugin_dataset_routers + list(routers or []), dataset_route_prefix)
+        )
+
+        self.init_app_kwargs(app_kws)
+        self.init_cache_kwargs(cache_kws)
+
+    def load_plugins(
+        self, exclude_plugins: Optional[List[str]] = None, plugin_configs: Optional[Dict] = None
+    ):
+        """Initialize and load plugins from entry_points"""
+        found_plugins = find_plugins(exclude_plugins=exclude_plugins)
+        self._plugins = configure_plugins(found_plugins, plugin_configs=plugin_configs)
+
+    def plugin_routers(self):
+        """Load the app and dataset routers for plugins"""
+        app_routers = []
+        dataset_routers = []
+
+        for plugin in self._plugins.values():
+            if plugin.app_router.routes:
+                router_kwargs = {}
+                if plugin.app_router_prefix:
+                    router_kwargs['prefix'] = plugin.app_router_prefix
+                if plugin.app_router_tags:
+                    router_kwargs['tags'] = plugin.app_router_tags
+
+                app_routers.append((plugin.app_router, router_kwargs))
+
+            if plugin.dataset_router.routes:
+                router_kwargs = {}
+                if plugin.dataset_router_prefix:
+                    router_kwargs['prefix'] = plugin.dataset_router_prefix
+                if plugin.dataset_router_tags:
+                    router_kwargs['tags'] = plugin.dataset_router_tags
+
+                dataset_routers.append((plugin.dataset_router, router_kwargs))
+
+        return app_routers, dataset_routers
+
+    def init_cache_kwargs(self, cache_kws):
+        """Set up cache kwargs"""
+        self._cache = None
+        self._cache_kws = {'available_bytes': 1e6}
+        if cache_kws is not None:
+            self._cache_kws.update(cache_kws)
+
+    def init_app_kwargs(self, app_kws):
+        """Set up FastAPI application kwargs"""
+        self._app = None
+        self._app_kws = {}
+        if app_kws is not None:
+            self._app_kws.update(app_kws)
 
+    def init_datasets(self, datasets):
+        """Initialize datasets and getter functions"""
         self._datasets = normalize_datasets(datasets)
 
         if not self._datasets:
@@ -117,18 +204,7 @@ def __init__(self, datasets, routers=None, cache_kws=None, app_kws=None):
         else:
             self._get_dataset_func = _dataset_from_collection_getter(self._datasets)
             dataset_route_prefix = '/datasets/{dataset_id}'
-
-        self._app_routers = _set_app_routers(routers, dataset_route_prefix)
-
-        self._app = None
-        self._app_kws = {}
-        if app_kws is not None:
-            self._app_kws.update(app_kws)
-
-        self._cache = None
-        self._cache_kws = {'available_bytes': 1e6}
-        if cache_kws is not None:
-            self._cache_kws.update(cache_kws)
+        return dataset_route_prefix
 
     @property
     def cache(self) -> cachey.Cache:
diff --git a/xpublish/routers/__init__.py b/xpublish/routers/__init__.py
index 7919ae01..ef14749a 100644
--- a/xpublish/routers/__init__.py
+++ b/xpublish/routers/__init__.py
@@ -1,3 +1 @@
-from .base import BaseFactory
-from .common import common_router, dataset_collection_router
-from .zarr import ZarrFactory
+from .common import dataset_collection_router
diff --git a/xpublish/routers/common.py b/xpublish/routers/common.py
index 8451b9cc..4d70a1bd 100644
--- a/xpublish/routers/common.py
+++ b/xpublish/routers/common.py
@@ -2,42 +2,10 @@
 Dataset-independent API routes.
 
 """
-import importlib
-import sys
 
 from fastapi import APIRouter, Depends
 
 from ..dependencies import get_dataset_ids
-from ..utils.info import get_sys_info, netcdf_and_hdf5_versions
-
-common_router = APIRouter()
-
-
-@common_router.get('/versions')
-def get_versions():
-    versions = dict(get_sys_info() + netcdf_and_hdf5_versions())
-    modules = [
-        'xarray',
-        'zarr',
-        'numcodecs',
-        'fastapi',
-        'starlette',
-        'pandas',
-        'numpy',
-        'dask',
-        'distributed',
-        'uvicorn',
-    ]
-    for modname in modules:
-        try:
-            if modname in sys.modules:
-                mod = sys.modules[modname]
-            else:  # pragma: no cover
-                mod = importlib.import_module(modname)
-            versions[modname] = getattr(mod, '__version__', None)
-        except ImportError:  # pragma: no cover
-            pass
-    return versions
 
 
 dataset_collection_router = APIRouter()
diff --git a/xpublish/routers/factory.py b/xpublish/routers/factory.py
deleted file mode 100644
index 81708986..00000000
--- a/xpublish/routers/factory.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Callable, List
-
-import cachey
-import xarray as xr
-from fastapi import APIRouter
-
-from ..dependencies import get_cache, get_dataset, get_dataset_ids
-
-
-@dataclass
-class XpublishFactory:
-    """Xpublish API router factory."""
-
-    router: APIRouter = field(default_factory=APIRouter)
-
-    dataset_ids_dependency: Callable[..., List[str]] = get_dataset_ids
-    dataset_dependency: Callable[..., xr.Dataset] = get_dataset
-    cache_dependency: Callable[..., cachey.Cache] = get_cache
-
-    def __post_init__(self):
-        self.register_routes()
-
-    def register_routes(self):
-        """Register xpublish routes."""
-        raise NotImplementedError()