diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ca658a8..1ec26234 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: double-quote-string-fixer - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 22.12.0 hooks: - id: black args: ["--line-length", "100", "--skip-string-normalization"] diff --git a/setup.cfg b/setup.cfg index 7ad2703f..416ce2bb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ select = B,C,E,F,W,T4,B9 [isort] known_first_party=xpublish -known_third_party=cachey,dask,fastapi,numcodecs,numpy,pandas,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,starlette,uvicorn,xarray,zarr +known_third_party=cachey,dask,fastapi,numcodecs,numpy,pandas,pkg_resources,pydantic,pytest,setuptools,sphinx_autosummary_accessors,starlette,uvicorn,xarray,zarr multi_line_output=3 include_trailing_comma=True force_grid_wrap=0 diff --git a/setup.py b/setup.py index 106125ca..43255029 100644 --- a/setup.py +++ b/setup.py @@ -46,4 +46,12 @@ keywords=['xarray', 'zarr', 'api'], use_scm_version={'version_scheme': 'post-release', 'local_scheme': 'dirty-tag'}, setup_requires=['setuptools_scm>=3.4', 'setuptools>=42'], + entry_points={ + 'xpublish.plugin': [ + 'info = xpublish.included_plugins.dataset_info:DatasetInfoPlugin', + 'zarr = xpublish.included_plugins.zarr:ZarrPlugin', + 'module_version = xpublish.included_plugins.module_version:ModuleVersionPlugin', + 'plugin_info = xpublish.included_plugins.plugin_info:PluginInfoPlugin', + ] + }, ) diff --git a/tests/test_rest_api.py b/tests/test_rest_api.py index 04564a42..64a179fc 100644 --- a/tests/test_rest_api.py +++ b/tests/test_rest_api.py @@ -102,7 +102,7 @@ def test_custom_app_routers(airtemp_ds, dims_router, router_kws, path): else: routers = [(dims_router, router_kws)] - rest = Rest(airtemp_ds, routers=routers) + rest = Rest(airtemp_ds, routers=routers, plugins={}) client = TestClient(rest.app) response = client.get(path) diff --git a/xpublish/__init__.py b/xpublish/__init__.py index f1109835..a58a2c5b 100644 --- a/xpublish/__init__.py +++ b/xpublish/__init__.py @@ -1,6 +1,8 @@ from pkg_resources import DistributionNotFound, get_distribution -from .rest import Rest, RestAccessor # noqa: F401 +from .accessor import RestAccessor # noqa: F401 +from .plugin import Plugin, Router # noqa: F401 +from .rest import Rest # noqa: F401 try: __version__ = get_distribution(__name__).version diff --git a/xpublish/accessor.py b/xpublish/accessor.py new file mode 100644 index 00000000..f70e6d70 --- /dev/null +++ b/xpublish/accessor.py @@ -0,0 +1,74 @@ +import cachey +import xarray as xr +from fastapi import FastAPI + +from .rest import Rest + + +@xr.register_dataset_accessor('rest') +class RestAccessor: + """REST API Accessor for serving one dataset in its + dedicated FastAPI application. + + """ + + def __init__(self, xarray_obj): + + self._obj = xarray_obj + self._rest = None + + self._initialized = False + + def _get_rest_obj(self): + if self._rest is None: + self._rest = Rest(self._obj) + + return self._rest + + def __call__(self, **kwargs): + """Initialize this accessor by setting optional configuration values. + + Parameters + ---------- + **kwargs + Arguments passed to :func:`xpublish.Rest.__init__`. + + Notes + ----- + This method can only be invoked once. + + """ + if self._initialized: + raise RuntimeError('This accessor has already been initialized') + self._initialized = True + + self._rest = Rest(self._obj, **kwargs) + + return self + + @property + def cache(self) -> cachey.Cache: + """Returns the :class:`cachey.Cache` instance used by the FastAPI application.""" + + return self._get_rest_obj().cache + + @property + def app(self) -> FastAPI: + """Returns the :class:`fastapi.FastAPI` application instance.""" + + return self._get_rest_obj().app + + def serve(self, **kwargs): + """Serve this FastAPI application via :func:`uvicorn.run`. + + Parameters + ---------- + **kwargs : + Arguments passed to :func:`xpublish.Rest.serve`. + + Notes + ----- + This method is blocking and does not return. + + """ + self._get_rest_obj().serve(**kwargs) diff --git a/xpublish/dependencies.py b/xpublish/dependencies.py index 6499acc7..dcd3a412 100644 --- a/xpublish/dependencies.py +++ b/xpublish/dependencies.py @@ -1,6 +1,8 @@ """ Helper functions to use a FastAPI dependencies. """ +from typing import TYPE_CHECKING, Dict, List + import cachey import xarray as xr from fastapi import Depends @@ -8,8 +10,11 @@ from .utils.api import DATASET_ID_ATTR_KEY from .utils.zarr import create_zmetadata, create_zvariables, zarr_metadata_key +if TYPE_CHECKING: + from .plugin import Plugin + -def get_dataset_ids(): +def get_dataset_ids() -> List[str]: """FastAPI dependency for getting the list of ids (string keys) of the collection of datasets being served. @@ -23,7 +28,7 @@ def get_dataset_ids(): return [] # pragma: no cover -def get_dataset(dataset_id: str): +def get_dataset(dataset_id: str) -> xr.Dataset: """FastAPI dependency for accessing the published xarray dataset object. Use this callable as dependency in any FastAPI path operation @@ -33,10 +38,10 @@ def get_dataset(dataset_id: str): application. """ - return None # pragma: no cover + return xr.Dataset() -def get_cache(): +def get_cache() -> cachey.Cache: """FastAPI dependency for accessing the application's cache. Use this callable as dependency in any FastAPI path operation @@ -47,7 +52,7 @@ def get_cache(): application. """ - return None # pragma: no cover + return cachey.Cache(available_bytes=1e6) def get_zvariables( @@ -84,3 +89,9 @@ def get_zmetadata( cache.put(cache_key, zmeta, 99999) return zmeta + + +def get_plugins() -> Dict[str, 'Plugin']: + """FastAPI dependency that returns the a dictionary of loaded plugins""" + + return [] diff --git a/xpublish/included_plugins/dataset_info.py b/xpublish/included_plugins/dataset_info.py new file mode 100644 index 00000000..86d53675 --- /dev/null +++ b/xpublish/included_plugins/dataset_info.py @@ -0,0 +1,71 @@ +import xarray as xr +from fastapi import Depends +from pydantic import Field +from starlette.responses import HTMLResponse +from zarr.storage import attrs_key + +from ..dependencies import get_zmetadata, get_zvariables +from ..plugin import Plugin, Router + + +class DatasetInfoRouter(Router): + """API entry-points providing basic information about the dataset(s).""" + + prefix = '' + + def register(self): + @self._router.get('/') + def html_representation( + dataset=Depends(self.deps.dataset), + ): + """Returns a HTML representation of the dataset.""" + + with xr.set_options(display_style='html'): + return HTMLResponse(dataset._repr_html_()) + + @self._router.get('/keys') + def list_keys( + dataset=Depends(self.deps.dataset), + ): + return list(dataset.variables) + + @self._router.get('/dict') + def to_dict( + dataset=Depends(self.deps.dataset), + ): + return dataset.to_dict(data=False) + + @self._router.get('/info') + def info( + dataset=Depends(self.deps.dataset), + cache=Depends(self.deps.cache), + ): + """Dataset schema (close to the NCO-JSON schema).""" + + zvariables = get_zvariables(dataset, cache) + zmetadata = get_zmetadata(dataset, cache, zvariables) + + info = {} + info['dimensions'] = dict(dataset.dims.items()) + info['variables'] = {} + + meta = zmetadata['metadata'] + + for name, var in zvariables.items(): + attrs = meta[f'{name}/{attrs_key}'] + attrs.pop('_ARRAY_DIMENSIONS') + info['variables'][name] = { + 'type': var.data.dtype.name, + 'dimensions': list(var.dims), + 'attributes': attrs, + } + + info['global_attributes'] = meta[attrs_key] + + return info + + +class DatasetInfoPlugin(Plugin): + name = 'dataset_info' + + dataset_router: DatasetInfoRouter = Field(default_factory=DatasetInfoRouter) diff --git a/xpublish/included_plugins/module_version.py b/xpublish/included_plugins/module_version.py new file mode 100644 index 00000000..f0210380 --- /dev/null +++ b/xpublish/included_plugins/module_version.py @@ -0,0 +1,49 @@ +""" +Version information router +""" +import importlib +import sys + +from pydantic import Field + +from ..plugin import Plugin, Router +from ..utils.info import get_sys_info, netcdf_and_hdf5_versions + + +class ModuleVersionAppRouter(Router): + """Module and system version information""" + + prefix = '' + + def register(self): + @self._router.get('/versions') + def get_versions(): + versions = dict(get_sys_info() + netcdf_and_hdf5_versions()) + modules = [ + 'xarray', + 'zarr', + 'numcodecs', + 'fastapi', + 'starlette', + 'pandas', + 'numpy', + 'dask', + 'distributed', + 'uvicorn', + ] + for modname in modules: + try: + if modname in sys.modules: + mod = sys.modules[modname] + else: # pragma: no cover + mod = importlib.import_module(modname) + versions[modname] = getattr(mod, '__version__', None) + except ImportError: # pragma: no cover + pass + return versions + + +class ModuleVersionPlugin(Plugin): + name = 'module_version' + + app_router: ModuleVersionAppRouter = Field(default_factory=ModuleVersionAppRouter) diff --git a/xpublish/included_plugins/plugin_info.py b/xpublish/included_plugins/plugin_info.py new file mode 100644 index 00000000..d32a2996 --- /dev/null +++ b/xpublish/included_plugins/plugin_info.py @@ -0,0 +1,49 @@ +""" +Plugin information router +""" +import importlib +from typing import Dict, Optional + +from fastapi import Depends +from pydantic import BaseModel, Field + +from ..plugin import Plugin, Router + + +class PluginInfo(BaseModel): + path: str + version: Optional[str] + + +class PluginInfoAppRouter(Router): + """Plugin information""" + + prefix = '' + + def register(self): + @self._router.get('/plugins') + def get_plugins( + plugins: Dict[str, Plugin] = Depends(self.deps.plugins) + ) -> Dict[str, PluginInfo]: + plugin_info = {} + + for name, plugin in plugins.items(): + plugin_type = type(plugin) + module_name = plugin_type.__module__.split('.')[0] + try: + mod = importlib.import_module(module_name) + version = getattr(mod, '__version__', None) + except ImportError: + version = None + + plugin_info[name] = PluginInfo( + path=f'{plugin_type.__module__}.{plugin.__repr_name__()}', version=version + ) + + return plugin_info + + +class PluginInfoPlugin(Plugin): + name = 'plugin_info' + + app_router: PluginInfoAppRouter = Field(default_factory=PluginInfoAppRouter) diff --git a/xpublish/included_plugins/zarr.py b/xpublish/included_plugins/zarr.py new file mode 100644 index 00000000..543ba39b --- /dev/null +++ b/xpublish/included_plugins/zarr.py @@ -0,0 +1,114 @@ +import json +import logging +from typing import List + +import cachey +import xarray as xr +from fastapi import Depends, HTTPException +from pydantic import Field +from starlette.responses import Response +from zarr.storage import array_meta_key, attrs_key, group_meta_key + +from ..dependencies import get_cache, get_dataset, get_zmetadata, get_zvariables +from ..plugin import Plugin, Router +from ..utils.api import DATASET_ID_ATTR_KEY +from ..utils.cache import CostTimer +from ..utils.zarr import encode_chunk, get_data_chunk, jsonify_zmetadata, zarr_metadata_key + +logger = logging.getLogger('zarr_api') + + +class ZarrDatasetRouter(Router): + """Provides access to data and metadata through as Zarr compatible API.""" + + prefix: str = '' + tags: List[str] = Field(default_factory=lambda: ['zarr']) + + def register(self): + @self._router.get(f'/{zarr_metadata_key}') + def get_zarr_metadata( + dataset=Depends(self.deps.dataset), + cache=Depends(self.deps.cache), + ): + zvariables = get_zvariables(dataset, cache) + zmetadata = get_zmetadata(dataset, cache, zvariables) + + zjson = jsonify_zmetadata(dataset, zmetadata) + + return Response(json.dumps(zjson).encode('ascii'), media_type='application/json') + + @self._router.get(f'/{group_meta_key}') + def get_zarr_group( + dataset=Depends(self.deps.dataset), + cache=Depends(self.deps.cache), + ): + zvariables = get_zvariables(dataset, cache) + zmetadata = get_zmetadata(dataset, cache, zvariables) + + return zmetadata['metadata'][group_meta_key] + + @self._router.get(f'/{attrs_key}') + def get_zarr_attrs( + dataset=Depends(self.deps.dataset), + cache=Depends(self.deps.cache), + ): + zvariables = get_zvariables(dataset, cache) + zmetadata = get_zmetadata(dataset, cache, zvariables) + + return zmetadata['metadata'][attrs_key] + + @self._router.get('/{var}/{chunk}') + def get_variable_chunk( + var: str, + chunk: str, + dataset: xr.Dataset = Depends(get_dataset), + cache: cachey.Cache = Depends(get_cache), + ): + """Get a zarr array chunk. + + This will return cached responses when available. + + """ + zvariables = get_zvariables(dataset, cache) + zmetadata = get_zmetadata(dataset, cache, zvariables) + + # First check that this request wasn't for variable metadata + if array_meta_key in chunk: + return zmetadata['metadata'][f'{var}/{array_meta_key}'] + elif attrs_key in chunk: + return zmetadata['metadata'][f'{var}/{attrs_key}'] + elif group_meta_key in chunk: + raise HTTPException(status_code=404, detail='No subgroups') + else: + logger.debug('var is %s', var) + logger.debug('chunk is %s', chunk) + + cache_key = dataset.attrs.get(DATASET_ID_ATTR_KEY, '') + '/' + f'{var}/{chunk}' + response = cache.get(cache_key) + + if response is None: + with CostTimer() as ct: + arr_meta = zmetadata['metadata'][f'{var}/{array_meta_key}'] + da = zvariables[var].data + + data_chunk = get_data_chunk(da, chunk, out_shape=arr_meta['chunks']) + + echunk = encode_chunk( + data_chunk.tobytes(), + filters=arr_meta['filters'], + compressor=arr_meta['compressor'], + ) + + response = Response(echunk, media_type='application/octet-stream') + + cache.put(cache_key, response, ct.time, len(echunk)) + + return response + + +class ZarrPlugin(Plugin): + """Adds Zarr-like accessing endpoints for datasets""" + + name = 'zarr' + + dataset_router: ZarrDatasetRouter = Field(default_factory=ZarrDatasetRouter) diff --git a/xpublish/plugin/__init__.py b/xpublish/plugin/__init__.py new file mode 100644 index 00000000..ff53127a --- /dev/null +++ b/xpublish/plugin/__init__.py @@ -0,0 +1,2 @@ +from .base import Plugin, Router, get_plugins # noqa: F401 +from .manage import configure_plugins, find_default_plugins, load_default_plugins # noqa: F401 diff --git a/xpublish/plugin/base.py b/xpublish/plugin/base.py new file mode 100644 index 00000000..369f1fa8 --- /dev/null +++ b/xpublish/plugin/base.py @@ -0,0 +1,118 @@ +from typing import Any, Callable, List, Optional + +import cachey +import xarray as xr +from fastapi import APIRouter +from pydantic import BaseModel, Field, PrivateAttr + +from ..dependencies import get_cache, get_dataset, get_dataset_ids, get_plugins + + +class PluginDependencies(BaseModel): + dataset_ids: Callable[..., List[str]] = get_dataset_ids + dataset: Callable[..., xr.Dataset] = get_dataset + cache: Callable[..., cachey.Cache] = get_cache + plugins: Callable[..., 'Plugin'] = get_plugins + + +class Plugin(BaseModel): + """ + Xpublish plugins provide ways to extend the core of xpublish with + new routers and other functionality. + + To create a plugin, subclass ``Plugin` and add attributes that are + subclasses of `PluginType` (`Router` for instance). + + The specific attributes correspond to how Xpublish should use + the plugin. + """ + + name: str + dependencies: PluginDependencies = Field( + default_factory=PluginDependencies, + description='Xpublish dependencies, which can be overridden on a per-plugin basis', + ) + + app_router: Optional['Router'] = Field( + description='Top level routes that are not dependent on specific datasets' + ) + dataset_router: Optional['Router'] = Field( + description='Routes that are dependent on specific datasets' + ) + + def __init__(self, **data: Any) -> None: + super().__init__(**data) + self.set_parent() + self.register() + + def register(self): + """Setup routes and other plugin functionality""" + for extension in self.iter_extensions(): + extension.register() + + def iter_extensions(self): + """Iterate over all types of plugins that the plugin supports""" + for key in self.dict(): + attr = getattr(self, key) + + if isinstance(attr, PluginType): + yield attr + + def _parent(self): + """Secret helper method to allow extensions to access the parent plugin without + ending up in a recursive loop""" + return self + + def set_parent(self): + """ + Set the parent attribute on extensions to allow them to access attributes + from other parts of plugins + """ + for extension in self.iter_extensions(): + extension._parent = self._parent + + +class PluginType(BaseModel): + """A base class for various plugin functionality to be built off of. + + This helps provide access to the parent class, and dependencies. + + Subclasses need to reimplement the `register()` method to enable + their functionality. + """ + + _parent: Optional[Callable[[], Plugin]] = PrivateAttr() + + @property + def plugin(self): + """Access the plugin from an extension""" + return self._parent() + + @property + def deps(self): + """Access the dependencies of plugin""" + return self.plugin.dependencies + + def register(self): + """Implement for any plugin functionality that requires setup. + + If no setup is needed, re-implement and pass to avoid errors. + """ + raise NotImplementedError + + +class Router(PluginType): + """Base class used by plugins implementing new routes. + + Subclass Router, and create routes by re-implementing `register()` with + using `@self._router.METHOD()` on nested functions. + """ + + prefix: str = Field(description='Shared route prefix') + tags: List[str] = Field(default_factory=list, description='Tags in OpenAPI documentation') + + _router: APIRouter = PrivateAttr(default_factory=APIRouter) + + def register(self): + """Register API Routes""" + raise NotImplementedError diff --git a/xpublish/plugin/manage.py b/xpublish/plugin/manage.py new file mode 100644 index 00000000..5926baa9 --- /dev/null +++ b/xpublish/plugin/manage.py @@ -0,0 +1,45 @@ +""" +Load and configure Xpublish plugins from entry point group `xpublish.plugin` +""" +from importlib.metadata import entry_points +from typing import Dict, Iterable, Optional + +from .base import Plugin + + +def find_default_plugins(exclude_plugins: Optional[Iterable[str]] = None): + """Find Xpublish plugins from entry point group `xpublish.plugin` + + Individual plugins may be ignored by adding them to `exclude_plugins`. + """ + exclude_plugins = set(exclude_plugins or []) + + plugins: Dict[str, Plugin] = {} + + for entry_point in entry_points()['xpublish.plugin']: + if entry_point.name not in exclude_plugins: + plugins[entry_point.name] = entry_point.load() + + return plugins + + +def load_default_plugins(exclude_plugins: Optional[Iterable[str]] = None): + """Find and initialize plugins from entry point group `xpublish.plugin`""" + initialized_plugins: Dict[str, Plugin] = {} + + for name, plugin in find_default_plugins(exclude_plugins=exclude_plugins).items(): + initialized_plugins[name] = plugin() + + return initialized_plugins + + +def configure_plugins(plugins: Dict[str, Plugin], plugin_configs: Optional[Dict] = None): + """Initialize and configure plugins""" + initialized_plugins: Dict[str, Plugin] = {} + plugin_configs = plugin_configs or {} + + for name, plugin in plugins.items(): + kwargs = plugin_configs.get(name, {}) + initialized_plugins[name] = plugin(**kwargs) + + return initialized_plugins diff --git a/xpublish/rest.py b/xpublish/rest.py index 7981531b..9a410f04 100644 --- a/xpublish/rest.py +++ b/xpublish/rest.py @@ -1,10 +1,12 @@ +from typing import Dict, Optional + import cachey import uvicorn -import xarray as xr from fastapi import FastAPI, HTTPException from .dependencies import get_cache, get_dataset, get_dataset_ids -from .routers import base_router, common_router, dataset_collection_router, zarr_router +from .plugin import Plugin, get_plugins, load_default_plugins +from .routers import dataset_collection_router from .utils.api import ( SingleDatasetOpenAPIOverrider, check_route_conflicts, @@ -44,19 +46,9 @@ def _set_app_routers(dataset_routers=None, dataset_route_prefix=''): app_routers = [] - # top-level api endpoints - app_routers.append((common_router, {})) - if dataset_route_prefix: app_routers.append((dataset_collection_router, {'tags': ['info']})) - # dataset-specifc api endpoints - if dataset_routers is None: - dataset_routers = [ - (base_router, {'tags': ['info']}), - (zarr_router, {'tags': ['zarr']}), - ] - app_routers += normalize_app_routers(dataset_routers, dataset_route_prefix) check_route_conflicts(app_routers) @@ -79,8 +71,8 @@ class Rest: are converted to strings. See also the notes below. routers : list, optional A list of dataset-specific :class:`fastapi.APIRouter` instances to - include in the fastAPI application. If None, the default routers will be - included. + include in the fastAPI application. These routers are in addition + to any loaded via plugins. The items of the list may also be tuples with the following format: ``[(router1, {'prefix': '/foo', 'tags': ['foo', 'bar']})]``, where the 1st tuple element is a :class:`fastapi.APIRouter` instance and the @@ -94,6 +86,10 @@ class Rest: app_kws : dict, optional Dictionary of keyword arguments to be passed to :meth:`fastapi.FastAPI.__init__()`. + plugins : dict, optional + Optional dictionary of loaded, configured plugins. + Overrides automatic loading of plugins. + If no plugins are desired, set to an empty dict. Notes ----- @@ -106,30 +102,92 @@ class Rest: """ - def __init__(self, datasets, routers=None, cache_kws=None, app_kws=None): + def __init__( + self, + datasets, + routers=None, + cache_kws=None, + app_kws=None, + plugins: Optional[Dict[str, Plugin]] = None, + ): - self._datasets = normalize_datasets(datasets) + dataset_route_prefix = self.setup_datasets(datasets) - if not self._datasets: - # publish single dataset - self._get_dataset_func = _dataset_unique_getter(datasets) - dataset_route_prefix = '' + self.setup_plugins(plugins) + self.setup_routers(routers, dataset_route_prefix) + + self.init_app_kwargs(app_kws) + self.init_cache_kwargs(cache_kws) + + def setup_routers(self, routers, dataset_route_prefix): + """Setup plugin and dataset routers""" + plugin_app_routers, plugin_dataset_routers = self.plugin_routers() + + self._app_routers = plugin_app_routers + self._app_routers.extend( + _set_app_routers(plugin_dataset_routers + (routers or []), dataset_route_prefix) + ) + + def setup_plugins(self, plugins: Optional[Dict[str, Plugin]] = None): + """Initialize and load plugins from entry_points""" + if plugins is None: + self._plugins = load_default_plugins() else: - self._get_dataset_func = _dataset_from_collection_getter(self._datasets) - dataset_route_prefix = '/datasets/{dataset_id}' + self._plugins = plugins - self._app_routers = _set_app_routers(routers, dataset_route_prefix) + def plugin_routers(self): + """Load the app and dataset routers for plugins""" + app_routers = [] + dataset_routers = [] - self._app = None - self._app_kws = {} - if app_kws is not None: - self._app_kws.update(app_kws) + for plugin in self._plugins.values(): + if plugin.app_router: + router_kwargs = {} + if plugin.app_router.prefix: + router_kwargs['prefix'] = plugin.app_router.prefix + if plugin.app_router.tags: + router_kwargs['tags'] = plugin.app_router.tags + + app_routers.append((plugin.app_router._router, router_kwargs)) + + if plugin.dataset_router: + router_kwargs = {} + if plugin.dataset_router.prefix: + router_kwargs['prefix'] = plugin.dataset_router.prefix + if plugin.dataset_router.tags: + router_kwargs['tags'] = plugin.dataset_router.tags + + dataset_routers.append((plugin.dataset_router._router, router_kwargs)) + + return app_routers, dataset_routers + def init_cache_kwargs(self, cache_kws): + """Set up cache kwargs""" self._cache = None self._cache_kws = {'available_bytes': 1e6} if cache_kws is not None: self._cache_kws.update(cache_kws) + def init_app_kwargs(self, app_kws): + """Set up FastAPI application kwargs""" + self._app = None + self._app_kws = {} + if app_kws is not None: + self._app_kws.update(app_kws) + + def setup_datasets(self, datasets): + """Initialize datasets and getter functions""" + self._datasets = normalize_datasets(datasets) + + if not self._datasets: + # publish single dataset + self._get_dataset_func = _dataset_unique_getter(datasets) + dataset_route_prefix = '' + else: + self._get_dataset_func = _dataset_from_collection_getter(self._datasets) + dataset_route_prefix = '/datasets/{dataset_id}' + return dataset_route_prefix + @property def cache(self) -> cachey.Cache: """Returns the :class:`cachey.Cache` instance used by the FastAPI application.""" @@ -138,6 +196,18 @@ def cache(self) -> cachey.Cache: self._cache = cachey.Cache(**self._cache_kws) return self._cache + @property + def plugins(self) -> Dict[str, Plugin]: + """Returns the loaded plugins""" + return self._plugins + + def _init_dependencies(self): + """Initialize dependencies""" + self._app.dependency_overrides[get_dataset_ids] = lambda: list(self._datasets) + self._app.dependency_overrides[get_dataset] = self._get_dataset_func + self._app.dependency_overrides[get_cache] = lambda: self.cache + self._app.dependency_overrides[get_plugins] = lambda: self.plugins + def _init_app(self): """Initiate the FastAPI application.""" @@ -146,9 +216,7 @@ def _init_app(self): for rt, kwargs in self._app_routers: self._app.include_router(rt, **kwargs) - self._app.dependency_overrides[get_dataset_ids] = lambda: list(self._datasets) - self._app.dependency_overrides[get_dataset] = self._get_dataset_func - self._app.dependency_overrides[get_cache] = lambda: self.cache + self._init_dependencies() if not self._datasets: # fix openapi spec for single dataset @@ -184,72 +252,3 @@ def serve(self, host='0.0.0.0', port=9000, log_level='debug', **kwargs): """ uvicorn.run(self.app, host=host, port=port, log_level=log_level, **kwargs) - - -@xr.register_dataset_accessor('rest') -class RestAccessor: - """REST API Accessor for serving one dataset in its - dedicated FastAPI application. - - """ - - def __init__(self, xarray_obj): - - self._obj = xarray_obj - self._rest = None - - self._initialized = False - - def _get_rest_obj(self): - if self._rest is None: - self._rest = Rest(self._obj) - - return self._rest - - def __call__(self, **kwargs): - """Initialize this accessor by setting optional configuration values. - - Parameters - ---------- - **kwargs - Arguments passed to :func:`xpublish.Rest.__init__`. - - Notes - ----- - This method can only be invoked once. - - """ - if self._initialized: - raise RuntimeError('This accessor has already been initialized') - self._initialized = True - - self._rest = Rest(self._obj, **kwargs) - - return self - - @property - def cache(self) -> cachey.Cache: - """Returns the :class:`cachey.Cache` instance used by the FastAPI application.""" - - return self._get_rest_obj().cache - - @property - def app(self) -> FastAPI: - """Returns the :class:`fastapi.FastAPI` application instance.""" - - return self._get_rest_obj().app - - def serve(self, **kwargs): - """Serve this FastAPI application via :func:`uvicorn.run`. - - Parameters - ---------- - **kwargs : - Arguments passed to :func:`xpublish.Rest.serve`. - - Notes - ----- - This method is blocking and does not return. - - """ - self._get_rest_obj().serve(**kwargs) diff --git a/xpublish/routers/__init__.py b/xpublish/routers/__init__.py index 6a71e11b..ef14749a 100644 --- a/xpublish/routers/__init__.py +++ b/xpublish/routers/__init__.py @@ -1,3 +1 @@ -from .base import base_router -from .common import common_router, dataset_collection_router -from .zarr import zarr_router +from .common import dataset_collection_router diff --git a/xpublish/routers/base.py b/xpublish/routers/base.py deleted file mode 100644 index 00ad1c46..00000000 --- a/xpublish/routers/base.py +++ /dev/null @@ -1,54 +0,0 @@ -import xarray as xr -from fastapi import APIRouter, Depends -from starlette.responses import HTMLResponse -from zarr.storage import attrs_key - -from ..dependencies import get_dataset, get_zmetadata, get_zvariables - -base_router = APIRouter() - - -@base_router.get('/') -def html_representation(dataset: xr.Dataset = Depends(get_dataset)): - """Returns a HTML representation of the dataset.""" - - with xr.set_options(display_style='html'): - return HTMLResponse(dataset._repr_html_()) - - -@base_router.get('/keys') -def list_keys(dataset: xr.Dataset = Depends(get_dataset)): - return list(dataset.variables) - - -@base_router.get('/dict') -def to_dict(dataset: xr.Dataset = Depends(get_dataset)): - return dataset.to_dict(data=False) - - -@base_router.get('/info') -def info( - dataset: xr.Dataset = Depends(get_dataset), - zvariables: dict = Depends(get_zvariables), - zmetadata: dict = Depends(get_zmetadata), -): - """Dataset schema (close to the NCO-JSON schema).""" - - info = {} - info['dimensions'] = dict(dataset.dims.items()) - info['variables'] = {} - - meta = zmetadata['metadata'] - - for name, var in zvariables.items(): - attrs = meta[f'{name}/{attrs_key}'] - attrs.pop('_ARRAY_DIMENSIONS') - info['variables'][name] = { - 'type': var.data.dtype.name, - 'dimensions': list(var.dims), - 'attributes': attrs, - } - - info['global_attributes'] = meta[attrs_key] - - return info diff --git a/xpublish/routers/common.py b/xpublish/routers/common.py index 8451b9cc..658bf084 100644 --- a/xpublish/routers/common.py +++ b/xpublish/routers/common.py @@ -2,42 +2,43 @@ Dataset-independent API routes. """ -import importlib -import sys +# import importlib +# import sys from fastapi import APIRouter, Depends from ..dependencies import get_dataset_ids -from ..utils.info import get_sys_info, netcdf_and_hdf5_versions - -common_router = APIRouter() - - -@common_router.get('/versions') -def get_versions(): - versions = dict(get_sys_info() + netcdf_and_hdf5_versions()) - modules = [ - 'xarray', - 'zarr', - 'numcodecs', - 'fastapi', - 'starlette', - 'pandas', - 'numpy', - 'dask', - 'distributed', - 'uvicorn', - ] - for modname in modules: - try: - if modname in sys.modules: - mod = sys.modules[modname] - else: # pragma: no cover - mod = importlib.import_module(modname) - versions[modname] = getattr(mod, '__version__', None) - except ImportError: # pragma: no cover - pass - return versions + +# from ..utils.info import get_sys_info, netcdf_and_hdf5_versions + +# common_router = APIRouter() + + +# @common_router.get('/versions') +# def get_versions(): +# versions = dict(get_sys_info() + netcdf_and_hdf5_versions()) +# modules = [ +# 'xarray', +# 'zarr', +# 'numcodecs', +# 'fastapi', +# 'starlette', +# 'pandas', +# 'numpy', +# 'dask', +# 'distributed', +# 'uvicorn', +# ] +# for modname in modules: +# try: +# if modname in sys.modules: +# mod = sys.modules[modname] +# else: # pragma: no cover +# mod = importlib.import_module(modname) +# versions[modname] = getattr(mod, '__version__', None) +# except ImportError: # pragma: no cover +# pass +# return versions dataset_collection_router = APIRouter() diff --git a/xpublish/routers/zarr.py b/xpublish/routers/zarr.py deleted file mode 100644 index 582e91f9..00000000 --- a/xpublish/routers/zarr.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -import logging - -import cachey -import xarray as xr -from fastapi import APIRouter, Depends, HTTPException -from starlette.responses import Response -from zarr.storage import array_meta_key, attrs_key, group_meta_key - -from ..dependencies import get_cache, get_dataset, get_zmetadata as _get_zmetadata, get_zvariables -from ..utils.api import DATASET_ID_ATTR_KEY -from ..utils.cache import CostTimer -from ..utils.zarr import encode_chunk, get_data_chunk, jsonify_zmetadata, zarr_metadata_key - -logger = logging.getLogger('api') - -zarr_router = APIRouter() - - -@zarr_router.get(f'/{zarr_metadata_key}') -def get_zmetadata( - dataset: xr.Dataset = Depends(get_dataset), zmetadata: dict = Depends(_get_zmetadata) -): - zjson = jsonify_zmetadata(dataset, zmetadata) - - return Response(json.dumps(zjson).encode('ascii'), media_type='application/json') - - -@zarr_router.get(f'/{group_meta_key}') -def get_zgroup(zmetadata: dict = Depends(_get_zmetadata)): - - return zmetadata['metadata'][group_meta_key] - - -@zarr_router.get(f'/{attrs_key}') -def get_zattrs(zmetadata: dict = Depends(_get_zmetadata)): - - return zmetadata['metadata'][attrs_key] - - -@zarr_router.get('/{var}/{chunk}') -def get_variable_chunk( - var: str, - chunk: str, - dataset: xr.Dataset = Depends(get_dataset), - cache: cachey.Cache = Depends(get_cache), - zvariables: dict = Depends(get_zvariables), - zmetadata: dict = Depends(_get_zmetadata), -): - """Get a zarr array chunk. - - This will return cached responses when available. - - """ - # First check that this request wasn't for variable metadata - if array_meta_key in chunk: - return zmetadata['metadata'][f'{var}/{array_meta_key}'] - elif attrs_key in chunk: - return zmetadata['metadata'][f'{var}/{attrs_key}'] - elif group_meta_key in chunk: - raise HTTPException(status_code=404, detail='No subgroups') - else: - logger.debug('var is %s', var) - logger.debug('chunk is %s', chunk) - - cache_key = dataset.attrs.get(DATASET_ID_ATTR_KEY, '') + '/' + f'{var}/{chunk}' - response = cache.get(cache_key) - - if response is None: - with CostTimer() as ct: - arr_meta = zmetadata['metadata'][f'{var}/{array_meta_key}'] - da = zvariables[var].data - - data_chunk = get_data_chunk(da, chunk, out_shape=arr_meta['chunks']) - - echunk = encode_chunk( - data_chunk.tobytes(), - filters=arr_meta['filters'], - compressor=arr_meta['compressor'], - ) - - response = Response(echunk, media_type='application/octet-stream') - - cache.put(cache_key, response, ct.time, len(echunk)) - - return response