diff --git a/zoobot/pytorch/training/webdatamodule.py b/zoobot/pytorch/training/webdatamodule.py index a359c635..6449af55 100644 --- a/zoobot/pytorch/training/webdatamodule.py +++ b/zoobot/pytorch/training/webdatamodule.py @@ -100,7 +100,8 @@ def make_loader(self, urls, mode="train"): batch_size=None, # already batched shuffle=False, # already shuffled num_workers=self.num_workers, - pin_memory=True + pin_memory=True, + prefetch_factor=10 ) # print('sampling')