From 6a29afaf48354352f7be8a4f9969d312627ff9fe Mon Sep 17 00:00:00 2001 From: Emma Ai Date: Mon, 26 Aug 2024 10:28:08 +0930 Subject: [PATCH] Fix bugs in cultivated and woody cover plugins (#149) * round the predict output to float32 resolution * correct the data type in woody cover * exit gracefully if band is missing for cultivated and woody * fix woody cover aggregation * add the hacky fix * comment the docker file * please docker lint --------- Co-authored-by: Emma Ai --- docker/Dockerfile | 5 +++++ odc/stats/plugins/lc_ml_treelite.py | 16 +++++++++++++++- odc/stats/plugins/lc_treelite_woody.py | 6 +++--- tests/test_rf_models.py | 4 ++++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 00a0026e..50ec182a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,6 +22,11 @@ ENV GDAL_DRIVER_PATH=/env/lib/gdalplugins \ GDAL_DATA=/env/share/gdal \ PATH=/env/bin:$PATH +# here is very hacky fix for the threading issue +# MUST follow up with package owner and further address the issue accordingly + +RUN wget -q -O /env/lib/python3.10/site-packages/numexpr/necompiler.py https://raw.githubusercontent.com/emmaai/numexpr/master/numexpr/necompiler.py + WORKDIR /tmp RUN odc-stats --version diff --git a/odc/stats/plugins/lc_ml_treelite.py b/odc/stats/plugins/lc_ml_treelite.py index 8334f4bc..dcf1c528 100644 --- a/odc/stats/plugins/lc_ml_treelite.py +++ b/odc/stats/plugins/lc_ml_treelite.py @@ -6,6 +6,7 @@ from typing import Dict, Sequence, Optional import os +import sys import numpy as np import numexpr as ne import xarray as xr @@ -21,6 +22,7 @@ from ._registry import StatsPluginInterface from ._worker import TreeliteModelPlugin import tl2cgen +import logging def mask_and_predict( @@ -44,6 +46,8 @@ def mask_and_predict( if block_masked.shape[0] > 0: dmat = tl2cgen.DMatrix(block_masked) output_data = predictor.predict(dmat).squeeze(axis=1) + # round the number to float32 resolution + output_data = np.round(output_data, 6) if ptype == "categorical": prediction[mask_flat] = output_data.argmax(axis=-1)[..., np.newaxis] else: @@ -70,6 +74,7 @@ def __init__( self.dask_worker_plugin = TreeliteModelPlugin(model_path) self.output_classes = output_classes self.mask_bands = mask_bands + self._log = logging.getLogger(__name__) def input_data( self, datasets: Sequence[Dataset], geobox: GeoBox, **kwargs @@ -117,6 +122,7 @@ def input_data( def preprocess_predict_input(self, xx: xr.Dataset): images = [] + veg_mask = None for var in xx.data_vars: image = xx[var].data if var not in self.mask_bands: @@ -140,6 +146,9 @@ def preprocess_predict_input(self, xx: xr.Dataset): **{"_v": int(self.mask_bands[var])}, ) + if veg_mask is None: + raise TypeError("Missing Veg Mask") + images = [ da.concatenate([image, veg_mask[..., np.newaxis]], axis=-1).rechunk( (None, None, image.shape[-1] + veg_mask.shape[-1]) @@ -157,7 +166,12 @@ def aggregate_results_from_group(self, predict_output): pass def reduce(self, xx: xr.Dataset) -> xr.Dataset: - images = self.preprocess_predict_input(xx) + try: + images = self.preprocess_predict_input(xx) + except TypeError as e: + self._log.warning(e) + sys.exit(0) + res = [] for image in images: diff --git a/odc/stats/plugins/lc_treelite_woody.py b/odc/stats/plugins/lc_treelite_woody.py index 62612a10..9136541a 100644 --- a/odc/stats/plugins/lc_treelite_woody.py +++ b/odc/stats/plugins/lc_treelite_woody.py @@ -63,13 +63,13 @@ def aggregate_results_from_group(self, predict_output): ) if m_size > 1: - predict_output = predict_output.sum(axis=0).astype("int") + predict_output = predict_output.sum(axis=0) predict_output = expr_eval( "where((a/nodata)>=_l, nodata, a%nodata)", {"a": predict_output}, name="summary_over_classes", - dtype="uint8", + dtype="float32", **{ "_l": m_size, "nodata": NODATA, @@ -80,7 +80,7 @@ def aggregate_results_from_group(self, predict_output): "where((a>0)&(a