Skip to content

Commit

Permalink
Add entry point based plugins
Browse files Browse the repository at this point in the history
Builds on top of @benbovy 's work in building router factories in xpublish-community#89 to build a plugin system.

The plugin system uses entry points, which are most commonly used for console or GUI scripts. The entry_point group is `xpublish.plugin` Right now plugins can provide dataset specific and general (app) routes, with default prefixes and tags for both.

Xpublish will by default load plugins via the entry point. Additionally, plugins can also be loaded directly via the init, as well as being disabled, or configured. The existing dataset router pattern also still works, so that folks aren't forced into using plugins

Entry point reference:
- https://setuptools.pypa.io/en/latest/userguide/entry_point.html
- https://packaging.python.org/en/latest/specifications/entry-points/
- https://amir.rachum.com/amp/blog/2017/07/28/python-entry-points.html
  • Loading branch information
abkfenris committed Dec 10, 2022
1 parent cdb717a commit 900c3d1
Show file tree
Hide file tree
Showing 12 changed files with 281 additions and 103 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ repos:
- id: check-yaml
- id: double-quote-string-fixer

- repo: https://github.com/ambv/black
rev: 21.6b0
- repo: https://github.com/psf/black
rev: 22.12.0
hooks:
- id: black
args: ["--line-length", "100", "--skip-string-normalization"]
Expand Down
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,11 @@
keywords=['xarray', 'zarr', 'api'],
use_scm_version={'version_scheme': 'post-release', 'local_scheme': 'dirty-tag'},
setup_requires=['setuptools_scm', 'setuptools>=30.3.0'],
entry_points={
'xpublish.plugin': [
'base = xpublish.plugins.base:BasePlugin',
'zarr = xpublish.plugins.zarr:ZarrPlugin',
'module_version = xpublish.plugins.module_version:ModuleVersionPlugin'
]
},
)
2 changes: 2 additions & 0 deletions xpublish/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .factory import XpublishPluginFactory # noqa: F401
from .load import configure_plugins, find_plugins # noqa: F401
18 changes: 11 additions & 7 deletions xpublish/routers/base.py → xpublish/plugins/base.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List

import xarray as xr
from fastapi import Depends
from starlette.responses import HTMLResponse
from zarr.storage import attrs_key

from ..dependencies import get_zmetadata, get_zvariables
from .factory import XpublishFactory
from .factory import XpublishPluginFactory


@dataclass
class BaseFactory(XpublishFactory):
class BasePlugin(XpublishPluginFactory):
"""API entry-points providing basic information about the dataset(s)."""

dataset_router_prefix: str = '/info'
dataset_router_tags: List[str] = field(default_factory=lambda: ['info'])

def register_routes(self):
@self.router.get('/')
@self.dataset_router.get('/')
def html_representation(
dataset=Depends(self.dataset_dependency),
):
Expand All @@ -23,19 +27,19 @@ def html_representation(
with xr.set_options(display_style='html'):
return HTMLResponse(dataset._repr_html_())

@self.router.get('/keys')
@self.dataset_router.get('/keys')
def list_keys(
dataset=Depends(self.dataset_dependency),
):
return list(dataset.variables)

@self.router.get('/dict')
@self.dataset_router.get('/dict')
def to_dict(
dataset=Depends(self.dataset_dependency),
):
return dataset.to_dict(data=False)

@self.router.get('/info')
@self.dataset_router.get('/info')
def info(
dataset=Depends(self.dataset_dependency),
cache=Depends(self.cache_dependency),
Expand Down
39 changes: 39 additions & 0 deletions xpublish/plugins/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from dataclasses import dataclass, field
from typing import Callable, List, Optional

import cachey
import xarray as xr
from fastapi import APIRouter

from ..dependencies import get_cache, get_dataset, get_dataset_ids


@dataclass
class XpublishPluginFactory:
"""Xpublish plugin factory.
Xpublish plugins are designed to be automatically loaded via the entry point
group `xpublish.plugin` from any installed package.
Plugins can define both app (top-level) and dataset based routes, and
default prefixes and tags for both.
"""

app_router: APIRouter = field(default_factory=APIRouter)
app_router_prefix: Optional[str] = None
app_router_tags: List[str] = field(default_factory=list)

dataset_router: APIRouter = field(default_factory=APIRouter)
dataset_router_prefix: Optional[str] = None
dataset_router_tags: List[str] = field(default_factory=list)

dataset_ids_dependency: Callable[..., List[str]] = get_dataset_ids
dataset_dependency: Callable[..., xr.Dataset] = get_dataset
cache_dependency: Callable[..., cachey.Cache] = get_cache

def __post_init__(self):
self.register_routes()

def register_routes(self):
"""Register xpublish routes."""
raise NotImplementedError()
34 changes: 34 additions & 0 deletions xpublish/plugins/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Load and configure Xpublish plugins from entry point group `xpublish.plugin`
"""
from importlib.metadata import entry_points
from typing import Dict, List, Optional

from .factory import XpublishPluginFactory


def find_plugins(exclude_plugins: Optional[List[str]] = None):
"""Find Xpublish plugins from entry point group `xpublish.plugin`"""
exclude_plugin_names = set(exclude_plugins or [])

plugins: Dict[str, XpublishPluginFactory] = {}

for entry_point in entry_points()['xpublish.plugin']:
if entry_point.name not in exclude_plugin_names:
plugins[entry_point.name] = entry_point.load()

return plugins


def configure_plugins(
plugins: Dict[str, XpublishPluginFactory], plugin_configs: Optional[Dict] = None
):
"""Initialize and configure plugins"""
initialized_plugins: Dict[str, XpublishPluginFactory] = {}
plugin_configs = plugin_configs or {}

for name, plugin in plugins.items():
kwargs = plugin_configs.get(name, {})
initialized_plugins[name] = plugin(**kwargs)

return initialized_plugins
42 changes: 42 additions & 0 deletions xpublish/plugins/module_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Version information router
"""
from dataclasses import dataclass
import importlib
import sys
from typing import List

from ..utils.info import get_sys_info, netcdf_and_hdf5_versions
from .factory import XpublishPluginFactory


@dataclass
class ModuleVersionPlugin(XpublishPluginFactory):
"""Module and system version information"""

def register_routes(self):
@self.app_router.get("/versions")
def get_versions():
versions = dict(get_sys_info() + netcdf_and_hdf5_versions())
modules = [
'xarray',
'zarr',
'numcodecs',
'fastapi',
'starlette',
'pandas',
'numpy',
'dask',
'distributed',
'uvicorn',
]
for modname in modules:
try:
if modname in sys.modules:
mod = sys.modules[modname]
else: # pragma: no cover
mod = importlib.import_module(modname)
versions[modname] = getattr(mod, '__version__', None)
except ImportError: # pragma: no cover
pass
return versions
18 changes: 11 additions & 7 deletions xpublish/routers/zarr.py → xpublish/plugins/zarr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List

import cachey
import xarray as xr
Expand All @@ -12,17 +13,20 @@
from ..utils.api import DATASET_ID_ATTR_KEY
from ..utils.cache import CostTimer
from ..utils.zarr import encode_chunk, get_data_chunk, jsonify_zmetadata, zarr_metadata_key
from .factory import XpublishFactory
from .factory import XpublishPluginFactory

logger = logging.getLogger('zarr_api')


@dataclass
class ZarrFactory(XpublishFactory):
class ZarrPlugin(XpublishPluginFactory):
"""Provides access to data and metadata through as Zarr compatible API."""

dataset_router_prefix: str = '/zarr'
dataset_router_tags: List[str] = field(default_factory=lambda: ['zarr'])

def register_routes(self):
@self.router.get(f'/{zarr_metadata_key}')
@self.dataset_router.get(f'/{zarr_metadata_key}')
def get_zarr_metadata(
dataset=Depends(self.dataset_dependency),
cache=Depends(self.cache_dependency),
Expand All @@ -34,7 +38,7 @@ def get_zarr_metadata(

return Response(json.dumps(zjson).encode('ascii'), media_type='application/json')

@self.router.get(f'/{group_meta_key}')
@self.dataset_router.get(f'/{group_meta_key}')
def get_zarr_group(
dataset=Depends(self.dataset_dependency),
cache=Depends(self.cache_dependency),
Expand All @@ -44,7 +48,7 @@ def get_zarr_group(

return zmetadata['metadata'][group_meta_key]

@self.router.get(f'/{attrs_key}')
@self.dataset_router.get(f'/{attrs_key}')
def get_zarr_attrs(
dataset=Depends(self.dataset_dependency),
cache=Depends(self.cache_dependency),
Expand All @@ -54,7 +58,7 @@ def get_zarr_attrs(

return zmetadata['metadata'][attrs_key]

@self.router.get('/{var}/{chunk}')
@self.dataset_router.get('/{var}/{chunk}')
def get_variable_chunk(
var: str,
chunk: str,
Expand Down
Loading

0 comments on commit 900c3d1

Please sign in to comment.