Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change cultivated aggregation logic and make sensor choice an option #165

Merged
merged 7 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions odc/stats/plugins/lc_ml_treelite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from odc.algo.io import load_with_native_transform

from odc.stats._algebra import expr_eval
from odc.stats.model import DateTimeRange
from ._registry import StatsPluginInterface
from ._worker import TreeliteModelPlugin
import tl2cgen
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
output_classes: Dict,
model_path: str,
mask_bands: Optional[Dict] = None,
temporal_coverage: Optional[Dict] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -74,6 +76,7 @@ def __init__(
self.dask_worker_plugin = TreeliteModelPlugin(model_path)
self.output_classes = output_classes
self.mask_bands = mask_bands
self.temporal_coverage = temporal_coverage
self._log = logging.getLogger(__name__)

def input_data(
Expand Down Expand Up @@ -107,40 +110,77 @@ def input_data(
(self.chunks["x"], self.chunks["y"], -1, -1),
dtype="float32",
name=ds.type.name + "_yxbt",
).squeeze("spec", drop=True)
).squeeze("spec")
data_vars[ds.type.name] = input_array
else:
for var in xx.data_vars:
data_vars[var] = yxt_sink(
xx[var].astype("uint8"),
(self.chunks["x"], self.chunks["y"], -1),
name=ds.type.name + "_yxt",
).squeeze("spec", drop=True)
).squeeze("spec")

coords = dict((dim, input_array.coords[dim]) for dim in input_array.dims)
return xr.Dataset(data_vars=data_vars, coords=coords)

def impute_missing_values(self, xx: xr.Dataset, image):
    """Fill invalid pixels in *image* from the other (non-mask) bands of *xx*.

    For each data variable of ``xx`` that is not a mask band, any pixel where
    the running result is NaN is replaced by that band's value, provided the
    value itself is valid (not NaN and greater than the band's nodata).

    :param xx: dataset whose non-mask bands supply fallback values; each band
        may carry a ``nodata`` attribute (default -999)
    :param image: array to impute (float32 with NaN marking missing pixels)
    :return: the imputed array; ``image`` unchanged if no candidate band exists
    """
    imputed = image
    for var in xx.data_vars:
        # mask bands hold class labels, never impute pixel values from them;
        # guard against mask_bands being None (its declared default)
        if self.mask_bands and var in self.mask_bands:
            continue
        nodata = xx[var].attrs.get("nodata", -999)
        # keep `a` where it is already valid (a==a filters NaN), otherwise
        # take `b` unless `b` is itself invalid (<=nodata or NaN)
        imputed = expr_eval(
            "where((a==a)|(b<=nodata)|(b!=b), a, b)",
            {
                # chain on the running result so every band contributes,
                # not just the last one iterated
                "a": imputed,
                "b": xx[var].data,
            },
            name="impute_missing",
            dtype="float32",
            **{"nodata": nodata},
        )
    return imputed

def preprocess_predict_input(self, xx: xr.Dataset):
images = []
veg_mask = None

def convert_dtype(var):
nodata = xx[var].attrs.get("nodata", -999)
image = expr_eval(
"where((a<=nodata), _nan, a)",
{
"a": xx[var].data,
},
name="convert_dtype",
dtype="float32",
**{"nodata": nodata, "_nan": np.nan},
)
return image

for var in xx.data_vars:
image = xx[var].data
if var not in self.mask_bands:
nodata = xx[var].attrs.get("nodata", -999)
image = expr_eval(
"where((a<=nodata), _nan, a)",
{
"a": image,
},
name="convert_dtype",
dtype="float32",
**{"nodata": nodata, "_nan": np.nan},
)
images += [image]
if self.temporal_coverage is not None:
# filter and impute by sensors
temporal_range = [
DateTimeRange(v) for v in self.temporal_coverage.get(var)
]
for tr in temporal_range:
if xx.solar_day.data.astype("M8[ms]") in tr:
self._log.info("Impute missing values of %s", var)
image = convert_dtype(var)
images += [
self.impute_missing_values(xx.drop_vars(var), image)
]
break
else:
# use data from all sensors
image = convert_dtype(var)
images += [image]
else:
veg_mask = expr_eval(
"where(a==_v, 1, 0)",
{"a": image},
{"a": xx[var].data},
name="make_mask",
dtype="float32",
**{"_v": int(self.mask_bands[var])},
Expand Down
22 changes: 7 additions & 15 deletions odc/stats/plugins/lc_treelite_cultivated.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class StatsCultivatedClass(StatsMLTree):

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["cultivated_class"]
_measurements = ["cultivated"]
return _measurements

def predict(self, input_array):
Expand Down Expand Up @@ -259,7 +259,7 @@ def predict(self, input_array):

def aggregate_results_from_group(self, predict_output):
# if there are >= 2 images
# any is cultivated -> final class is cultivated
# any is natural -> final class is natural
# any is valid -> final class is valid
# for each pixel
m_size = len(predict_output)
Expand All @@ -268,14 +268,6 @@ def aggregate_results_from_group(self, predict_output):
else:
predict_output = predict_output[0]

predict_output = expr_eval(
"where(a<nodata, 1-a, a)",
{"a": predict_output},
name="invert_output",
dtype="float32",
**{"nodata": NODATA},
)

if m_size > 1:
predict_output = predict_output.sum(axis=0)

Expand All @@ -290,17 +282,17 @@ def aggregate_results_from_group(self, predict_output):
predict_output = expr_eval(
"where((a>0)&(a<nodata), _u, a)",
{"a": predict_output},
name="output_classes_cultivated",
name="output_classes_natural",
dtype="float32",
**{"_u": self.output_classes["cultivated"], "nodata": NODATA},
**{"_u": self.output_classes["natural"], "nodata": NODATA},
)

predict_output = expr_eval(
"where(a<=0, _nu, a)",
{"a": predict_output},
name="output_classes_natural",
name="output_classes_cultivated",
dtype="uint8",
**{"_nu": self.output_classes["natural"]},
**{"_nu": self.output_classes["cultivated"]},
)

return predict_output.rechunk(-1, -1)
Expand All @@ -312,7 +304,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = res[var].attrs.copy()
attrs["nodata"] = int(NODATA)
res[var].attrs = attrs
var_rename = {var: "cultivated_class"}
var_rename = dict(zip(res.data_vars, self.measurements))
return res.rename(var_rename)


Expand Down
4 changes: 2 additions & 2 deletions odc/stats/plugins/lc_treelite_woody.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class StatsWoodyCover(StatsMLTree):

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["woody_cover"]
_measurements = ["woody"]
return _measurements

def predict(self, input_array):
Expand Down Expand Up @@ -101,7 +101,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = res[var].attrs.copy()
attrs["nodata"] = int(NODATA)
res[var].attrs = attrs
var_rename = {var: "woody_cover"}
var_rename = dict(zip(res.data_vars, self.measurements))
return res.rename(var_rename)


Expand Down
15 changes: 7 additions & 8 deletions tests/test_rf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def test_cultivated_aggregate_results(
res = cultivated.aggregate_results_from_group([cultivated_results[0]])
assert (res.compute() == np.array([[112, 255], [111, 112]], dtype="uint8")).all()
res = cultivated.aggregate_results_from_group(cultivated_results)
assert (res.compute() == np.array([[111, 112], [111, 112]], dtype="uint8")).all()
assert (res.compute() == np.array([[112, 112], [111, 112]], dtype="uint8")).all()


def test_cultivated_reduce(
Expand All @@ -485,10 +485,10 @@ def test_cultivated_reduce(
)
dask_client.register_plugin(cultivated.dask_worker_plugin)
res = cultivated.reduce(input_datasets)
assert res["cultivated_class"].attrs["nodata"] == 255
assert res["cultivated_class"].data.dtype == "uint8"
assert res["cultivated"].attrs["nodata"] == 255
assert res["cultivated"].data.dtype == "uint8"
assert (
res["cultivated_class"].data.compute()
res["cultivated"].data.compute()
== np.array([[112, 255], [112, 112]], dtype="uint8")
).all()

Expand Down Expand Up @@ -528,9 +528,8 @@ def test_woody_reduce(
)
dask_client.register_plugin(woody_cover.dask_worker_plugin)
res = woody_cover.reduce(woody_inputs)
assert res["woody_cover"].attrs["nodata"] == 255
assert res["woody_cover"].data.dtype == "uint8"
assert res["woody"].attrs["nodata"] == 255
assert res["woody"].data.dtype == "uint8"
assert (
res["woody_cover"].data.compute()
== np.array([[114, 255], [114, 114]], dtype="uint8")
res["woody"].data.compute() == np.array([[114, 255], [114, 114]], dtype="uint8")
).all()
Loading