
Feat/boundary dataloader #90

Open · wants to merge 88 commits into base: main (changes shown from 75 of the 88 commits)

Commits
5df1bff add datastore_boundary to neural_lam (sadamov, Nov 18, 2024)
46590ef complete integration of boundary in weatherDataset (sadamov, Nov 18, 2024)
b990f49 Add test to check timestep length and spacing (sadamov, Nov 18, 2024)
3fd1d6b setting default mdp boundary to 0 gridcells (sadamov, Nov 18, 2024)
1f2499c implement time-based slicing (sadamov, Nov 18, 2024)
1af1481 remove all interior_mask and boundary_mask (sadamov, Nov 19, 2024)
d545cb7 added gcsfs dependency for era5 weatherbench download (sadamov, Nov 19, 2024)
5c1a7d7 added new era5 datastore config for boundary (sadamov, Nov 19, 2024)
30e4f05 removed left-over boundary-mask references (sadamov, Nov 19, 2024)
6a8c593 make check for existing category in datastore more flexible (for boun… (sadamov, Nov 19, 2024)
17c920d implement xarray based (mostly) time slicing and windowing (sadamov, Nov 20, 2024)
7919995 cleanup analysis based time-slicing (sadamov, Nov 21, 2024)
9bafcee implement datastore_boundary in existing tests (sadamov, Nov 19, 2024)
ce06bbc allow for grid shape retrieval from forcing data (sadamov, Nov 21, 2024)
884b5c6 rearrange time slicing, boundary first (sadamov, Nov 21, 2024)
5904cbe identified issue, cleanup next (leifdenby, Nov 25, 2024)
efe0302 use xarray plot only (leifdenby, Nov 26, 2024)
a489c2e don't reraise (leifdenby, Nov 26, 2024)
242d08b remove debug plot (leifdenby, Nov 26, 2024)
c1f706c remove extent calc used in diagnosing issue (leifdenby, Nov 26, 2024)
cf8e3e4 add type annotation (leifdenby, Nov 29, 2024)
85160ce ensure tensor copy to cpu mem before data-array creation (leifdenby, Nov 29, 2024)
52c4528 apply time-indexing to support ar_steps_val > 1 (leifdenby, Nov 29, 2024)
b96d8eb renaming test datastores (sadamov, Nov 30, 2024)
72da25f adding num_past/future_boundary_step args (sadamov, Nov 30, 2024)
244f1cc using combined config file (sadamov, Nov 30, 2024)
a9cc36e proper handling of state/forcing/boundary in dataset (sadamov, Nov 30, 2024)
dcc0b46 datastore_boundars=None introduced (sadamov, Nov 30, 2024)
a3b3bde bug fix for file retrieval per member (sadamov, Nov 30, 2024)
3ffc413 rename datastore for tests (sadamov, Nov 30, 2024)
85aad66 aligned time with danra for easier boundary testing (sadamov, Nov 30, 2024)
64f057f Fixed test for temporal embedding (sadamov, Nov 30, 2024)
6205dbd pin dataclass-wizard <0.31.0 to avoid bug in dataclass-wizard (leifdenby, Dec 2, 2024)
551cd26 allow boundary as input to ar_model.common_step (sadamov, Dec 2, 2024)
fc95350 linting (sadamov, Dec 2, 2024)
01fa807 improved docstrings and added some assertions (sadamov, Dec 2, 2024)
5a749f3 update mdp dependency (sadamov, Dec 2, 2024)
45ba607 remove boundary datastore from tests that don't need it (sadamov, Dec 2, 2024)
f36f360 fix scope of _get_slice_time (sadamov, Dec 2, 2024)
105108e fix scope of _get_time_step (sadamov, Dec 2, 2024)
d760145 Merge branch 'feat/boundary_dataloader' of https://github.com/sadamov… (sadamov, Dec 2, 2024)
ae0cf76 added information about optional boundary datastore (sadamov, Dec 2, 2024)
9af27e0 add datastore_boundary to neural_lam (sadamov, Nov 18, 2024)
c25fb30 complete integration of boundary in weatherDataset (sadamov, Nov 18, 2024)
505ceeb Add test to check timestep length and spacing (sadamov, Nov 18, 2024)
e733066 setting default mdp boundary to 0 gridcells (sadamov, Nov 18, 2024)
d8349a4 implement time-based slicing (sadamov, Nov 18, 2024)
fd791bf remove all interior_mask and boundary_mask (sadamov, Nov 19, 2024)
ae82cdb added gcsfs dependency for era5 weatherbench download (sadamov, Nov 19, 2024)
34a6cc7 added new era5 datastore config for boundary (sadamov, Nov 19, 2024)
2dc67a0 removed left-over boundary-mask references (sadamov, Nov 19, 2024)
9f8628e make check for existing category in datastore more flexible (for boun… (sadamov, Nov 19, 2024)
388c79d implement xarray based (mostly) time slicing and windowing (sadamov, Nov 20, 2024)
2529969 cleanup analysis based time-slicing (sadamov, Nov 21, 2024)
179a035 implement datastore_boundary in existing tests (sadamov, Nov 19, 2024)
2daeb16 allow for grid shape retrieval from forcing data (sadamov, Nov 21, 2024)
cbcdcae rearrange time slicing, boundary first (sadamov, Nov 21, 2024)
e6ace27 renaming test datastores (sadamov, Nov 30, 2024)
42818f0 adding num_past/future_boundary_step args (sadamov, Nov 30, 2024)
0103b6e using combined config file (sadamov, Nov 30, 2024)
0896344 proper handling of state/forcing/boundary in dataset (sadamov, Nov 30, 2024)
355423c datastore_boundars=None introduced (sadamov, Nov 30, 2024)
121d460 bug fix for file retrieval per member (sadamov, Nov 30, 2024)
7e82eef rename datastore for tests (sadamov, Nov 30, 2024)
320d7c4 aligned time with danra for easier boundary testing (sadamov, Nov 30, 2024)
f18dcc2 Fixed test for temporal embedding (sadamov, Nov 30, 2024)
e6327d8 allow boundary as input to ar_model.common_step (sadamov, Dec 2, 2024)
1374a19 linting (sadamov, Dec 2, 2024)
779f3e9 improved docstrings and added some assertions (sadamov, Dec 2, 2024)
f126ec2 remove boundary datastore from tests that don't need it (sadamov, Dec 2, 2024)
4b656da fix scope of _get_time_step (sadamov, Dec 2, 2024)
75db4b8 added information about optional boundary datastore (sadamov, Dec 2, 2024)
58b4af6 Merge branch 'feat/boundary_dataloader' of https://github.com/sadamov… (sadamov, Dec 2, 2024)
4c17545 moved gcsfs to dev group (sadamov, Dec 3, 2024)
a700350 linting (sadamov, Dec 3, 2024)
16d5d04 Fixed issue with temporal encoding dimensions (sadamov, Dec 3, 2024)
f1f3f73 format docstrings (sadamov, Dec 3, 2024)
8fd7a10 introduced time slicing test for forecast type data (sadamov, Dec 3, 2024)
252a33c bugfix temporal embedding dimension (sadamov, Dec 3, 2024)
8a9114a linting (sadamov, Dec 3, 2024)
8c7709a switched to low-res data (sadamov, Dec 3, 2024)
24cbf13 add datastore_boundary as explicit attribute (sadamov, Dec 3, 2024)
1d53ce7 fixing up forecast type data tests, (sadamov, Dec 5, 2024)
cfe1e27 time step can and should be retrieved in __init__ (sadamov, Dec 5, 2024)
e4e4e37 Fix dataset issue in npy stat script (joeloskarsson, Dec 4, 2024)
3df3fcb Merge remote-tracking branch 'mllam/main' into feat/boundary_dataloader (sadamov, Dec 5, 2024)
f8613da added static feature to era5 boundary test datastore (sadamov, Dec 5, 2024)
f0a7046 Merge remote-tracking branch 'mllam/main' into feat/boundary_dataloader (sadamov, Dec 6, 2024)
README.md (22 changes: 13 additions & 9 deletions)
@@ -108,7 +108,9 @@ Once `neural-lam` is installed you will be able to train/evaluate models. For th
 interface that provides the data in a data-structure that can be used within
 neural-lam. A datastore is used to create a `pytorch.Dataset`-derived
 class that samples the data in time to create individual samples for
-training, validation and testing.
+training, validation and testing. A secondary datastore can be provided
+for the boundary data. Currently, the boundary datastore must be of type `mdp`
+and contain only forcing features. This can be expanded in the future.
 
 2. **The graph structure** is used to define message-passing GNN layers,
    that are trained to emulate fluid flow in the atmosphere over time. The
@@ -121,7 +123,7 @@ different aspects about the training and evaluation of the model.
 
 The path you provide to the neural-lam config (`config.yaml`) also sets the
 root directory relative to which all other paths are resolved; that is, the parent
-directory of the config becomes the root directory. Both the datastore and
+directory of the config becomes the root directory. Both the datastores and
 graphs you generate are then stored in subdirectories of this root directory.
 Exactly how and where a specific datastore expects its source data to be stored
 and where it stores its derived data is up to the implementation of the
@@ -134,6 +136,7 @@ assume you placed `config.yaml` in a folder called `data`):
 data/
 ├── config.yaml - Configuration file for neural-lam
 ├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml
+├── era5.datastore.yaml - Optional configuration file for the boundary datastore, referred to from config.yaml
 └── graphs/ - Directory containing graphs for training
 ```

@@ -142,18 +145,20 @@ And the content of `config.yaml` could in this case look like:
 datastore:
   kind: mdp
   config_path: danra.datastore.yaml
+datastore_boundary:
+  kind: mdp
+  config_path: era5.datastore.yaml
 training:
   state_feature_weighting:
     __config_class__: ManualStateFeatureWeighting
-    values:
+    weights:
       u100m: 1.0
       v100m: 1.0
 ```
 
-For now the neural-lam config only defines two things: 1) the kind of data
-store and the path to its config, and 2) the weighting of different features in
-the loss function. If you don't define the state feature weighting it will default
-to weighting all features equally.
+For now the neural-lam config only defines two things:
+1) the kind of datastores and the paths to their configs
+2) the weighting of different features in the loss function. If you don't define the state feature weighting it will default to weighting all features equally.

(This example is taken from the `tests/datastore_examples/mdp` directory.)
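As a quick orientation, this is roughly how such a config is consumed; a minimal sketch based on the updated `load_config_and_datastore` from this PR (the `data/config.yaml` path is the example layout from above):

```python
from neural_lam.config import load_config_and_datastore

# The function now returns three values; `datastore_boundary` is None
# when config.yaml has no `datastore_boundary` entry.
config, datastore, datastore_boundary = load_config_and_datastore(
    config_path="data/config.yaml"
)

if datastore_boundary is None:
    print("no boundary datastore configured, using interior data only")
```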

@@ -525,5 +530,4 @@ Furthermore, all tests in the ```tests``` directory will be run upon pushing cha
 
 # Contact
 If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch.
-There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join (after following the link you have to request to join, this is to avoid spam bots).
-You can also open a github issue on this page, or (if more suitable) send an email to [[email protected]](mailto:[email protected]).
+There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [[email protected]](mailto:[email protected]).
neural_lam/config.py (13 changes: 12 additions & 1 deletion)
@@ -168,4 +168,15 @@ def load_config_and_datastore(
         datastore_kind=config.datastore.kind, config_path=datastore_config_path
     )
 
-    return config, datastore
+    if config.datastore_boundary is not None:
+        datastore_boundary_config_path = (
+            Path(config_path).parent / config.datastore_boundary.config_path
+        )
+        datastore_boundary = init_datastore(
+            datastore_kind=config.datastore_boundary.kind,
+            config_path=datastore_boundary_config_path,
+        )
+    else:
+        datastore_boundary = None
+
+    return config, datastore, datastore_boundary

Review comment (Member): Maybe we could rename the function to `load_datastores` (rather than `load_datastore`) and update the docstring accordingly? :)
neural_lam/datastore/base.py (26 changes: 7 additions & 19 deletions)
@@ -228,23 +228,6 @@ def get_dataarray(
         """
         pass
 
-    @cached_property
-    @abc.abstractmethod
-    def boundary_mask(self) -> xr.DataArray:
-        """
-        Return the boundary mask for the dataset, with spatial dimensions
-        stacked. Where the value is 1, the grid point is a boundary point, and
-        where the value is 0, the grid point is not a boundary point.
-
-        Returns
-        -------
-        xr.DataArray
-            The boundary mask for the dataset, with dimensions
-            `('grid_index',)`.
-
-        """
-        pass
-
     @abc.abstractmethod
     def get_xy(self, category: str) -> np.ndarray:
         """
@@ -295,8 +278,13 @@ def get_xy_extent(self, category: str) -> List[float]:
             The extent of the x, y coordinates.
 
         """
-        xy = self.get_xy(category, stacked=False)
-        extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
+        xy = self.get_xy(category, stacked=True)
+        extent = [
+            xy[:, 0].min(),
+            xy[:, 0].max(),
+            xy[:, 1].min(),
+            xy[:, 1].max(),
+        ]
         return [float(v) for v in extent]
 
     @property
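The rewritten `get_xy_extent` reduces over a stacked `(grid_index, 2)` coordinate array instead of the unstacked layout; a minimal standalone sketch of that reduction (toy coordinates, not a real datastore):

```python
import numpy as np

# Toy stacked coordinates: 6 grid points, columns are (x, y)
xy = np.array(
    [
        [0.0, 10.0], [1.0, 10.0], [2.0, 10.0],
        [0.0, 11.0], [1.0, 11.0], [2.0, 11.0],
    ]
)

# Same reduction as the updated method: per-column min/max
extent = [xy[:, 0].min(), xy[:, 0].max(), xy[:, 1].min(), xy[:, 1].max()]
print([float(v) for v in extent])  # [0.0, 2.0, 10.0, 11.0]
```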
neural_lam/datastore/mdp.py (75 changes: 26 additions & 49 deletions)
@@ -1,4 +1,5 @@
 # Standard library
+import copy
 import warnings
 from functools import cached_property
 from pathlib import Path
@@ -26,11 +27,10 @@ class MDPDatastore(BaseRegularGridDatastore):
 
     SHORT_NAME = "mdp"
 
-    def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
+    def __init__(self, config_path, reuse_existing=True):
         """
         Construct a new MDPDatastore from the configuration file at
-        `config_path`. A boundary mask is created with `n_boundary_points`
-        boundary points. If `reuse_existing` is True, the dataset is loaded
+        `config_path`. If `reuse_existing` is True, the dataset is loaded
         from a zarr file if it exists (unless the config has been modified
         since the zarr was created), otherwise it is created from the
         configuration file.
@@ -41,8 +41,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
             The path to the configuration file, this will be fed to the
             `mllam_data_prep.Config.from_yaml_file` method to then call
             `mllam_data_prep.create_dataset` to create the dataset.
-        n_boundary_points : int
-            The number of boundary points to use in the boundary mask.
         reuse_existing : bool
             Whether to reuse an existing dataset zarr file if it exists and its
             creation date is newer than the configuration file.
@@ -69,7 +67,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
         if self._ds is None:
             self._ds = mdp.create_dataset(config=self._config)
             self._ds.to_zarr(fp_ds)
-        self._n_boundary_points = n_boundary_points
 
         print("The loaded datastore contains the following features:")
         for category in ["state", "forcing", "static"]:
@@ -157,8 +154,8 @@ def get_vars_units(self, category: str) -> List[str]:
             The units of the variables in the given category.
 
         """
-        if category not in self._ds and category == "forcing":
-            warnings.warn("no forcing data found in datastore")
+        if category not in self._ds:
+            warnings.warn(f"no {category} data found in datastore")
             return []
         return self._ds[f"{category}_feature_units"].values.tolist()
 
@@ -176,8 +173,8 @@ def get_vars_names(self, category: str) -> List[str]:
             The names of the variables in the given category.
 
         """
-        if category not in self._ds and category == "forcing":
-            warnings.warn("no forcing data found in datastore")
+        if category not in self._ds:
+            warnings.warn(f"no {category} data found in datastore")
             return []
         return self._ds[f"{category}_feature"].values.tolist()
 
@@ -196,8 +193,8 @@ def get_vars_long_names(self, category: str) -> List[str]:
             The long names of the variables in the given category.
 
         """
-        if category not in self._ds and category == "forcing":
-            warnings.warn("no forcing data found in datastore")
+        if category not in self._ds:
+            warnings.warn(f"no {category} data found in datastore")
             return []
         return self._ds[f"{category}_feature_long_name"].values.tolist()
 
@@ -252,9 +249,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
             The xarray DataArray object with processed dataset.
 
         """
-        if category not in self._ds and category == "forcing":
-            warnings.warn("no forcing data found in datastore")
-            return None
+        if category not in self._ds:
+            warnings.warn(f"no {category} data found in datastore")
+            return []
 
         da_category = self._ds[category]
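With the generalized check, any missing category now warns instead of only `forcing`; for example, querying state variables on a forcing-only boundary datastore would go roughly like this (a sketch, assuming `datastore_boundary` is an era5-style boundary datastore as in the test examples):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    names = datastore_boundary.get_vars_names(category="state")

assert names == []  # empty list rather than an exception
assert "no state data" in str(caught[0].message)
```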
@@ -318,37 +315,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
         ds_stats = self._ds[stats_variables.keys()].rename(stats_variables)
         return ds_stats
 
-    @cached_property
-    def boundary_mask(self) -> xr.DataArray:
-        """
-        Produce a 0/1 mask for the boundary points of the dataset, these will
-        sit at the edges of the domain (in x/y extent) and will be used to mask
-        out the boundary points from the loss function and to overwrite the
-        boundary points from the prediction. For now this is created when the
-        mask is requested, but in the future this could be saved to the zarr
-        file.
-
-        Returns
-        -------
-        xr.DataArray
-            A 0/1 mask for the boundary points of the dataset, where 1 is a
-            boundary point and 0 is not.
-
-        """
-        ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds)
-        da_state_variable = (
-            ds_unstacked["state"].isel(time=0).isel(state_feature=0)
-        )
-        da_domain_allzero = xr.zeros_like(da_state_variable)
-        ds_unstacked["boundary_mask"] = da_domain_allzero.isel(
-            x=slice(self._n_boundary_points, -self._n_boundary_points),
-            y=slice(self._n_boundary_points, -self._n_boundary_points),
-        )
-        ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(
-            1
-        ).astype(int)
-        return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)
-
     @property
     def coords_projection(self) -> ccrs.Projection:
         """
@@ -394,7 +360,9 @@ def coords_projection(self) -> ccrs.Projection:
 
         class_name = projection_info["class_name"]
         ProjectionClass = getattr(ccrs, class_name)
-        kwargs = projection_info["kwargs"]
+        # need to copy otherwise we modify the dict stored in the dataclass
+        # in-place
+        kwargs = copy.deepcopy(projection_info["kwargs"])
 
         globe_kwargs = kwargs.pop("globe", {})
         if len(globe_kwargs) > 0:
@@ -412,8 +380,17 @@ def grid_shape_state(self):
             The shape of the cartesian grid for the state variables.
 
         """
-        ds_state = self.unstack_grid_coords(self._ds["state"])
-        da_x, da_y = ds_state.x, ds_state.y
+        # Boundary data often has no state features
+        if "state" not in self._ds:
+            warnings.warn(
+                "no state data found in datastore, "
+                "returning grid shape from forcing data"
+            )
+            ds_forcing = self.unstack_grid_coords(self._ds["forcing"])
+            da_x, da_y = ds_forcing.x, ds_forcing.y
+        else:
+            ds_state = self.unstack_grid_coords(self._ds["state"])
+            da_x, da_y = ds_state.x, ds_state.y
         assert da_x.ndim == da_y.ndim == 1
         return CartesianGridShape(x=da_x.size, y=da_y.size)

Review comment (Member), on the `ds_forcing` line: You could maybe just call this `da_grid_reference` or something and then unstack etc.
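The `copy.deepcopy` in `coords_projection` above matters because the subsequent `kwargs.pop("globe", {})` would otherwise remove the `globe` entry from the dict held by the config object, so only the first access would see it. A standalone sketch of the failure mode (illustrative dict, not the datastore code itself):

```python
import copy

projection_info = {
    "class_name": "LambertConformal",
    "kwargs": {"central_longitude": 10.0, "globe": {"radius": 6371000.0}},
}

# Without a copy, pop() mutates the stored dict: "globe" is gone afterwards
kwargs = projection_info["kwargs"]
kwargs.pop("globe", {})
assert "globe" not in projection_info["kwargs"]

# With a deep copy, repeated accesses keep seeing the full kwargs
projection_info["kwargs"]["globe"] = {"radius": 6371000.0}
kwargs = copy.deepcopy(projection_info["kwargs"])
kwargs.pop("globe", {})
assert "globe" in projection_info["kwargs"]
```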
sadamov marked this conversation as resolved.
@@ -172,6 +172,7 @@ def main(
     ar_steps = 63
     ds = WeatherDataset(
         datastore=datastore,
+        datastore_boundary=None,
        split="train",
        ar_steps=ar_steps,
        standardize=False,
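For comparison, constructing the dataset with a boundary datastore looks roughly like this; a sketch only, where the `num_past_boundary_steps`/`num_future_boundary_steps` argument names are assumed from the "adding num_past/future_boundary_step args" commits rather than taken from the diff shown here:

```python
from neural_lam.weather_dataset import WeatherDataset

ds = WeatherDataset(
    datastore=datastore,                    # interior data (state/forcing/static)
    datastore_boundary=datastore_boundary,  # optional; None disables boundary input
    split="train",
    ar_steps=3,
    standardize=True,
    num_past_boundary_steps=1,   # assumed names, see note above
    num_future_boundary_steps=1,
)
```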
neural_lam/datastore/npyfilesmeps/store.py (42 changes: 11 additions & 31 deletions)
@@ -445,7 +445,9 @@ def _get_single_timeseries_dataarray(
                 * np.timedelta64(1, "h")
             )
         elif d == "analysis_time":
-            coord_values = self._get_analysis_times(split=split)
+            coord_values = self._get_analysis_times(
+                split=split, member_id=member
+            )
         elif d == "y":
             coord_values = y
         elif d == "x":
@@ -505,23 +507,29 @@
 
         return da
 
-    def _get_analysis_times(self, split) -> List[np.datetime64]:
+    def _get_analysis_times(self, split, member_id) -> List[np.datetime64]:
         """Get the analysis times for the given split by parsing the filenames
         of all the files found for the given split.
 
         Parameters
         ----------
         split : str
             The dataset split to get the analysis times for.
+        member_id : int
+            The ensemble member to get the analysis times for.
 
         Returns
         -------
         List[dt.datetime]
             The analysis times for the given split.
 
         """
+        if member_id is None:
+            # Only interior state data files have member_id, to avoid duplicates
+            # we only look at the first member for all other categories
+            member_id = 0
         pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
-        pattern = re.sub(r"{member_id:[^}]*}", "*", pattern)
+        pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern)
 
         sample_dir = self.root_path / "samples" / split
         sample_files = sample_dir.glob(pattern)
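The two `re.sub` calls turn the filename format string into a glob pattern: `analysis_time` becomes a wildcard while `member_id` is pinned to a single member, which is what avoids duplicate analysis times across members. A standalone sketch with an assumed format string (the real `STATE_FILENAME_FORMAT` is defined elsewhere in this module):

```python
import re

# Assumed, illustrative format; the real constant lives in store.py
STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"

member_id = 0
pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern)
print(pattern)  # nwp_*_mbr000.npy
```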
@@ -668,34 +676,6 @@ def grid_shape_state(self) -> CartesianGridShape:
         ny, nx = self.config.grid_shape_state
         return CartesianGridShape(x=nx, y=ny)
 
-    @cached_property
-    def boundary_mask(self) -> xr.DataArray:
-        """The boundary mask for the dataset. This is a binary mask that is 1
-        where the grid cell is on the boundary of the domain, and 0 otherwise.
-
-        Returns
-        -------
-        xr.DataArray
-            The boundary mask for the dataset, with dimensions `[grid_index]`.
-
-        """
-        xy = self.get_xy(category="state", stacked=False)
-        xs = xy[:, :, 0]
-        ys = xy[:, :, 1]
-        # Check if x-coordinates are constant along columns
-        assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant"
-        # Check if y-coordinates are constant along rows
-        assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant"
-        # Extract unique x and y coordinates
-        x = xs[:, 0]  # Unique x-coordinates (changes along the first axis)
-        y = ys[0, :]  # Unique y-coordinates (changes along the second axis)
-        values = np.load(self.root_path / "static" / "border_mask.npy")
-        da_mask = xr.DataArray(
-            values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
-        )
-        da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int)
-        return da_mask_stacked_xy
-
     def get_standardization_dataarray(self, category: str) -> xr.Dataset:
         """Return the standardization dataarray for the given category. This
         should contain a `{category}_mean` and `{category}_std` variable for