Skip to content

Commit

Permalink
Make output band name an option (#166)
Browse files Browse the repository at this point in the history
* fix the nodata issue when default to nan with float dtype

* flexible to the case that numpy change convention

* filter data by sensor temporal coverage in ml

* correct time dim and format

* make output band name as config options

* make output band name an option for landcover final

* address too many warnings from tests

---------

Co-authored-by: Emma Ai <[email protected]>
  • Loading branch information
emmaai and Emma Ai authored Nov 13, 2024
1 parent b6c2e6d commit 1aa3496
Show file tree
Hide file tree
Showing 22 changed files with 68 additions and 118 deletions.
7 changes: 5 additions & 2 deletions odc/stats/plugins/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
rgb_clamp: Tuple[float, float] = (1.0, 3_000.0),
transform_code: Optional[str] = None,
area_of_interest: Optional[Sequence[float]] = None,
measurements: Optional[Sequence[str]] = None,
):
self.resampling = resampling
self.input_bands = input_bands if input_bands is not None else []
Expand All @@ -40,12 +41,14 @@ def __init__(
self.rgb_clamp = rgb_clamp
self.transform_code = transform_code
self.area_of_interest = area_of_interest
self._measurements = measurements
self.dask_worker_plugin = None

@property
@abstractmethod
def measurements(self) -> Tuple[str, ...]:
pass
if self._measurements is None:
raise NotImplementedError("Plugins must provide 'measurements'")
return self._measurements

def native_transform(self, xx: xr.Dataset) -> xr.Dataset:
for var in xx.data_vars:
Expand Down
13 changes: 2 additions & 11 deletions odc/stats/plugins/lc_fc_wo_a0.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ def __init__(
self.ue_threshold = ue_threshold if ue_threshold is not None else 30
self.cloud_filters = cloud_filters if cloud_filters is not None else {}

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["veg_frequency", "water_frequency"]
return _measurements

def native_transform(self, xx):
"""
Loads data in its native projection. It performs the following:
Expand Down Expand Up @@ -217,12 +212,8 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
"veg_frequency": xr.DataArray(
max_count_veg, dims=xx["wet"].dims[1:], attrs=attrs
),
"water_frequency": xr.DataArray(
max_count_water, dims=xx["wet"].dims[1:], attrs=attrs
),
k: xr.DataArray(v, dims=xx["wet"].dims[1:], attrs=attrs)
for k, v in zip(self.measurements, [max_count_veg, max_count_water])
}
coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
Expand Down
12 changes: 3 additions & 9 deletions odc/stats/plugins/lc_level34.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Plugin of Module A3 in LandCover PipeLine
"""

from typing import Tuple, Optional, List
from typing import Optional, List

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -55,11 +55,6 @@ def __init__(
water_seasonality_threshold if water_seasonality_threshold else 3
)

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["level3", "level4"]
return _measurements

def fuser(self, xx):
return xx

Expand Down Expand Up @@ -111,10 +106,9 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = xx.attrs.copy()
attrs["nodata"] = NODATA
dims = xx.classes_l3_l4.dims[1:]

data_vars = {
"level3": xr.DataArray(level3.squeeze(), dims=dims, attrs=attrs),
"level4": xr.DataArray(level4.squeeze(), dims=dims, attrs=attrs),
k: xr.DataArray(v, dims=dims, attrs=attrs)
for k, v in zip(self.measurements, [level3.squeeze(), level4.squeeze()])
}

coords = dict((dim, xx.coords[dim]) for dim in dims)
Expand Down
2 changes: 1 addition & 1 deletion odc/stats/plugins/lc_ml_treelite.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:

res = self.aggregate_results_from_group(res)
attrs = xx.attrs.copy()
dims = list(xx.dims.keys())[:2]
dims = list(xx.sizes.keys())[:2]
data_vars = {"predict_output": xr.DataArray(res, dims=dims, attrs=attrs)}
coords = {dim: xx.coords[dim] for dim in dims}
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
9 changes: 2 additions & 7 deletions odc/stats/plugins/lc_tf_urban.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Plugin of TF urban model in LandCover PipeLine
"""

from typing import Tuple, Dict, Sequence
from typing import Dict, Sequence

import os
import numpy as np
Expand Down Expand Up @@ -91,11 +91,6 @@ def __init__(
else:
self.crop_size = crop_size

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["urban_classes"]
return _measurements

def input_data(
self, datasets: Sequence[Dataset], geobox: GeoBox, **kwargs
) -> xr.Dataset:
Expand Down Expand Up @@ -219,7 +214,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
dims = list(xx.dims.keys())[:2]
data_vars = {"urban_classes": xr.DataArray(um, dims=dims, attrs=attrs)}
data_vars = {self.measurements[0]: xr.DataArray(um, dims=dims, attrs=attrs)}
coords = {dim: xx.coords[dim] for dim in dims}
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)

Expand Down
6 changes: 0 additions & 6 deletions odc/stats/plugins/lc_treelite_cultivated.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Plugin of RFclassfication cultivated model in LandCover PipeLine
"""

from typing import Tuple
import numpy as np
import xarray as xr
import dask.array as da
Expand Down Expand Up @@ -226,11 +225,6 @@ class StatsCultivatedClass(StatsMLTree):
VERSION = "0.0.1"
PRODUCT_FAMILY = "lccs"

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["cultivated"]
return _measurements

def predict(self, input_array):
bands_indices = dict(zip(self.input_bands, np.arange(len(self.input_bands))))
input_features = da.map_blocks(
Expand Down
6 changes: 0 additions & 6 deletions odc/stats/plugins/lc_treelite_woody.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Plugin of RFregressor woody cover model in LandCover PipeLine
"""

from typing import Tuple
import xarray as xr
import dask.array as da

Expand All @@ -19,11 +18,6 @@ class StatsWoodyCover(StatsMLTree):
VERSION = "0.0.1"
PRODUCT_FAMILY = "lccs"

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["woody"]
return _measurements

def predict(self, input_array):
wc = da.map_blocks(
mask_and_predict,
Expand Down
22 changes: 4 additions & 18 deletions odc/stats/plugins/lc_veg_class_a1.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self._measurements = (
measurements if measurements is not None else self.input_bands
)

@property
def measurements(self) -> Tuple[str, ...]:
return self._measurements

def native_transform(self, xx):
# reproject cannot work with nodata being int for float
Expand Down Expand Up @@ -89,11 +82,6 @@ def __init__(
)
self.output_classes = output_classes

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["classes_l3_l4", "water_seasonality"]
return _measurements

def fuser(self, xx):
return xx

Expand Down Expand Up @@ -249,12 +237,10 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
"classes_l3_l4": xr.DataArray(
l3_mask[0], dims=xx["veg_frequency"].dims[1:], attrs=attrs
),
"water_seasonality": xr.DataArray(
water_seasonality[0], dims=xx["veg_frequency"].dims[1:], attrs=attrs
),
k: xr.DataArray(v, dims=xx["veg_frequency"].dims[1:], attrs=attrs)
for k, v in zip(
self.measurements, [l3_mask.squeeze(0), water_seasonality.squeeze(0)]
)
}
coords = dict((dim, xx.coords[dim]) for dim in xx["veg_frequency"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
Expand Down
15 changes: 1 addition & 14 deletions tests/test_fc_percentiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def dataset():
coords = {
"x": np.linspace(10, 20, band_1.shape[2]),
"y": np.linspace(0, 5, band_1.shape[1]),
"spec": index,
}

data_vars = {
Expand All @@ -57,6 +56,7 @@ def dataset():
"ue": xr.DataArray(band_3, dims=("spec", "y", "x")),
}
xx = xr.Dataset(data_vars=data_vars, coords=coords)
xx = xx.assign_coords(xr.Coordinates.from_pandas_multiindex(index, "spec"))

return xx

Expand Down Expand Up @@ -110,10 +110,6 @@ def test_native_transform(dataset, bits):
def test_fusing(dataset):
stats_fcp = StatsFCP()
xx = stats_fcp.native_transform(dataset)
for dim in xx.dims:
if isinstance(xx.get_index(dim), pd.MultiIndex):
xx = xx.reset_index(dim)
xx = xx.set_xindex("solar_day")
xx = xx.groupby("solar_day").map(partial(StatsFCP.fuser, None))
assert xx["band_1"].attrs["test_attr"] == 57

Expand Down Expand Up @@ -142,11 +138,6 @@ def test_fusing(dataset):
# Test fusing with UE filter and sum 120 limit
stats_fcp_ue30_sum120 = StatsFCP(ue_threshold=30, max_sum_limit=120)
xx_ue30_sum120 = stats_fcp_ue30_sum120.native_transform(dataset)
for dim in xx_ue30_sum120.dims:
if isinstance(xx_ue30_sum120.get_index(dim), pd.MultiIndex):
xx_ue30_sum120 = xx_ue30_sum120.reset_index(dim)
xx_ue30_sum120 = xx_ue30_sum120.set_xindex("solar_day")

xx_ue30_sum120 = xx_ue30_sum120.groupby("solar_day").map(
partial(StatsFCP.fuser, None)
)
Expand All @@ -164,10 +155,6 @@ def test_fusing(dataset):
def test_reduce(dataset):
stats_fcp = StatsFCP(count_valid=True)
xx = stats_fcp.native_transform(dataset)
for dim in xx.dims:
if isinstance(xx.get_index(dim), pd.MultiIndex):
xx = xx.reset_index(dim)
xx = xx.set_xindex("solar_day")
xx = xx.groupby("solar_day").map(partial(StatsFCP.fuser, None))
xx = xx.compute()
xx = stats_fcp.reduce(xx)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gm_ls_bitmask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def dataset(usgs_ls8_sr_definition):
coords = {
"x": np.linspace(10, 20, band_red.shape[2]),
"y": np.linspace(0, 5, band_pq.shape[1]),
"spec": index,
}
pq_flags_definition = {}
for measurement in usgs_ls8_sr_definition["measurements"]:
Expand All @@ -55,6 +54,7 @@ def dataset(usgs_ls8_sr_definition):
"QA_PIXEL": (("spec", "y", "x"), band_pq, attrs),
}
xx = xr.Dataset(data_vars=data_vars, coords=coords)
xx = xx.assign_coords(xr.Coordinates.from_pandas_multiindex(index, "spec"))
xx["band_red"].attrs["nodata"] = 0
return xx

Expand Down
16 changes: 8 additions & 8 deletions tests/test_landcover_plugin_a0.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def dataset_md():
coords = {
"x": np.linspace(10, 20, band_1.shape[2]),
"y": np.linspace(0, 5, band_1.shape[1]),
"spec": index,
}
data_vars = {
"band_1": xr.DataArray(
band_1, dims=("spec", "y", "x"), attrs={"nodata": np.nan}
),
}
xx = xr.Dataset(data_vars=data_vars, coords=coords)
xx = xx.assign_coords(xr.Coordinates.from_pandas_multiindex(index, "spec"))
return xx


Expand Down Expand Up @@ -301,7 +301,6 @@ def fc_wo_dataset():
coords = {
"x": np.linspace(10, 20, water.shape[2]),
"y": np.linspace(0, 5, water.shape[1]),
"spec": index,
}
data_vars = {
"water": xr.DataArray(water, dims=("spec", "y", "x"), attrs={"nodata": 1}),
Expand All @@ -311,6 +310,7 @@ def fc_wo_dataset():
"bs": xr.DataArray(bs, dims=("spec", "y", "x"), attrs={"nodata": 255}),
}
xx = xr.Dataset(data_vars=data_vars, coords=coords)
xx = xx.assign_coords(xr.Coordinates.from_pandas_multiindex(index, "spec"))

return xx

Expand All @@ -319,7 +319,7 @@ def fc_wo_dataset():
def test_native_transform(fc_wo_dataset, bits):
xx = fc_wo_dataset.copy()
xx["water"] = da.bitwise_or(xx["water"], bits)
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
out_xx = stats_veg.native_transform(xx).compute()

expected_valid = (
Expand Down Expand Up @@ -349,7 +349,7 @@ def test_native_transform(fc_wo_dataset, bits):


def test_fusing(fc_wo_dataset):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)).compute()
valid_index = (
Expand All @@ -369,7 +369,7 @@ def test_fusing(fc_wo_dataset):


def test_veg_or_not(fc_wo_dataset):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
yy = stats_veg._veg_or_not(xx).compute()
Expand All @@ -386,7 +386,7 @@ def test_veg_or_not(fc_wo_dataset):


def test_water_or_not(fc_wo_dataset):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
yy = stats_veg._water_or_not(xx).compute()
Expand All @@ -403,7 +403,7 @@ def test_water_or_not(fc_wo_dataset):


def test_reduce(fc_wo_dataset):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
xx = stats_veg.reduce(xx).compute()
Expand Down Expand Up @@ -437,7 +437,7 @@ def test_reduce(fc_wo_dataset):


def test_consecutive_month(consecutive_count):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg._max_consecutive_months(consecutive_count, 255).compute()
expected_value = np.array(
[
Expand Down
5 changes: 4 additions & 1 deletion tests/test_landcover_plugin_a1.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def dataset():
coords = {
"x": np.linspace(10, 20, wo_fq.shape[2]),
"y": np.linspace(0, 5, wo_fq.shape[1]),
"spec": index,
}
data_vars = {
"frequency": xr.DataArray(
Expand All @@ -122,6 +121,7 @@ def dataset():
),
}
xx = xr.Dataset(data_vars=data_vars, coords=coords)
xx = xx.assign_coords(xr.Coordinates.from_pandas_multiindex(index, "spec"))
return xx


Expand All @@ -135,6 +135,7 @@ def test_l3_classes(dataset):
"surface": 210,
},
optional_bands=["canopy_cover_class", "elevation"],
measurements=["level_3_4", "water_season"],
)

expected_res = np.array(
Expand Down Expand Up @@ -163,6 +164,7 @@ def test_l4_water_seasonality(dataset):
"surface": 210,
},
optional_bands=["canopy_cover_class", "elevation"],
measurements=["level_3_4", "water_season"],
)

wo_fq = np.array(
Expand Down Expand Up @@ -208,6 +210,7 @@ def test_reduce(dataset):
"surface": 210,
},
optional_bands=["canopy_cover_class", "elevation"],
measurements=["level_3_4", "water_season"],
)
res = stats_l3.reduce(dataset)

Expand Down
Loading

0 comments on commit 1aa3496

Please sign in to comment.