Skip to content

Commit

Permalink
Fix bugs in cultivated and woody cover plugins (#149)
Browse files Browse the repository at this point in the history
* round the predict output to float32 resolution

* correct the data type in woody cover

* exit gracefully if band is missing for cultivated and woody

* fix woody cover aggregation

* add the hacky fix

* comment the docker file

* please docker lint

---------

Co-authored-by: Emma Ai <[email protected]>
  • Loading branch information
emmaai and Emma Ai authored Aug 26, 2024
1 parent db360b0 commit 6a29afa
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 4 deletions.
5 changes: 5 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ ENV GDAL_DRIVER_PATH=/env/lib/gdalplugins \
GDAL_DATA=/env/share/gdal \
PATH=/env/bin:$PATH

# here is very hacky fix for the threading issue
# MUST follow up with package owner and further address the issue accordingly

RUN wget -q -O /env/lib/python3.10/site-packages/numexpr/necompiler.py https://raw.githubusercontent.com/emmaai/numexpr/master/numexpr/necompiler.py

WORKDIR /tmp

RUN odc-stats --version
16 changes: 15 additions & 1 deletion odc/stats/plugins/lc_ml_treelite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, Sequence, Optional

import os
import sys
import numpy as np
import numexpr as ne
import xarray as xr
Expand All @@ -21,6 +22,7 @@
from ._registry import StatsPluginInterface
from ._worker import TreeliteModelPlugin
import tl2cgen
import logging


def mask_and_predict(
Expand All @@ -44,6 +46,8 @@ def mask_and_predict(
if block_masked.shape[0] > 0:
dmat = tl2cgen.DMatrix(block_masked)
output_data = predictor.predict(dmat).squeeze(axis=1)
# round the number to float32 resolution
output_data = np.round(output_data, 6)
if ptype == "categorical":
prediction[mask_flat] = output_data.argmax(axis=-1)[..., np.newaxis]
else:
Expand All @@ -70,6 +74,7 @@ def __init__(
self.dask_worker_plugin = TreeliteModelPlugin(model_path)
self.output_classes = output_classes
self.mask_bands = mask_bands
self._log = logging.getLogger(__name__)

def input_data(
self, datasets: Sequence[Dataset], geobox: GeoBox, **kwargs
Expand Down Expand Up @@ -117,6 +122,7 @@ def input_data(

def preprocess_predict_input(self, xx: xr.Dataset):
images = []
veg_mask = None
for var in xx.data_vars:
image = xx[var].data
if var not in self.mask_bands:
Expand All @@ -140,6 +146,9 @@ def preprocess_predict_input(self, xx: xr.Dataset):
**{"_v": int(self.mask_bands[var])},
)

if veg_mask is None:
raise TypeError("Missing Veg Mask")

images = [
da.concatenate([image, veg_mask[..., np.newaxis]], axis=-1).rechunk(
(None, None, image.shape[-1] + veg_mask.shape[-1])
Expand All @@ -157,7 +166,12 @@ def aggregate_results_from_group(self, predict_output):
pass

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
images = self.preprocess_predict_input(xx)
try:
images = self.preprocess_predict_input(xx)
except TypeError as e:
self._log.warning(e)
sys.exit(0)

res = []

for image in images:
Expand Down
6 changes: 3 additions & 3 deletions odc/stats/plugins/lc_treelite_woody.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ def aggregate_results_from_group(self, predict_output):
)

if m_size > 1:
predict_output = predict_output.sum(axis=0).astype("int")
predict_output = predict_output.sum(axis=0)

predict_output = expr_eval(
"where((a/nodata)>=_l, nodata, a%nodata)",
{"a": predict_output},
name="summary_over_classes",
dtype="uint8",
dtype="float32",
**{
"_l": m_size,
"nodata": NODATA,
Expand All @@ -80,7 +80,7 @@ def aggregate_results_from_group(self, predict_output):
"where((a>0)&(a<nodata), _nw, a)",
{"a": predict_output},
name="output_classes_herbaceous",
dtype="uint8",
dtype="float32",
**{"nodata": NODATA, "_nw": self.output_classes["herbaceous"]},
)

Expand Down
4 changes: 4 additions & 0 deletions tests/test_rf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ def test_cultivated_reduce(
== np.array([[112, 255], [112, 112]], dtype="uint8")
).all()

with pytest.raises(SystemExit) as excinfo:
cultivated.reduce(input_datasets.drop("classes_l3_l4"))
assert excinfo.value.code == 0


def test_woody_aggregate_results(
woody_input_bands,
Expand Down

0 comments on commit 6a29afa

Please sign in to comment.