Commit aeb2150
add back NCCL, remove nodesplitter func
mwalmsley committed Nov 6, 2023
1 parent 1feb215 commit aeb2150
Showing 2 changed files with 11 additions and 10 deletions.

only_for_me/narval/train.sh (1 addition, 1 deletion)
@@ -13,7 +13,7 @@ PYTHON=/home/walml/envs/zoobot39_dev/bin/python
 mkdir $SLURM_TMPDIR/cache
 # mkdir /tmp/cache
 
-# export NCCL_BLOCKING_WAIT=1  # Set this environment variable if you wish to use the NCCL backend for inter-GPU communication.
+export NCCL_BLOCKING_WAIT=1  # Set this environment variable if you wish to use the NCCL backend for inter-GPU communication.
 # export MASTER_ADDR=$(hostname)  # Store the master node’s IP address in the MASTER_ADDR environment variable.
 # echo "r$SLURM_NODEID master: $MASTER_ADDR"
 # echo "r$SLURM_NODEID Launching python script"
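
For context on the re-enabled variable: NCCL_BLOCKING_WAIT=1 is read by PyTorch when a process group is created with the NCCL backend, and makes collectives block until they complete (or hit the timeout), so a stuck multi-node all-reduce raises an error instead of hanging silently. A minimal sketch of the consuming side, assuming the rendezvous variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) are provided by the launcher, e.g. torchrun or the SLURM script above:

import os
import torch.distributed as dist

# NCCL_BLOCKING_WAIT=1 makes NCCL collectives block and raise on timeout
# rather than stalling indefinitely.
os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")

# Rank and world size are taken from the launcher's environment variables.
dist.init_process_group(backend="nccl")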

zoobot/pytorch/training/webdatamodule.py (10 additions, 9 deletions)
@@ -76,7 +76,8 @@ def make_loader(self, urls, mode="train"):
         dataset = (
             # https://webdataset.github.io/webdataset/multinode/
             # WDS 'knows' which worker it is running on and selects a subset of urls accordingly
-            wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0, nodesplitter=nodesplitter_func)
+            wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0)
+            # , nodesplitter=nodesplitter_func)
             .shuffle(shuffle)
             .decode("rgb")
             .to_tuple('image.jpg', 'labels.json')
@@ -137,14 +138,14 @@ def val_dataloader(self):
     # parser.add_argument("--valshards", default="imagenet-val-{000000..000006}.tar")
     # return parser
 
-def nodesplitter_func(urls):  # SimpleShardList
-    # print(urls)
-    try:
-        node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
-        return list(urls)[node_id::node_count]
-    except RuntimeError:
-        print('Distributed not initialised. Hopefully single node.')
-        return urls
+# def nodesplitter_func(urls):  # SimpleShardList
+#     # print(urls)
+#     try:
+#         node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
+#         return list(urls)[node_id::node_count]
+#     except RuntimeError:
+#         print('Distributed not initialised. Hopefully single node.')
+#         return urls
 
 def identity(x):
     return x
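
Context for the removal: without a nodesplitter argument, shard assignment falls back to webdataset's default behaviour. If explicit per-node splitting is wanted again, webdataset's built-in wds.split_by_node helper does roughly what the deleted function did, keeping every world_size-th shard offset by the node's rank. A minimal sketch, with a hypothetical shard list standing in for the real urls:

import webdataset as wds

# Hypothetical shard names for illustration only.
urls = [f"shard-{i:06d}.tar" for i in range(8)]

# Roughly equivalent to the deleted nodesplitter_func: each node keeps a
# rank-offset slice of the shards, and all shards when not distributed.
dataset = wds.WebDataset(urls, nodesplitter=wds.split_by_node)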
