Skip to content

Commit

Permalink
try repeat and pin memory
Browse files (browse the repository at this point in the history)
  • Loading branch information
mwalmsley committed Nov 4, 2023
1 parent f7aeac4 commit 1b2e2f2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion only_for_me/narval/train.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash
#SBATCH --mem=80G
#SBATCH --nodes=1
#SBATCH --time=0:20:0
#SBATCH --time=0:40:0
#SBATCH --tasks-per-node=2
#SBATCH --cpus-per-task=12
#SBATCH --gres=gpu:a100:2
Expand Down
10 changes: 6 additions & 4 deletions zoobot/pytorch/training/webdatamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def make_loader(self, urls, mode="train"):
.map_tuple(transform_image, transform_label)
# torch collate stacks dicts nicely while webdataset only lists them
# so use the torch collate instead
.batched(self.batch_size, torch.utils.data.default_collate, partial=False)
.batched(self.batch_size, torch.utils.data.default_collate, partial=False)
.repeat(2)
)

# from itertools import islice
Expand All @@ -97,8 +98,9 @@ def make_loader(self, urls, mode="train"):
loader = wds.WebLoader(
dataset,
batch_size=None, # already batched
shuffle=False,
shuffle=False, # already shuffled
num_workers=self.num_workers,
pin_memory=True
)

# print('sampling')
Expand Down Expand Up @@ -134,8 +136,8 @@ def val_dataloader(self):
# parser.add_argument("--valshards", default="imagenet-val-{000000..000006}.tar")
# return parser

def nodesplitter_func(urls):
print(urls)
def nodesplitter_func(urls): # SimpleShardList
# print(urls)
try:
node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
return list(urls)[node_id::node_count]
Expand Down

0 comments on commit 1b2e2f2

Please sign in to comment.