
Feat/boundary dataloader #90

Open

wants to merge 88 commits into base: main
Conversation

sadamov
Copy link
Collaborator

@sadamov sadamov commented Nov 21, 2024

This PR introduces support for handling boundary data through an additional, optional datastore. This allows forcing the model's interior state with boundary conditions during training and evaluation. Boundary data handling is implemented up to WeatherDataset.__getitem__, which now returns 5 elements: init_states, target_states, forcing, **boundary**, target_times (#84 will later implement boundary handling on the model side).
Currently the boundary datastore must be of type mdp and contain forcing category data. The code was developed in a way that this requirement can be easily relaxed in the future.

Motivation and Context:

In weather modeling, incorporating realistic boundary conditions is crucial for accurately simulating atmospheric processes. By providing an optional boundary datastore, the model gains the flexibility to accept external boundary data, enhancing prediction accuracy and allowing for more complex simulations. The current workflow, which uses the n outermost grid cells as the boundary, will be fully replaced.
The creation of a new ERA5 WeatherBench MDP example datastore demonstrates the implementation of these features and serves as a template for future models.

Key Changes:

  • Introduced num_past/future_forcing_steps and num_past/future_boundary_steps as separate arguments and CLI flags.
  • Created a new ERA5 WeatherBench MDP example datastore: Developed an example datastore using ERA5 WeatherBench data within the Model-Data-Platform (MDP) framework.
  • Ensured num_grid_points also works for forcing-only datastores (i.e. those not containing state variables).
  • Combined the time slicing for state and forcing and made it timestep-based. The same function now handles state, forcing, and boundary data, matching the corresponding samples based on actual time (using minimal time differences between state and forcing/boundary).
  • Updated global configuration to support two datastores: Modified the global configuration to accommodate both state and boundary/forcing datastores simultaneously.
  • Added temporal embedding of time_step_diffs: Introduced temporal embeddings by adding time step differences as additional input features to improve the model's temporal awareness.
  • Removed interior/exterior masks from the codebase: Eliminated the use of data masks to simplify the codebase. This is now handled by having two separate datastores.
  • Accept boundary data in common_step but do not return it: Modified the common_step method to accept boundary data as input without returning it, preparing for future boundary/state implementation on the model side.
  • Moved the handling of ensemble members to the initialization of WeatherDataset to prevent duplicate time steps in the dataset.
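The time-based matching of state samples to forcing/boundary samples can be sketched roughly as follows. This is a minimal numpy illustration with hypothetical times and variable names, not the actual implementation:

```python
import numpy as np

# Hypothetical example: state data on a 1 h step, boundary forcing on a
# 6 h step (times in hours since some common epoch).
state_times = np.array([0, 1, 2, 3, 4, 5], dtype=float)
forcing_times = np.array([0, 6, 12], dtype=float)

# For each state time, pick the forcing index with the smallest |dt|
# (the "minimal time difference" matching described above).
matched_idx = np.abs(
    forcing_times[None, :] - state_times[:, None]
).argmin(axis=1)

# Time delta between each state sample and its matched forcing sample,
# expressed in state time steps (later used as a temporal input feature).
state_step = 1.0
time_diff_steps = (forcing_times[matched_idx] - state_times) / state_step
```

With these example times, state samples at 0-3 h match the 0 h forcing step and samples at 4-5 h match the 6 h step, with the signed offsets recorded in `time_diff_steps`.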

Bugfix:

  • Fixed a bug in analysis time retrieval in the MEPS store: Corrected an issue with retrieving analysis_time from the MEPS data store where duplicate time steps for one member were returned.

Notes:

  • I have introduced a strict assertion that time, analysis_time and elapsed_forecast_time all have a consistent step size. This is not strictly necessary, but it is probably what most users want, and it helps with the temporal encoding.
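A minimal sketch of such a step-size assertion (the helper name and call site here are hypothetical, not the actual code):

```python
import numpy as np

def assert_consistent_step(times):
    """Assert that a time coordinate advances with a single constant step."""
    steps = np.diff(np.asarray(times))
    unique_steps = np.unique(steps)
    assert len(unique_steps) == 1, (
        f"inconsistent time step sizes found: {unique_steps}"
    )
    return steps[0]

# e.g. analysis times every 6 h pass the check and return the step size
step = assert_consistent_step([0, 6, 12, 18])
```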

Introduced Dependencies

  • "gcsfs>=2021.10.0" Where should we download era5 from, do we want this dependency to Google Cloud Storage?

TODOs:

  • Upload a MEPS example with dates close to Danra and next to each other for test/val/train @leifdenby. This allows testing both npy and mdp datastores with the same boundary datastore and limits the required data.

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • author has added an entry to the changelog (and designated the change as added, changed or fixed)
  • Once the PR is ready to be merged, squash commits and merge the PR.

sadamov and others added 30 commits November 21, 2024 13:49
combine two slicing fcts into one
Comment on lines 396 to 399
idx = np.abs(
    da_forcing_boundary.analysis_time.values
    - self.da_state.analysis_time.values[idx]
).argmin()
Collaborator

For the case where the boundary forcing time steps differ from the state time steps:

Thinking about how this will be used downstream in the model, I am not sure about the choice to find the closest analysis time. I would rather, for state time t, find the previous-or-equal boundary forcing time (<= t). The point being that the same index of boundary forcing features should always correspond to "the previous boundary forcing" or "the next boundary forcing". This relates to the discussed idea of the model learning to implicitly interpolate between the boundary forcing at different times.

Consider this setup, with num_past_boundary_steps = 0 and num_future_boundary_steps = 1.
boundary_forcing_timescales
And let's say we have more $X$ values in between 0 h and 6 h. If I am not misunderstanding this, with the current approach $X_\text{1h}$ would be paired with a window of $F_\text{0h}$ and $F_\text{6h}$ (as desired). But for $X_\text{5h}$, the closest boundary forcing time is 6 h, meaning it would be paired with $F_\text{6h}$ and $F_\text{12h}$. This makes this "implicit interpolation" to boundary forcing at 5 h impossible, as the information from $F_\text{0h}$ is not included. Now one could technically solve this by setting num_past_boundary_steps = 1, but that would make the problem harder for the network, as it has to switch between "interpolating" between the first boundary forcings in the window and the last two. I hope this makes sense, I realize this got a bit involved.

I think that what I propose could potentially also simplify the argmin-expressions, if you just keep track of the timedeltas for both the boundary forcing and state time series. It seems to me that this should be doable without comparing to all times, with just .sel.

Collaborator Author

I see your point, but what if the closest time step from the boundary is +1 h in the future of state_time[idx]? With num_past_boundary_steps = 1; num_future_boundary_steps = 0, in your setup the two boundary steps would be -5 h and -11 h away from state_time[idx], giving the model a much harder problem than just using the boundary at +1 h and the one at -5 h.

I agree with .sel being a much simpler and faster solution than my global time diff matrices. ✔️

Collaborator

I see your point, but what if the closest time step from the boundary is +1 h in the future of state_time[idx]? With num_past_boundary_steps = 1; num_future_boundary_steps = 0, in your setup the two boundary steps would be -5 h and -11 h away from state_time[idx], giving the model a much harder problem than just using the boundary at +1 h and the one at -5 h.

Yes, that's a bad situation, but I think this just means that the user has made a bad choice with num_future_boundary_steps = 0 given their data. Likely you would always want to set that to at least 1 when you have this kind of time-step mismatch. Even in that situation, I think it's still preferable to the model having to adjust to the forcing features at one specific index sometimes being from the future and sometimes from the past. That information is in the delta-time features, but using those deltas in that way seems like a harder problem for the model to untangle.

Collaborator Author

@sadamov sadamov Dec 3, 2024

So basically you are saying it's easier for the model to have forcing either (1) in the past or (2) in the past and future, than to have forcing features in (1) + (2) and (3) only in the future. Okay, I can see that. And there might be some scenarios where it is technically or scientifically not valid to choose data from the future, so having the default be the closest past time step is sound. In addition, we can use xarray.sel(), which is more efficient and lazy. OK, I'm convinced, will get to work on this.

Collaborator

I guess another simple way to put what I mean is: If I look at one specific feature dimension in the boundary forcing tensor (e.g. boundary_forcing[..., 21]) , I think it's good if the value there either always comes from the past or always comes from the future. That seems easier for the model to handle.

Collaborator Author

Agree with you, implemented this with xarray's .sel and method="pad", which does exactly what you describe.
cfe1e27
Sorry, this commit is a bit overloaded, introducing more than one change. Let me know if it's confusing!
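For reference, the previous-or-equal semantics that xarray's .sel(..., method="pad") provides can be sketched in plain numpy. The times and names below are hypothetical, illustration only:

```python
import numpy as np

# Boundary forcing on a 6 h step, state on a 1 h step (hypothetical).
boundary_times = np.array([0.0, 6.0, 12.0])
state_times = np.array([0.0, 1.0, 5.0, 6.0, 11.0])

# searchsorted with side="right" counts the boundary times <= t;
# subtracting 1 yields the previous-or-equal ("pad") index, so a given
# window index always refers to "the previous forcing" and never jumps
# ahead to a future step.
pad_idx = np.searchsorted(boundary_times, state_times, side="right") - 1
matched_times = boundary_times[pad_idx]
```

Here every state time between 0 h and 5 h maps back to the 0 h boundary step, and 6 h through 11 h map to the 6 h step.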

Comment on lines 522 to 525
da_windowed = xr.concat(
    [da_windowed, da_windowed.time_diff_steps],
    dim="forcing_feature_windowed",
)
Collaborator

How will the concatenation of these time deltas work with e.g. BaseDatastore.get_vars_names(category="forcing")? I guess they will not be included there. But that also does not reflect the forcing windowing, so maybe one should not make the assumption that what is returned there maps to the indices of the forcing tensor returned from the dataset? That would require the datastore to have knowledge of the WeatherDataset config, which we probably don't want 🤔

I fully agree that we need these time-delta features, but it's tricky how to handle them. We've been talking about doing some form of sinusoidal embeddings of them later in the model. Would it even make sense to return these separately? Since if we want to do something special (embedding) with these features we would otherwise have to extract them from the boundary forcing tensor (assuming that they are the last features). If we would return them separately we should maybe create some form of custom batch-object, as it would start to be a little unstructured to have this tuple of 7 objects.

Collaborator Author

@sadamov sadamov Dec 3, 2024

Yes, BaseDatastore.get_vars_names(category="forcing") will not return the windowed features + temporal embeddings. This is important for saving model predictions in #89, too. We don't need them in the model output, but want to keep the original features exactly in order.

  • Ordering the forcing and boundary return to always have the temporal embedding as the last feature sounds like an easy option for now.
  • I am still a bit confused how exactly the temporal embedding will be used in the model. The temporal embedding is the same for all grid_points and time, and maps to windowed_forcing_features based on the window size. Can one just leave these additional temporal embeddings as additional input feature channels (sin(embedding)), or should we directly merge them with the forcing_feature in some sort of batch-norm in ar_model? Is there a world where we can already merge the embedding and windowed_forcing_features in WeatherDataset?

Collaborator

Can one just leave these additional temporal embeddings as additional input feature channels (sin(embedding))

I guess we should leave it open to different model implementations what to do with these. What I am thinking of first is to just do the sin(embedding) and feed them as additional input feature channels. Then we could expand on that later as needed.

or should we directly merge them with the forcing_feature in some sort of batch-norm in ar_model?

That's another option: to map from these features to the mean and std-dev used in LayerNorms (akin to how time-embeddings are often used in diffusion models and others). It's a bit trickier to implement, and might not really be necessary here, but still potentially interesting.

Is there a world where we can already merge the embedding and windowed_forcing_features in weatherDataset?

I don't think that would be possible (considering the LayerNorm setup discussed above). As the use of these features in the LayerNorms involves trainable parameters, and they act on intermediate representations in the model, this would have to be done on the model side of things.

I suppose we could do the sin(embedding) already in WeatherDataset, if that does not involve trainable parameters. But I don't think we want to, as it should be better to do those computations on GPU (similar to how we want to move standardization to GPU).
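As a rough illustration of the sin(embedding)-as-extra-channels idea discussed above (the function name, frequency ladder, and shapes are all hypothetical, not the actual model code):

```python
import numpy as np

def sinusoidal_embedding(time_deltas, n_freqs=4):
    """Map per-window time deltas (in state time steps) to sinusoidal
    features that can be appended as extra input channels."""
    freqs = 2.0 ** np.arange(n_freqs)            # geometric frequency ladder
    angles = time_deltas[:, None] * freqs[None, :]
    # concatenate sin and cos components along the feature axis
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

# e.g. a window of 3 time deltas yields a (3, 2 * n_freqs) feature block
emb = sinusoidal_embedding(np.array([-1.0, 0.0, 1.0]))
```

Doing this on the model side (on GPU) rather than in WeatherDataset matches the point above about moving such computations off the dataloader.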

Collaborator Author

Okay, thanks for the explanations. While not totally relevant here, it's still good to have the bigger picture in mind. So for now I will just focus on returning all temporal embedding features as the last channels in forcing or boundary.

Collaborator Author

Okay, I fixed this now in cfe1e27:
the temporal embedding has size window and is not repeated for each feature;
the embedding is concatenated to the existing forcing_feature_windowed at the very end.

sadamov and others added 4 commits December 5, 2024 13:26
more and better defined scenarios
match of state with forcing/boundary is now done with .sel and "pad"
renaming some variables to make the code easier to read
fixing the temporal encoding to only include embeddings for window-size
@sadamov sadamov marked this pull request as ready for review December 5, 2024 12:51
@joeloskarsson
Collaborator

I think that if we merge main into here the tests will pass. They seem to have failed for the same reason as on main, which was fixed with #94.

Member

@leifdenby leifdenby left a comment

First read-through: this was a joy to read 😊 I have a few comments, but in general I think this is very close to done.

@@ -168,4 +172,15 @@ def load_config_and_datastore(
datastore_kind=config.datastore.kind, config_path=datastore_config_path
)

return config, datastore
if config.datastore_boundary is not None:
Member

Maybe we could rename the function to load_datastores (rather than load_datastore) and update the docstring accordingly? :)

"no state data found in datastore"
"returning grid shape from forcing data"
)
ds_forcing = self.unstack_grid_coords(self._ds["forcing"])
Member

You could maybe just call this da_grid_reference or something and then unstack etc.

Comment on lines 23 to +26
datastore : BaseDatastore
The datastore to load the data from (e.g. mdp).
datastore_boundary : BaseDatastore
The boundary datastore to load the data from (e.g. mdp).
Member

Suggested change
datastore : BaseDatastore
The datastore to load the data from (e.g. mdp).
datastore_boundary : BaseDatastore
The boundary datastore to load the data from (e.g. mdp).
datastore : BaseDatastore
The datastore to load the data from.
datastore_boundary : BaseDatastore
The boundary datastore to load the data from.

I've realised my docstring before doesn't make much sense; mdp is the shorthand name. But people need to provide an instance here, not a string.


self.da_state = self.datastore.get_dataarray(
category="state", split=self.split
)
if self.da_state is None:
Member

Ah yes, we have relaxed the requirement that a datastore must always return state category data :) So this line in the docstring should be changed too: https://github.com/mllam/neural-lam/pull/90/files#diff-e9c07db146431708a50478784657ea66de598c734fda9f159dfdfe6f2a08002eR197

self.da_forcing = self.datastore.get_dataarray(
category="forcing", split=self.split
)
# XXX For now boundary data is always considered mdp-forcing data
if self.datastore_boundary is not None:
self.da_boundary = self.datastore_boundary.get_dataarray(
Member

Suggested change
self.da_boundary = self.datastore_boundary.get_dataarray(
self.da_boundary_forcing = self.datastore_boundary.get_dataarray(

Could we call it boundary_forcing instead? That makes it clearer that it is only the forcing on the boundary

def _process_windowed_data(self, da_windowed, da_state, da_target_times):
"""Helper function to process windowed data. This function stacks the
'forcing_feature' and 'window' dimensions and adds the time step
differences to the existing features as a temporal embedding.
Member

Would it maybe make sense to split adding the temporal embedding into a separate function? To me this doesn't have anything to do with the windowing, but maybe I am missing something. Something like _add_temporal_embedding(...)

Collaborator

I think this is highly related. The temporal embedding (which we should probably not refer to as such, as this is not an embedding; so from here on: the delta times) is the time differences for the different entries in the window. After these have been stacked, it is not clear to me how you would know what the delta times are, since you've got rid of the window dimension.

da_forcing_matched["window"]
* (forcing_time_step / state_time_step),
)
time_diff_steps = da_forcing_matched.isel(
Member

Until I read this line I had misunderstood what time_diff was referring to :D I thought it meant the difference in the time variable (as in time increments) rather than change in forcing over time. Could we maybe call it forcing_time_diff instead? If you think that makes sense

@@ -66,10 +82,20 @@ def test_dataset_item_shapes(datastore_name):
assert forcing.ndim == 3
assert forcing.shape[0] == N_pred_steps
assert forcing.shape[1] == N_gridpoints
assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * (
# each time step in the window has one corresponding temporal embedding
Member

since this isn't actually an embedding of time (in the way that for example LLMs encode word order as a "temporal embedding") it might be more correct to call this extra feature as simply the forcing increments?

Collaborator

Agree that this should not be referred to as an embedding (this is rather what we will use to create an embedding later). I've also been sloppily calling it this 😅 I think something like time delta is more clear (forcing increments makes me think it is the increment in forcing values, rather than in time).

@@ -16,40 +16,76 @@ class SinglePointDummyDatastore(BaseDatastore):
root_path = None

def __init__(self, time_values, state_data, forcing_data, is_forecast):
Member

thank you for making this datastore for testing more complete!

# Compute expected initial states and target states based on ar_steps
offset = max(0, num_past_forcing_steps - INIT_STEPS)
init_idx = INIT_STEPS + offset
expected_init_states = STATE_VALUES_FORECAST[0][offset:init_idx]
Member

These tests are so great. I was just thinking it might be nice to make it explicit here that this 0-index is due to the fact that we have only implemented sampling from the first ensemble member. Is that the reason why? Then it will be clearer how to update the tests once we add sampling of more ensemble members.

Labels
enhancement New feature or request
3 participants