Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch WMS router to use async #98

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xpublish_wms/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def dataset_router(self, deps: Dependencies) -> APIRouter:

@router.get("", include_in_schema=False)
@router.get("/")
def wms_root(
async def wms_root(
request: Request,
dataset: xr.Dataset = Depends(deps.dataset),
cache: cachey.Cache = Depends(deps.cache),
):
return wms_handler(request, dataset, cache)
return await wms_handler(request, dataset, cache)

return router
15 changes: 8 additions & 7 deletions xpublish_wms/wms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
OGC WMS router for datasets with CF convention metadata
"""

import asyncio
import logging

import cachey
Expand All @@ -21,7 +22,7 @@
logger = logging.getLogger("uvicorn")


def wms_handler(
async def wms_handler(
request: Request,
dataset: xr.Dataset = Depends(get_dataset),
cache: cachey.Cache = Depends(get_cache),
Expand All @@ -32,19 +33,19 @@ def wms_handler(
logger.info(f"WMS: {method}")

if method == "getcapabilities":
return get_capabilities(dataset, request, query_params)
return await asyncio.to_thread(get_capabilities, dataset, request, query_params)
elif method == "getmap":
getmap_service = GetMap(cache=cache)
return getmap_service.get_map(dataset, query_params)
return await getmap_service.get_map(dataset, query_params)
elif method == "getfeatureinfo" or method == "gettimeseries":
return get_feature_info(dataset, query_params)
return await asyncio.to_thread(get_feature_info, dataset, query_params)
elif method == "getverticalprofile":
query_params["elevation"] = "all"
return get_feature_info(dataset, query_params)
return await asyncio.to_thread(get_feature_info, dataset, query_params)
elif method == "getmetadata":
return get_metadata(dataset, cache, query_params)
return await get_metadata(dataset, cache, query_params)
elif method == "getlegendgraphic":
return get_legend_info(dataset, query_params)
return await asyncio.to_thread(get_legend_info, dataset, query_params)
else:
raise HTTPException(
status_code=404,
Expand Down
50 changes: 31 additions & 19 deletions xpublish_wms/wms/get_map.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import io
import logging
import time
Expand Down Expand Up @@ -32,8 +33,6 @@ class GetMap:
DEFAULT_STYLE: str = "raster/default"
DEFAULT_PALETTE: str = "turbo"

BBOX_BUFFER = 0.18

cache: cachey.Cache

# Data selection
Expand All @@ -58,7 +57,7 @@ class GetMap:
def __init__(self, cache: cachey.Cache):
self.cache = cache

def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
async def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
"""
Return the WMS map for the dataset and given parameters
"""
Expand All @@ -76,13 +75,13 @@ def get_map(self, ds: xr.Dataset, query: dict) -> StreamingResponse:
# The grid type for now. This can be revisited if we choose to interpolate or
# use the contoured renderer for regular grid datasets
image_buffer = io.BytesIO()
render_result = self.render(ds, da, image_buffer, False)
render_result = await self.render(ds, da, image_buffer, False)
if render_result:
image_buffer.seek(0)

return StreamingResponse(image_buffer, media_type="image/png")

def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
async def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
"""
Return the range of values for the dataset and given parameters
"""
Expand All @@ -109,7 +108,7 @@ def get_minmax(self, ds: xr.Dataset, query: dict) -> dict:
if entire_layer:
return {"min": float(da.min()), "max": float(da.max())}
else:
return self.render(ds, da, None, minmax_only=True)
return await self.render(ds, da, None, minmax_only=True)

def ensure_query_types(self, ds: xr.Dataset, query: dict):
"""
Expand Down Expand Up @@ -255,7 +254,7 @@ def select_custom_dim(self, da: xr.DataArray) -> xr.DataArray:

return da

def render(
async def render(
self,
ds: xr.Dataset,
da: xr.DataArray,
Expand All @@ -280,28 +279,41 @@ def render(
logger.warning("Falling back to default minmax")
return {"min": float(da.min()), "max": float(da.max())}

# x and y are only set for triangle grids, we dont subset the data for triangle grids
# at this time.
if x is None:
try:
# Grab a buffer around the bbox to ensure we have enough data to render
# TODO: Base this on actual data resolution?
if self.crs == "EPSG:4326":
coord_buffer = 0.5 # degrees
elif self.crs == "EPSG:3857":
coord_buffer = 30000 # meters
else:
# Default to 0.5, this should never happen
coord_buffer = 0.5

# Filter the data to only include the data within the bbox + buffer so
# we don't have to render a ton of empty space or pull down more chunks
# than we need
da = filter_data_within_bbox(da, self.bbox, coord_buffer)
except Exception as e:
logger.error(f"Error filtering data within bbox: {e}")
logger.warning("Falling back to full layer")

logger.debug(f"Projection time: {time.time() - projection_start}")

start_dask = time.time()

da = da.persist()
if x is not None and y is not None:
x = x.persist()
y = y.persist()
else:
da["x"] = da.x.persist()
da["y"] = da.y.persist()
da = await asyncio.to_thread(da.compute)

logger.debug(f"dask compute: {time.time() - start_dask}")

if minmax_only:
da = da.persist()
data_sel = filter_data_within_bbox(da, self.bbox, self.BBOX_BUFFER)

try:
return {
"min": float(np.nanmin(data_sel)),
"max": float(np.nanmax(data_sel)),
"min": float(np.nanmin(da)),
"max": float(np.nanmax(da)),
}
except Exception as e:
logger.error(
Expand Down
8 changes: 4 additions & 4 deletions xpublish_wms/wms/get_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .get_map import GetMap


def get_metadata(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> Response:
async def get_metadata(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> Response:
"""
Return the WMS metadata for the dataset

Expand Down Expand Up @@ -39,7 +39,7 @@ def get_metadata(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> Response:
da = ds[layer_name]
payload = get_timesteps(da, params)
elif metadata_type == "minmax":
payload = get_minmax(ds, cache, params)
payload = await get_minmax(ds, cache, params)
else:
raise HTTPException(
status_code=400,
Expand Down Expand Up @@ -79,14 +79,14 @@ def get_timesteps(da: xr.DataArray, params: dict) -> dict:
}


def get_minmax(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> dict:
async def get_minmax(ds: xr.Dataset, cache: cachey.Cache, params: dict) -> dict:
"""
Returns the min and max range of values for a given layer in a given area

If BBOX is not specified, the entire selected temporal and elevation range is used.
"""
getmap = GetMap(cache=cache)
return getmap.get_minmax(ds, params)
return await getmap.get_minmax(ds, params)


def get_layer_details(ds: xr.Dataset, layer_name: str) -> dict:
Expand Down
Loading