Feat/boundary dataloader #90
base: main
Conversation
combine two slicing functions into one
neural_lam/weather_dataset.py
Outdated
```python
idx = np.abs(
    da_forcing_boundary.analysis_time.values
    - self.da_state.analysis_time.values[idx]
).argmin()
```
For the case where boundary forcing time steps are different from state time steps:

Thinking about how this will be used downstream in the model, I am not sure about the choice to find the closest analysis time. I would rather, for state time t, find the previous or equal boundary forcing time (<= t). The point being that the same index of the boundary forcing features should always correspond to "the previous boundary forcing" or "the next boundary forcing". This relates to the discussed idea of the model learning to implicitly interpolate between the boundary forcing at different times.

Consider this setup, with `num_past_boundary_steps = 0` and `num_future_boundary_steps = 1`. And let's say we instead have `num_past_boundary_steps = 1`; then we would make the problem harder for the network, as it has to change between "interpolating" between the first boundary forcings in the window and the last two. I hope this makes sense, I realize this got a bit involved.

I think that what I propose could potentially also simplify the argmin-expressions, if you just keep track of the timedeltas for both the boundary forcing and state time series. It seems to me that this should be doable without comparing to all times, with just `.sel`.
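A minimal sketch of this previous-or-equal matching idea with toy data (names are illustrative; xarray's `method="pad"` selects the latest index at or before the requested label):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy boundary forcing available every 6 h
boundary_times = pd.date_range("2020-01-01T00:00", periods=4, freq="6h")
da_boundary_forcing = xr.DataArray(
    np.arange(4.0), coords={"time": boundary_times}, dims="time"
)

# For a state time t, method="pad" picks the latest boundary time <= t,
# so a fixed window index always means "the previous boundary forcing".
state_time = np.datetime64("2020-01-01T05:00")
matched = da_boundary_forcing.sel(time=state_time, method="pad")
print(matched.time.values)  # 2020-01-01T00:00, not the closer 06:00 step
```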
I see your point, but what if the closest time step from the boundary is +1h in the future of state_time[idx]? With `num_past_boundary_steps = 1; num_future_boundary_steps = 0`, in your setup the two boundary steps would be -5h and -11h away from state_time[idx], giving the model a much harder problem than just using the boundary at +1h and the one at -5h.

I agree with `.sel` being a much simpler and faster solution than my global time-diff matrices. ✔️
> I see your point, but what if the closest time step from the boundary is +1h in the future of state_time[idx]? With num_past_boundary_steps = 1; num_future_boundary_steps = 0, in your setup the two boundary steps would be -5h and -11h away from state_time[idx], giving the model a much harder problem than just using the boundary at +1h and the one at -5h.

Yes, that's a bad situation, but I think this just means that the user has made a bad choice with `num_future_boundary_steps = 0` given their data. Likely you would always want to set that to at least 1 when you have this kind of time-step mismatch. Even in that situation I think it's still preferable over the model having to adjust to the forcing features at one specific index sometimes being from the future and sometimes being from the past. That information is in the delta-time features, but using those deltas in that way seems like a harder problem for the model to untangle.
So basically you are saying it's easier for the model to have forcing either (1) in the past or (2) in the past and future, than it is to have forcing features in (1) + (2) and (3) only in the future. Okay, I can see that. And there might be some scenario where it is technically or scientifically not valid to choose data from the future. So having the default be the closest past time step is sound. In addition we can use `xarray.sel()`, which is more efficient and lazy. OK, I'm convinced; will get to work on this.
I guess another simple way to put what I mean is: if I look at one specific feature dimension in the boundary forcing tensor (e.g. `boundary_forcing[..., 21]`), I think it's good if the value there either always comes from the past or always comes from the future. That seems easier for the model to handle.
Agree with you, implemented this with xarray's `.sel` and `method="pad"`, which does exactly what you describe: cfe1e27

Sorry, this commit is a bit overloaded, introducing more than one change. Let me know if it is confusing!
neural_lam/weather_dataset.py
Outdated
```python
da_windowed = xr.concat(
    [da_windowed, da_windowed.time_diff_steps],
    dim="forcing_feature_windowed",
)
```
How will the concatenation of these time deltas work with e.g. `BaseDatastore.get_vars_names(category="forcing")`? I guess they will not be included there. But that also does not reflect the forcing windowing, so maybe one should not make the assumption that what is returned there maps to the indices of the forcing tensor returned from the dataset? That would require the datastore to have knowledge of the WeatherDataset config, which we probably don't want 🤔

I fully agree that we need these time-delta features, but it's tricky how to handle them. We've been talking about doing some form of sinusoidal embedding of them later in the model. Would it even make sense to return these separately? Since if we want to do something special (embedding) with these features, we would otherwise have to extract them from the boundary forcing tensor (assuming that they are the last features). If we would return them separately we should maybe create some form of custom batch-object, as it would start to be a little unstructured to have this tuple of 7 objects.
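If the time deltas stay appended as the last features, extracting them again in the model could look roughly like this (a sketch that assumes the deltas are the last `window_size` channels; the function and names are illustrative, not part of this PR):

```python
import torch

def split_time_deltas(boundary: torch.Tensor, window_size: int):
    """Split off the appended time-delta channels for separate treatment
    (e.g. a sinusoidal embedding) inside the model."""
    features = boundary[..., :-window_size]     # original windowed forcing
    time_deltas = boundary[..., -window_size:]  # one delta per window step
    return features, time_deltas
```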
Yes, `BaseDatastore.get_vars_names(category="forcing")` will not return the windowed features + temporal embeddings. This is important for saving model predictions in #89, too. We don't need them in the model output, but we want to keep the original features exactly in order.

- Ordering the forcing and boundary return to always have the temporal embedding as the last feature sounds like an easy option for now.
- I am still a bit confused how exactly the temporal embedding will be used in the model. The temporal embedding is the same for all `grid_points` and `time` and maps to `windowed_forcing_features` based on the window size. Can one just leave these additional temporal embeddings as additional input feature channels (sin(embedding)), or should we directly merge them with the `forcing_feature` in some sort of batch-norm in `ar_model`? Is there a world where we can already merge the embedding and `windowed_forcing_features` in WeatherDataset?
> Can one just leave these additional temporal embeddings as additional input feature channels (sin(embedding))?

I guess we should leave it open to different model implementations what to do with these. What I am thinking of first is to just do the sin(embedding) and feed them as additional input feature channels. Then we could expand on that later as needed.

> or should we directly merge them with the forcing_feature in some sort of batch-norm in ar_model?

That's another option: to map from these features to the mean and std-dev used in LayerNorms (akin to how time-embeddings are often used in diffusion models and others). It's a bit trickier to implement, and might not really be necessary here, but still potentially interesting.
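A hedged sketch of what such LayerNorm conditioning might look like (FiLM-style adaptive LayerNorm; entirely illustrative and not part of this PR):

```python
import torch
import torch.nn as nn

class TimeConditionedLayerNorm(nn.Module):
    """LayerNorm whose scale/shift are predicted from a conditioning vector,
    e.g. (embedded) time deltas, as is common in diffusion models."""

    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift
```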
> Is there a world where we can already merge the embedding and windowed_forcing_features in WeatherDataset?

I don't think that would be possible (considering the LayerNorm setup discussed above). As the use of these features in the LayerNorms involves trainable parameters, and they act on intermediate representations in the model, this would have to be done on the model side of things.

I suppose we could do the sin(embedding) already in WeatherDataset, since that does not involve trainable parameters. But I don't think we want to, as it should be better to do those computations on GPU (similar to how we want to move standardization to GPU).
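For illustration, a sketch of the sin(embedding)-as-extra-channels idea (the frequency choice and shapes are assumptions; in the model this would run on GPU):

```python
import torch

def sinusoidal_time_delta_channels(
    time_deltas: torch.Tensor, num_frequencies: int = 4
) -> torch.Tensor:
    """Encode time deltas (in state time steps) as sin/cos feature channels."""
    freqs = 2.0 ** torch.arange(num_frequencies, device=time_deltas.device)
    # (..., window) -> (..., window, num_frequencies)
    angles = time_deltas.unsqueeze(-1) * freqs
    # concatenate sin and cos along the last axis as extra input channels
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```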
Okay, thanks for the explanations. While not totally relevant here, it's still good to have the bigger picture in mind. So for now I will just focus on returning all temporal embedding features as the last channels in `forcing` or `boundary`.
Okay, I fixed this now in cfe1e27:

- the temporal embedding has size `window` and is not repeated for each feature
- the embedding is concatenated to the existing `forcing_feature_windowed` at the very end
- plus some more comments
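A toy numpy sketch of the resulting feature layout (shapes assumed; this mirrors the shape check in the tests, `num_vars * window + window`):

```python
import numpy as np

n_time, n_grid, n_features, window = 3, 5, 4, 2
forcing = np.zeros((n_time, n_grid, n_features, window))
time_deltas = np.array([-1.0, 0.0])  # one value per window step

# flatten (feature, window) into the channel axis ...
flat = forcing.reshape(n_time, n_grid, n_features * window)
# ... then append the per-window time deltas as the very last channels
deltas = np.broadcast_to(time_deltas, (n_time, n_grid, window))
combined = np.concatenate([flat, deltas], axis=-1)
assert combined.shape[-1] == n_features * window + window  # 4 * 2 + 2 = 10
```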
neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
Outdated (resolved)
more and better defined scenarios

match of state with forcing/boundary is now done with .sel and "pad"; renaming some variables to make the code easier to read; fixing the temporal encoding to only include embeddings for window-size
I think that if we merge …
First read-through: this was a joy to read 😊 I have a few comments, but in general I think this is very close to done.
```
@@ -168,4 +172,15 @@ def load_config_and_datastore(
        datastore_kind=config.datastore.kind, config_path=datastore_config_path
    )

    return config, datastore
    if config.datastore_boundary is not None:
```
Maybe we could rename the function to `load_datastores` (rather than `load_datastore`) and update the docstring accordingly? :)
"no state data found in datastore" | ||
"returning grid shape from forcing data" | ||
) | ||
ds_forcing = self.unstack_grid_coords(self._ds["forcing"]) |
You could maybe just call this `da_grid_reference` or something and then unstack etc.
```python
datastore : BaseDatastore
    The datastore to load the data from (e.g. mdp).
datastore_boundary : BaseDatastore
    The boundary datastore to load the data from (e.g. mdp).
```
Suggested change:

```diff
-datastore : BaseDatastore
-    The datastore to load the data from (e.g. mdp).
-datastore_boundary : BaseDatastore
-    The boundary datastore to load the data from (e.g. mdp).
+datastore : BaseDatastore
+    The datastore to load the data from.
+datastore_boundary : BaseDatastore
+    The boundary datastore to load the data from.
```
I've realised my docstring before doesn't make much sense; `mdp` is the shorthand name, but people need to provide an instance here, not a string.
```python
self.da_state = self.datastore.get_dataarray(
    category="state", split=self.split
)
if self.da_state is None:
```
Ah yes, we have relaxed the requirement that a datastore must always return `state` category data :) So this line in the docstring should be changed too: https://github.com/mllam/neural-lam/pull/90/files#diff-e9c07db146431708a50478784657ea66de598c734fda9f159dfdfe6f2a08002eR197
```python
self.da_forcing = self.datastore.get_dataarray(
    category="forcing", split=self.split
)
# XXX For now boundary data is always considered mdp-forcing data
if self.datastore_boundary is not None:
    self.da_boundary = self.datastore_boundary.get_dataarray(
```
Suggested change:

```diff
-    self.da_boundary = self.datastore_boundary.get_dataarray(
+    self.da_boundary_forcing = self.datastore_boundary.get_dataarray(
```
Could we call it `boundary_forcing` instead? That makes it clearer that it is only the forcing on the boundary.
```python
def _process_windowed_data(self, da_windowed, da_state, da_target_times):
    """Helper function to process windowed data. This function stacks the
    'forcing_feature' and 'window' dimensions and adds the time step
    differences to the existing features as a temporal embedding.
```
Would it maybe make sense to split adding the temporal embedding into a separate function? To me this doesn't have anything to do with the windowing, but maybe I am missing something. Something like `_add_temporal_embedding(...)`.
I think this is highly related. The temporal embedding (which we should probably not refer to as such, as this is not an embedding, so from here on: the delta times) is the time differences for the different entries in the window. After these have been stacked, it is not clear to me how you would know what the delta times are, since you've got rid of the window dimension.
da_forcing_matched["window"] | ||
* (forcing_time_step / state_time_step), | ||
) | ||
time_diff_steps = da_forcing_matched.isel( |
Until I read this line I had misunderstood what `time_diff` was referring to :D I thought it meant the difference in the time variable (as in time increments) rather than the change in forcing over time. Could we maybe call it `forcing_time_diff` instead, if you think that makes sense?
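For reference, a worked example of the window scaling in the snippet above (step sizes assumed for illustration):

```python
# With 6 h boundary forcing and a 3 h state step, window offsets expressed
# in state time steps are doubled.
forcing_time_step = 6  # hours, assumed
state_time_step = 3    # hours, assumed
window_indices = [-1, 0, 1]
time_diff_steps = [
    w * (forcing_time_step / state_time_step) for w in window_indices
]
print(time_diff_steps)  # [-2.0, 0.0, 2.0]
```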
```
@@ -66,10 +82,20 @@ def test_dataset_item_shapes(datastore_name):
    assert forcing.ndim == 3
    assert forcing.shape[0] == N_pred_steps
    assert forcing.shape[1] == N_gridpoints
    assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * (
        # each time step in the window has one corresponding temporal embedding
```
Since this isn't actually an embedding of time (in the way that, for example, LLMs encode word order as a "temporal embedding"), it might be more correct to call this extra feature simply the forcing increments?
Agree that this should not be referred to as an embedding (this is rather what we will use to create an embedding later). I've also been sloppily calling it that 😅 I think something like "time delta" is clearer (forcing increments makes me think it is the increment in forcing values, rather than in time).
```
@@ -16,40 +16,76 @@ class SinglePointDummyDatastore(BaseDatastore):
    root_path = None

    def __init__(self, time_values, state_data, forcing_data, is_forecast):
```
thank you for making this datastore for testing more complete!
```python
# Compute expected initial states and target states based on ar_steps
offset = max(0, num_past_forcing_steps - INIT_STEPS)
init_idx = INIT_STEPS + offset
expected_init_states = STATE_VALUES_FORECAST[0][offset:init_idx]
```
These tests are so great. I was just thinking it might be nice to make it explicit here that this 0-index is due to the fact that we have only implemented sampling from the first ensemble member. That is the reason, right? Then it will be clearer how to update the tests once we add sampling of more ensemble members.
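A worked example of the offset arithmetic in the snippet above (`INIT_STEPS = 2` is an assumed value for illustration):

```python
INIT_STEPS = 2            # assumed number of initial states
num_past_forcing_steps = 3
offset = max(0, num_past_forcing_steps - INIT_STEPS)  # -> 1
init_idx = INIT_STEPS + offset                        # -> 3
# the first sample cannot start at t = 0, so the expected initial
# states would be STATE_VALUES_FORECAST[0][1:3]
```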
This PR introduces support for handling boundary data through an additional optional datastore. This allows forcing the model's interior state with boundary conditions during training and evaluation. This PR introduces boundary data handling up until `WeatherDataset.__getitem__`, which now returns 5 elements: `init_states, target_states, forcing, **boundary**, target_times` (#84 will later implement boundary handling on the model side). Currently the boundary datastore must be of type mdp and contain forcing category data. The code was developed in a way that this requirement can be easily relaxed in the future.
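A hedged usage sketch of the new signature (datastore construction is omitted; keyword names as described in this PR):

```python
from neural_lam.weather_dataset import WeatherDataset

# assumes `datastore` and `boundary_datastore` were created beforehand
ds = WeatherDataset(
    datastore=datastore,                    # interior state + forcing
    datastore_boundary=boundary_datastore,  # optional boundary datastore
    split="train",
)
init_states, target_states, forcing, boundary, target_times = ds[0]
```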
Motivation and Context:

In weather modeling, incorporating realistic boundary conditions is crucial for accurately simulating atmospheric processes. By providing an optional boundary datastore, the model gains the flexibility to accept external boundary data, enhancing prediction accuracy and allowing for more complex simulations. The current workflow using the n outermost grid cells as boundary will be fully replaced.

The creation of a new ERA5 WeatherBench MDP example datastore demonstrates the implementation of these features and serves as a template for future models.
Key Changes:

- `num_past/future_forcing_steps` and `num_past/future_boundary_steps` as separate arguments and CLI flags.
- `time` (using minimal time differences between state and forcing/boundary).

Bugfix:
Notes:

- `time`, `analysis_time` and `elapsed_forecast_time` all have a consistent step size. This is not strictly necessary, but it is probably what most users want, and it helps with the temporal encoding.