diff --git a/README.md b/README.md index e21b7c2..7a5e5ca 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,9 @@ Once `neural-lam` is installed you will be able to train/evaluate models. For th interface that provides the data in a data-structure that can be used within neural-lam. A datastore is used to create a `pytorch.Dataset`-derived class that samples the data in time to create individual samples for - training, validation and testing. + training, validation and testing. A secondary datastore can be provided + for the boundary data. Currently, boundary datastore must be of type `mdp` + and only contain forcing features. This can easily be expanded in the future. 2. **The graph structure** is used to define message-passing GNN layers, that are trained to emulate fluid flow in the atmosphere over time. The @@ -121,7 +123,7 @@ different aspects about the training and evaluation of the model. The path you provide to the neural-lam config (`config.yaml`) also sets the root directory relative to which all other paths are resolved, as in the parent -directory of the config becomes the root directory. Both the datastore and +directory of the config becomes the root directory. Both the datastores and graphs you generate are then stored in subdirectories of this root directory. Exactly how and where a specific datastore expects its source data to be stored and where it stores its derived data is up to the implementation of the @@ -134,6 +136,7 @@ assume you placed `config.yaml` in a folder called `data`): data/ ├── config.yaml - Configuration file for neural-lam ├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml +├── era5.datastore.zarr/ - Optional configuration file for the boundary datastore, referred to from config.yaml └── graphs/ - Directory containing graphs for training ``` @@ -142,18 +145,20 @@ And the content of `config.yaml` could in this case look like: datastore: kind: mdp config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml training: state_feature_weighting: __config_class__: ManualStateFeatureWeighting - values: + weights: u100m: 1.0 v100m: 1.0 ``` -For now the neural-lam config only defines two things: 1) the kind of data -store and the path to its config, and 2) the weighting of different features in -the loss function. If you don't define the state feature weighting it will default -to weighting all features equally. +For now the neural-lam config only defines two things: +1) the kind of datastores and the path to their config +2) the weighting of different features in the loss function. If you don't define the state feature weighting it will default to weighting all features equally. (This example is taken from the `tests/datastore_examples/mdp` directory.) @@ -525,5 +530,4 @@ Furthermore, all tests in the ```tests``` directory will be run upon pushing cha # Contact If you are interested in machine learning models for LAM, have questions about the implementation or ideas for extending it, feel free to get in touch. -There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join (after following the link you have to request to join, this is to avoid spam bots). -You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). +There is an open [mllam slack channel](https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ) that anyone can join. You can also open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/neural_lam/config.py b/neural_lam/config.py index d3e0969..f887981 100644 --- a/neural_lam/config.py +++ b/neural_lam/config.py @@ -97,11 +97,15 @@ class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard): ---------- datastore : DatastoreSelection The configuration for the datastore to use. + datastore_boundary : Union[DatastoreSelection, None] + The configuration for the boundary datastore to use, if any. If None, + no boundary datastore is used. training : TrainingConfig The configuration for training the model. """ datastore: DatastoreSelection + datastore_boundary: Union[DatastoreSelection, None] = None training: TrainingConfig = dataclasses.field(default_factory=TrainingConfig) class _(dataclass_wizard.JSONWizard.Meta): @@ -168,4 +172,15 @@ def load_config_and_datastore( datastore_kind=config.datastore.kind, config_path=datastore_config_path ) - return config, datastore + if config.datastore_boundary is not None: + datastore_boundary_config_path = ( + Path(config_path).parent / config.datastore_boundary.config_path + ) + datastore_boundary = init_datastore( + datastore_kind=config.datastore_boundary.kind, + config_path=datastore_boundary_config_path, + ) + else: + datastore_boundary = None + + return config, datastore, datastore_boundary diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index b0055e3..e2d2140 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -228,23 +228,6 @@ def get_dataarray( """ pass - @cached_property - @abc.abstractmethod - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - pass - @abc.abstractmethod def get_xy(self, category: str) -> np.ndarray: """ diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 0d1aac7..809bbdb 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -27,11 +27,10 @@ class MDPDatastore(BaseRegularGridDatastore): SHORT_NAME = "mdp" - def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): + def __init__(self, config_path, reuse_existing=True): """ Construct a new MDPDatastore from the configuration file at - `config_path`. A boundary mask is created with `n_boundary_points` - boundary points. If `reuse_existing` is True, the dataset is loaded + `config_path`. If `reuse_existing` is True, the dataset is loaded from a zarr file if it exists (unless the config has been modified since the zarr was created), otherwise it is created from the configuration file. @@ -42,8 +41,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): The path to the configuration file, this will be fed to the `mllam_data_prep.Config.from_yaml_file` method to then call `mllam_data_prep.create_dataset` to create the dataset. - n_boundary_points : int - The number of boundary points to use in the boundary mask. reuse_existing : bool Whether to reuse an existing dataset zarr file if it exists and its creation date is newer than the configuration file. @@ -70,7 +67,6 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): if self._ds is None: self._ds = mdp.create_dataset(config=self._config) self._ds.to_zarr(fp_ds) - self._n_boundary_points = n_boundary_points print("The loaded datastore contains the following features:") for category in ["state", "forcing", "static"]: @@ -158,8 +154,8 @@ def get_vars_units(self, category: str) -> List[str]: The units of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_units"].values.tolist() @@ -177,8 +173,8 @@ def get_vars_names(self, category: str) -> List[str]: The names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature"].values.tolist() @@ -197,8 +193,8 @@ def get_vars_long_names(self, category: str) -> List[str]: The long names of the variables in the given category. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") return [] return self._ds[f"{category}_feature_long_name"].values.tolist() @@ -253,9 +249,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: The xarray DataArray object with processed dataset. """ - if category not in self._ds and category == "forcing": - warnings.warn("no forcing data found in datastore") - return None + if category not in self._ds: + warnings.warn(f"no {category} data found in datastore") + return [] da_category = self._ds[category] @@ -319,37 +315,6 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: ds_stats = self._ds[stats_variables.keys()].rename(stats_variables) return ds_stats - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Produce a 0/1 mask for the boundary points of the dataset, these will - sit at the edges of the domain (in x/y extent) and will be used to mask - out the boundary points from the loss function and to overwrite the - boundary points from the prediction. For now this is created when the - mask is requested, but in the future this could be saved to the zarr - file. - - Returns - ------- - xr.DataArray - A 0/1 mask for the boundary points of the dataset, where 1 is a - boundary point and 0 is not. - - """ - ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds) - da_state_variable = ( - ds_unstacked["state"].isel(time=0).isel(state_feature=0) - ) - da_domain_allzero = xr.zeros_like(da_state_variable) - ds_unstacked["boundary_mask"] = da_domain_allzero.isel( - x=slice(self._n_boundary_points, -self._n_boundary_points), - y=slice(self._n_boundary_points, -self._n_boundary_points), - ) - ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna( - 1 - ).astype(int) - return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask) - @property def coords_projection(self) -> ccrs.Projection: """ @@ -415,8 +380,17 @@ def grid_shape_state(self): The shape of the cartesian grid for the state variables. """ - ds_state = self.unstack_grid_coords(self._ds["state"]) - da_x, da_y = ds_state.x, ds_state.y + # Boundary data often has no state features + if "state" not in self._ds: + warnings.warn( + "no state data found in datastore" + "returning grid shape from forcing data" + ) + ds_forcing = self.unstack_grid_coords(self._ds["forcing"]) + da_x, da_y = ds_forcing.x, ds_forcing.y + else: + ds_state = self.unstack_grid_coords(self._ds["state"]) + da_x, da_y = ds_state.x, ds_state.y assert da_x.ndim == da_y.ndim == 1 return CartesianGridShape(x=da_x.size, y=da_y.size) diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index f2c80e8..1f1c694 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -172,6 +172,7 @@ def main( ar_steps = 63 ds = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=ar_steps, standardize=False, @@ -201,7 +202,7 @@ def main( print("Computing mean and std.-dev. for parameters...") means, squares, flux_means, flux_squares = [], [], [], [] - for init_batch, target_batch, forcing_batch, _ in tqdm(loader): + for init_batch, target_batch, forcing_batch, _, _ in tqdm(loader): if distributed: init_batch, target_batch, forcing_batch = ( init_batch.to(device), @@ -275,6 +276,7 @@ def main( print("Computing mean and std.-dev. for one-step differences...") ds_standard = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=ar_steps, standardize=True, @@ -303,7 +305,7 @@ def main( diff_means, diff_squares = [], [] - for init_batch, target_batch, _, _ in tqdm( + for init_batch, target_batch, _, _, _ in tqdm( loader_standard, disable=rank != 0 ): if distributed: diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 42e8070..24349e7 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -445,7 +445,9 @@ def _get_single_timeseries_dataarray( * np.timedelta64(1, "h") ) elif d == "analysis_time": - coord_values = self._get_analysis_times(split=split) + coord_values = self._get_analysis_times( + split=split, member_id=member + ) elif d == "y": coord_values = y elif d == "x": @@ -505,7 +507,7 @@ def _get_single_timeseries_dataarray( return da - def _get_analysis_times(self, split) -> List[np.datetime64]: + def _get_analysis_times(self, split, member_id) -> List[np.datetime64]: """Get the analysis times for the given split by parsing the filenames of all the files found for the given split. @@ -513,6 +515,8 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: ---------- split : str The dataset split to get the analysis times for. + member_id : int + The ensemble member to get the analysis times for. Returns ------- @@ -520,8 +524,12 @@ def _get_analysis_times(self, split) -> List[np.datetime64]: The analysis times for the given split. """ + if member_id is None: + # Only interior state data files have member_id, to avoid duplicates + # we only look at the first member for all other categories + member_id = 0 pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT) - pattern = re.sub(r"{member_id:[^}]*}", "*", pattern) + pattern = re.sub(r"{member_id:[^}]*}", f"{member_id:03d}", pattern) sample_dir = self.root_path / "samples" / split sample_files = sample_dir.glob(pattern) @@ -668,34 +676,6 @@ def grid_shape_state(self) -> CartesianGridShape: ny, nx = self.config.grid_shape_state return CartesianGridShape(x=nx, y=ny) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """The boundary mask for the dataset. This is a binary mask that is 1 - where the grid cell is on the boundary of the domain, and 0 otherwise. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions `[grid_index]`. - - """ - xy = self.get_xy(category="state", stacked=False) - xs = xy[:, :, 0] - ys = xy[:, :, 1] - # Check if x-coordinates are constant along columns - assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant" - # Check if y-coordinates are constant along rows - assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant" - # Extract unique x and y coordinates - x = xs[:, 0] # Unique x-coordinates (changes along the first axis) - y = ys[0, :] # Unique y-coordinates (changes along the second axis) - values = np.load(self.root_path / "static" / "border_mask.npy") - da_mask = xr.DataArray( - values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask" - ) - da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int) - return da_mask_stacked_xy - def get_standardization_dataarray(self, category: str) -> xr.Dataset: """Return the standardization dataarray for the given category. This should contain a `{category}_mean` and `{category}_std` variable for diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 44baf9c..81d5a62 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -45,7 +45,6 @@ def __init__( da_state_stats = datastore.get_standardization_dataarray( category="state" ) - da_boundary_mask = datastore.boundary_mask num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps @@ -111,25 +110,14 @@ def __init__( self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim - + num_forcing_vars + # Temporal Embedding counts as one additional forcing_feature + + (num_forcing_vars + 1) * (num_past_forcing_steps + num_future_forcing_steps + 1) ) # Instantiate loss function self.loss = metrics.get_metric(args.loss) - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.float32 - ).unsqueeze( - 1 - ) # add feature dim - - self.register_buffer("boundary_mask", boundary_mask, persistent=False) - # Pre-compute interior mask for use in loss function - self.register_buffer( - "interior_mask", 1.0 - self.boundary_mask, persistent=False - ) # (num_grid_nodes, 1), 1 for non-border - self.val_metrics = { "mse": [], } @@ -194,13 +182,6 @@ def configure_optimizers(self): ) return opt - @property - def interior_mask_bool(self): - """ - Get the interior mask as a boolean (N,) mask. - """ - return self.interior_mask[:, 0].to(torch.bool) - @staticmethod def expand_to_batch(x, batch_size): """ @@ -211,18 +192,18 @@ def expand_to_batch(x, batch_size): def predict_step(self, prev_state, prev_prev_state, forcing): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B, - num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes, - forcing_dim) + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + forcing: (B, num_grid_nodes, forcing_dim) """ raise NotImplementedError("No prediction step implemented") def unroll_prediction(self, init_states, forcing_features, true_states): """ Roll out prediction taking multiple autoregressive steps with model - init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B, - pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps, - num_grid_nodes, d_f) + init_states: (B, 2, num_grid_nodes, d_f) + forcing_features: (B, pred_steps, num_grid_nodes, d_static_f) + true_states: (B, pred_steps, num_grid_nodes, d_f) """ prev_prev_state = init_states[:, 0] prev_state = init_states[:, 1] @@ -232,7 +213,6 @@ def unroll_prediction(self, init_states, forcing_features, true_states): for i in range(pred_steps): forcing = forcing_features[:, i] - border_state = true_states[:, i] pred_state, pred_std = self.predict_step( prev_state, prev_prev_state, forcing @@ -240,19 +220,13 @@ def unroll_prediction(self, init_states, forcing_features, true_states): # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes, # d_f) or None - # Overwrite border with true state - new_state = ( - self.boundary_mask * border_state - + self.interior_mask * pred_state - ) - - prediction_list.append(new_state) + prediction_list.append(pred_state) if self.output_std: pred_std_list.append(pred_std) # Update conditioning states prev_prev_state = prev_state - prev_state = new_state + prev_state = pred_state prediction = torch.stack( prediction_list, dim=1 @@ -268,19 +242,20 @@ def unroll_prediction(self, init_states, forcing_features, true_states): def common_step(self, batch): """ - Predict on single batch batch consists of: init_states: (B, 2, - num_grid_nodes, d_features) target_states: (B, pred_steps, - num_grid_nodes, d_features) forcing_features: (B, pred_steps, - num_grid_nodes, d_forcing), - where index 0 corresponds to index 1 of init_states + Predict on single batch batch consists of: + init_states: (B, 2,num_grid_nodes, d_features) + target_states: (B, pred_steps,num_grid_nodes, d_features) + forcing_features: (B, pred_steps,num_grid_nodes, d_forcing) + boundary_features: (B, pred_steps,num_grid_nodes, d_boundaries) + batch_times: (B, pred_steps) """ - (init_states, target_states, forcing_features, batch_times) = batch + (init_states, target_states, forcing_features, _, batch_times) = batch prediction, pred_std = self.unroll_prediction( init_states, forcing_features, target_states - ) # (B, pred_steps, num_grid_nodes, d_f) - # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, - # pred_steps, num_grid_nodes, d_f) or (d_f,) + ) + # prediction: (B, pred_steps, num_grid_nodes, d_f) + # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) return prediction, target_states, pred_std, batch_times @@ -290,12 +265,14 @@ def training_step(self, batch): """ prediction, target, pred_std, _ = self.common_step(batch) - # Compute loss + # Compute loss - mean over unrolled times and batch batch_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, ) - ) # mean over unrolled times and batch + ) log_dict = {"train_loss": batch_loss} self.log_dict( @@ -328,9 +305,7 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, _ = self.common_step(batch) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1) mean_loss = torch.mean(time_step_loss) @@ -355,7 +330,6 @@ def validation_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.val_metrics["mse"].append(entry_mses) @@ -382,9 +356,7 @@ def test_step(self, batch, batch_idx): # pred_steps, num_grid_nodes, d_f) or (d_f,) time_step_loss = torch.mean( - self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool - ), + self.loss(prediction, target, pred_std), dim=0, ) # (time_steps-1,) mean_loss = torch.mean(time_step_loss) @@ -413,16 +385,13 @@ def test_step(self, batch, batch_idx): prediction, target, pred_std, - mask=self.interior_mask_bool, sum_vars=False, ) # (B, pred_steps, d_f) self.test_metrics[metric_name].append(batch_metric_vals) if self.output_std: # Store output std. per variable, spatially averaged - mean_pred_std = torch.mean( - pred_std[..., self.interior_mask_bool, :], dim=-2 - ) # (B, pred_steps, d_f) + mean_pred_std = torch.mean(pred_std, dim=-2) # (B, pred_steps, d_f) self.test_metrics["output_std"].append(mean_pred_std) # Save per-sample spatial loss for specific times diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 74146c8..2a61e86 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -203,6 +203,18 @@ def main(input_args=None): default=1, help="Number of future time steps to use as input for forcing data", ) + parser.add_argument( + "--num_past_boundary_steps", + type=int, + default=1, + help="Number of past time steps to use as input for boundary data", + ) + parser.add_argument( + "--num_future_boundary_steps", + type=int, + default=1, + help="Number of future time steps to use as input for boundary data", + ) args = parser.parse_args(input_args) args.var_leads_metrics_watch = { int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() @@ -226,16 +238,21 @@ def main(input_args=None): seed.seed_everything(args.seed) # Load neural-lam configuration and datastore to use - config, datastore = load_config_and_datastore(config_path=args.config_path) + config, datastore, datastore_boundary = load_config_and_datastore( + config_path=args.config_path + ) # Create datamodule data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=args.ar_steps_train, ar_steps_eval=args.ar_steps_eval, standardize=True, num_past_forcing_steps=args.num_past_forcing_steps, num_future_forcing_steps=args.num_future_forcing_steps, + num_past_boundary_steps=args.num_past_boundary_steps, + num_future_boundary_steps=args.num_future_boundary_steps, batch_size=args.batch_size, num_workers=args.num_workers, ) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index d6b57f8..efab20b 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -87,11 +87,6 @@ def plot_prediction( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_values = np.invert(da_mask.values.astype(bool)).astype(float) - pixel_alpha = mask_values.clip(0.7, 1) # Faded border region - fig, axes = plt.subplots( 1, 2, @@ -107,7 +102,6 @@ def plot_prediction( origin="lower", x="x", extent=extent, - alpha=pixel_alpha.T, vmin=vmin, vmax=vmax, cmap="plasma", @@ -141,11 +135,6 @@ def plot_spatial_error( extent = datastore.get_xy_extent("state") - # Set up masking of border region - da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region - fig, ax = plt.subplots( figsize=(5, 4.8), subplot_kw={"projection": datastore.coords_projection}, @@ -164,7 +153,6 @@ def plot_spatial_error( error_grid, origin="lower", extent=extent, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="OrRd", diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index b5f8558..0ddad87 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -22,6 +22,8 @@ class WeatherDataset(torch.utils.data.Dataset): ---------- datastore : BaseDatastore The datastore to load the data from (e.g. mdp). + datastore_boundary : BaseDatastore + The boundary datastore to load the data from (e.g. mdp). split : str, optional The data split to use ("train", "val" or "test"). Default is "train". ar_steps : int, optional @@ -36,6 +38,16 @@ class WeatherDataset(torch.utils.data.Dataset): forcing from times t, t+1, ..., t+j-1, t+j (and potentially times before t, given num_past_forcing_steps) are included as forcing inputs at time t. Default is 1. + num_past_boundary_steps: int, optional + Number of past time steps to include in boundary input. If set to i, + boundary from times t-i, t-i+1, ..., t-1, t (and potentially beyond, + given num_future_forcing_steps) are included as boundary inputs at time + t Default is 1. + num_future_boundary_steps: int, optional + Number of future time steps to include in boundary input. If set to j, + boundary from times t, t+1, ..., t+j-1, t+j (and potentially times + before t, given num_past_forcing_steps) are included as boundary inputs + at time t. Default is 1. standardize : bool, optional Whether to standardize the data. Default is True. """ @@ -43,10 +55,13 @@ class WeatherDataset(torch.utils.data.Dataset): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, split="train", ar_steps=3, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, standardize=True, ): super().__init__() @@ -54,15 +69,31 @@ def __init__( self.split = split self.ar_steps = ar_steps self.datastore = datastore + self.datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps self.da_state = self.datastore.get_dataarray( category="state", split=self.split ) + if self.da_state is None: + raise ValueError( + "A non-empty state dataarray must be provided. " + "The datastore.get_dataarray() returned None or empty array " + "for category='state'" + ) self.da_forcing = self.datastore.get_dataarray( category="forcing", split=self.split ) + # XXX For now boundary data is always considered mdp-forcing data + if self.datastore_boundary is not None: + self.da_boundary = self.datastore_boundary.get_dataarray( + category="forcing", split=self.split + ) + else: + self.da_boundary = None # check that with the provided data-arrays and ar_steps that we have a # non-zero amount of samples @@ -94,6 +125,98 @@ def __init__( "the data in `BaseDatastore.get_dataarray`?" ) + # handling ensemble data + if self.datastore.is_ensemble: + # for the now the strategy is to only include the first ensemble + # member + # XXX: this could be changed to include all ensemble members by + # splitting `idx` into two parts, one for the analysis time and one + # for the ensemble member and then increasing self.__len__ to + # include all ensemble members + warnings.warn( + "only use of ensemble member 0 (the first member) is " + "implemented for ensemble data" + ) + i_ensemble = 0 + self.da_state = self.da_state.isel(ensemble_member=i_ensemble) + else: + self.da_state = self.da_state + + # Check time step consistency in state data + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + self.forecast_step_state = self._get_time_step( + self.da_state.elapsed_forecast_duration + ) + else: + state_times = self.da_state.time + self.time_step_state = self._get_time_step(state_times) + + # Check time coverage for forcing and boundary data + if self.da_forcing is not None or self.da_boundary is not None: + if self.datastore.is_forecast: + state_times = self.da_state.analysis_time + else: + state_times = self.da_state.time + state_time_min = state_times.min().values + state_time_max = state_times.max().values + + if self.da_forcing is not None: + # Forcing data is part of the same datastore as state data + # During creation the time dimension of the forcing data + # is matched to the state data + if self.datastore.is_forecast: + forcing_times = self.da_forcing.analysis_time + self.forecast_step_forcing = self._get_time_step( + self.da_forcing.elapsed_forecast_duration + ) + else: + forcing_times = self.da_forcing.time + self.time_step_forcing = self._get_time_step( + forcing_times.values + ) + + if self.da_boundary is not None: + # Boundary data is part of a separate datastore + # The boundary data is allowed to have a different time_step + # Check that the boundary data covers the required time range + if self.datastore_boundary.is_forecast: + boundary_times = self.da_boundary.analysis_time + self.forecast_step_boundary = self._get_time_step( + self.da_boundary.elapsed_forecast_duration + ) + else: + boundary_times = self.da_boundary.time + self.time_step_boundary = self._get_time_step( + boundary_times.values + ) + boundary_time_min = boundary_times.min().values + boundary_time_max = boundary_times.max().values + + # Calculate required bounds for boundary using its time step + boundary_required_time_min = ( + state_time_min + - self.num_past_forcing_steps * self.time_step_boundary + ) + boundary_required_time_max = ( + state_time_max + + self.num_future_forcing_steps * self.time_step_boundary + ) + + if boundary_time_min > boundary_required_time_min: + raise ValueError( + f"Boundary data starts too late." + f"Required start: {boundary_required_time_min}, " + f"but boundary starts at {boundary_time_min}." + ) + + if boundary_time_max < boundary_required_time_max: + raise ValueError( + f"Boundary data ends too early." + f"Required end: {boundary_required_time_max}, " + f"but boundary ends at {boundary_time_max}." + ) + # Set up for standardization # TODO: This will become part of ar_model.py soon! self.standardize = standardize @@ -114,6 +237,16 @@ def __init__( self.da_forcing_mean = self.ds_forcing_stats.forcing_mean self.da_forcing_std = self.ds_forcing_stats.forcing_std + # XXX: Again, the boundary data is considered forcing data for now + if self.da_boundary is not None: + self.ds_boundary_stats = ( + self.datastore_boundary.get_standardization_dataarray( + category="forcing" + ) + ) + self.da_boundary_mean = self.ds_boundary_stats.forcing_mean + self.da_boundary_std = self.ds_boundary_stats.forcing_std + def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time @@ -126,7 +259,7 @@ def __len__(self): warnings.warn( "only using first ensemble member, so dataset size is " " effectively reduced by the number of ensemble members " - f"({self.da_state.ensemble_member.size})", + f"({self.datastore._num_ensemble_members})", UserWarning, ) @@ -160,32 +293,75 @@ def __len__(self): - self.num_future_forcing_steps ) - def _slice_state_time(self, da_state, idx, n_steps: int): + def _get_time_step(self, times): + """Calculate the time step from the data + + Parameters + ---------- + times : xr.DataArray + The time dataarray to calculate the time step from. + + Returns + ------- + time_step : float + The time step in the the format of the times dataarray. + """ + time_diffs = np.diff(times) + if not np.all(time_diffs == time_diffs[0]): + raise ValueError( + "Inconsistent time steps in data. " + f"Found different time steps: {np.unique(time_diffs)}" + ) + return time_diffs[0] + + def _slice_time( + self, + da_state, + idx, + n_steps: int, + da_forcing=None, + num_past_steps=None, + num_future_steps=None, + is_boundary=False, + ): """ - Produce a time slice of the given dataarray `da_state` (state) starting - at `idx` and with `n_steps` steps. An `offset`is calculated based on the - `num_past_forcing_steps` class attribute. `Offset` is used to offset the - start of the sample, to assert that enough previous time steps are - available for the 2 initial states and any corresponding forcings - (calculated in `_slice_forcing_time`). + Produce time slices of the given dataarrays `da_state` (state) and + `da_forcing`. For the state data, slicing is done + based on `idx`. For the forcing/boundary data, nearest neighbor matching + is performed based on the state times. Additionally, the time difference + between the matched forcing/boundary times and state times (in multiples + of state time steps) is added to the forcing dataarray. This will be + used as an additional input feature in the model (temporal embedding). Parameters ---------- da_state : xr.DataArray - The dataarray to slice. This is expected to have a `time` dimension - if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. + The state dataarray to slice. idx : int - The index of the time step to start the sample from. + The index of the time step to start the sample from in the state + data. n_steps : int The number of time steps to include in the sample. + da_forcing : xr.DataArray + The forcing/boundary dataarray to slice. + num_past_steps : int, optional + The number of past time steps to include in the forcing/boundary + data. Default is `None`. + num_future_steps : int, optional + The number of future time steps to include in the forcing/boundary + data. Default is `None`. + is_boundary : bool, optional + Whether the data is boundary data. Default is `False`. Returns ------- - da_sliced : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', + da_state_sliced : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', 'state_feature'). + da_forcing_matched : xr.DataArray + The sliced state dataarray with dims ('time', 'grid_index', + 'forcing/boundary_feature_windowed'). + If no forcing/boundary data is provided, this will be `None`. """ # The current implementation requires at least 2 time steps for the # initial state (see GraphCast). @@ -199,84 +375,58 @@ def _slice_state_time(self, da_state, idx, n_steps: int): # simply select a analysis time and the first `n_steps` forecast # times (given no offset). Note that this means that we get one # sample per forecast, always starting at forecast time 2. - da_sliced = da_state.isel( + da_state_sliced = da_state.isel( analysis_time=idx, elapsed_forecast_duration=slice(start_idx, end_idx), ) # create a new time dimension so that the produced sample has a # `time` dimension, similarly to the analysis only data - da_sliced["time"] = ( - da_sliced.analysis_time + da_sliced.elapsed_forecast_duration + da_state_sliced["time"] = ( + da_state_sliced.analysis_time + + da_state_sliced.elapsed_forecast_duration ) - da_sliced = da_sliced.swap_dims( + da_state_sliced = da_state_sliced.swap_dims( {"elapsed_forecast_duration": "time"} ) + else: # For analysis data we slice the time dimension directly. The offset # is only relevant for the very first (and last) samples in the # dataset. - start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) - end_idx = ( - idx + max(init_steps, self.num_past_forcing_steps) + n_steps - ) - da_sliced = da_state.isel(time=slice(start_idx, end_idx)) - return da_sliced + start_idx = idx + max(0, num_past_steps - init_steps) + end_idx = idx + max(init_steps, num_past_steps) + n_steps + da_state_sliced = da_state.isel(time=slice(start_idx, end_idx)) - def _slice_forcing_time(self, da_forcing, idx, n_steps: int): - """ - Produce a time slice of the given dataarray `da_forcing` (forcing) - starting at `idx` and with `n_steps` steps. An `offset` is calculated - based on the `num_past_forcing_steps` class attribute. It is used to - offset the start of the sample, to ensure that enough previous time - steps are available for the forcing data. The forcing data is windowed - around the current autoregressive time step to include the past and - future forcings. + if da_forcing is None: + return da_state_sliced, None - Parameters - ---------- - da_forcing : xr.DataArray - The forcing dataarray to slice. This is expected to have a `time` - dimension if the datastore is providing analysis only data, and a - `analysis_time` and `elapsed_forecast_duration` dimensions if the - datastore is providing forecast data. - idx : int - The index of the time step to start the sample from. - n_steps : int - The number of time steps to include in the sample. - - Returns - ------- - da_concat : xr.DataArray - The sliced dataarray with dims ('time', 'grid_index', - 'window', 'forcing_feature'). - """ - # The current implementation requires at least 2 time steps for the - # initial state (see GraphCast). The forcing data is windowed around the - # current autregressive time step. The two `init_steps` can also be used - # as past forcings. - init_steps = 2 + # Get the state times and its temporal resolution for matching with + # forcing data. + state_times = da_state_sliced["time"] da_list = [] - - if self.datastore.is_forecast: - # This implies that the data will have both `analysis_time` and - # `elapsed_forecast_duration` dimensions for forecasts. We for now - # simply select an analysis time and the first `n_steps` forecast - # times (given no offset). Note that this means that we get one - # sample per forecast. - # Add a 'time' dimension using the actual forecast times - offset = max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps + # Here we cannot check 'self.datastore.is_forecast' directly because we + # might be dealing with a datastore_boundary + if "analysis_time" in da_forcing.dims: + # For forecast data with analysis_time and elapsed_forecast_duration + # Select the closest analysis_time in the past in the + # forcing/boundary data + offset = max(0, num_past_steps - init_steps) + state_time = state_times[init_steps].values + forcing_analysis_time_idx = da_forcing.analysis_time.get_index( + "analysis_time" + ).get_indexer([state_time], method="pad")[0] + for step_idx in range(init_steps, len(state_times)): + start_idx = offset + step_idx - num_past_steps + end_idx = offset + step_idx + num_future_steps + 1 current_time = ( - da_forcing.analysis_time[idx] - + da_forcing.elapsed_forecast_duration[offset + step] + forcing_analysis_time_idx + + da_forcing.elapsed_forecast_duration[step_idx] ) da_sliced = da_forcing.isel( - analysis_time=idx, - elapsed_forecast_duration=slice(start_idx, end_idx + 1), + analysis_time=forcing_analysis_time_idx, + elapsed_forecast_duration=slice(start_idx, end_idx), ) da_sliced = da_sliced.rename( @@ -285,7 +435,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int): # Assign the 'window' coordinate to be relative positions da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) + window=np.arange(-num_past_steps, num_future_steps + 1) ) da_sliced = da_sliced.expand_dims( @@ -294,46 +444,133 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int): da_list.append(da_sliced) - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") - else: - # For analysis data, we slice the time dimension directly. The - # offset is only relevant for the very first (and last) samples in - # the dataset. - offset = idx + max(init_steps, self.num_past_forcing_steps) - for step in range(n_steps): - start_idx = offset + step - self.num_past_forcing_steps - end_idx = offset + step + self.num_future_forcing_steps - - # Slice the data over the desired time window - da_sliced = da_forcing.isel(time=slice(start_idx, end_idx + 1)) + for idx_time in range(init_steps, len(state_times)): + state_time = state_times[idx_time].values + + # Select the closest time in the past from forcing data using + # sel with method="pad" + forcing_time_idx = da_forcing.time.get_index( + "time" + ).get_indexer([state_time], method="pad")[0] + + # Use isel to select the window + da_window = da_forcing.isel( + time=slice( + forcing_time_idx - num_past_steps, + forcing_time_idx + num_future_steps + 1, + ), + ) - da_sliced = da_sliced.rename({"time": "window"}) + da_window = da_window.rename({"time": "window"}) - # Assign the 'window' coordinate to be relative positions - da_sliced = da_sliced.assign_coords( - window=np.arange(len(da_sliced.window)) + # Assign 'window' coordinate + da_window = da_window.assign_coords( + window=np.arange(-num_past_steps, num_future_steps + 1) ) - # Add a 'time' dimension to keep track of steps using actual - # time coordinates - current_time = da_forcing.time[offset + step] - da_sliced = da_sliced.expand_dims( - dim={"time": [current_time.values]} - ) + da_window = da_window.expand_dims(dim={"time": [state_time]}) - da_list.append(da_sliced) + da_list.append(da_window) + + da_forcing_matched = xr.concat(da_list, dim="time") + + # Generate temporal embedding `time_diff_steps` for the + # forcing/boundary data. This is the time difference in multiples + # of state time steps between the forcing/boundary time and the + # state time - # Concatenate the list of DataArrays along the 'time' dimension - da_concat = xr.concat(da_list, dim="time") + if is_boundary: + if self.datastore_boundary.is_forecast: + boundary_time_step = self.forecast_step_boundary + state_time_step = self.forecast_step_state + else: + boundary_time_step = self.time_step_boundary + state_time_step = self.time_step_state + time_diff_steps = ( + da_forcing_matched["window"] + * (boundary_time_step / state_time_step), + ) + else: + if self.datastore.is_forecast: + forcing_time_step = self.forecast_step_forcing + state_time_step = self.forecast_step_state + else: + forcing_time_step = self.time_step_forcing + state_time_step = self.time_step_state + time_diff_steps = ( + da_forcing_matched["window"] + * (forcing_time_step / state_time_step), + ) + time_diff_steps = da_forcing_matched.isel( + grid_index=0, forcing_feature=0 + ).window.values + # Add time difference as a new coordinate to concatenate to the + # forcing features later as temporal embedding + da_forcing_matched["time_diff_steps"] = ( + ("window"), + time_diff_steps, + ) - return da_concat + return da_state_sliced, da_forcing_matched + + def _process_windowed_data(self, da_windowed, da_state, da_target_times): + """Helper function to process windowed data. This function stacks the + 'forcing_feature' and 'window' dimensions and adds the time step + differences to the existing features as a temporal embedding. + + Parameters + ---------- + da_windowed : xr.DataArray + The windowed data to process. Can be `None` if no data is provided. + da_state : xr.DataArray + The state dataarray. + da_target_times : xr.DataArray + The target times. + + Returns + ------- + da_windowed : xr.DataArray + The processed windowed data. If `da_windowed` is `None`, an empty + DataArray with the correct dimensions and coordinates is returned. + + """ + stacked_dim = "forcing_feature_windowed" + if da_windowed is not None: + window_size = da_windowed.window.size + # Stack the 'feature' and 'window' dimensions and add the + # time step differences to the existing features as a temporal + # embedding + da_windowed = da_windowed.stack( + {stacked_dim: ("forcing_feature", "window")} + ) + # Add the time step differences as a new feature to the windowed + # data + time_diff_steps = da_windowed["time_diff_steps"].isel( + forcing_feature_windowed=slice(0, window_size) + ) + # All data variables share the same temporal embedding + da_windowed = xr.concat( + [da_windowed, time_diff_steps], + dim="forcing_feature_windowed", + ) + else: + # Create empty DataArray with the correct dimensions and coordinates + da_windowed = xr.DataArray( + data=np.empty((self.ar_steps, da_state.grid_index.size, 0)), + dims=("time", "grid_index", f"{stacked_dim}"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + f"{stacked_dim}": [], + }, + ) + return da_windowed def _build_item_dataarrays(self, idx): """ - Create the dataarrays for the initial states, target states and forcing - data for the sample at index `idx`. + Create the dataarrays for the initial states, target states, forcing + and boundary data for the sample at index `idx`. Parameters ---------- @@ -348,26 +585,13 @@ def _build_item_dataarrays(self, idx): The dataarray for the target states. da_forcing_windowed : xr.DataArray The dataarray for the forcing data, windowed for the sample. + da_boundary_windowed : xr.DataArray + The dataarray for the boundary data, windowed for the sample. + Boundary data is always considered forcing data. da_target_times : xr.DataArray The dataarray for the target times. """ - # handling ensemble data - if self.datastore.is_ensemble: - # for the now the strategy is to only include the first ensemble - # member - # XXX: this could be changed to include all ensemble members by - # splitting `idx` into two parts, one for the analysis time and one - # for the ensemble member and then increasing self.__len__ to - # include all ensemble members - warnings.warn( - "only use of ensemble member 0 (the first member) is " - "implemented for ensemble data" - ) - i_ensemble = 0 - da_state = self.da_state.isel(ensemble_member=i_ensemble) - else: - da_state = self.da_state - + da_state = self.da_state if self.da_forcing is not None: if "ensemble_member" in self.da_forcing.dims: raise NotImplementedError( @@ -377,20 +601,45 @@ def _build_item_dataarrays(self, idx): else: da_forcing = None - # handle time sampling in a way that is compatible with both analysis - # and forecast data - da_state = self._slice_state_time( - da_state=da_state, idx=idx, n_steps=self.ar_steps - ) - if da_forcing is not None: - da_forcing_windowed = self._slice_forcing_time( - da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps + if self.da_boundary is not None: + da_boundary = self.da_boundary + else: + da_boundary = None + + # This function will return a slice of the state data and the forcing + # and boundary data (if provided) for one sample (idx). + # If da_forcing is None, the function will return None for + # da_forcing_windowed. + if da_boundary is not None: + _, da_boundary_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_boundary, + num_future_steps=self.num_future_boundary_steps, + num_past_steps=self.num_past_boundary_steps, + is_boundary=True, ) + else: + da_boundary_windowed = None + # XXX: Currently, the order of the `slice_time` calls is important + # as `da_state` is modified in the second call. This should be + # refactored to be more robust. + da_state, da_forcing_windowed = self._slice_time( + da_state=da_state, + idx=idx, + n_steps=self.ar_steps, + da_forcing=da_forcing, + num_future_steps=self.num_future_forcing_steps, + num_past_steps=self.num_past_forcing_steps, + ) # load the data into memory da_state.load() if da_forcing is not None: da_forcing_windowed.load() + if da_boundary is not None: + da_boundary_windowed.load() da_init_states = da_state.isel(time=slice(0, 2)) da_target_states = da_state.isel(time=slice(2, None)) @@ -413,30 +662,27 @@ def _build_item_dataarrays(self, idx): da_forcing_windowed - self.da_forcing_mean ) / self.da_forcing_std - if da_forcing is not None: - # stack the `forcing_feature` and `window_sample` dimensions into a - # single `forcing_feature` dimension - da_forcing_windowed = da_forcing_windowed.stack( - forcing_feature_windowed=("forcing_feature", "window") - ) - else: - # create an empty forcing tensor with the right shape - da_forcing_windowed = xr.DataArray( - data=np.empty( - (self.ar_steps, da_state.grid_index.size, 0), - ), - dims=("time", "grid_index", "forcing_feature"), - coords={ - "time": da_target_times, - "grid_index": da_state.grid_index, - "forcing_feature": [], - }, - ) + if da_boundary is not None: + da_boundary_windowed = ( + da_boundary_windowed - self.da_boundary_mean + ) / self.da_boundary_std + + # This function handles the stacking of the forcing and boundary data + # and adds the time step differences as a temporal embedding. + # It can handle `None` inputs for the forcing and boundary data + # (and simlpy return an empty DataArray in that case). + da_forcing_windowed = self._process_windowed_data( + da_forcing_windowed, da_state, da_target_times + ) + da_boundary_windowed = self._process_windowed_data( + da_boundary_windowed, da_state, da_target_times + ) return ( da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) @@ -471,6 +717,7 @@ def __getitem__(self, idx): da_init_states, da_target_states, da_forcing_windowed, + da_boundary_windowed, da_target_times, ) = self._build_item_dataarrays(idx=idx) @@ -487,13 +734,20 @@ def __getitem__(self, idx): ) forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype) + boundary = torch.tensor(da_boundary_windowed.values, dtype=tensor_dtype) # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) # forcing: (ar_steps, N_grid, d_windowed_forcing) + # boundary: (ar_steps, N_grid, d_windowed_boundary) # target_times: (ar_steps,) - return init_states, target_states, forcing, target_times + # Assert that the boundary data is an empty tensor if the corresponding + # datastore_boundary is `None` + if self.datastore_boundary is None: + assert boundary.numel() == 0 + + return init_states, target_states, forcing, boundary, target_times def __iter__(self): """ @@ -606,18 +860,24 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, datastore: BaseDatastore, + datastore_boundary: BaseDatastore, ar_steps_train=3, ar_steps_eval=25, standardize=True, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, batch_size=4, num_workers=16, ): super().__init__() self._datastore = datastore + self._datastore_boundary = datastore_boundary self.num_past_forcing_steps = num_past_forcing_steps self.num_future_forcing_steps = num_future_forcing_steps + self.num_past_boundary_steps = num_past_boundary_steps + self.num_future_boundary_steps = num_future_boundary_steps self.ar_steps_train = ar_steps_train self.ar_steps_eval = ar_steps_eval self.standardize = standardize @@ -627,8 +887,10 @@ def __init__( self.val_dataset = None self.test_dataset = None if num_workers > 0: - # default to spawn for now, as the default on linux "fork" hangs - # when using dask (which the npyfilesmeps datastore uses) + # BUG: There also seem to be issues with "spawn" and `gloo`, to be + # investigated. Defaults to spawn for now, as the default on linux + # "fork" hangs when using dask (which the npyfilesmeps datastore + # uses) self.multiprocessing_context = "spawn" else: self.multiprocessing_context = None @@ -637,29 +899,38 @@ def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="train", ar_steps=self.ar_steps_train, standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) self.val_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="val", ar_steps=self.ar_steps_eval, standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) if stage == "test" or stage is None: self.test_dataset = WeatherDataset( datastore=self._datastore, + datastore_boundary=self._datastore_boundary, split="test", ar_steps=self.ar_steps_eval, standardize=self.standardize, num_past_forcing_steps=self.num_past_forcing_steps, num_future_forcing_steps=self.num_future_forcing_steps, + num_past_boundary_steps=self.num_past_boundary_steps, + num_future_boundary_steps=self.num_future_boundary_steps, ) def train_dataloader(self): diff --git a/pyproject.toml b/pyproject.toml index fdcb7f3..f556ef6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ requires-python = ">=3.9" [project.optional-dependencies] -dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"] +dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2", "gcsfs>=2021.10.0"] [tool.setuptools] py-modules = ["neural_lam"] diff --git a/tests/conftest.py b/tests/conftest.py index 5d799c7..15ee159 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,15 @@ def download_meps_example_reduced_dataset(): dummydata=None, ) +DATASTORES_BOUNDARY_EXAMPLES = { + "mdp": ( + DATASTORE_EXAMPLES_ROOT_PATH + / "mdp" + / "era5_1000hPa_danra_100m_winds" + / "era5.datastore.yaml" + ), +} + DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore @@ -102,5 +111,13 @@ def init_datastore_example(datastore_kind): datastore_kind=datastore_kind, config_path=DATASTORES_EXAMPLES[datastore_kind], ) - return datastore + + +def init_datastore_boundary_example(datastore_kind): + datastore_boundary = init_datastore( + datastore_kind=datastore_kind, + config_path=DATASTORES_BOUNDARY_EXAMPLES[datastore_kind], + ) + + return datastore_boundary diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore index e84e649..4fbd232 100644 --- a/tests/datastore_examples/.gitignore +++ b/tests/datastore_examples/.gitignore @@ -1,2 +1,3 @@ npyfilesmeps/*.zip -npyfilesmeps/meps_example_reduced/ +npyfilesmeps/meps_example_reduced +npyfilesmeps/era5_1000hPa_temp_meps_example_reduced diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore new file mode 100644 index 0000000..f2828f4 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/.gitignore @@ -0,0 +1,2 @@ +*.zarr/ +graph/ diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml new file mode 100644 index 0000000..a158bee --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/config.yaml @@ -0,0 +1,12 @@ +datastore: + kind: mdp + config_path: danra.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + u100m: 1.0 + v100m: 1.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml new file mode 100644 index 0000000..3edf126 --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/danra.datastore.yaml @@ -0,0 +1,99 @@ +schema_version: v0.5.0 +dataset_version: v0.1.0 + +output: + variables: + static: [grid_index, static_feature] + state: [time, grid_index, state_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-03T00:00 + end: 1990-09-09T00:00 + step: PT3H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-03T00:00 + end: 1990-09-06T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-06T00:00 + end: 1990-09-07T00:00 + test: + start: 1990-09-07T00:00 + end: 1990-09-09T00:00 + +inputs: + danra_height_levels: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr + dims: [time, x, y, altitude] + variables: + u: + altitude: + values: [100,] + units: m + v: + altitude: + values: [100, ] + units: m + dim_mapping: + time: + method: rename + dim: time + state_feature: + method: stack_variables_by_var_name + dims: [altitude] + name_format: "{var_name}{altitude}m" + grid_index: + method: stack + dims: [x, y] + target_output_variable: state + + danra_surface: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr + dims: [time, x, y] + variables: + # use surface incoming shortwave radiation as forcing + - swavr0m + dim_mapping: + time: + method: rename + dim: time + grid_index: + method: stack + dims: [x, y] + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: forcing + + danra_lsm: + path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr + dims: [x, y] + variables: + - lsm + dim_mapping: + grid_index: + method: stack + dims: [x, y] + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + target_output_variable: static + +extra: + projection: + class_name: LambertConformal + kwargs: + central_longitude: 25.0 + central_latitude: 56.7 + standard_parallels: [56.7, 56.7] + globe: + semimajor_axis: 6367470.0 + semiminor_axis: 6367470.0 diff --git a/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml new file mode 100644 index 0000000..c83489c --- /dev/null +++ b/tests/datastore_examples/mdp/era5_1000hPa_danra_100m_winds/era5.datastore.yaml @@ -0,0 +1,106 @@ +schema_version: v0.5.0 +dataset_version: v1.0.0 + +output: + variables: + static: [grid_index, static_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + test: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_sea_level_pressure + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_static: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - land_sea_mask + dim_mapping: + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: static + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml new file mode 100644 index 0000000..27cc976 --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/config.yaml @@ -0,0 +1,18 @@ +datastore: + kind: npyfilesmeps + config_path: meps_example_reduced.datastore.yaml +datastore_boundary: + kind: mdp + config_path: era5.datastore.yaml +training: + state_feature_weighting: + __config_class__: ManualStateFeatureWeighting + weights: + nlwrs_0: 1.0 + nswrs_0: 1.0 + pres_0g: 1.0 + pres_0s: 1.0 + r_2: 1.0 + r_65: 1.0 + t_2: 1.0 + t_65: 1.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml new file mode 100644 index 0000000..c83489c --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/era5.datastore.yaml @@ -0,0 +1,106 @@ +schema_version: v0.5.0 +dataset_version: v1.0.0 + +output: + variables: + static: [grid_index, static_feature] + forcing: [time, grid_index, forcing_feature] + coord_ranges: + time: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + step: PT6H + chunking: + time: 1 + splitting: + dim: time + splits: + train: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + compute_statistics: + ops: [mean, std, diff_mean, diff_std] + dims: [grid_index, time] + val: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + test: + start: 1990-09-01T00:00 + end: 2022-09-30T00:00 + +inputs: + era_height_levels: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + u_component_of_wind: + level: + values: [1000,] + units: hPa + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + dims: [level] + name_format: "{var_name}{level}hPa" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_surface: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - mean_sea_level_pressure + dim_mapping: + time: + method: rename + dim: time + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + forcing_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: forcing + + era5_static: + path: 'gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-64x32_equiangular_conservative.zarr' + dims: [time, longitude, latitude, level] + variables: + - land_sea_mask + dim_mapping: + x: + method: rename + dim: longitude + y: + method: rename + dim: latitude + static_feature: + method: stack_variables_by_var_name + name_format: "{var_name}" + grid_index: + method: stack + dims: [x, y] + target_output_variable: static + +extra: + projection: + class_name: PlateCarree + kwargs: + central_longitude: 0.0 diff --git a/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml new file mode 100644 index 0000000..3d88d4a --- /dev/null +++ b/tests/datastore_examples/npyfilesmeps/era5_1000hPa_temp_meps_example_reduced/meps_example_reduced.datastore.yaml @@ -0,0 +1,44 @@ +dataset: + name: meps_example_reduced + num_forcing_features: 16 + var_longnames: + - pres_heightAboveGround_0_instant + - pres_heightAboveSea_0_instant + - nlwrs_heightAboveGround_0_accum + - nswrs_heightAboveGround_0_accum + - r_heightAboveGround_2_instant + - r_hybrid_65_instant + - t_heightAboveGround_2_instant + - t_hybrid_65_instant + var_names: + - pres_0g + - pres_0s + - nlwrs_0 + - nswrs_0 + - r_2 + - r_65 + - t_2 + - t_65 + var_units: + - Pa + - Pa + - W/m**2 + - W/m**2 + - '' + - '' + - K + - K + num_timesteps: 65 + num_ensemble_members: 2 + step_length: 3 +grid_shape_state: +- 134 +- 119 +projection: + class_name: LambertConformal + kwargs: + central_latitude: 63.3 + central_longitude: 15.0 + standard_parallels: + - 63.3 + - 63.3 diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py index 9075d40..a958b8f 100644 --- a/tests/dummy_datastore.py +++ b/tests/dummy_datastore.py @@ -28,7 +28,7 @@ class DummyDatastore(BaseRegularGridDatastore): """ SHORT_NAME = "dummydata" - T0 = isodate.parse_datetime("2021-01-01T00:00:00") + T0 = isodate.parse_datetime("1990-09-02T00:00:00") N_FEATURES = dict(state=5, forcing=2, static=1) CARTESIAN_COORDS = ["x", "y"] @@ -148,12 +148,6 @@ def __init__( times = [self.T0 + dt * i for i in range(n_timesteps)] self.ds.coords["time"] = times - # Add boundary mask - self.ds["boundary_mask"] = xr.DataArray( - np.random.choice([0, 1], size=(n_points_1d, n_points_1d)), - dims=["x", "y"], - ) - # Stack the spatial dimensions into grid_index self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS) @@ -342,22 +336,6 @@ def get_dataarray( dim_order = self.expected_dim_order(category=category) return self.ds[category].transpose(*dim_order) - @cached_property - def boundary_mask(self) -> xr.DataArray: - """ - Return the boundary mask for the dataset, with spatial dimensions - stacked. Where the value is 1, the grid point is a boundary point, and - where the value is 0, the grid point is not a boundary point. - - Returns - ------- - xr.DataArray - The boundary mask for the dataset, with dimensions - `('grid_index',)`. - - """ - return self.ds["boundary_mask"] - def get_xy(self, category: str, stacked: bool) -> ndarray: """Return the x, y coordinates of the dataset. diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 419aece..aa7b645 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -14,12 +14,19 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataset -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) from tests.dummy_datastore import DummyDatastore @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item_shapes(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_dataset_item_shapes(datastore_name, datastore_boundary_name): """Check that the `datastore.get_dataarray` method is implemented. Validate the shapes of the tensors match between the different @@ -31,24 +38,33 @@ def test_dataset_item_shapes(datastore_name): """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) N_gridpoints = datastore.num_grid_points + N_gridpoints_boundary = datastore_boundary.num_grid_points N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 dataset = WeatherDataset( datastore=datastore, + datastore_boundary=datastore_boundary, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, num_future_forcing_steps=num_future_forcing_steps, + num_past_boundary_steps=num_past_boundary_steps, + num_future_boundary_steps=num_future_boundary_steps, ) item = dataset[0] # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - init_states, target_states, forcing, target_times = item + init_states, target_states, forcing, boundary, target_times = item # initial states assert init_states.ndim == 3 @@ -66,10 +82,20 @@ def test_dataset_item_shapes(datastore_name): assert forcing.ndim == 3 assert forcing.shape[0] == N_pred_steps assert forcing.shape[1] == N_gridpoints - assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * ( + # each time step in the window has one corresponding temporal embedding + # that is shared across all grid points, times and variables + assert forcing.shape[2] == (datastore.get_num_data_vars("forcing") + 1) * ( num_past_forcing_steps + num_future_forcing_steps + 1 ) + # boundary + assert boundary.ndim == 3 + assert boundary.shape[0] == N_pred_steps + assert boundary.shape[1] == N_gridpoints_boundary + assert boundary.shape[2] == ( + datastore_boundary.get_num_data_vars("forcing") + 1 + ) * (num_past_boundary_steps + num_future_boundary_steps + 1) + # batch times assert target_times.ndim == 1 assert target_times.shape[0] == N_pred_steps @@ -87,8 +113,10 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): N_pred_steps = 4 num_past_forcing_steps = 1 num_future_forcing_steps = 1 + dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=N_pred_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -99,10 +127,14 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): # unpack the item, this is the current return signature for # WeatherDataset.__getitem__ - _, target_states, _, target_times_arr = dataset[idx] - _, da_target_true, _, da_target_times_true = dataset._build_item_dataarrays( - idx=idx - ) + _, target_states, _, _, target_times_arr = dataset[idx] + ( + _, + da_target_true, + _, + _, + da_target_times_true, + ) = dataset._build_item_dataarrays(idx=idx) target_times = np.array(target_times_arr, dtype="datetime64[ns]") np.testing.assert_equal(target_times, da_target_times_true.values) @@ -158,13 +190,19 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name): @pytest.mark.parametrize("split", ["train", "val", "test"]) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_single_batch(datastore_name, split): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_single_batch(datastore_name, datastore_boundary_name, split): """Check that the `datastore.get_dataarray` method is implemented. And that it returns an xarray DataArray with the correct dimensions. """ datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) device_name = ( torch.device("cuda") if torch.cuda.is_available() else "cpu" @@ -210,7 +248,9 @@ def _create_graph(): ) ) - dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2) + dataset = WeatherDataset( + datastore=datastore, datastore_boundary=datastore_boundary, split=split + ) model = GraphLAM(args=args, datastore=datastore, config=config) # noqa @@ -244,6 +284,7 @@ def test_dataset_length(dataset_config): dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, split="train", ar_steps=dataset_config["ar_steps"], num_past_forcing_steps=dataset_config["past"], diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 4a4b110..a91f624 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -18,8 +18,6 @@ dataarray for the given category. - `get_dataarray` (method): Return the processed data (as a single `xr.DataArray`) for the given category and test/train/val-split. -- `boundary_mask` (property): Return the boundary mask for the dataset, - with spatial dimensions stacked. - `config` (property): Return the configuration of the datastore. In addition BaseRegularGridDatastore must have the following methods and @@ -213,25 +211,6 @@ def test_get_dataarray(datastore_name): assert n_features["train"] == n_features["val"] == n_features["test"] -@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_boundary_mask(datastore_name): - """Check that the `datastore.boundary_mask` property is implemented and - that the returned object is an xarray DataArray with the correct shape.""" - datastore = init_datastore_example(datastore_name) - da_mask = datastore.boundary_mask - - assert isinstance(da_mask, xr.DataArray) - assert set(da_mask.dims) == {"grid_index"} - assert da_mask.dtype == "int" - assert set(da_mask.values) == {0, 1} - assert da_mask.sum() > 0 - assert da_mask.sum() < da_mask.size - - if isinstance(datastore, BaseRegularGridDatastore): - grid_shape = datastore.grid_shape_state - assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y - - @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_get_xy_extent(datastore_name): """Check that the `datastore.get_xy_extent` method is implemented and that diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py index 2916150..21038e7 100644 --- a/tests/test_time_slicing.py +++ b/tests/test_time_slicing.py @@ -16,40 +16,76 @@ class SinglePointDummyDatastore(BaseDatastore): root_path = None def __init__(self, time_values, state_data, forcing_data, is_forecast): - self._time_values = np.array(time_values) - self._state_data = np.array(state_data) - self._forcing_data = np.array(forcing_data) self.is_forecast = is_forecast - if is_forecast: - assert self._state_data.ndim == 2 + self._analysis_times, self._forecast_times = time_values + self._state_data = np.array(state_data) + self._forcing_data = np.array(forcing_data) + # state_data and forcing_data should be 2D arrays with shape + # (n_analysis_times, n_forecast_times) else: - assert self._state_data.ndim == 1 + self._time_values = np.array(time_values) + self._state_data = np.array(state_data) + self._forcing_data = np.array(forcing_data) + + if is_forecast: + assert self._state_data.ndim == 2 + else: + assert self._state_data.ndim == 1 def get_num_data_vars(self, category): return 1 def get_dataarray(self, category, split): - if category == "state": - values = self._state_data - elif category == "forcing": - values = self._forcing_data - else: - raise NotImplementedError(category) - if self.is_forecast: - raise NotImplementedError() + if category == "state": + # Create DataArray with dims ('analysis_time', + # 'elapsed_forecast_duration') + da = xr.DataArray( + self._state_data, + dims=["analysis_time", "elapsed_forecast_duration"], + coords={ + "analysis_time": self._analysis_times, + "elapsed_forecast_duration": self._forecast_times, + }, + ) + elif category == "forcing": + da = xr.DataArray( + self._forcing_data, + dims=["analysis_time", "elapsed_forecast_duration"], + coords={ + "analysis_time": self._analysis_times, + "elapsed_forecast_duration": self._forecast_times, + }, + ) + else: + raise NotImplementedError(category) + # Add 'grid_index' and '{category}_feature' dimensions + da = da.expand_dims("grid_index") + da = da.expand_dims(f"{category}_feature") + dim_order = self.expected_dim_order(category=category) + return da.transpose(*dim_order) else: - da = xr.DataArray( - values, dims=["time"], coords={"time": self._time_values} - ) - # add `{category}_feature` and `grid_index` dimensions + if category == "state": + values = self._state_data + elif category == "forcing": + values = self._forcing_data + else: + raise NotImplementedError(category) + + if self.is_forecast: + raise NotImplementedError() + else: + da = xr.DataArray( + values, dims=["time"], coords={"time": self._time_values} + ) - da = da.expand_dims("grid_index") - da = da.expand_dims(f"{category}_feature") + # add `{category}_feature` and `grid_index` dimensions + da = da.expand_dims("grid_index") + da = da.expand_dims(f"{category}_feature") - dim_order = self.expected_dim_order(category=category) - return da.transpose(*dim_order) + dim_order = self.expected_dim_order(category=category) + return da.transpose(*dim_order) def get_standardization_dataarray(self, category): raise NotImplementedError() @@ -67,25 +103,55 @@ def get_vars_long_names(self, category): raise NotImplementedError() -ANALYSIS_STATE_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +INIT_STEPS = 2 + +STATE_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] FORCING_VALUES = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] +STATE_VALUES_FORECAST = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], # Analysis time 0 + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], # Analysis time 1 + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], # Analysis time 2 +] +FORCING_VALUES_FORECAST = [ + [100, 101, 102, 103, 104, 105, 106, 107, 108, 109], # Analysis time 0 + [110, 111, 112, 113, 114, 115, 116, 117, 118, 119], # Analysis time 1 + [120, 121, 122, 123, 124, 125, 126, 127, 128, 129], # Analysis time 2 +] + +SCENARIOS = [ + [3, 0, 0], + [3, 1, 0], + [3, 2, 0], + [3, 3, 0], + [3, 0, 1], + [3, 0, 2], + [3, 0, 3], + [3, 1, 1], + [3, 2, 1], + [3, 3, 1], + [3, 1, 2], + [3, 1, 3], + [3, 2, 2], + [3, 2, 3], + [3, 3, 2], + [3, 3, 3], +] + @pytest.mark.parametrize( "ar_steps,num_past_forcing_steps,num_future_forcing_steps", - [[3, 0, 0], [3, 1, 0], [3, 2, 0], [3, 3, 0]], + SCENARIOS, ) def test_time_slicing_analysis( ar_steps, num_past_forcing_steps, num_future_forcing_steps ): - # state and forcing variables have only on dimension, `time` - time_values = np.datetime64("2020-01-01") + np.arange( - len(ANALYSIS_STATE_VALUES) - ) - assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values) + # state and forcing variables have only one dimension, `time` + time_values = np.datetime64("2020-01-01") + np.arange(len(STATE_VALUES)) + assert len(STATE_VALUES) == len(FORCING_VALUES) == len(time_values) datastore = SinglePointDummyDatastore( - state_data=ANALYSIS_STATE_VALUES, + state_data=STATE_VALUES, forcing_data=FORCING_VALUES, time_values=time_values, is_forecast=False, @@ -93,6 +159,7 @@ def test_time_slicing_analysis( dataset = WeatherDataset( datastore=datastore, + datastore_boundary=None, ar_steps=ar_steps, num_future_forcing_steps=num_future_forcing_steps, num_past_forcing_steps=num_past_forcing_steps, @@ -101,16 +168,14 @@ def test_time_slicing_analysis( sample = dataset[0] - init_states, target_states, forcing, _ = [ + init_states, target_states, forcing, _, _ = [ tensor.numpy() for tensor in sample ] + # Some scenarios for the human reader expected_init_states = [0, 1] if ar_steps == 3: expected_target_states = [2, 3, 4] - else: - raise NotImplementedError() - if num_past_forcing_steps == num_future_forcing_steps == 0: expected_forcing_values = [[12], [13], [14]] elif num_past_forcing_steps == 1 and num_future_forcing_steps == 0: @@ -125,22 +190,138 @@ def test_time_slicing_analysis( [11, 12, 13, 14], [12, 13, 14, 15], ] - else: - raise NotImplementedError() + + # Compute expected initial states and target states based on ar_steps + offset = max(0, num_past_forcing_steps - INIT_STEPS) + init_idx = INIT_STEPS + offset + # Compute expected forcing values based on num_past_forcing_steps and + # num_future_forcing_steps for all scenarios + expected_init_states = STATE_VALUES[offset:init_idx] + expected_target_states = STATE_VALUES[init_idx : init_idx + ar_steps] + total_forcing_window = num_past_forcing_steps + num_future_forcing_steps + 1 + expected_forcing_values = [] + for i in range(ar_steps): + start_idx = i + init_idx - num_past_forcing_steps + end_idx = i + init_idx + num_future_forcing_steps + 1 + forcing_window = FORCING_VALUES[start_idx:end_idx] + expected_forcing_values.append(forcing_window) # init_states: (2, N_grid, d_features) # target_states: (ar_steps, N_grid, d_features) - # forcing: (ar_steps, N_grid, d_windowed_forcing) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) # target_times: (ar_steps,) - assert init_states.shape == (2, 1, 1) - assert init_states[:, 0, 0].tolist() == expected_init_states - assert target_states.shape == (3, 1, 1) - assert target_states[:, 0, 0].tolist() == expected_target_states + # Adjust assertions to use computed expected values + assert init_states.shape == (INIT_STEPS, 1, 1) + np.testing.assert_array_equal(init_states[:, 0, 0], expected_init_states) + + assert target_states.shape == (ar_steps, 1, 1) + np.testing.assert_array_equal( + target_states[:, 0, 0], expected_target_states + ) assert forcing.shape == ( - 3, + ar_steps, 1, - 1 + num_past_forcing_steps + num_future_forcing_steps, + total_forcing_window + * 2, # Each windowed feature includes temporal embedding + ) + + # Extract the forcing values from the tensor (excluding temporal embeddings) + forcing_values = forcing[:, 0, :total_forcing_window] + + # Compare with expected forcing values + for i in range(ar_steps): + np.testing.assert_array_equal( + forcing_values[i], expected_forcing_values[i] + ) + + +@pytest.mark.parametrize( + "ar_steps,num_past_forcing_steps,num_future_forcing_steps", + SCENARIOS, +) +def test_time_slicing_forecast( + ar_steps, num_past_forcing_steps, num_future_forcing_steps +): + # Constants for forecast data + ANALYSIS_TIMES = np.datetime64("2020-01-01") + np.arange( + len(STATE_VALUES_FORECAST) + ) + ELAPSED_FORECAST_DURATION = np.timedelta64(0, "D") + np.arange( + len(FORCING_VALUES_FORECAST[0]) + ) + # Create a dummy datastore with forecast data + time_values = (ANALYSIS_TIMES, ELAPSED_FORECAST_DURATION) + datastore = SinglePointDummyDatastore( + state_data=STATE_VALUES_FORECAST, + forcing_data=FORCING_VALUES_FORECAST, + time_values=time_values, + is_forecast=True, ) - np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values)) + + dataset = WeatherDataset( + datastore=datastore, + datastore_boundary=None, + split="train", + ar_steps=ar_steps, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + standardize=False, + ) + + # Test the dataset length + assert len(dataset) == len(ANALYSIS_TIMES) + + sample = dataset[0] + + init_states, target_states, forcing, _, _ = [ + tensor.numpy() for tensor in sample + ] + + # Compute expected initial states and target states based on ar_steps + offset = max(0, num_past_forcing_steps - INIT_STEPS) + init_idx = INIT_STEPS + offset + expected_init_states = STATE_VALUES_FORECAST[0][offset:init_idx] + expected_target_states = STATE_VALUES_FORECAST[0][ + init_idx : init_idx + ar_steps + ] + + # Compute expected forcing values based on num_past_forcing_steps and + # num_future_forcing_steps + total_forcing_window = num_past_forcing_steps + num_future_forcing_steps + 1 + expected_forcing_values = [] + for i in range(ar_steps): + start_idx = i + init_idx - num_past_forcing_steps + end_idx = i + init_idx + num_future_forcing_steps + 1 + forcing_window = FORCING_VALUES_FORECAST[INIT_STEPS][start_idx:end_idx] + expected_forcing_values.append(forcing_window) + + # init_states: (2, N_grid, d_features) + # target_states: (ar_steps, N_grid, d_features) + # forcing: (ar_steps, N_grid, d_windowed_forcing * 2) + # target_times: (ar_steps,) + + # Assertions + np.testing.assert_array_equal(init_states[:, 0, 0], expected_init_states) + np.testing.assert_array_equal( + target_states[:, 0, 0], expected_target_states + ) + + # Verify the shape of the forcing data + expected_forcing_shape = ( + ar_steps, # Number of AR steps + 1, # Number of grid points + total_forcing_window # Total number of forcing steps in the window + * 2, # Each windowed feature includes temporal embedding + ) + assert forcing.shape == expected_forcing_shape + + # Extract the forcing values from the tensor (excluding temporal embeddings) + forcing_values = forcing[:, 0, :total_forcing_window] + + # Compare with expected forcing values + for i in range(ar_steps): + np.testing.assert_array_equal( + forcing_values[i], expected_forcing_values[i] + ) diff --git a/tests/test_training.py b/tests/test_training.py index 1ed1847..ca0ebf4 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -14,18 +14,33 @@ from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataModule -from tests.conftest import init_datastore_example +from tests.conftest import ( + DATASTORES_BOUNDARY_EXAMPLES, + init_datastore_boundary_example, + init_datastore_example, +) @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_training(datastore_name): +@pytest.mark.parametrize( + "datastore_boundary_name", DATASTORES_BOUNDARY_EXAMPLES.keys() +) +def test_training(datastore_name, datastore_boundary_name): datastore = init_datastore_example(datastore_name) + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name + ) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip( f"Skipping test for {datastore_name} as it is not a regular " "grid datastore." ) + if not isinstance(datastore_boundary, BaseRegularGridDatastore): + pytest.skip( + f"Skipping test for {datastore_boundary_name} as it is not a " + "regular grid datastore." + ) if torch.cuda.is_available(): device_name = "cuda" @@ -59,6 +74,7 @@ def test_training(datastore_name): data_module = WeatherDataModule( datastore=datastore, + datastore_boundary=datastore_boundary, ar_steps_train=3, ar_steps_eval=5, standardize=True, @@ -66,6 +82,8 @@ def test_training(datastore_name): num_workers=1, num_past_forcing_steps=1, num_future_forcing_steps=1, + num_past_boundary_steps=1, + num_future_boundary_steps=1, ) class ModelArgs: @@ -85,6 +103,8 @@ class ModelArgs: metrics_watch = [] num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 model_args = ModelArgs()