From 729e2947f89e71ec8745d2e58aac59f728537067 Mon Sep 17 00:00:00 2001 From: Emma Ai Date: Wed, 6 Nov 2024 04:52:09 +0000 Subject: [PATCH] align cultivated band name to official --- odc/stats/plugins/lc_treelite_cultivated.py | 4 ++-- tests/test_rf_models.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/odc/stats/plugins/lc_treelite_cultivated.py b/odc/stats/plugins/lc_treelite_cultivated.py index 64ff8c51..81edc6ab 100644 --- a/odc/stats/plugins/lc_treelite_cultivated.py +++ b/odc/stats/plugins/lc_treelite_cultivated.py @@ -228,7 +228,7 @@ class StatsCultivatedClass(StatsMLTree): @property def measurements(self) -> Tuple[str, ...]: - _measurements = ["cultivated_class"] + _measurements = ["cultivated"] return _measurements def predict(self, input_array): @@ -304,7 +304,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset: attrs = res[var].attrs.copy() attrs["nodata"] = int(NODATA) res[var].attrs = attrs - var_rename = {var: "cultivated_class"} + var_rename = dict(zip(res.data_vars, self.measurements)) return res.rename(var_rename) diff --git a/tests/test_rf_models.py b/tests/test_rf_models.py index 90646a8b..ad16e6f5 100644 --- a/tests/test_rf_models.py +++ b/tests/test_rf_models.py @@ -485,10 +485,10 @@ def test_cultivated_reduce( ) dask_client.register_plugin(cultivated.dask_worker_plugin) res = cultivated.reduce(input_datasets) - assert res["cultivated_class"].attrs["nodata"] == 255 - assert res["cultivated_class"].data.dtype == "uint8" + assert res["cultivated"].attrs["nodata"] == 255 + assert res["cultivated"].data.dtype == "uint8" assert ( - res["cultivated_class"].data.compute() + res["cultivated"].data.compute() == np.array([[112, 255], [112, 112]], dtype="uint8") ).all()