Skip to content

Commit

Permalink
Dev icenet-ai#186: dynamic masking of land within the dataset generation
Browse files: browse the repository at this point in the history
  • Loading branch information
JimCircadian committed Aug 14, 2024
1 parent: 55870fd; commit: 50fae3c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 40 deletions.
2 changes: 1 addition & 1 deletion icenet/data/loaders/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -253,7 +253,7 @@ def _construct_channels(self):
self._channels[var_name] = 1
self._add_channel_files(
var_name,
meta_channel["files"])
meta_channel["processed_files"][var_name])

logging.debug(
"Channel quantities deduced:\n{}\n\nTotal channels: {}".format(
Expand Down
17 changes: 12 additions & 5 deletions icenet/data/loaders/dask.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -136,8 +136,10 @@ def __init__(self,
super().__init__(*args, **kwargs)

# FIXME
self._masks = da.array(
[np.load(self._config["masks"]["active_grid_cell"][month-1]) for month in range(1, 13)])
# self._masks = da.array(
# [np.load(self._config["masks"]["active_grid_cell"][month-1]) for month in range(1, 13)])
self._masks = {var_name: xr.open_dataarray(mask_cfg["processed_files"][var_name][0])
for var_name, mask_cfg in self._config["masks"].items()}

self._futures = futures_per_worker

Expand Down Expand Up @@ -439,7 +441,10 @@ def generate_sample(forecast_date: object,
sample_weight = da.zeros(shape, dtype)
else:
# Zero loss outside of 'active grid cells'
sample_weight = masks[forecast_day.month - 1]
#sample_weight = masks["active_grid_cell"][forecast_day.month - 1]
sample_weight = masks["active_grid_cell"].sel(month=forecast_day.month).data
sample_weight[masks["land"]] = True
# TODO: dynamic inclusion of polarhole
sample_weight = sample_weight.astype(dtype)

# We can pick up nans, which messes up training
Expand Down Expand Up @@ -475,8 +480,10 @@ def generate_sample(forecast_date: object,
channel_data = []
for idx in channel_idxs:
try:
channel_data.append(
getattr(channel_ds, var_name).isel(time=idx))
data = getattr(channel_ds, var_name).isel(time=idx)
if var_name.startswith("siconca"):
data = da.ma.where(masks["land"], 0., data)
channel_data.append(data)
except KeyError:
channel_data.append(da.zeros(shape))

Expand Down
81 changes: 47 additions & 34 deletions icenet/data/masks/osisaf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -172,7 +172,7 @@ def __init__(self,
self._dataset_config = mask_ds.save_config()

super().__init__(mask_ds,
absolute_vars=["active_grid_cell", "land", "polarhole"],
absolute_vars=["active_grid_cell", "land", "land_map", "polarhole"],
identifier="masks",
**kwargs)

Expand All @@ -192,58 +192,67 @@ def get_config(self,

def process(self):
# Active grid cell mask preparation
destination_filename = os.path.join(self.path, "active_grid_cell.nc")
mask_files = self._source_files["active_grid_cell"]

if not os.path.exists(destination_filename):
da = xr.DataArray(data=[np.load(acgm_month_file) for acgm_month_file in mask_files],
dims=["month", "yc", "xc"],
coords=dict(
month=range(1, 13),
),
attrs=dict(description="IceNet active grid cell mask metadata"))
self.save_processed_file("active_grid_cell", os.path.basename(destination_filename), da)
da = xr.DataArray(data=[np.load(acgm_month_file) for acgm_month_file in mask_files],
dims=["month", "yc", "xc"],
coords=dict(
month=range(1, 13),
),
attrs=dict(description="IceNet active grid cell mask metadata"))

self.save_processed_file("active_grid_cell",
os.path.basename(self.active_grid_cell_filename),
da,
overwrite=False)

# Land mask preparation
filename = "land.nc"
if not os.path.exists(os.path.join(self.path, filename)):
land_mask = np.load(self._source_files["land"])
land_map = np.ones(land_mask.shape, dtype=self.dtype)
land_map[~land_mask] = -1.
land_mask = np.load(self._source_files["land"])

da = xr.DataArray(data=land_mask,
dims=["yc", "xc"],
attrs=dict(description="IceNet land mask metadata"))

self.save_processed_file("land", os.path.basename(self.land_filename), da, overwrite=False)

land_map = np.ones(land_mask.shape, dtype=self.dtype)
land_map[~land_mask] = -1.
da = xr.DataArray(data=land_map,
dims=["yc", "xc"],
attrs=dict(description="IceNet land map metadata"))

da = xr.DataArray(data=land_map,
dims=["yc", "xc"],
attrs=dict(description="IceNet land mask metadata"))
self.save_processed_file("land", filename, da)
self.save_processed_file("land_map", os.path.basename(self.land_map_filename), da, overwrite=False)

# Polar hole mask preparation
destination_filename = os.path.join(self.path, "polarhole.nc")
mask_files = self._source_files["polarhole"]

if not os.path.exists(destination_filename):
da = xr.DataArray(data=[np.load(polarhole_file) for polarhole_file in mask_files],
dims=["polarhole", "yc", "xc"],
coords=dict(
polarhole=[pd.Timestamp(el) for el in [
dt.date(1987, 6, 1),
dt.date(2005, 10, 1),
dt.date(2015, 12, 1),
]],
),
attrs=dict(description="IceNet polar hole mask metadata"))
self.save_processed_file("polarhole", os.path.basename(destination_filename), da)
da = xr.DataArray(data=[np.load(polarhole_file) for polarhole_file in mask_files],
dims=["polarhole", "yc", "xc"],
coords=dict(
polarhole=[pd.Timestamp(el) for el in [
dt.date(1987, 6, 1),
dt.date(2005, 10, 1),
dt.date(2015, 12, 1),
]],
),
attrs=dict(description="IceNet polar hole mask metadata"))

self.save_processed_file("polarhole",
os.path.basename(self.polarhole_filename),
da,
overwrite=False)

self.save_config()

def active_grid_cell(self, date=None, *args, **kwargs):
da = xr.open_dataarray(self.active_grid_cell_filename)
da = da.sel(month=pd.to_datetime(date).month)
return da.data
return ~da.data

# TODO: caching please
def land(self, *args, **kwargs):
da = xr.open_dataarray(self.land_filename)
return da.data > 0
return da.data

def polarhole(self, date, *args, **kwargs):
da = xr.open_dataarray(self.polarhole_filename)
Expand All @@ -265,6 +274,10 @@ def active_grid_cell_filename(self):
def land_filename(self):
return os.path.join(self.path, "land.nc")

@property
def land_map_filename(self):
return os.path.join(self.path, "land_map.nc")

@property
def polarhole_filename(self):
return os.path.join(self.path, "polarhole.nc")

0 comments on commit 50fae3c

Please sign in to comment.