From 5ae90d9171f0289427bbbbf6e13e2a4ec1f148cc Mon Sep 17 00:00:00 2001
From: Matthew Iannucci
Date: Thu, 7 Nov 2024 09:56:03 -0500
Subject: [PATCH] First pass at prototyping async getmap

---
 xpublish_wms/plugin.py       |  4 +--
 xpublish_wms/wms/__init__.py |  4 +--
 xpublish_wms/wms/get_map.py  | 52 ++++++++++++++++++++++--------------
 3 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/xpublish_wms/plugin.py b/xpublish_wms/plugin.py
index c21070b..1d9c66e 100644
--- a/xpublish_wms/plugin.py
+++ b/xpublish_wms/plugin.py
@@ -33,11 +33,11 @@ def dataset_router(self, deps: Dependencies) -> APIRouter:
 
         @router.get("", include_in_schema=False)
         @router.get("/")
-        def wms_root(
+        async def wms_root(
            request: Request,
            dataset: xr.Dataset = Depends(deps.dataset),
            cache: cachey.Cache = Depends(deps.cache),
        ):
-            return wms_handler(request, dataset, cache)
+            return await wms_handler(request, dataset, cache)
 
         return router

diff --git a/xpublish_wms/wms/__init__.py b/xpublish_wms/wms/__init__.py
index f2def79..b6dcf00 100644
--- a/xpublish_wms/wms/__init__.py
+++ b/xpublish_wms/wms/__init__.py
@@ -21,7 +21,7 @@
 logger = logging.getLogger("uvicorn")
 
 
-def wms_handler(
+async def wms_handler(
     request: Request,
     dataset: xr.Dataset = Depends(get_dataset),
     cache: cachey.Cache = Depends(get_cache),
@@ -35,7 +35,7 @@ def wms_handler(
         return get_capabilities(dataset, request, query_params)
     elif method == "getmap":
         getmap_service = GetMap(cache=cache)
-        return getmap_service.get_map(dataset, query_params)
+        return await getmap_service.get_map(dataset, query_params)
     elif method == "getfeatureinfo" or method == "gettimeseries":
         return get_feature_info(dataset, query_params)
     elif method == "getverticalprofile":

diff --git a/xpublish_wms/wms/get_map.py b/xpublish_wms/wms/get_map.py
index 8f09f38..53f5ac9 100644
--- a/xpublish_wms/wms/get_map.py
+++ b/xpublish_wms/wms/get_map.py
@@ -1,3 +1,4 @@
+import asyncio
 import io
 import logging
 import time
@@ -32,7 +33,7 @@ class GetMap:
 
     DEFAULT_STYLE: str = "raster/default"
     DEFAULT_PALETTE: str = "turbo"
-    BBOX_BUFFER = 0.18
+    BBOX_BUFFER = 30_000  # meters
 
     cache: cachey.Cache
 
@@ -58,7 +59,7 @@ class GetMap:
     def __init__(self, cache: cachey.Cache):
         self.cache = cache
 
-    def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
+    async def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
         """
         Return the WMS map for the dataset and given parameters
         """
@@ -76,13 +77,13 @@ def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
         # The grid type for now. This can be revisited if we choose to interpolate or
         # use the contoured renderer for regular grid datasets
         image_buffer = io.BytesIO()
-        render_result = self.render(ds, da, image_buffer, False)
+        render_result = await self.render(ds, da, image_buffer, False)
         if render_result:
             image_buffer.seek(0)
 
         return StreamingResponse(image_buffer, media_type="image/png")
 
-    def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
+    async def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
         """
         Return the range of values for the dataset and given parameters
         """
@@ -109,7 +110,7 @@ def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
         if entire_layer:
             return {"min": float(da.min()), "max": float(da.max())}
         else:
-            return self.render(ds, da, None, minmax_only=True)
+            return await self.render(ds, da, None, minmax_only=True)
 
     def ensure_query_types(self, ds: xr.Dataset, query: dict):
         """
@@ -255,7 +256,7 @@ def select_custom_dim(self, da: xr.DataArray) -> xr.DataArray:
 
         return da
 
-    def render(
+    async def render(
         self,
         ds: xr.Dataset,
         da: xr.DataArray,
@@ -279,29 +280,40 @@ def render(
             if minmax_only:
                 logger.warning("Falling back to default minmax")
                 return {"min": float(da.min()), "max": float(da.max())}
+
+        try:
+            da = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)
+        except Exception as e:
+            logger.error(f"Error filtering data within bbox: {e}")
+            logger.warning("Falling back to full layer")
 
-        logger.debug(f"Projection time: {time.time() - projection_start}")
+        print(f"Projection time: {time.time() - projection_start}")
 
         start_dask = time.time()
-        da = da.persist()
-        if x is not None and y is not None:
-            x = x.persist()
-            y = y.persist()
-        else:
-            da["x"] = da.x.persist()
-            da["y"] = da.y.persist()
+        da = await asyncio.to_thread(da.compute)
+
+        # da = da.persist()
+        # if x is not None and y is not None:
+        #     x = x.persist()
+        #     y = y.persist()
+        # else:
+        #     da["x"] = da.x.persist()
+        #     da["y"] = da.y.persist()
+
+        print(da.x[1].values - da.x[0].values)
+        print(da.y[1].values - da.y[0].values)
 
-        logger.debug(f"dask compute: {time.time() - start_dask}")
+        print(f"dask compute: {time.time() - start_dask}")
 
         if minmax_only:
-            da = da.persist()
-            data_sel = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)
+            # da = da.persist()
+            # data_sel = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)
             try:
                 return {
-                    "min": float(np.nanmin(data_sel)),
-                    "max": float(np.nanmax(data_sel)),
+                    "min": float(np.nanmin(da)),
+                    "max": float(np.nanmax(da)),
                 }
             except Exception as e:
                 logger.error(
@@ -354,7 +366,7 @@ def render(
             how="linear",
             span=(vmin, vmax),
         )
-        logger.debug(f"Shade time: {time.time() - start_shade}")
+        print(f"Shade time: {time.time() - start_shade}")
 
         im = shaded.to_pil()
         im.save(buffer, format="PNG")
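
Note: the central change above is replacing the blocking dask persist/compute calls with `da = await asyncio.to_thread(da.compute)` inside the now-async render path. A minimal, self-contained sketch of that pattern follows; the names blocking_compute and handle_get_map are illustrative placeholders, not part of xpublish-wms.

import asyncio
import time


def blocking_compute() -> float:
    # Stand-in for the blocking work (e.g. da.compute() on a dask-backed
    # DataArray); sleeps to simulate a long-running computation.
    time.sleep(0.5)
    return 42.0


async def handle_get_map() -> float:
    # Run the blocking call in a worker thread so the asyncio event loop
    # stays free to serve other requests while the computation runs.
    return await asyncio.to_thread(blocking_compute)


if __name__ == "__main__":
    print(asyncio.run(handle_get_map()))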