diff --git a/xpublish_wms/plugin.py b/xpublish_wms/plugin.py index c21070b..1d9c66e 100644 --- a/xpublish_wms/plugin.py +++ b/xpublish_wms/plugin.py @@ -33,11 +33,11 @@ def dataset_router(self, deps: Dependencies) -> APIRouter: @router.get("", include_in_schema=False) @router.get("/") - def wms_root( + async def wms_root( request: Request, dataset: xr.Dataset = Depends(deps.dataset), cache: cachey.Cache = Depends(deps.cache), ): - return wms_handler(request, dataset, cache) + return await wms_handler(request, dataset, cache) return router diff --git a/xpublish_wms/wms/__init__.py b/xpublish_wms/wms/__init__.py index f2def79..fd2ef35 100644 --- a/xpublish_wms/wms/__init__.py +++ b/xpublish_wms/wms/__init__.py @@ -2,6 +2,7 @@ OGC WMS router for datasets with CF convention metadata """ +import asyncio import logging import cachey @@ -21,7 +22,7 @@ logger = logging.getLogger("uvicorn") -def wms_handler( +async def wms_handler( request: Request, dataset: xr.Dataset = Depends(get_dataset), cache: cachey.Cache = Depends(get_cache), @@ -32,19 +33,19 @@ def wms_handler( logger.info(f"WMS: {method}") if method == "getcapabilities": - return get_capabilities(dataset, request, query_params) + return await asyncio.to_thread(get_capabilities, dataset, request, query_params) elif method == "getmap": getmap_service = GetMap(cache=cache) - return getmap_service.get_map(dataset, query_params) + return await getmap_service.get_map(dataset, query_params) elif method == "getfeatureinfo" or method == "gettimeseries": - return get_feature_info(dataset, query_params) + return await asyncio.to_thread(get_feature_info, dataset, query_params) elif method == "getverticalprofile": query_params["elevation"] = "all" - return get_feature_info(dataset, query_params) + return await asyncio.to_thread(get_feature_info, dataset, query_params) elif method == "getmetadata": - return get_metadata(dataset, cache, query_params) + return await get_metadata(dataset, cache, query_params) elif method == "getlegendgraphic": - return get_legend_info(dataset, query_params) + return await asyncio.to_thread(get_legend_info, dataset, query_params) else: raise HTTPException( status_code=404, diff --git a/xpublish_wms/wms/get_map.py b/xpublish_wms/wms/get_map.py index 8f09f38..e85acf7 100644 --- a/xpublish_wms/wms/get_map.py +++ b/xpublish_wms/wms/get_map.py @@ -1,3 +1,4 @@ +import asyncio import io import logging import time @@ -32,8 +33,6 @@ class GetMap: DEFAULT_STYLE: str = "raster/default" DEFAULT_PALETTE: str = "turbo" - BBOX_BUFFER = 0.18 - cache: cachey.Cache # Data selection @@ -58,7 +57,7 @@ class GetMap: def __init__(self, cache: cachey.Cache): self.cache = cache - def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse: + async def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse: """ Return the WMS map for the dataset and given parameters """ @@ -76,13 +75,13 @@ def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse: # The grid type for now. This can be revisited if we choose to interpolate or # use the contoured renderer for regular grid datasets image_buffer = io.BytesIO() - render_result = self.render(ds, da, image_buffer, False) + render_result = await self.render(ds, da, image_buffer, False) if render_result: image_buffer.seek(0) return StreamingResponse(image_buffer, media_type="image/png") - def get_minmax(self, ds: xr.Dataset, query: dict) -> dict: + async def get_minmax(self, ds: xr.Dataset, query: dict) -> dict: """ Return the range of values for the dataset and given parameters """ @@ -109,7 +108,7 @@ def get_minmax(self, ds: xr.Dataset, query: dict) -> dict: if entire_layer: return {"min": float(da.min()), "max": float(da.max())} else: - return self.render(ds, da, None, minmax_only=True) + return await self.render(ds, da, None, minmax_only=True) def ensure_query_types(self, ds: xr.Dataset, query: dict): """ @@ -255,7 +254,7 @@ def select_custom_dim(self, da: xr.DataArray) -> xr.DataArray: return da - def render( + async def render( self, ds: xr.Dataset, da: xr.DataArray, @@ -280,28 +279,41 @@ def render( logger.warning("Falling back to default minmax") return {"min": float(da.min()), "max": float(da.max())} + # x and y are only set for triangle grids, we dont subset the data for triangle grids + # at this time. + if x is None: + try: + # Grab a buffer around the bbox to ensure we have enough data to render + # TODO: Base this on actual data resolution? + if self.crs == "EPSG:4326": + coord_buffer = 0.5 # degrees + elif self.crs == "EPSG:3857": + coord_buffer = 30000 # meters + else: + # Default to 0.5, this should never happen + coord_buffer = 0.5 + + # Filter the data to only include the data within the bbox + buffer so + # we don't have to render a ton of empty space or pull down more chunks + # than we need + da = filter_data_within_bbox(da, self.bbox, coord_buffer) + except Exception as e: + logger.error(f"Error filtering data within bbox: {e}") + logger.warning("Falling back to full layer") + logger.debug(f"Projection time: {time.time() - projection_start}") start_dask = time.time() - da = da.persist() - if x is not None and y is not None: - x = x.persist() - y = y.persist() - else: - da["x"] = da.x.persist() - da["y"] = da.y.persist() + da = await asyncio.to_thread(da.compute) logger.debug(f"dask compute: {time.time() - start_dask}") if minmax_only: - da = da.persist() - data_sel = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER) - try: return { - "min": float(np.nanmin(data_sel)), - "max": float(np.nanmax(data_sel)), + "min": float(np.nanmin(da)), + "max": float(np.nanmax(da)), } except Exception as e: logger.error( diff --git a/xpublish_wms/wms/get_metadata.py b/xpublish_wms/wms/get_metadata.py index 72064fb..e5baf8d 100644 --- a/xpublish_wms/wms/get_metadata.py +++ b/xpublish_wms/wms/get_metadata.py @@ -11,7 +11,7 @@ from .get_map import GetMap -def get_metadata(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> Response: +async def get_metadata(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> Response: """ Return the WMS metadata for the dataset @@ -39,7 +39,7 @@ def get_metadata(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> Response: da = ds[layer_name] payload = get_timesteps(da, params) elif metadata_type == "minmax": - payload = get_minmax(ds, cache, params) + payload = await get_minmax(ds, cache, params) else: raise HTTPException( status_code=400, @@ -79,14 +79,14 @@ def get_timesteps(da: xr.DataArray, params: dict) -> dict: } -def get_minmax(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> dict: +async def get_minmax(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> dict: """ Returns the min and max range of values for a given layer in a given area If BBOX is not specified, the entire selected temporal and elevation range is used. """ getmap = GetMap(cache=cache) - return getmap.get_minmax(ds, params) + return await getmap.get_minmax(ds, params) def get_layer_details(ds: xr.Dataset, layer_name: str) -> dict: