diff --git a/zoobot/pytorch/training/webdatamodule.py b/zoobot/pytorch/training/webdatamodule.py index ae23d2d6..dd01bad6 100644 --- a/zoobot/pytorch/training/webdatamodule.py +++ b/zoobot/pytorch/training/webdatamodule.py @@ -14,10 +14,10 @@ class WebDataModule(pl.LightningDataModule): def __init__(self, train_urls, val_urls, train_size=None, val_size=None, label_cols=None, batch_size=64, num_workers=4, cache_dir=None): super().__init__() - if isinstance(train_urls, types.GeneratorType): - train_urls = list(train_urls) - if isinstance(val_urls, types.GeneratorType): - val_urls = list(val_urls) + # if isinstance(train_urls, types.GeneratorType): + # train_urls = list(train_urls) + # if isinstance(val_urls, types.GeneratorType): + # val_urls = list(val_urls) self.train_urls = train_urls self.val_urls = val_urls @@ -135,9 +135,10 @@ def val_dataloader(self): # return parser def nodesplitter_func(urls): + print(urls) try: node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size() - return urls[node_id::node_count] + return list(urls)[node_id::node_count] except RuntimeError: print('Distributed not initialised. Hopefully single node.') return urls