diff --git a/tensorflow_datasets/core/dataset_builders/croissant_builder_test.py b/tensorflow_datasets/core/dataset_builders/croissant_builder_test.py
index 0fc3103066b..de06a95df70 100644
--- a/tensorflow_datasets/core/dataset_builders/croissant_builder_test.py
+++ b/tensorflow_datasets/core/dataset_builders/croissant_builder_test.py
@@ -31,16 +31,21 @@ FileFormat = file_adapters.FileFormat
 
-DUMMY_ENTRIES = entries = [
-    {"index": i, "text": f"Dummy example {i}"} for i in range(2)
+DUMMY_ENTRIES = [
+    {
+        "index": i,
+        "text": f"Dummy example {i}",
+        "split": "train" if i == 0 else "test",
+    }
+    for i in range(2)
 ]
 DUMMY_ENTRIES_WITH_NONE_VALUES = [
-    {"index": 0, "text": "Dummy example 0"},
-    {"index": 1, "text": None},
+    {"split": "train", "index": 0, "text": "Dummy example 0"},
+    {"split": "test", "index": 1, "text": None},
 ]
 DUMMY_ENTRIES_WITH_CONVERTED_NONE_VALUES = [
-    {"index": 0, "text": "Dummy example 0"},
-    {"index": 1, "text": ""},
+    {"split": "train", "index": 0, "text": "Dummy example 0"},
+    {"split": "test", "index": 1, "text": ""},
 ]
@@ -173,6 +178,7 @@ def mock_croissant_dataset_builder(tmp_path, request):
   with testing.dummy_croissant_file(
       dataset_name=dataset_name,
       entries=request.param["entries"],
+      split_names=["train", "test"],
   ) as croissant_file:
     builder = croissant_builder.CroissantBuilder(
         jsonld=croissant_file,
@@ -203,8 +209,12 @@ def test_croissant_builder(crs_builder):
   assert crs_builder._info().description == "Dummy description."
   assert crs_builder._info().homepage == "https://dummy_url"
   assert crs_builder._info().redistribution_info.license == "Public"
-  assert len(crs_builder.metadata.record_sets) == 1
-  assert crs_builder.metadata.record_sets[0].id == "jsonl"
+  # One `split` and one `jsonl` record set.
+  assert len(crs_builder.metadata.record_sets) == 2
+  assert {rs.id for rs in crs_builder.metadata.record_sets} == {
+      "jsonl",
+      "split",
+  }
   assert (
       crs_builder.metadata.ctx.conforms_to.value
       == "http://mlcommons.org/croissant/1.0"
@@ -228,11 +238,15 @@
     ],
     indirect=["crs_builder"],
 )
-@pytest.mark.parametrize("split_name", ["all", "default"])
+@pytest.mark.parametrize("split_name", ["train", "test"])
 def test_download_and_prepare(crs_builder, expected_entries, split_name):
   crs_builder.download_and_prepare()
   data_source = crs_builder.as_data_source(split=split_name)
-  assert len(data_source) == 2
+  assert len(data_source) == 1
+  # Each split holds one entry; compare only against that split's entries.
+  expected_entries = [
+      entry for entry in expected_entries if entry["split"] == split_name
+  ]
   for entry, expected_entry in zip(data_source, expected_entries):
     assert entry["index"] == expected_entry["index"]
     assert entry["text"].decode() == expected_entry["text"]
diff --git a/tensorflow_datasets/testing/test_utils.py b/tensorflow_datasets/testing/test_utils.py
index 4c970108a94..be0b0020456 100644
--- a/tensorflow_datasets/testing/test_utils.py
+++ b/tensorflow_datasets/testing/test_utils.py
@@ -723,6 +723,7 @@ def dummy_croissant_file(
     entries: Sequence[dict[str, Any]] | None = None,
     raw_data_filename: epath.PathLike = 'raw_data.jsonl',
     croissant_filename: epath.PathLike = 'croissant.json',
+    split_names: Sequence[str] | None = None,
 ) -> Iterator[epath.Path]:
   """Yields temporary path to a dummy Croissant file.
@@ -732,13 +733,27 @@
   Args:
     dataset_name: The name of the dataset.
     entries: A list of dictionaries representing the dataset's entries. Each
-      dictionary should contain an 'index' and a 'text' key. If None, the
-      function will create two entries with indices 0 and 1 and dummy text.
-    raw_data_filename: Filename of the raw data file.
+      dictionary should contain an 'index', a 'text', and a 'split' key. If
+      None, the function will create two entries with indices 0 and 1 and
+      dummy text, with the first entry belonging to the split 'train' and the
+      second to 'test'.
+    raw_data_filename: Filename of the raw data file. If `split_names` is
+      given, the function creates one raw data file per split, inserting the
+      split name before the file extension.
     croissant_filename: Filename of the Croissant JSON-LD file.
+    split_names: Split names to populate the split record set with. If given,
+      the 'split' key of each entry must match one of these names. If None,
+      no split record set is created.
   """
   if entries is None:
-    entries = [{'index': i, 'text': f'Dummy example {i}'} for i in range(2)]
+    entries = [
+        {
+            'index': i,
+            'text': f'Dummy example {i}',
+            'split': 'train' if i % 2 == 0 else 'test',
+        }
+        for i in range(2)
+    ]
 
   fields = [
       mlc.Field(
@@ -771,6 +786,41 @@
           fields=fields,
       )
   ]
+  if split_names:
+    # Add a `jsonl/split` field that recovers the split name from each data
+    # file's path and references the `split` record set defined below.
+    record_sets[0].fields.append(
+        mlc.Field(
+            id='jsonl/split',
+            name='jsonl/split',
+            description='The dummy split.',
+            data_types=mlc.DataType.TEXT,
+            source=mlc.Source(
+                file_object='raw_data',
+                extract=mlc.Extract(file_property='fullpath'),
+                transforms=[mlc.Transform(regex=r'.*_(\w+)\.jsonl$')],
+            ),
+            references=mlc.Source(field='split/name'),
+        ),
+    )
+    record_sets.append(
+        mlc.RecordSet(
+            id='split',
+            name='split',
+            key='split/name',
+            data_types=[mlc.DataType.SPLIT],
+            description='Dummy split.',
+            fields=[
+                mlc.Field(
+                    id='split/name',
+                    name='split/name',
+                    description='The dummy split name.',
+                    data_types=mlc.DataType.TEXT,
+                )
+            ],
+            data=[{'split/name': split_name} for split_name in split_names],
+        )
+    )
 
   with tempfile.TemporaryDirectory() as tempdir:
     tempdir = epath.Path(tempdir)
@@ -778,22 +828,43 @@
     # Write raw examples to tempdir/data.
     raw_data_dir = tempdir / 'data'
     raw_data_dir.mkdir()
-    raw_data_file = raw_data_dir / raw_data_filename
-    raw_data_file.write_text('\n'.join(map(json.dumps, entries)))
-
-    # Get the actual raw file's hash, set distribution and metadata.
-    raw_data_file_content = raw_data_file.read_text()
-    sha256 = hashlib.sha256(raw_data_file_content.encode()).hexdigest()
-    distribution = [
-        mlc.FileObject(
-            id='raw_data',
-            name='raw_data',
-            description='File with the data.',
-            encoding_format='application/jsonlines',
-            content_url=f'data/{raw_data_filename}',
-            sha256=sha256,
-        ),
-    ]
+    if split_names:
+      # Write one raw data file per split, named '<stem>_<split_name>.<ext>'.
+      parts = str(raw_data_filename).split('.')
+      file_name, extension = '.'.join(parts[:-1]), parts[-1]
+      for split_name in split_names:
+        raw_data_file = raw_data_dir / (
+            file_name + '_' + split_name + '.' + extension
+        )
+        split_entries = [
+            entry for entry in entries if entry['split'] == split_name
+        ]
+        raw_data_file.write_text('\n'.join(map(json.dumps, split_entries)))
+      distribution = [
+          mlc.FileSet(
+              id='raw_data',
+              name='raw_data',
+              description='Files with the data.',
+              encoding_format='application/jsonlines',
+              includes=f'data/{file_name}*.{extension}',
+          ),
+      ]
+    else:
+      raw_data_file = raw_data_dir / raw_data_filename
+      raw_data_file.write_text('\n'.join(map(json.dumps, entries)))
+      # Get the actual raw file's hash, set distribution and metadata.
+      raw_data_file_content = raw_data_file.read_text()
+      sha256 = hashlib.sha256(raw_data_file_content.encode()).hexdigest()
+      distribution = [
+          mlc.FileObject(
+              id='raw_data',
+              name='raw_data',
+              description='File with the data.',
+              encoding_format='application/jsonlines',
+              content_url=f'data/{raw_data_filename}',
+              sha256=sha256,
+          ),
+      ]
     dummy_metadata = mlc.Metadata(
         name=dataset_name,
         description='Dummy description.',
@@ -807,7 +878,6 @@
         version='1.2.0',
         license='Public',
     )
-    # Write Croissant JSON-LD to tempdir.
     croissant_file = tempdir / croissant_filename
     croissant_file.write_text(json.dumps(dummy_metadata.to_json(), indent=2))
     yield croissant_file
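
Note on the split wiring above: the new `jsonl/split` field does not read a
column from the JSONL records; it takes each data file's `fullpath` property,
extracts the split name with a regex transform (as I understand croissant
regex transforms, the first capture group is kept), and joins that value
against the `split` record set via `references=mlc.Source(field='split/name')`.
A minimal standalone sketch of just the regex step, assuming the file-naming
scheme used by `dummy_croissant_file`; `SPLIT_PATTERN` is an illustrative name,
and this is plain Python `re`, not TFDS or mlcroissant API:

import re

# Files are written as '<stem>_<split_name>.<extension>',
# e.g. 'data/raw_data_train.jsonl' and 'data/raw_data_test.jsonl'.
SPLIT_PATTERN = r'.*_(\w+)\.jsonl$'

for path in ('data/raw_data_train.jsonl', 'data/raw_data_test.jsonl'):
  match = re.match(SPLIT_PATTERN, path)
  assert match is not None
  print(path, '->', match.group(1))  # -> 'train', then 'test'

Because the greedy '.*_' anchors on the last underscore, the pattern still
extracts the right group when the stem itself contains underscores, as in
'raw_data_train.jsonl'.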