diff --git a/only_for_me/narval/train.sh b/only_for_me/narval/train.sh index ceff355e..be3032cd 100644 --- a/only_for_me/narval/train.sh +++ b/only_for_me/narval/train.sh @@ -13,7 +13,7 @@ PYTHON=/home/walml/envs/zoobot39_dev/bin/python mkdir $SLURM_TMPDIR/cache # mkdir /tmp/cache -# export NCCL_BLOCKING_WAIT=1 #Set this environment variable if you wish to use the NCCL backend for inter-GPU communication. +export NCCL_BLOCKING_WAIT=1 #Set this environment variable if you wish to use the NCCL backend for inter-GPU communication. # export MASTER_ADDR=$(hostname) #Store the master node’s IP address in the MASTER_ADDR environment variable. # echo "r$SLURM_NODEID master: $MASTER_ADDR" # echo "r$SLURM_NODEID Launching python script" diff --git a/zoobot/pytorch/training/webdatamodule.py b/zoobot/pytorch/training/webdatamodule.py index f1a94680..a965a777 100644 --- a/zoobot/pytorch/training/webdatamodule.py +++ b/zoobot/pytorch/training/webdatamodule.py @@ -76,7 +76,8 @@ def make_loader(self, urls, mode="train"): dataset = ( # https://webdataset.github.io/webdataset/multinode/ # WDS 'knows' which worker it is running on and selects a subset of urls accordingly - wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0, nodesplitter=nodesplitter_func) + wds.WebDataset(urls, cache_dir=self.cache_dir, shardshuffle=shuffle>0) + # , nodesplitter=nodesplitter_func) .shuffle(shuffle) .decode("rgb") .to_tuple('image.jpg', 'labels.json') @@ -137,14 +138,14 @@ def val_dataloader(self): # parser.add_argument("--valshards", default="imagenet-val-{000000..000006}.tar") # return parser -def nodesplitter_func(urls): # SimpleShardList - # print(urls) - try: - node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size() - return list(urls)[node_id::node_count] - except RuntimeError: - print('Distributed not initialised. Hopefully single node.') - return urls +# def nodesplitter_func(urls): # SimpleShardList +# # print(urls) +# try: +# node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size() +# return list(urls)[node_id::node_count] +# except RuntimeError: +# print('Distributed not initialised. Hopefully single node.') +# return urls def identity(x): return x \ No newline at end of file