From 57807ceaea1825a7fc22c30d2ee21b0f1ff433ac Mon Sep 17 00:00:00 2001
From: Gabriel Ilharco
Date: Thu, 5 Jan 2023 14:29:47 +0000
Subject: [PATCH 1/2] Stop trying to guess the length of webdataset datasets

---
 src/training/data.py   | 43 ++++++++++---------------------------------
 src/training/params.py |  4 ++--
 2 files changed, 12 insertions(+), 35 deletions(-)

diff --git a/src/training/data.py b/src/training/data.py
index 671ab8c82..1320df0e4 100644
--- a/src/training/data.py
+++ b/src/training/data.py
@@ -72,28 +72,6 @@ def set_epoch(self, epoch):
         self.sampler.set_epoch(epoch)
 
 
-def get_dataset_size(shards):
-    shards_list = list(braceexpand.braceexpand(shards))
-    dir_path = os.path.dirname(shards)
-    sizes_filename = os.path.join(dir_path, 'sizes.json')
-    len_filename = os.path.join(dir_path, '__len__')
-    if os.path.exists(sizes_filename):
-        sizes = json.load(open(sizes_filename, 'r'))
-        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
-    elif os.path.exists(len_filename):
-        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
-        total_size = ast.literal_eval(open(len_filename, 'r').read())
-    else:
-        total_size = None  # num samples undefined
-        # some common dataset sizes (at time of authors last download)
-        # CC3M (train): 2905954
-        # CC12M: 10968539
-        # LAION-400M: 407332084
-        # LAION-2B (english): 2170337258
-    num_shards = len(shards_list)
-    return total_size, num_shards
-
-
 def get_imagenet(args, preprocess_fns, split):
     assert split in ["train", "val", "v2"]
     is_train = split == "train"
@@ -301,16 +279,14 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
     assert input_shards is not None
     resampled = getattr(args, 'dataset_resampled', False) and is_train
 
-    num_samples, num_shards = get_dataset_size(input_shards)
-    if not num_samples:
-        if is_train:
-            num_samples = args.train_num_samples
-            if not num_samples:
-                raise RuntimeError(
-                    'Currently, number of dataset samples must be specified for training dataset. '
-                    'Please specify via `--train-num-samples` if no dataset length info present.')
-        else:
-            num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified
+    if is_train:
+        num_samples = args.train_num_samples
+        if num_samples is None:
+            raise RuntimeError(
+                'The number of training samples must be specified when training with webdataset. '
+                'Please specify it via `--train-num-samples`.')
+    else:
+        num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified
 
     shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
 
@@ -358,7 +334,8 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
     dataset = wds.DataPipeline(*pipeline)
     if is_train:
         if not resampled:
-            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
+            num_shards = len(wds.shardlists.expand_urls(input_shards))
+            assert num_shards >= args.workers * args.world_size, 'The number of shards must be >= total workers'
         # roll over and repeat a few samples to get same number of full batches on each node
         round_fn = math.floor if floor else math.ceil
         global_batch_size = args.batch_size * args.world_size
diff --git a/src/training/params.py b/src/training/params.py
index abc07dd50..d592a363e 100644
--- a/src/training/params.py
+++ b/src/training/params.py
@@ -28,13 +28,13 @@ def parse_args(args):
         "--train-num-samples",
         type=int,
         default=None,
-        help="Number of samples in dataset. Required for webdataset if not available in info file.",
+        help="Number of samples in dataset. Required for webdataset.",
     )
     parser.add_argument(
         "--val-num-samples",
         type=int,
         default=None,
-        help="Number of samples in dataset. Useful for webdataset if not available in info file.",
+        help="Number of samples in dataset. Useful for webdataset.",
     )
     parser.add_argument(
         "--dataset-type",

From 342dcfd9a8e19e5f9406d268baf5905a7631348f Mon Sep 17 00:00:00 2001
From: Gabriel Ilharco
Date: Thu, 5 Jan 2023 15:26:36 +0000
Subject: [PATCH 2/2] remove old test

---
 tests/test_num_shards.py | 20 --------------------
 1 file changed, 20 deletions(-)
 delete mode 100644 tests/test_num_shards.py

diff --git a/tests/test_num_shards.py b/tests/test_num_shards.py
deleted file mode 100644
index 70ca8fecc..000000000
--- a/tests/test_num_shards.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import pytest
-
-from training.data import get_dataset_size
-
-@pytest.mark.parametrize(
-    "shards,expected_size",
-    [
-        ('/path/to/shard.tar', 1),
-        ('/path/to/shard_{000..000}.tar', 1),
-        ('/path/to/shard_{000..009}.tar', 10),
-        ('/path/to/shard_{000..009}_{000..009}.tar', 100),
-        ('/path/to/shard.tar::/path/to/other_shard_{000..009}.tar', 11),
-        ('/path/to/shard_{000..009}.tar::/path/to/other_shard_{000..009}.tar', 20),
-        (['/path/to/shard.tar'], 1),
-        (['/path/to/shard.tar', '/path/to/other_shard.tar'], 2),
-    ]
-)
-def test_num_shards(shards, expected_size):
-    _, size = get_dataset_size(shards)
-    assert size == expected_size, f'Expected {expected_size} for {shards} but found {size} instead.'
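
Note on the new shard counting (a minimal sketch, not part of the patch): webdataset's
expand_urls splits "::"-joined shard specs and brace-expands each pattern, so it reproduces
the counts that the removed test exercised. The shard paths below are placeholders.

    # Sketch only: mirrors the patch's num_shards = len(wds.shardlists.expand_urls(...)) logic.
    from webdataset.shardlists import expand_urls

    # A "::"-joined spec combining a brace pattern with a single plain shard (placeholder paths).
    input_shards = '/path/to/shard_{000..009}.tar::/path/to/other_shard.tar'

    num_shards = len(expand_urls(input_shards))
    print(num_shards)  # 11: ten shards from the brace range plus one plain shard

With get_dataset_size removed, training runs must pass the sample count explicitly, e.g.
--train-num-samples 2905954 for CC3M (one of the counts documented in the removed comment).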