Change cultivated aggregation logic and make sensor choice an option (#165)

* fix the nodata issue when defaulting to nan with float dtype

* be flexible to the case that numpy changes convention

* make cultivated classification conservative

* filter data by sensor temporal coverage in ml

* correct time dim and format

* change woody cover band name and fix tests

* align cultivated band name to official

---------

Co-authored-by: Emma Ai <[email protected]>
emmaai and Emma Ai authored Nov 7, 2024
1 parent 182de0d commit 000f79b
Showing 4 changed files with 71 additions and 40 deletions.
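For context on the new temporal_coverage option added in lc_ml_treelite.py below, a minimal sketch of the idea: each input band is mapped to the date ranges its sensor actually covers, and the band only participates (with imputation) when the tile's solar_day falls inside one of those ranges. The band name, the range string, and the start--period string format are illustrative assumptions, not values from this commit.

import numpy as np
from odc.stats.model import DateTimeRange

# hypothetical coverage map: band name and range string are made up
temporal_coverage = {
    "nbart_red": ["2015-07--P20Y"],
}

solar_day = np.datetime64("2020-01-01", "ms")  # stand-in for xx.solar_day.data
for band, ranges in temporal_coverage.items():
    # mirrors the membership check used in preprocess_predict_input
    covered = any(solar_day in DateTimeRange(r) for r in ranges)
    print(band, "covered by sensor" if covered else "outside sensor coverage")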
70 changes: 55 additions & 15 deletions odc/stats/plugins/lc_ml_treelite.py
@@ -19,6 +19,7 @@
from odc.algo.io import load_with_native_transform

from odc.stats._algebra import expr_eval
from odc.stats.model import DateTimeRange
from ._registry import StatsPluginInterface
from ._worker import TreeliteModelPlugin
import tl2cgen
@@ -66,6 +67,7 @@ def __init__(
output_classes: Dict,
model_path: str,
mask_bands: Optional[Dict] = None,
temporal_coverage: Optional[Dict] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -74,6 +76,7 @@
self.dask_worker_plugin = TreeliteModelPlugin(model_path)
self.output_classes = output_classes
self.mask_bands = mask_bands
self.temporal_coverage = temporal_coverage
self._log = logging.getLogger(__name__)

def input_data(
@@ -107,40 +110,77 @@ def input_data(
(self.chunks["x"], self.chunks["y"], -1, -1),
dtype="float32",
name=ds.type.name + "_yxbt",
).squeeze("spec", drop=True)
).squeeze("spec")
data_vars[ds.type.name] = input_array
else:
for var in xx.data_vars:
data_vars[var] = yxt_sink(
xx[var].astype("uint8"),
(self.chunks["x"], self.chunks["y"], -1),
name=ds.type.name + "_yxt",
).squeeze("spec", drop=True)
).squeeze("spec")

coords = dict((dim, input_array.coords[dim]) for dim in input_array.dims)
return xr.Dataset(data_vars=data_vars, coords=coords)

def impute_missing_values(self, xx: xr.Dataset, image):
imputed = None
for var in xx.data_vars:
if var in self.mask_bands:
continue
nodata = xx[var].attrs.get("nodata", -999)
imputed = expr_eval(
"where((a==a)|(b<=nodata)|(b!=b), a, b)",
{
"a": image,
"b": xx[var].data,
},
name="impute_missing",
dtype="float32",
**{"nodata": nodata},
)
return imputed if imputed is not None else image

def preprocess_predict_input(self, xx: xr.Dataset):
images = []
veg_mask = None

def convert_dtype(var):
nodata = xx[var].attrs.get("nodata", -999)
image = expr_eval(
"where((a<=nodata), _nan, a)",
{
"a": xx[var].data,
},
name="convert_dtype",
dtype="float32",
**{"nodata": nodata, "_nan": np.nan},
)
return image

for var in xx.data_vars:
image = xx[var].data
if var not in self.mask_bands:
nodata = xx[var].attrs.get("nodata", -999)
image = expr_eval(
"where((a<=nodata), _nan, a)",
{
"a": image,
},
name="convert_dtype",
dtype="float32",
**{"nodata": nodata, "_nan": np.nan},
)
images += [image]
if self.temporal_coverage is not None:
# filter and impute by sensors
temporal_range = [
DateTimeRange(v) for v in self.temporal_coverage.get(var)
]
for tr in temporal_range:
if xx.solar_day.data.astype("M8[ms]") in tr:
self._log.info("Impute missing values of %s", var)
image = convert_dtype(var)
images += [
self.impute_missing_values(xx.drop_vars(var), image)
]
break
else:
# use data from all sensors
image = convert_dtype(var)
images += [image]
else:
veg_mask = expr_eval(
"where(a==_v, 1, 0)",
{"a": image},
{"a": xx[var].data},
name="make_mask",
dtype="float32",
**{"_v": int(self.mask_bands[var])},
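To make the two new preprocessing steps above easier to follow, here is a simplified numpy rendering of the expressions they evaluate: convert_dtype turns nodata sentinels into NaN so the model input is float, and impute_missing_values fills those NaNs from another band only where that band itself holds valid data. The arrays and nodata value are made up for illustration.

import numpy as np

nodata = -999  # example sentinel; real bands carry theirs in attrs
raw = np.array([1.0, nodata, 3.0, nodata], dtype="float32")

# convert_dtype: "where((a<=nodata), _nan, a)"
a = np.where(raw <= nodata, np.nan, raw)

# impute_missing_values: "where((a==a)|(b<=nodata)|(b!=b), a, b)"
# keep a where it is already valid, or where the fallback b is nodata/NaN
b = np.array([9.0, 2.0, nodata, np.nan], dtype="float32")
imputed = np.where((a == a) | (b <= nodata) | (b != b), a, b)
print(imputed)  # [ 1.  2.  3. nan] -- only genuinely valid fallback values fill gaps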
22 changes: 7 additions & 15 deletions odc/stats/plugins/lc_treelite_cultivated.py
@@ -228,7 +228,7 @@ class StatsCultivatedClass(StatsMLTree):

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["cultivated_class"]
_measurements = ["cultivated"]
return _measurements

def predict(self, input_array):
@@ -259,7 +259,7 @@ def predict(self, input_array):

def aggregate_results_from_group(self, predict_output):
# if there are >= 2 images
# any is cultivated -> final class is cultivated
# any is natural -> final class is natural
# any is valid -> final class is valid
# for each pixel
m_size = len(predict_output)
@@ -268,14 +268,6 @@
else:
predict_output = predict_output[0]

predict_output = expr_eval(
"where(a<nodata, 1-a, a)",
{"a": predict_output},
name="invert_output",
dtype="float32",
**{"nodata": NODATA},
)

if m_size > 1:
predict_output = predict_output.sum(axis=0)

@@ -290,17 +282,17 @@
predict_output = expr_eval(
"where((a>0)&(a<nodata), _u, a)",
{"a": predict_output},
name="output_classes_cultivated",
name="output_classes_natural",
dtype="float32",
**{"_u": self.output_classes["cultivated"], "nodata": NODATA},
**{"_u": self.output_classes["natural"], "nodata": NODATA},
)

predict_output = expr_eval(
"where(a<=0, _nu, a)",
{"a": predict_output},
name="output_classes_natural",
name="output_classes_cultivated",
dtype="uint8",
**{"_nu": self.output_classes["natural"]},
**{"_nu": self.output_classes["cultivated"]},
)

return predict_output.rechunk(-1, -1)
@@ -312,7 +304,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = res[var].attrs.copy()
attrs["nodata"] = int(NODATA)
res[var].attrs = attrs
var_rename = {var: "cultivated_class"}
var_rename = dict(zip(res.data_vars, self.measurements))
return res.rename(var_rename)


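A simplified numpy rendering of the reworked aggregation above: per-pixel outputs are summed across the group's images, any positive (and valid) total maps to the natural class, and anything else valid maps to cultivated, so a single natural vote overrules cultivated -- the conservative behaviour the commit message describes. The class codes follow the tests below; the example totals and the omitted nodata handling are assumptions.

import numpy as np

NODATA = 255
output_classes = {"cultivated": 111, "natural": 112}

# assumed per-pixel totals after summing the group's predictions
summed = np.array([[0.0, 1.0], [2.0, NODATA]], dtype="float32")

out = np.where((summed > 0) & (summed < NODATA), output_classes["natural"], summed)
out = np.where(out <= 0, output_classes["cultivated"], out).astype("uint8")
print(out)  # [[111 112] [112 255]]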
4 changes: 2 additions & 2 deletions odc/stats/plugins/lc_treelite_woody.py
@@ -21,7 +21,7 @@ class StatsWoodyCover(StatsMLTree):

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["woody_cover"]
_measurements = ["woody"]
return _measurements

def predict(self, input_array):
@@ -101,7 +101,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = res[var].attrs.copy()
attrs["nodata"] = int(NODATA)
res[var].attrs = attrs
var_rename = {var: "woody_cover"}
var_rename = dict(zip(res.data_vars, self.measurements))
return res.rename(var_rename)


15 changes: 7 additions & 8 deletions tests/test_rf_models.py
@@ -466,7 +466,7 @@ def test_cultivated_aggregate_results(
res = cultivated.aggregate_results_from_group([cultivated_results[0]])
assert (res.compute() == np.array([[112, 255], [111, 112]], dtype="uint8")).all()
res = cultivated.aggregate_results_from_group(cultivated_results)
assert (res.compute() == np.array([[111, 112], [111, 112]], dtype="uint8")).all()
assert (res.compute() == np.array([[112, 112], [111, 112]], dtype="uint8")).all()


def test_cultivated_reduce(
@@ -485,10 +485,10 @@
)
dask_client.register_plugin(cultivated.dask_worker_plugin)
res = cultivated.reduce(input_datasets)
assert res["cultivated_class"].attrs["nodata"] == 255
assert res["cultivated_class"].data.dtype == "uint8"
assert res["cultivated"].attrs["nodata"] == 255
assert res["cultivated"].data.dtype == "uint8"
assert (
res["cultivated_class"].data.compute()
res["cultivated"].data.compute()
== np.array([[112, 255], [112, 112]], dtype="uint8")
).all()

@@ -528,9 +528,8 @@ def test_woody_reduce(
)
dask_client.register_plugin(woody_cover.dask_worker_plugin)
res = woody_cover.reduce(woody_inputs)
assert res["woody_cover"].attrs["nodata"] == 255
assert res["woody_cover"].data.dtype == "uint8"
assert res["woody"].attrs["nodata"] == 255
assert res["woody"].data.dtype == "uint8"
assert (
res["woody_cover"].data.compute()
== np.array([[114, 255], [114, 114]], dtype="uint8")
res["woody"].data.compute() == np.array([[114, 255], [114, 114]], dtype="uint8")
).all()
