diff --git a/.gitignore b/.gitignore index 4d985df0..a4fc5038 100644 --- a/.gitignore +++ b/.gitignore @@ -115,6 +115,9 @@ venv.bak/ .spyderproject .spyproject +# Pycharm project settings +.idea + # Rope project settings .ropeproject diff --git a/xpublish/routers/__init__.py b/xpublish/routers/__init__.py index 6a71e11b..5b272d18 100644 --- a/xpublish/routers/__init__.py +++ b/xpublish/routers/__init__.py @@ -1,3 +1,8 @@ from .base import base_router from .common import common_router, dataset_collection_router from .zarr import zarr_router + +try: + from .xyz import xyz_router +except ImportError: + pass \ No newline at end of file diff --git a/xpublish/routers/xyz.py b/xpublish/routers/xyz.py new file mode 100644 index 00000000..ab406f85 --- /dev/null +++ b/xpublish/routers/xyz.py @@ -0,0 +1,79 @@ +import xarray as xr +import cachey +from fastapi import APIRouter, Depends, Response, Query, Path +from typing import Optional + +from xpublish.utils.cache import CostTimer +from xpublish.utils.api import DATASET_ID_ATTR_KEY +from xpublish.dependencies import get_dataset, get_cache +from xpublish.utils.ows import ( + get_image_datashader, + get_bounds, + LayerOptionsMixin, + get_tiles, +) + + +class XYZRouter(APIRouter, LayerOptionsMixin): + pass + + +xyz_router = XYZRouter() + + +def query_builder(time, xleft, xright, ybottom, ytop, xlab, ylab): + query = {} + query.update({xlab: slice(xleft, xright), ylab: slice(ytop, ybottom)}) + if time: + query["time"] = time + return query + + +@xyz_router.get("/tiles/{var}/{z}/{x}/{y}") +@xyz_router.get("/tiles/{var}/{z}/{x}/{y}.{format}") +async def tiles( + var: str = Path( + ..., description="Dataset's variable. It defines the map's data layer" + ), + z: int = Path(..., description="Tiles' zoom level"), + x: int = Path(..., description="Tiles' column"), + y: int = Path(..., description="Tiles' row"), + format: str = Query("PNG", description="Image format. Default to PNG"), + time: str = Query( + None, + description="Filter by time in time-varying datasets. String time format should match dataset's time format", + ), + xlab: str = Query("x", description="Dataset x coordinate label"), + ylab: str = Query("y", description="Dataset y coordinate label"), + cache: cachey.Cache = Depends(get_cache), + dataset: xr.Dataset = Depends(get_dataset), +): + + # color mapping settings + datashader_settings = getattr(xyz_router, "datashader_settings") + + TMS = getattr(xyz_router, "TMS") + + xleft, xright, ybottom, ytop = get_bounds(TMS, z, x, y) + + query = query_builder(time, xleft, xright, ybottom, ytop, xlab, ylab) + + cache_key = ( + dataset.attrs.get(DATASET_ID_ATTR_KEY, "") + + "/" + + f"/tiles/{var}/{z}/{x}/{y}.{format}?{time}" + ) + response = cache.get(cache_key) + + if response is None: + with CostTimer() as ct: + + tile = get_tiles(var, dataset, query) + + byte_image = get_image_datashader(tile, datashader_settings, format) + + response = Response(content=byte_image, media_type=f"image/{format}") + + cache.put(cache_key, response, ct.time, len(byte_image)) + + return response diff --git a/xpublish/utils/ows.py b/xpublish/utils/ows.py new file mode 100644 index 00000000..055fdf57 --- /dev/null +++ b/xpublish/utils/ows.py @@ -0,0 +1,89 @@ +from datashader import transfer_functions as tf +import datashader as ds +import xarray as xr +from fastapi import HTTPException +import morecantile + + +# From Morecantile, morecantile.tms.list() +WEB_CRS = { + 3857: "WebMercatorQuad", + 32631: "UTM31WGS84Quad", + 3978: "CanadianNAD83_LCC", + 5482: "LINZAntarticaMapTilegrid", + 4326: "WorldCRS84Quad", + 5041: "UPSAntarcticWGS84Quad", + 3035: "EuropeanETRS89_LAEAQuad", + 3395: "WorldMercatorWGS84Quad", + 2193: "NZTM2000", +} + + +class DataValidationError(KeyError): + pass + + +class LayerOptionsMixin: + def set_options(self, crs_epsg: int = 3857, color_mapping: dict = {}) -> None: + + self.datashader_settings = color_mapping.get("datashader_settings") + self.matplotlib_settings = color_mapping.get("matplotlib_settings") + + if crs_epsg not in WEB_CRS.keys(): + raise DataValidationError(f"User input {crs_epsg} not supported") + + self.TMS = morecantile.tms.get(WEB_CRS[crs_epsg]) + + +def get_bounds(TMS, zoom, x, y): + + bbx = TMS.xy_bounds(morecantile.Tile(int(x), int(y), int(zoom))) + + return bbx.left, bbx.right, bbx.bottom, bbx.top + + +def get_tiles(var, dataset, query) -> xr.DataArray: + + if query.get("time"): + tile = dataset[var].sel(query) # noqa + else: + tile = dataset[var].sel(query) # noqa + + if 0 in tile.sizes.values(): + raise HTTPException(status_code=406, detail=f"Map outside dataset domain") + + return tile + + +def get_image_datashader(tile, datashader_settings, format): + + raster_param = datashader_settings.get("raster", {}) + shade_param = datashader_settings.get("shade", {"cmap": ["blue", "red"]}) + + cvs = ds.Canvas(plot_width=256, plot_height=256) + + agg = cvs.raster(tile, **raster_param) + + img = tf.shade(agg, **shade_param) + + img_io = img.to_bytesio(format) + + return img_io.read() + + +def get_legend(): + pass + + +def validate_dataset(dataset): + dims = dataset.dims + if "x" not in dims or "y" not in dims: + raise DataValidationError( + f" Expected spatial dimension names 'x' and 'y', found: {dims}" + ) + if "time" not in dims and len(dims) >= 3: + raise DataValidationError( + f" Expected time dimension name 'time', found: {dims}" + ) + if len(dims) > 4: + raise DataValidationError(f" Not implemented for dimensions > 4")