Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Entry point plugins #140

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repos:
- id: double-quote-string-fixer

- repo: https://github.com/psf/black
rev: 22.10.0
rev: 22.12.0
hooks:
- id: black
args: ["--line-length", "100", "--skip-string-normalization"]
Expand Down
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,11 @@
keywords=['xarray', 'zarr', 'api'],
use_scm_version={'version_scheme': 'post-release', 'local_scheme': 'dirty-tag'},
setup_requires=['setuptools_scm>=3.4', 'setuptools>=42'],
entry_points={
'xpublish.plugin': [
'base = xpublish.plugins.base:BasePlugin',
'zarr = xpublish.plugins.zarr:ZarrPlugin',
'module_version = xpublish.plugins.module_version:ModuleVersionPlugin',
]
},
)
2 changes: 1 addition & 1 deletion tests/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_custom_app_routers(airtemp_ds, dims_router, router_kws, path):
else:
routers = [(dims_router, router_kws)]

rest = Rest(airtemp_ds, routers=routers)
rest = Rest(airtemp_ds, routers=routers, plugins={})
client = TestClient(rest.app)

response = client.get(path)
Expand Down
12 changes: 7 additions & 5 deletions xpublish/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Helper functions to use as FastAPI dependencies.
"""
from typing import List

import cachey
import xarray as xr
from fastapi import Depends
Expand All @@ -9,7 +11,7 @@
from .utils.zarr import create_zmetadata, create_zvariables, zarr_metadata_key


def get_dataset_ids():
def get_dataset_ids() -> List[str]:
"""FastAPI dependency for getting the list of ids (string keys)
of the collection of datasets being served.

Expand All @@ -23,7 +25,7 @@ def get_dataset_ids():
return [] # pragma: no cover


def get_dataset(dataset_id: str):
def get_dataset(dataset_id: str) -> xr.Dataset:
"""FastAPI dependency for accessing the published xarray dataset object.

Use this callable as dependency in any FastAPI path operation
Expand All @@ -33,10 +35,10 @@ def get_dataset(dataset_id: str):
application.

"""
return None # pragma: no cover
return xr.Dataset()


def get_cache():
def get_cache() -> cachey.Cache:
"""FastAPI dependency for accessing the application's cache.

Use this callable as dependency in any FastAPI path operation
Expand All @@ -47,7 +49,7 @@ def get_cache():
application.

"""
return None # pragma: no cover
return cachey.Cache(available_bytes=1e6)


def get_zvariables(
Expand Down
2 changes: 2 additions & 0 deletions xpublish/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .factory import XpublishPluginFactory # noqa: F401
from .load import configure_plugins, find_plugins # noqa: F401
69 changes: 69 additions & 0 deletions xpublish/plugins/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from dataclasses import dataclass, field
from typing import List

import xarray as xr
from fastapi import Depends
from starlette.responses import HTMLResponse
from zarr.storage import attrs_key

from ..dependencies import get_zmetadata, get_zvariables
from .factory import XpublishPluginFactory


@dataclass
class BasePlugin(XpublishPluginFactory):
    """API entry-points providing basic information about the dataset(s).

    Registers four GET routes on ``dataset_router``: ``/`` (HTML
    representation), ``/keys``, ``/dict`` and ``/info``.
    """

    dataset_router_prefix: str = ''
    dataset_router_tags: List[str] = field(default_factory=lambda: ['info'])

    def register_routes(self):
        """Attach the dataset information routes to ``self.dataset_router``."""

        @self.dataset_router.get('/')
        def html_representation(
            dataset=Depends(self.dataset_dependency),
        ):
            """Returns a HTML representation of the dataset."""

            with xr.set_options(display_style='html'):
                return HTMLResponse(dataset._repr_html_())

        # List the names of all variables of the dataset.
        @self.dataset_router.get('/keys')
        def list_keys(
            dataset=Depends(self.dataset_dependency),
        ):
            return list(dataset.variables)

        # Dictionary form of the dataset, without the data values.
        @self.dataset_router.get('/dict')
        def to_dict(
            dataset=Depends(self.dataset_dependency),
        ):
            return dataset.to_dict(data=False)

        @self.dataset_router.get('/info')
        def info(
            dataset=Depends(self.dataset_dependency),
            cache=Depends(self.cache_dependency),
        ):
            """Dataset schema (close to the NCO-JSON schema)."""

            zvariables = get_zvariables(dataset, cache)
            zmetadata = get_zmetadata(dataset, cache, zvariables)

            info = {}
            info['dimensions'] = dict(dataset.dims.items())
            info['variables'] = {}

            meta = zmetadata['metadata']

            for name, var in zvariables.items():
                # Work on a copy: ``zmetadata`` is obtained through the cache
                # dependency and may be the very object shared with other
                # requests/routes, so popping from it in place would strip
                # '_ARRAY_DIMENSIONS' for every later consumer (and make a
                # second call to this route fail).
                attrs = dict(meta[f'{name}/{attrs_key}'])
                attrs.pop('_ARRAY_DIMENSIONS', None)
                info['variables'][name] = {
                    'type': var.data.dtype.name,
                    'dimensions': list(var.dims),
                    'attributes': attrs,
                }

            info['global_attributes'] = meta[attrs_key]

            return info
61 changes: 61 additions & 0 deletions xpublish/plugins/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from dataclasses import dataclass, field
from typing import Callable, List, Optional

import cachey
import xarray as xr
from fastapi import APIRouter

from ..dependencies import get_cache, get_dataset, get_dataset_ids


@dataclass
class XpublishPluginFactory:
    """Base class for Xpublish plugins.

    Plugins are meant to be discovered automatically through the
    `xpublish.plugin` entry point group of any installed package.

    A plugin may contribute app-level (top-level) routes as well as
    dataset-level routes, each with its own default prefix and tags.
    Subclasses implement :meth:`register_routes`, which is invoked right
    after the dataclass fields have been initialized.

    Parameters
    ----------
    app_router : APIRouter
        Router holding top level routes, i.e. routes that are not nested
        under a dataset.
    app_router_prefix : str
        Default prefix applied to all app level routes.
    app_router_tags : list
        Default OpenAPI tags for app level routes.
    dataset_router : APIRouter
        Router holding routes that operate on an individual dataset.
    dataset_router_prefix : str
        Default prefix for routes nested under a dataset.
    dataset_router_tags : list
        Default OpenAPI tags for dataset level routes.
    dataset_ids_dependency :
        Dependency giving access to the current dataset ids.
    dataset_dependency :
        Dependency loading the dataset specified in the path.
    cache_dependency :
        Dependency giving access to the cache.
    """

    # --- app (top-level) routing -------------------------------------------
    app_router: APIRouter = field(default_factory=APIRouter)
    app_router_prefix: Optional[str] = None
    app_router_tags: List[str] = field(default_factory=list)

    # --- per-dataset routing -----------------------------------------------
    dataset_router: APIRouter = field(default_factory=APIRouter)
    dataset_router_prefix: Optional[str] = None
    dataset_router_tags: List[str] = field(default_factory=list)

    # --- FastAPI dependencies available to the routes ------------------------
    dataset_ids_dependency: Callable[..., List[str]] = get_dataset_ids
    dataset_dependency: Callable[..., xr.Dataset] = get_dataset
    cache_dependency: Callable[..., cachey.Cache] = get_cache

    def __post_init__(self):
        # Routes are registered as soon as the instance is constructed.
        self.register_routes()

    def register_routes(self):
        """Register xpublish routes.

        Must be overridden by subclasses; the base implementation always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError
34 changes: 34 additions & 0 deletions xpublish/plugins/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Load and configure Xpublish plugins from entry point group `xpublish.plugin`
"""
from importlib.metadata import entry_points
from typing import Dict, List, Optional

from .factory import XpublishPluginFactory


def find_plugins(exclude_plugins: Optional[List[str]] = None):
    """Find Xpublish plugins from entry point group `xpublish.plugin`.

    Parameters
    ----------
    exclude_plugins : list of str, optional
        Entry point names that must not be loaded.

    Returns
    -------
    dict
        Mapping of entry point name to the loaded entry point object.
        Empty when no installed package registers the group.
    """
    exclude_plugin_names = set(exclude_plugins or [])

    discovered = entry_points()
    if hasattr(discovered, 'select'):
        # Selection API (Python 3.10+). Dict-style access by group name is
        # deprecated there and no longer available from Python 3.12 on.
        group = discovered.select(group='xpublish.plugin')
    else:
        # Older interpreters return a plain dict keyed by group name; use
        # ``get`` so that a missing group yields no plugins instead of KeyError.
        group = discovered.get('xpublish.plugin', [])

    plugins: Dict[str, XpublishPluginFactory] = {}

    for entry_point in group:
        if entry_point.name not in exclude_plugin_names:
            plugins[entry_point.name] = entry_point.load()

    return plugins


def configure_plugins(
    plugins: Dict[str, XpublishPluginFactory], plugin_configs: Optional[Dict] = None
):
    """Instantiate every plugin, passing its per-name configuration.

    Each value of ``plugins`` is called with the keyword arguments found
    under the same name in ``plugin_configs`` (no arguments when the name
    is absent). Returns a dict mapping plugin name to the resulting object.
    """
    configs = plugin_configs if plugin_configs else {}

    return {
        plugin_name: factory(**configs.get(plugin_name, {}))
        for plugin_name, factory in plugins.items()
    }
41 changes: 41 additions & 0 deletions xpublish/plugins/module_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Version information router
"""
import importlib
import sys
from dataclasses import dataclass

from ..utils.info import get_sys_info, netcdf_and_hdf5_versions
from .factory import XpublishPluginFactory


@dataclass
class ModuleVersionPlugin(XpublishPluginFactory):
    """Plugin exposing an app-level route that reports system and module versions."""

    def register_routes(self):
        """Attach the ``/versions`` route to ``self.app_router``."""
        # Modules whose ``__version__`` is reported, in output order.
        tracked_modules = (
            'xarray',
            'zarr',
            'numcodecs',
            'fastapi',
            'starlette',
            'pandas',
            'numpy',
            'dask',
            'distributed',
            'uvicorn',
        )

        @self.app_router.get('/versions')
        def get_versions():
            # Start from the system / netCDF / HDF5 information pairs.
            report = dict(get_sys_info() + netcdf_and_hdf5_versions())

            for module_name in tracked_modules:
                try:
                    # Reuse an already imported module, import it otherwise.
                    module = (
                        sys.modules[module_name]
                        if module_name in sys.modules
                        else importlib.import_module(module_name)  # pragma: no cover
                    )
                    report[module_name] = getattr(module, '__version__', None)
                except ImportError:  # pragma: no cover
                    # Modules that cannot be imported are left out of the report.
                    continue

            return report
107 changes: 107 additions & 0 deletions xpublish/plugins/zarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import json
import logging
from dataclasses import dataclass, field
from typing import List

import cachey
import xarray as xr
from fastapi import Depends, HTTPException
from starlette.responses import Response
from zarr.storage import array_meta_key, attrs_key, group_meta_key

from ..dependencies import get_cache, get_dataset, get_zmetadata, get_zvariables
from ..utils.api import DATASET_ID_ATTR_KEY
from ..utils.cache import CostTimer
from ..utils.zarr import encode_chunk, get_data_chunk, jsonify_zmetadata, zarr_metadata_key
from .factory import XpublishPluginFactory

logger = logging.getLogger('zarr_api')


@dataclass
class ZarrPlugin(XpublishPluginFactory):
    """Provides access to data and metadata through a Zarr compatible API."""

    dataset_router_prefix: str = ''
    dataset_router_tags: List[str] = field(default_factory=lambda: ['zarr'])

    def register_routes(self):
        """Attach the Zarr metadata and chunk routes to ``self.dataset_router``.

        Every route resolves the dataset and cache through the plugin's own
        ``dataset_dependency`` / ``cache_dependency`` so that dependencies
        configured on the plugin instance apply uniformly to all routes.
        """

        # Consolidated metadata, serialized as JSON.
        @self.dataset_router.get(f'/{zarr_metadata_key}')
        def get_zarr_metadata(
            dataset=Depends(self.dataset_dependency),
            cache=Depends(self.cache_dependency),
        ):
            zvariables = get_zvariables(dataset, cache)
            zmetadata = get_zmetadata(dataset, cache, zvariables)

            zjson = jsonify_zmetadata(dataset, zmetadata)

            return Response(json.dumps(zjson).encode('ascii'), media_type='application/json')

        # Group-level metadata entry.
        @self.dataset_router.get(f'/{group_meta_key}')
        def get_zarr_group(
            dataset=Depends(self.dataset_dependency),
            cache=Depends(self.cache_dependency),
        ):
            zvariables = get_zvariables(dataset, cache)
            zmetadata = get_zmetadata(dataset, cache, zvariables)

            return zmetadata['metadata'][group_meta_key]

        # Group-level attributes entry.
        @self.dataset_router.get(f'/{attrs_key}')
        def get_zarr_attrs(
            dataset=Depends(self.dataset_dependency),
            cache=Depends(self.cache_dependency),
        ):
            zvariables = get_zvariables(dataset, cache)
            zmetadata = get_zmetadata(dataset, cache, zvariables)

            return zmetadata['metadata'][attrs_key]

        @self.dataset_router.get('/{var}/{chunk}')
        def get_variable_chunk(
            var: str,
            chunk: str,
            # Use the plugin-configured dependencies (previously the module
            # level ``get_dataset`` / ``get_cache`` were hard-coded here,
            # unlike every other route of this plugin).
            dataset: xr.Dataset = Depends(self.dataset_dependency),
            cache: cachey.Cache = Depends(self.cache_dependency),
        ):
            """Get a zarr array chunk.

            This will return cached responses when available.

            """
            zvariables = get_zvariables(dataset, cache)
            zmetadata = get_zmetadata(dataset, cache, zvariables)

            # First check that this request wasn't for variable metadata
            if array_meta_key in chunk:
                return zmetadata['metadata'][f'{var}/{array_meta_key}']
            elif attrs_key in chunk:
                return zmetadata['metadata'][f'{var}/{attrs_key}']
            elif group_meta_key in chunk:
                raise HTTPException(status_code=404, detail='No subgroups')
            else:
                logger.debug('var is %s', var)
                logger.debug('chunk is %s', chunk)

                # Cache key is "<dataset id attr>/<var>/<chunk>".
                cache_key = dataset.attrs.get(DATASET_ID_ATTR_KEY, '') + '/' + f'{var}/{chunk}'
                response = cache.get(cache_key)

                if response is None:
                    # Time the chunk extraction/encoding; the elapsed time is
                    # handed to the cache together with the encoded size.
                    with CostTimer() as ct:
                        arr_meta = zmetadata['metadata'][f'{var}/{array_meta_key}']
                        da = zvariables[var].data

                        data_chunk = get_data_chunk(da, chunk, out_shape=arr_meta['chunks'])

                        echunk = encode_chunk(
                            data_chunk.tobytes(),
                            filters=arr_meta['filters'],
                            compressor=arr_meta['compressor'],
                        )

                        response = Response(echunk, media_type='application/octet-stream')

                    cache.put(cache_key, response, ct.time, len(echunk))

                return response
Loading