Skip to content

Commit

Permalink
First pass at prototyping async getmap
Browse files Browse the repository at this point in the history
  • Loading branch information
mpiannucci committed Nov 7, 2024
1 parent b8d19cc commit 5ae90d9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
4 changes: 2 additions & 2 deletions xpublish_wms/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def dataset_router(self, deps: Dependencies) -> APIRouter:

@router.get("", include_in_schema=False)
@router.get("/")
async def wms_root(
    request: Request,
    dataset: xr.Dataset = Depends(deps.dataset),
    cache: cachey.Cache = Depends(deps.cache),
):
    """Single WMS entry point for the dataset router.

    Delegates the incoming request to ``wms_handler``, which dispatches on
    the WMS ``request`` query parameter (GetCapabilities, GetMap, ...).
    Declared ``async`` so the awaitable handler runs on the event loop.
    """
    return await wms_handler(request, dataset, cache)

return router
4 changes: 2 additions & 2 deletions xpublish_wms/wms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
logger = logging.getLogger("uvicorn")


def wms_handler(
async def wms_handler(
request: Request,
dataset: xr.Dataset = Depends(get_dataset),
cache: cachey.Cache = Depends(get_cache),
Expand All @@ -35,7 +35,7 @@ def wms_handler(
return get_capabilities(dataset, request, query_params)
elif method == "getmap":
getmap_service = GetMap(cache=cache)
return getmap_service.get_map(dataset, query_params)
return await getmap_service.get_map(dataset, query_params)
elif method == "getfeatureinfo" or method == "gettimeseries":
return get_feature_info(dataset, query_params)
elif method == "getverticalprofile":
Expand Down
52 changes: 32 additions & 20 deletions xpublish_wms/wms/get_map.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import io
import logging
import time
Expand Down Expand Up @@ -32,7 +33,7 @@ class GetMap:
DEFAULT_STYLE: str = "raster/default"
DEFAULT_PALETTE: str = "turbo"

BBOX_BUFFER = 0.18
BBOX_BUFFER = 30_000 # meters

cache: cachey.Cache

Expand All @@ -58,7 +59,7 @@ class GetMap:
def __init__(self, cache: cachey.Cache):
self.cache = cache

def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
async def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
"""
Return the WMS map for the dataset and given parameters
"""
Expand All @@ -76,13 +77,13 @@ def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
# The grid type for now. This can be revisited if we choose to interpolate or
# use the contoured renderer for regular grid datasets
image_buffer = io.BytesIO()
render_result = self.render(ds, da, image_buffer, False)
render_result = await self.render(ds, da, image_buffer, False)
if render_result:
image_buffer.seek(0)

return StreamingResponse(image_buffer, media_type="image/png")

def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
async def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
"""
Return the range of values for the dataset and given parameters
"""
Expand All @@ -109,7 +110,7 @@ def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
if entire_layer:
return {"min": float(da.min()), "max": float(da.max())}
else:
return self.render(ds, da, None, minmax_only=True)
return await self.render(ds, da, None, minmax_only=True)

def ensure_query_types(self, ds: xr.Dataset, query: dict):
"""
Expand Down Expand Up @@ -255,7 +256,7 @@ def select_custom_dim(self, da: xr.DataArray) -> xr.DataArray:

return da

def render(
async def render(
self,
ds: xr.Dataset,
da: xr.DataArray,
Expand All @@ -279,29 +280,40 @@ def render(
if minmax_only:
logger.warning("Falling back to default minmax")
return {"min": float(da.min()), "max": float(da.max())}

try:
da = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)
except Exception as e:
logger.error(f"Error filtering data within bbox: {e}")
logger.warning("Falling back to full layer")

logger.debug(f"Projection time: {time.time() - projection_start}")
print(f"Projection time: {time.time() - projection_start}")

start_dask = time.time()

da = da.persist()
if x is not None and y is not None:
x = x.persist()
y = y.persist()
else:
da["x"] = da.x.persist()
da["y"] = da.y.persist()
da = await asyncio.to_thread(da.compute)

# da = da.persist()
# if x is not None and y is not None:
# x = x.persist()
# y = y.persist()
# else:
# da["x"] = da.x.persist()
# da["y"] = da.y.persist()

print(da.x[1].values -da.x[0].values)
print(da.y[1].values - da.y[0].values)

logger.debug(f"dask compute: {time.time() - start_dask}")
print(f"dask compute: {time.time() - start_dask}")

if minmax_only:
da = da.persist()
data_sel = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)
# da = da.persist()
# data_sel = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)

try:
return {
"min": float(np.nanmin(data_sel)),
"max": float(np.nanmax(data_sel)),
"min": float(np.nanmin(da)),
"max": float(np.nanmax(da)),
}
except Exception as e:
logger.error(
Expand Down Expand Up @@ -354,7 +366,7 @@ def render(
how="linear",
span=(vmin, vmax),
)
logger.debug(f"Shade time: {time.time() - start_shade}")
print(f"Shade time: {time.time() - start_shade}")

im = shaded.to_pil()
im.save(buffer, format="PNG")
Expand Down

0 comments on commit 5ae90d9

Please sign in to comment.