Description

To accelerate training with PyTorch, it would be useful to make use of the preprocessed and cached `tfrecord` datasets we've produced for operational TensorFlow training! Loading from these `tfrecord` files is (should be?) much faster than loading from NetCDF files on the fly using the `generate_sample` logic in, e.g., `DaskMultiWorkerLoader.generate_sample` defined in `data/loaders/dask.py`, which must perform a decent amount of computation.
What I Did
I've implemented an `IterableIceNetDatasetPyTorch` class which inherits from `torch.utils.data.IterableDataset` as a first go, and will link it below in a pull request. The following script demonstrates its use and can easily be stepped through in a debugger. The batching logic works properly with `torch.utils.data.DataLoader` in the script below, but I've seen some weird behaviour during training runs when `num_workers > 1` in the `torch.utils.data.DataLoader` (the loader yields more samples per epoch than the dataset should contain), so there's room to improve this first implementation.
```python
import os

import torch

from utils import IterableIceNetDataSetPyTorch

dataset_config = "dataset_config.exp23_south.json"
ds = IterableIceNetDataSetPyTorch(dataset_config, "test", batch_size=4, shuffling=False)
dl = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=False)

for i, batch in enumerate(dl):
    x, y, sw = batch
    print(x, y, sw)

print(i)
```
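The per-epoch overrun with `num_workers > 1` is the classic `IterableDataset` pitfall: each `DataLoader` worker process receives a full copy of the dataset, so without explicit sharding every sample is yielded `num_workers` times. A minimal sketch of the usual fix, using `torch.utils.data.get_worker_info()` on a toy dataset (not the IceNet class itself):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ShardedIterable(IterableDataset):
    """Toy iterable dataset that splits its index range across DataLoader workers."""

    def __init__(self, n_samples: int):
        self.n_samples = n_samples

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading: this process handles every sample.
            start, step = 0, 1
        else:
            # Each worker takes an interleaved slice, so no sample is duplicated.
            start, step = info.id, info.num_workers
        yield from range(start, self.n_samples, step)


def collect(n_samples: int, num_workers: int):
    dl = DataLoader(ShardedIterable(n_samples), batch_size=4, num_workers=num_workers)
    return sorted(x for batch in dl for x in batch.tolist())


if __name__ == "__main__":
    # Each of the 10 samples appears exactly once, even with two workers.
    print(collect(10, num_workers=2))
```

The same pattern would slot into `IterableIceNetDatasetPyTorch.__iter__`: partition the sample indices by `worker_info.id` before generating batches, so each worker only produces its own shard.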