diff --git a/only_for_me/narval/train.py b/only_for_me/narval/train.py index 8158f267..846b88e0 100644 --- a/only_for_me/narval/train.py +++ b/only_for_me/narval/train.py @@ -118,7 +118,7 @@ resize_after_crop=args.resize_after_crop, # hardware parameters # gpus=args.gpus, - gpus=4, + gpus=1, nodes=args.nodes, mixed_precision=args.mixed_precision, wandb_logger=wandb_logger, diff --git a/only_for_me/narval/train.sh b/only_for_me/narval/train.sh index fc03f575..9be5eda5 100644 --- a/only_for_me/narval/train.sh +++ b/only_for_me/narval/train.sh @@ -2,9 +2,9 @@ #SBATCH --mem-per-cpu 4G #SBATCH --nodes=1 #SBATCH --time=0:40:0 -#SBATCH --tasks-per-node=4 +#SBATCH --tasks-per-node=1 #SBATCH --cpus-per-task=12 -#SBATCH --gres=gpu:a100:4 +#SBATCH --gres=gpu:a100:1 nvidia-smi diff --git a/zoobot/pytorch/training/webdatamodule.py b/zoobot/pytorch/training/webdatamodule.py index e75aa986..f1a94680 100644 --- a/zoobot/pytorch/training/webdatamodule.py +++ b/zoobot/pytorch/training/webdatamodule.py @@ -84,7 +84,7 @@ def make_loader(self, urls, mode="train"): # torch collate stacks dicts nicely while webdataset only lists them # so use the torch collate instead .batched(self.batch_size, torch.utils.data.default_collate, partial=False) - .repeat(2) + .repeat(5) ) # from itertools import islice