Remove batch-static tensor from dataset class and models #13

Merged: joeloskarsson merged 3 commits into main from remove_batch_static on Mar 18, 2024

joeloskarsson
Collaborator

The batch-static tensor contained forcing that differed between initialization times, but stayed static for all lead times of a forecast. For the MEPS data we used this for the land-water mask, as it could differ throughout the year, but we could not produce separate values per lead time (as we do for all other forcing).

This PR removes the batch-static features as an explicit extra input. The motivation is:

  1. Such input features are a rare and highly specific case.
  2. If such inputs exist, it is better to treat them as any other type of forcing. The values then have to be repeated over the temporal dimension, but this can be handled either in pre-processing or easily in the Dataset class. In this PR the MEPS Dataset class is changed to take this approach.
  3. Needing to pass around the batch-static features clutters up the code. For most datasets they would not be used, requiring constant special-case checks for whether they are None.

This PR changes:

  1. Bake the batch-static features into the normal forcing in the MEPS Dataset class.
  2. Change the Dataset class to only return 3 tensors per sample (init, target, forcing).
  3. Remove the batch-static tensor from being extracted from the batch and passed around in the graph-based models. This is done while making sure that input dimensions line up, so that older checkpoints can still be loaded correctly.
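
As a rough illustration of changes 1 and 2, the sketch below repeats a per-sample static feature over the time dimension and prepends it to the forcing. This is not the actual code of the MEPS Dataset class: NumPy arrays stand in for PyTorch tensors, and the function name `bake_batch_static`, the variable names and all shapes are assumptions made for the example.

```python
import numpy as np


def bake_batch_static(batch_static, forcing):
    """Repeat static features over time and put them first in forcing.

    batch_static: (num_grid_nodes, d_static), constant over lead times
    forcing:      (num_steps, num_grid_nodes, d_forcing)
    returns:      (num_steps, num_grid_nodes, d_static + d_forcing)
    """
    num_steps = forcing.shape[0]
    # Add a leading time axis and repeat the static values for every step
    repeated = np.broadcast_to(
        batch_static[None], (num_steps,) + batch_static.shape
    )
    # Static features become the first feature dimensions of forcing
    return np.concatenate((repeated, forcing), axis=-1)


# Toy sample with hypothetical sizes
num_steps, num_grid_nodes = 4, 6
land_water_mask = np.random.rand(num_grid_nodes, 1)
forcing = np.random.rand(num_steps, num_grid_nodes, 5)

combined = bake_batch_static(land_water_mask, forcing)
```

With forcing built this way, a sample only needs the three tensors (init, target, forcing), and no separate code path is required for datasets that lack such features.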

@@ -115,7 +112,6 @@ def predict_step(
(
prev_state,
prev_prev_state,
batch_static_features,
Collaborator Author

Note that the batch-static features are now put as the first feature dimension in forcing. Earlier they were stacked right on top of forcing in this tensor. This results in no change to what grid_features looks like for a sample. Importantly, this means that models trained before this PR can be loaded and work without any problems.
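
A small check of this claim, again with NumPy arrays in place of PyTorch tensors. The concatenation order (prev_state, prev_prev_state, batch_static_features, forcing) is assumed from the diff hunk and the description above, and all feature sizes are made up for the example.

```python
import numpy as np

# Hypothetical sizes, for illustration only
num_grid_nodes, d_state, d_static, d_forcing = 6, 3, 1, 5
rng = np.random.default_rng(0)

prev_state = rng.random((num_grid_nodes, d_state))
prev_prev_state = rng.random((num_grid_nodes, d_state))
batch_static_features = rng.random((num_grid_nodes, d_static))
forcing = rng.random((num_grid_nodes, d_forcing))

# Before: batch-static features passed as a separate input,
# concatenated directly before forcing (assumed order)
old_grid_features = np.concatenate(
    (prev_state, prev_prev_state, batch_static_features, forcing), axis=-1
)

# After: batch-static features are already the first feature
# dimensions of forcing when it leaves the Dataset class
new_forcing = np.concatenate((batch_static_features, forcing), axis=-1)
new_grid_features = np.concatenate(
    (prev_state, prev_prev_state, new_forcing), axis=-1
)

same_layout = np.array_equal(old_grid_features, new_grid_features)
```

Since both layouts put every feature at the same index, the input dimensions of the first layers are unchanged, which is why older checkpoints are expected to load as before.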

@joeloskarsson joeloskarsson requested a review from sadamov March 18, 2024 08:17
@joeloskarsson
Collaborator Author

@sadamov Hope it's ok that I ask you to review PRs like this :) I think it's valuable to get a second pair of eyes on the changes, and also good for you to get an update on the small things I am changing.

The changes to the MEPS Dataset class here are not very important; this is really motivated by moving away from things being too specific to that data.

Collaborator

@sadamov sadamov left a comment

I agree with both the general direction of and the explicit changes to the codebase.

  • In general, making the dataloader more flexible and reducing the complexity of input feature types allows for easier onboarding of new collaborators.
  • I tested the explicit changes with the meps_example dataset and training is successful without batch_static_features.

@joeloskarsson
Collaborator Author

Thanks for taking a look!

I just realized I forgot to change create_parameter_weights.py, as the Dataset class is also used there. I will fix that (it should only be a tiny change of index) and then merge.

@joeloskarsson joeloskarsson merged commit b0050b9 into main Mar 18, 2024
1 check passed
@joeloskarsson joeloskarsson deleted the remove_batch_static branch March 18, 2024 09:56
joeloskarsson added a commit to gitvicky/neural-lam-CP that referenced this pull request Apr 17, 2024
Squashed commit of the following:

commit b0050b9
Author: Joel Oskarsson <[email protected]>
Date:   Mon Mar 18 10:56:45 2024 +0100

    Remove batch-static tensor from dataset class and models (mllam#13)

    * Bake the batch-static features into the normal forcing in the MEPS Dataset class.
    * Change the Dataset class to only return 3 tensors per sample (init, target, forcing).
    * Remove the batch-static tensor from being extracted from the batch and passed around in the graph-based models. This while making sure that input dimensions line up so older checkpoints can still be loaded correctly.

commit 0669ff4
Author: Joel Oskarsson <[email protected]>
Date:   Thu Feb 29 11:50:27 2024 +0100

    Re-define RMSE metric to take sqrt after sample averaging (mllam#10)