Stop trying to guess the length of webdataset datasets #339

Open · wants to merge 3 commits into main
src/training/data.py (10 additions, 33 deletions)
@@ -71,28 +71,6 @@ def set_epoch(self, epoch):
         self.sampler.set_epoch(epoch)
 
 
-def get_dataset_size(shards):
-    shards_list = wds.shardlists.expand_urls(shards)
-    dir_path = os.path.dirname(shards_list[0])
-    sizes_filename = os.path.join(dir_path, 'sizes.json')
-    len_filename = os.path.join(dir_path, '__len__')
-    if os.path.exists(sizes_filename):
-        sizes = json.load(open(sizes_filename, 'r'))
-        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
-    elif os.path.exists(len_filename):
-        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
-        total_size = ast.literal_eval(open(len_filename, 'r').read())
-    else:
-        total_size = None  # num samples undefined
-    # some common dataset sizes (at time of authors last download)
-    # CC3M (train): 2905954
-    # CC12M: 10968539
-    # LAION-400M: 407332084
-    # LAION-2B (english): 2170337258
-    num_shards = len(shards_list)
-    return total_size, num_shards
-
-
 def get_imagenet(args, preprocess_fns, split):
     assert split in ["train", "val", "v2"]
     is_train = split == "train"
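
For context, the deleted `get_dataset_size` helper inferred the sample count from a `sizes.json` sidecar (or a `__len__` file) sitting next to the shards. A minimal sketch of that sidecar convention, with hypothetical paths and counts:

```python
import json
import os

# Hypothetical shard list; sizes.json lives in the same directory and maps
# shard basename -> sample count, e.g. {"shard-00000.tar": 10000, ...}
shards_list = ["/data/cc3m/shard-00000.tar", "/data/cc3m/shard-00001.tar"]
sizes_filename = os.path.join(os.path.dirname(shards_list[0]), "sizes.json")
with open(sizes_filename) as f:
    sizes = json.load(f)
total_size = sum(int(sizes[os.path.basename(s)]) for s in shards_list)
```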
@@ -300,16 +278,14 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
     assert input_shards is not None
     resampled = getattr(args, 'dataset_resampled', False) and is_train
 
-    num_samples, num_shards = get_dataset_size(input_shards)
-    if not num_samples:
-        if is_train:
-            num_samples = args.train_num_samples
-            if not num_samples:
-                raise RuntimeError(
-                    'Currently, number of dataset samples must be specified for training dataset. '
-                    'Please specify via `--train-num-samples` if no dataset length info present.')
-        else:
-            num_samples = args.val_num_samples or 0  # eval will just exhaust the iterator if not specified
+    if is_train:
+        num_samples = args.train_num_samples
+        if num_samples is None:
+            raise RuntimeError(
+                'The number of training samples must be specified when training with webdataset. '
+                'Please specify it via `--train-num-samples`.')
+    else:
+        num_samples = args.val_num_samples or 0  # eval will just exhaust the iterator if not specified
 
     shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
 
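With this change, webdataset training always requires the sample count on the command line. A hypothetical invocation, assuming the repo's usual `training.main` entry point and `--train-data` flag (the shard range is a placeholder; 2905954 is the CC3M train count quoted in the removed helper's comments):

```bash
python -m training.main \
    --dataset-type webdataset \
    --train-data '/data/cc3m/shard-{00000..00331}.tar' \
    --train-num-samples 2905954
```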
@@ -357,7 +333,8 @@ def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokeni
     dataset = wds.DataPipeline(*pipeline)
     if is_train:
         if not resampled:
-            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
+            num_shards = len(wds.shardlists.expand_urls(input_shards))
+            assert num_shards >= args.workers * args.world_size, 'The number of shards must be >= total workers'
         # roll over and repeat a few samples to get same number of full batches on each node
         round_fn = math.floor if floor else math.ceil
         global_batch_size = args.batch_size * args.world_size
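
`wds.shardlists.expand_urls` expands webdataset brace notation into the concrete list of shard URLs, which is why the shard count is taken with `len()`. A quick sketch with a hypothetical path:

```python
import webdataset as wds

# Brace notation expands to one URL per shard.
urls = wds.shardlists.expand_urls("/data/cc3m/shard-{00000..00002}.tar")
assert urls == [
    "/data/cc3m/shard-00000.tar",
    "/data/cc3m/shard-00001.tar",
    "/data/cc3m/shard-00002.tar",
]
num_shards = len(urls)  # compared against args.workers * args.world_size
```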
src/training/params.py (2 additions, 2 deletions)
@@ -28,13 +28,13 @@ def parse_args(args):
         "--train-num-samples",
         type=int,
         default=None,
-        help="Number of samples in dataset. Required for webdataset if not available in info file.",
+        help="Number of samples in dataset. Required for webdataset.",
     )
     parser.add_argument(
         "--val-num-samples",
         type=int,
         default=None,
-        help="Number of samples in dataset. Useful for webdataset if not available in info file.",
+        help="Number of samples in dataset. Useful for webdataset.",
     )
     parser.add_argument(
         "--dataset-type",
tests/test_num_shards.py (0 additions, 20 deletions)

This file was deleted.