diff --git a/src/cloudcasting/dataset.py b/src/cloudcasting/dataset.py
index 13f2e9b..b649976 100644
--- a/src/cloudcasting/dataset.py
+++ b/src/cloudcasting/dataset.py
@@ -44,7 +44,7 @@ def load_satellite_zarrs(zarr_path: list[str] | tuple[str] | str) -> xr.Dataset:
 
     if isinstance(zarr_path, list | tuple):
         ds = xr.combine_nested(
-            [xr.open_dataset(path, engine="zarr") for path in zarr_path],
+            [xr.open_dataset(path, engine="zarr", chunks="auto") for path in zarr_path],
             concat_dim="time",
             combine_attrs="override",
             join="override",
@@ -296,6 +296,8 @@ def __init__(
         val_period: list[str | None] | tuple[str | None] | None = None,
         test_period: list[str | None] | tuple[str | None] | None = None,
         nan_to_num: bool = False,
+        pin_memory: bool = False,
+        persistent_workers: bool = False,
     ):
         """A lightning DataModule for loading past and future satellite data
 
@@ -311,6 +313,11 @@ def __init__(
             train_period: Date range filter for train dataloader.
             val_period: Date range filter for val dataloader.
             test_period: Date range filter for test dataloader.
+            pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned
+                memory before returning them.
+            persistent_workers: If True, the data loader will not shut down the worker
+                processes after a dataset has been consumed once. This keeps the workers'
+                Dataset instances alive.
         """
 
         super().__init__()
@@ -339,12 +346,12 @@ def __init__(
            sampler=None,
            batch_sampler=None,
            num_workers=num_workers,
-           pin_memory=False,
+           pin_memory=pin_memory,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            prefetch_factor=prefetch_factor,
-           persistent_workers=False,
+           persistent_workers=persistent_workers,
        )
 
        self.nan_to_num = nan_to_num
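
Why the first hunk matters: passing chunks="auto" makes xarray return dask-backed variables, so combining many zarr stores along time builds a lazy task graph instead of pulling data into memory eagerly. A minimal sketch of the behaviour, where the store path "example.zarr" is a hypothetical stand-in:

import xarray as xr

# With chunks="auto", each variable is loaded as a dask array with
# automatically chosen chunk sizes; operations such as xr.combine_nested
# then stay lazy until the data is actually needed.
ds = xr.open_dataset("example.zarr", engine="zarr", chunks="auto")
print(ds.chunks)  # mapping from dimension name to the dask chunk sizes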
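
Usage sketch for the two new constructor arguments. The class name SatelliteDataModule and the zarr_path argument are assumptions for illustration (neither is visible in this diff); only num_workers, prefetch_factor, pin_memory, and persistent_workers appear in the patch itself. Note that PyTorch only allows persistent_workers=True when num_workers > 0.

from cloudcasting.dataset import SatelliteDataModule  # class name assumed, not shown in this diff

datamodule = SatelliteDataModule(
    zarr_path="path/to/satellite.zarr",  # hypothetical argument for illustration
    num_workers=4,      # persistent_workers=True requires num_workers > 0
    prefetch_factor=2,
    pin_memory=True,           # copy batches into pinned host memory for faster host-to-GPU transfer
    persistent_workers=True,   # keep worker processes (and their Dataset state) alive across epochs
)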