Skip to content

Commit

Permalink
unpack generator
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Nov 4, 2023
1 parent fdfc84d commit f7aeac4
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions zoobot/pytorch/training/webdatamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ class WebDataModule(pl.LightningDataModule):
def __init__(self, train_urls, val_urls, train_size=None, val_size=None, label_cols=None, batch_size=64, num_workers=4, cache_dir=None):
super().__init__()

if isinstance(train_urls, types.GeneratorType):
train_urls = list(train_urls)
if isinstance(val_urls, types.GeneratorType):
val_urls = list(val_urls)
# if isinstance(train_urls, types.GeneratorType):
# train_urls = list(train_urls)
# if isinstance(val_urls, types.GeneratorType):
# val_urls = list(val_urls)
self.train_urls = train_urls
self.val_urls = val_urls

Expand Down Expand Up @@ -135,9 +135,10 @@ def val_dataloader(self):
# return parser

def nodesplitter_func(urls):
print(urls)
try:
node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
return urls[node_id::node_count]
return list(urls)[node_id::node_count]
except RuntimeError:
print('Distributed not initialised. Hopefully single node.')
return urls
Expand Down

0 comments on commit f7aeac4

Please sign in to comment.