From 6c08a01b9aa80ea7b6ae3573165829833bb75b41 Mon Sep 17 00:00:00 2001
From: habanoz
Date: Fri, 29 Nov 2024 23:25:18 +0300
Subject: [PATCH] - resolve issue 308 - review unit tests - update datasets version to 3.1.0

---
 pyproject.toml                                |   2 +-
 src/datatrove/pipeline/readers/huggingface.py |  16 ++-
 tests/pipeline/test_hf_reader.py              | 130 ++++++++++++++++--
 3 files changed, 131 insertions(+), 17 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index cf226903..5031b6c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,7 +41,7 @@ io = [
     "pyarrow",
     "python-magic",
     "warcio",
-    "datasets>=2.18.0",
+    "datasets>=3.1.0",
     "orjson",
     "zstandard"
 ]
diff --git a/src/datatrove/pipeline/readers/huggingface.py b/src/datatrove/pipeline/readers/huggingface.py
index c031cab9..037d24cf 100644
--- a/src/datatrove/pipeline/readers/huggingface.py
+++ b/src/datatrove/pipeline/readers/huggingface.py
@@ -76,7 +76,19 @@ def _get_dataset_shard(self, dst, rank: int, world_size: int):
                 f"Requested shard {rank} of a streaming dataset, but it only has {dst.n_shards} shards."
             )
             return None
-        ex_iterable = dst._ex_iterable.shard_data_sources(rank, world_size)
+
+        # https://github.com/huggingface/datatrove/issues/308
+        # huggingface/datasets@65f6eb5#diff-edc4da5f2179552e25f4f3dc9d6bf07265b68bbef048a8f712e798520a23d048L103
+        # Order of the arguments to the shard_data_sources function in the datasets lib changed.
+        # Make sure the current version of datasets is up-to-date.
+        import inspect
+
+        arg_names = inspect.signature(dst._ex_iterable.shard_data_sources).parameters
+        # Guard against the old (pre-3.x) argument order: shard_data_sources(worker_id, num_workers)
+        assert list(arg_names.keys())[0] != "worker_id", "The first argument to shard_data_sources cannot be named 'worker_id'. Make sure datasets version is up-to-date"
+        assert list(arg_names.keys())[1] != "num_workers", "The second argument to shard_data_sources cannot be named 'num_workers'. Make sure datasets version is up-to-date"
+
+        ex_iterable = dst._ex_iterable.shard_data_sources(world_size, rank)
         return IterableDataset(
             ex_iterable=ex_iterable,
             info=dst._info.copy(),
@@ -96,7 +108,7 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1
         if data:
             yield from data
         ds = load_dataset(self.dataset, **self.dataset_options, streaming=self.streaming)
-        
+
         if self.shuffle_files:
             if not self.streaming:
                 ds = ds.shuffle(seed=42)
diff --git a/tests/pipeline/test_hf_reader.py b/tests/pipeline/test_hf_reader.py
index f8a31926..d3736d67 100644
--- a/tests/pipeline/test_hf_reader.py
+++ b/tests/pipeline/test_hf_reader.py
@@ -47,17 +47,119 @@ def test_read_streaming_dataset_shuffle(self):
         self.assertEqual(len(data[0].text), 69)
         self.assertEqual(len(data[1].text), 46)
 
-    def test_sharding(self):
-        for shards in [1, 3]:
-            for streaming in [True, False]:
-                reader = HuggingFaceDatasetReader(
-                    "huggingface/datatrove-tests",
-                    dataset_options={"name": f"sharding-{shards}", "split": "train"},
-                    text_key="text",
-                    streaming=streaming,
-                )
-                data0 = list(reader(rank=0, world_size=2))
-                data1 = list(reader(rank=1, world_size=2))
-
-                self.assertEqual(len(data0), 3)
-                self.assertEqual(len(data1), 2)
+    def test_sharding_1(self):
+        """
+        >>> ds = load_dataset("huggingface/datatrove-tests",name="sharding-1",split="train",streaming=True)
+        >>> ds
+        IterableDataset({
+            features: ['text'],
+            num_shards: 1
+        })
+
+        >>> print(list(ds.shard(num_shards=2, index=0)))
+        [{'text': 'hello'}, {'text': 'world'}, {'text': 'how'}, {'text': 'are'}, {'text': 'you'}]
+
+        >>> print(list(ds.shard(num_shards=2, index=1)))
+        IndexError: list index out of range
+
+        >>> ds = load_dataset("huggingface/datatrove-tests",name="sharding-1",split="train",streaming=False)
+        >>> ds
+        Dataset({
+            features: ['text'],
+            num_rows: 5
+        })
+
+        >>> print(list(ds.shard(num_shards=2, index=0)))
+        >>> print(list(ds.shard(num_shards=2, index=1)))
+        [{'text': 'hello'}, {'text': 'world'}, {'text': 'how'}]
+        [{'text': 'are'}, {'text': 'you'}]
+
+        >>> print(list(ds.shard(num_shards=3, index=0)))
+        >>> print(list(ds.shard(num_shards=3, index=1)))
+        >>> print(list(ds.shard(num_shards=3, index=2)))
+        [{'text': 'hello'}, {'text': 'world'}]
+        [{'text': 'how'}, {'text': 'are'}]
+        [{'text': 'you'}]
+
+        """
+        for streaming in [True, False]:
+            reader = HuggingFaceDatasetReader(
+                "huggingface/datatrove-tests",
+                dataset_options={"name": "sharding-1", "split": "train"},
+                text_key="text",
+                streaming=streaming,
+            )
+            data0 = list(reader(rank=0, world_size=2))
+            data1 = list(reader(rank=1, world_size=2))
+
+            self.assertEqual(len(data0), 3)
+            self.assertEqual(len(data1), 2)
+
+    def test_sharding_3_stream(self):
+        """
+        >>> ds_stream = load_dataset("huggingface/datatrove-tests",name="sharding-3",split="train",streaming=True)
+        >>> ds_stream
+        IterableDataset({
+            features: ['text'],
+            num_shards: 3
+        })
+
+        >>> print(list(ds_stream.shard(num_shards=2, index=0)))
+        >>> print(list(ds_stream.shard(num_shards=2, index=1)))
+        [{'text': 'hello'}, {'text': 'world'}, {'text': 'how'}, {'text': 'are'}]
+        [{'text': 'you'}]
+
+        >>> print(list(list(ds_stream.shard(num_shards=3, index=0))))
+        >>> print(list(list(ds_stream.shard(num_shards=3, index=1))))
+        >>> print(list(list(ds_stream.shard(num_shards=3, index=2))))
+        [{'text': 'hello'}, {'text': 'world'}]
+        [{'text': 'how'}, {'text': 'are'}]
+        [{'text': 'you'}]
+
+        """
+        reader = HuggingFaceDatasetReader(
+            "huggingface/datatrove-tests",
+            dataset_options={"name": "sharding-3", "split": "train"},
+            text_key="text",
+            streaming=True,
+        )
+        data0 = list(reader(rank=0, world_size=2))
+        data1 = list(reader(rank=1, world_size=2))
+
+        self.assertEqual(len(data0), 4)
+        self.assertEqual(len(data1), 1)
+
+    def test_sharding_3(self):
+        """
+        >>> ds = load_dataset("huggingface/datatrove-tests",name="sharding-3",split="train",streaming=False)
+        >>> ds
+        Dataset({
+            features: ['text'],
+            num_rows: 5
+        })
+
+        >>> print(list(ds.shard(num_shards=2, index=0)))
+        >>> print(list(ds.shard(num_shards=2, index=1)))
+        [{'text': 'hello'}, {'text': 'world'}, {'text': 'how'}]
+        [{'text': 'are'}, {'text': 'you'}]
+
+        >>> print(list(ds.shard(num_shards=3, index=0)))
+        >>> print(list(ds.shard(num_shards=3, index=1)))
+        >>> print(list(ds.shard(num_shards=3, index=2)))
+        [{'text': 'hello'}, {'text': 'world'}]
+        [{'text': 'how'}, {'text': 'are'}]
+        [{'text': 'you'}]
+
+        """
+        reader = HuggingFaceDatasetReader(
+            "huggingface/datatrove-tests",
+            dataset_options={"name": "sharding-3", "split": "train"},
+            text_key="text",
+            streaming=False,
+        )
+        data0 = list(reader(rank=0, world_size=2))
+        data1 = list(reader(rank=1, world_size=2))
+
+        self.assertEqual(len(data0), 3)
+        self.assertEqual(len(data1), 2)
+