Feat/boundary dataloader #90
Open: sadamov wants to merge 88 commits into mllam:main from sadamov:feat/boundary_dataloader (base: main)
Commits (88 total; the diff below shows changes from 75 commits):
- 5df1bff add datastore_boundary to neural_lam (sadamov)
- 46590ef complete integration of boundary in weatherDataset (sadamov)
- b990f49 Add test to check timestep length and spacing (sadamov)
- 3fd1d6b setting default mdp boundary to 0 gridcells (sadamov)
- 1f2499c implement time-based slicing (sadamov)
- 1af1481 remove all interior_mask and boundary_mask (sadamov)
- d545cb7 added gcsfs dependency for era5 weatherbench download (sadamov)
- 5c1a7d7 added new era5 datastore config for boundary (sadamov)
- 30e4f05 removed left-over boundary-mask references (sadamov)
- 6a8c593 make check for existing category in datastore more flexible (for boun… (sadamov)
- 17c920d implement xarray based (mostly) time slicing and windowing (sadamov)
- 7919995 cleanup analysis based time-slicing (sadamov)
- 9bafcee implement datastore_boundary in existing tests (sadamov)
- ce06bbc allow for grid shape retrieval from forcing data (sadamov)
- 884b5c6 rearrange time slicing, boundary first (sadamov)
- 5904cbe identified issue, cleanup next (leifdenby)
- efe0302 use xarray plot only (leifdenby)
- a489c2e don't reraise (leifdenby)
- 242d08b remove debug plot (leifdenby)
- c1f706c remove extent calc used in diagnosing issue (leifdenby)
- cf8e3e4 add type annotation (leifdenby)
- 85160ce ensure tensor copy to cpu mem before data-array creation (leifdenby)
- 52c4528 apply time-indexing to support ar_steps_val > 1 (leifdenby)
- b96d8eb renaming test datastores (sadamov)
- 72da25f adding num_past/future_boundary_step args (sadamov)
- 244f1cc using combined config file (sadamov)
- a9cc36e proper handling of state/forcing/boundary in dataset (sadamov)
- dcc0b46 datastore_boundars=None introduced (sadamov)
- a3b3bde bug fix for file retrieval per member (sadamov)
- 3ffc413 rename datastore for tests (sadamov)
- 85aad66 aligned time with danra for easier boundary testing (sadamov)
- 64f057f Fixed test for temporal embedding (sadamov)
- 6205dbd pin dataclass-wizard <0.31.0 to avoid bug in dataclass-wizard (leifdenby)
- 551cd26 allow boundary as input to ar_model.common_step (sadamov)
- fc95350 linting (sadamov)
- 01fa807 improved docstrings and added some assertions (sadamov)
- 5a749f3 update mdp dependency (sadamov)
- 45ba607 remove boundary datastore from tests that don't need it (sadamov)
- f36f360 fix scope of _get_slice_time (sadamov)
- 105108e fix scope of _get_time_step (sadamov)
- d760145 Merge branch 'feat/boundary_dataloader' of https://github.com/sadamov… (sadamov)
- ae0cf76 added information about optional boundary datastore (sadamov)
- 9af27e0 add datastore_boundary to neural_lam (sadamov)
- c25fb30 complete integration of boundary in weatherDataset (sadamov)
- 505ceeb Add test to check timestep length and spacing (sadamov)
- e733066 setting default mdp boundary to 0 gridcells (sadamov)
- d8349a4 implement time-based slicing (sadamov)
- fd791bf remove all interior_mask and boundary_mask (sadamov)
- ae82cdb added gcsfs dependency for era5 weatherbench download (sadamov)
- 34a6cc7 added new era5 datastore config for boundary (sadamov)
- 2dc67a0 removed left-over boundary-mask references (sadamov)
- 9f8628e make check for existing category in datastore more flexible (for boun… (sadamov)
- 388c79d implement xarray based (mostly) time slicing and windowing (sadamov)
- 2529969 cleanup analysis based time-slicing (sadamov)
- 179a035 implement datastore_boundary in existing tests (sadamov)
- 2daeb16 allow for grid shape retrieval from forcing data (sadamov)
- cbcdcae rearrange time slicing, boundary first (sadamov)
- e6ace27 renaming test datastores (sadamov)
- 42818f0 adding num_past/future_boundary_step args (sadamov)
- 0103b6e using combined config file (sadamov)
- 0896344 proper handling of state/forcing/boundary in dataset (sadamov)
- 355423c datastore_boundars=None introduced (sadamov)
- 121d460 bug fix for file retrieval per member (sadamov)
- 7e82eef rename datastore for tests (sadamov)
- 320d7c4 aligned time with danra for easier boundary testing (sadamov)
- f18dcc2 Fixed test for temporal embedding (sadamov)
- e6327d8 allow boundary as input to ar_model.common_step (sadamov)
- 1374a19 linting (sadamov)
- 779f3e9 improved docstrings and added some assertions (sadamov)
- f126ec2 remove boundary datastore from tests that don't need it (sadamov)
- 4b656da fix scope of _get_time_step (sadamov)
- 75db4b8 added information about optional boundary datastore (sadamov)
- 58b4af6 Merge branch 'feat/boundary_dataloader' of https://github.com/sadamov… (sadamov)
- 4c17545 moved gcsfs to dev group (sadamov)
- a700350 linting (sadamov)
- 16d5d04 Fixed issue with temporal encoding dimensions (sadamov)
- f1f3f73 format docstrings (sadamov)
- 8fd7a10 introduced time slicing test for forecast type data (sadamov)
- 252a33c bugfix temporal embedding dimension (sadamov)
- 8a9114a linting (sadamov)
- 8c7709a switched to low-res data (sadamov)
- 24cbf13 add datastore_boundary as explicit attribute (sadamov)
- 1d53ce7 fixing up forecast type data tests, (sadamov)
- cfe1e27 time step can and should be retrieved in __init__ (sadamov)
- e4e4e37 Fix dataset issue in npy stat script (joeloskarsson)
- 3df3fcb Merge remote-tracking branch 'mllam/main' into feat/boundary_dataloader (sadamov)
- f8613da added static feature to era5 boundary test datastore (sadamov)
- f0a7046 Merge remote-tracking branch 'mllam/main' into feat/boundary_dataloader (sadamov)
```diff
@@ -108,7 +108,9 @@ Once `neural-lam` is installed you will be able to train/evaluate models. For th
 interface that provides the data in a data-structure that can be used within
 neural-lam. A datastore is used to create a `pytorch.Dataset`-derived
 class that samples the data in time to create individual samples for
-training, validation and testing.
+training, validation and testing. A secondary datastore can be provided
+for the boundary data. Currently, boundary datastore must be of type `mdp`
+and only contain forcing features. This can easily be expanded in the future.

 2. **The graph structure** is used to define message-passing GNN layers,
    that are trained to emulate fluid flow in the atmosphere over time. The
```
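To make the sampling idea in the hunk above concrete, here is a minimal hedged sketch, not the PR's actual `WeatherDataset`, of a `pytorch.Dataset` that pairs interior state windows with time-aligned boundary forcing windows (the class name, tensor shapes, and window logic are illustrative only; only the `num_past_boundary_steps` idea is taken from the PR's commit messages):

```python
import torch
from torch.utils.data import Dataset


class ToyBoundaryDataset(Dataset):
    """Illustrative only; not neural-lam's WeatherDataset."""

    def __init__(self, state, boundary, ar_steps=1, num_past_boundary_steps=1):
        # state: (n_time, n_grid, n_state_features)
        # boundary: (n_time, n_grid_boundary, n_forcing_features); assumed
        # here to share the state's time axis for simplicity
        self.state = state
        self.boundary = boundary
        self.ar_steps = ar_steps
        self.past = num_past_boundary_steps

    def __len__(self):
        return self.state.shape[0] - self.ar_steps - self.past

    def __getitem__(self, idx):
        t0 = idx + self.past
        init_states = self.state[t0 : t0 + 1]
        target_states = self.state[t0 + 1 : t0 + 1 + self.ar_steps]
        # the boundary window also covers past steps, mirroring the
        # num_past/future_boundary_step args added in this PR
        boundary_window = self.boundary[idx : t0 + 1 + self.ar_steps]
        return init_states, target_states, boundary_window


ds = ToyBoundaryDataset(torch.randn(48, 100, 3), torch.randn(48, 40, 2))
init_states, target_states, boundary_window = ds[0]
```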
```diff
@@ -121,7 +123,7 @@ different aspects about the training and evaluation of the model.

 The path you provide to the neural-lam config (`config.yaml`) also sets the
 root directory relative to which all other paths are resolved, as in the parent
-directory of the config becomes the root directory. Both the datastore and
+directory of the config becomes the root directory. Both the datastores and
 graphs you generate are then stored in subdirectories of this root directory.
 Exactly how and where a specific datastore expects its source data to be stored
 and where it stores its derived data is up to the implementation of the
```
````diff
@@ -134,6 +136,7 @@ assume you placed `config.yaml` in a folder called `data`):
 data/
 ├── config.yaml - Configuration file for neural-lam
 ├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml
+├── era5.datastore.zarr/ - Optional configuration file for the boundary datastore, referred to from config.yaml
 └── graphs/ - Directory containing graphs for training
 ```
````
````diff
@@ -142,18 +145,20 @@ And the content of `config.yaml` could in this case look like:
 datastore:
   kind: mdp
   config_path: danra.datastore.yaml
+datastore_boundary:
+  kind: mdp
+  config_path: era5.datastore.yaml
 training:
   state_feature_weighting:
     __config_class__: ManualStateFeatureWeighting
-    values:
+    weights:
       u100m: 1.0
       v100m: 1.0
 ```

-For now the neural-lam config only defines two things: 1) the kind of data
-store and the path to its config, and 2) the weighting of different features in
-the loss function. If you don't define the state feature weighting it will default
-to weighting all features equally.
+For now the neural-lam config only defines two things:
+1) the kind of datastores and the path to their config
+2) the weighting of different features in the loss function. If you don't define the state feature weighting it will default to weighting all features equally.

 (This example is taken from the `tests/datastore_examples/mdp` directory.)
````
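As a side note on the root-directory rule described in this README section, a minimal sketch (not part of the PR; the `data/config.yaml` path and keys below follow the README example) of how a datastore config path resolves relative to the parent directory of `config.yaml`:

```python
from pathlib import Path

import yaml  # assumes PyYAML is available

# The parent directory of config.yaml becomes the root directory.
config_path = Path("data/config.yaml")
root_dir = config_path.parent

with open(config_path) as f:
    config = yaml.safe_load(f)

# The datastore config is resolved relative to that root; the optional
# datastore_boundary entry resolves the same way.
datastore_cfg_path = root_dir / config["datastore"]["config_path"]
print(datastore_cfg_path)  # data/danra.datastore.yaml
```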
```diff
@@ -525,5 +530,4 @@ Furthermore, all tests in the ```tests``` directory will be run upon pushing cha

 # Contact
 If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch.
-There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join (after following the link you have to request to join, this is to avoid spam bots).
-You can also open a github issue on this page, or (if more suitable) send an email to [[email protected]](mailto:[email protected]).
+There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [[email protected]](mailto:[email protected]).
```
```diff
@@ -1,4 +1,5 @@
 # Standard library
+import copy
 import warnings
 from functools import cached_property
 from pathlib import Path
```
```diff
@@ -26,11 +27,10 @@ class MDPDatastore(BaseRegularGridDatastore):

     SHORT_NAME = "mdp"

-    def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
+    def __init__(self, config_path, reuse_existing=True):
         """
         Construct a new MDPDatastore from the configuration file at
-        `config_path`. A boundary mask is created with `n_boundary_points`
-        boundary points. If `reuse_existing` is True, the dataset is loaded
+        `config_path`. If `reuse_existing` is True, the dataset is loaded
         from a zarr file if it exists (unless the config has been modified
         since the zarr was created), otherwise it is created from the
         configuration file.
```
```diff
@@ -41,8 +41,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
         The path to the configuration file, this will be fed to the
         `mllam_data_prep.Config.from_yaml_file` method to then call
         `mllam_data_prep.create_dataset` to create the dataset.
-        n_boundary_points : int
-            The number of boundary points to use in the boundary mask.
         reuse_existing : bool
             Whether to reuse an existing dataset zarr file if it exists and its
             creation date is newer than the configuration file.
```
```diff
@@ -69,7 +67,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
         if self._ds is None:
             self._ds = mdp.create_dataset(config=self._config)
             self._ds.to_zarr(fp_ds)
-        self._n_boundary_points = n_boundary_points

         print("The loaded datastore contains the following features:")
         for category in ["state", "forcing", "static"]:
```
```diff
@@ -157,8 +154,8 @@ def get_vars_units(self, category: str) -> List[str]:
         The units of the variables in the given category.

         """
-        if category not in self._ds and category == "forcing":
-            warnings.warn("no forcing data found in datastore")
+        if category not in self._ds:
+            warnings.warn(f"no {category} data found in datastore")
             return []
         return self._ds[f"{category}_feature_units"].values.tolist()
```
```diff
@@ -176,8 +173,8 @@ def get_vars_names(self, category: str) -> List[str]:
         The names of the variables in the given category.

         """
-        if category not in self._ds and category == "forcing":
-            warnings.warn("no forcing data found in datastore")
+        if category not in self._ds:
+            warnings.warn(f"no {category} data found in datastore")
             return []
         return self._ds[f"{category}_feature"].values.tolist()
```
```diff
@@ -196,8 +193,8 @@ def get_vars_long_names(self, category: str) -> List[str]:
         The long names of the variables in the given category.

         """
-        if category not in self._ds and category == "forcing":
-            warnings.warn("no forcing data found in datastore")
+        if category not in self._ds:
+            warnings.warn(f"no {category} data found in datastore")
             return []
         return self._ds[f"{category}_feature_long_name"].values.tolist()
```
```diff
@@ -252,9 +249,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
         The xarray DataArray object with processed dataset.

         """
-        if category not in self._ds and category == "forcing":
-            warnings.warn("no forcing data found in datastore")
-            return None
+        if category not in self._ds:
+            warnings.warn(f"no {category} data found in datastore")
+            return []

         da_category = self._ds[category]
```
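To make the relaxed guard in these hunks concrete, a tiny runnable stand-in (not the real `MDPDatastore`; the class and data below are invented for illustration):

```python
import warnings


class FakeDatastore:
    """Toy stand-in; only mimics the guard added in the hunks above."""

    def __init__(self, ds):
        self._ds = ds  # dict of category -> list of units, for illustration

    def get_vars_units(self, category):
        if category not in self._ds:
            warnings.warn(f"no {category} data found in datastore")
            return []
        return self._ds[category]


# A boundary datastore typically defines only forcing features:
boundary = FakeDatastore({"forcing": ["Pa", "K"]})
assert boundary.get_vars_units("state") == []  # warns instead of raising
assert boundary.get_vars_units("forcing") == ["Pa", "K"]
```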
```diff
@@ -318,37 +315,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
         ds_stats = self._ds[stats_variables.keys()].rename(stats_variables)
         return ds_stats

-    @cached_property
-    def boundary_mask(self) -> xr.DataArray:
-        """
-        Produce a 0/1 mask for the boundary points of the dataset, these will
-        sit at the edges of the domain (in x/y extent) and will be used to mask
-        out the boundary points from the loss function and to overwrite the
-        boundary points from the prediction. For now this is created when the
-        mask is requested, but in the future this could be saved to the zarr
-        file.
-
-        Returns
-        -------
-        xr.DataArray
-            A 0/1 mask for the boundary points of the dataset, where 1 is a
-            boundary point and 0 is not.
-
-        """
-        ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds)
-        da_state_variable = (
-            ds_unstacked["state"].isel(time=0).isel(state_feature=0)
-        )
-        da_domain_allzero = xr.zeros_like(da_state_variable)
-        ds_unstacked["boundary_mask"] = da_domain_allzero.isel(
-            x=slice(self._n_boundary_points, -self._n_boundary_points),
-            y=slice(self._n_boundary_points, -self._n_boundary_points),
-        )
-        ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(
-            1
-        ).astype(int)
-        return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)
-
     @property
     def coords_projection(self) -> ccrs.Projection:
         """
```
```diff
@@ -394,7 +360,9 @@ def coords_projection(self) -> ccrs.Projection:

         class_name = projection_info["class_name"]
         ProjectionClass = getattr(ccrs, class_name)
-        kwargs = projection_info["kwargs"]
+        # need to copy otherwise we modify the dict stored in the dataclass
+        # in-place
+        kwargs = copy.deepcopy(projection_info["kwargs"])

         globe_kwargs = kwargs.pop("globe", {})
         if len(globe_kwargs) > 0:
```
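The in-diff comment explains the motivation; here is a minimal self-contained demonstration of the aliasing bug that `copy.deepcopy` avoids (the dict literal below is invented for illustration):

```python
import copy

# Invented example, standing in for projection_info["kwargs"]
projection_info = {
    "kwargs": {"central_longitude": 10.0, "globe": {"radius": 6371000.0}},
}

# Before the fix: `kwargs` aliases the stored dict, so pop() mutates it
kwargs = projection_info["kwargs"]
kwargs.pop("globe", {})
assert "globe" not in projection_info["kwargs"]  # stored config corrupted!

# After the fix: deep-copy first, leaving the stored config intact
projection_info["kwargs"]["globe"] = {"radius": 6371000.0}  # restore
kwargs = copy.deepcopy(projection_info["kwargs"])
kwargs.pop("globe", {})
assert "globe" in projection_info["kwargs"]  # unchanged
```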
```diff
@@ -412,8 +380,17 @@ def grid_shape_state(self):
         The shape of the cartesian grid for the state variables.

         """
-        ds_state = self.unstack_grid_coords(self._ds["state"])
-        da_x, da_y = ds_state.x, ds_state.y
+        # Boundary data often has no state features
+        if "state" not in self._ds:
+            warnings.warn(
+                "no state data found in datastore"
+                "returning grid shape from forcing data"
+            )
+            ds_forcing = self.unstack_grid_coords(self._ds["forcing"])
+            da_x, da_y = ds_forcing.x, ds_forcing.y
+        else:
+            ds_state = self.unstack_grid_coords(self._ds["state"])
+            da_x, da_y = ds_state.x, ds_state.y
         assert da_x.ndim == da_y.ndim == 1
         return CartesianGridShape(x=da_x.size, y=da_y.size)
```

Review comment on this hunk (body truncated in the capture; the thread was later marked as resolved by sadamov): "You could maybe just call this …"
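One editorial observation on the hunk above, not raised in the captured review thread: the two adjacent string literals inside `warnings.warn(...)` are concatenated by Python without any separator, so the warning text runs together; a trailing space or newline in the first literal would fix it:

```python
# Python implicitly concatenates adjacent string literals with no separator:
msg = (
    "no state data found in datastore"
    "returning grid shape from forcing data"
)
print(msg)
# -> no state data found in datastorereturning grid shape from forcing data
```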
Review comment: "Maybe we could rename the function to `load_datastores` (rather than `load_datastore`) and update the docstring accordingly? :)"
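A hedged sketch of what that suggested `load_datastores` could look like; the existing `load_datastore` implementation is not shown in this capture, so everything below (the import path, the YAML keys, the return convention) is an assumption pieced together from the README diff and the `MDPDatastore(config_path=...)` signature above:

```python
from pathlib import Path

import yaml  # assumes PyYAML is available

# Assumed import path; MDPDatastore's constructor signature is taken
# from the diff above (config_path, reuse_existing=True).
from neural_lam.datastores.mdp import MDPDatastore


def load_datastores(config_path):
    """Return (datastore, datastore_boundary); the boundary may be None."""
    root_dir = Path(config_path).parent  # parent of config.yaml is the root
    with open(config_path) as f:
        cfg = yaml.safe_load(f)

    # This PR only supports kind "mdp" for the boundary, so MDPDatastore
    # is hard-coded here; a real implementation would dispatch on "kind".
    datastore = MDPDatastore(
        config_path=root_dir / cfg["datastore"]["config_path"]
    )
    boundary_cfg = cfg.get("datastore_boundary")  # optional entry
    datastore_boundary = (
        MDPDatastore(config_path=root_dir / boundary_cfg["config_path"])
        if boundary_cfg is not None
        else None
    )
    return datastore, datastore_boundary
```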