Add splits to the test croissant.
PiperOrigin-RevId: 699156000
The TensorFlow Datasets Authors committed Nov 27, 2024
1 parent 14f2854 commit 682d435
Showing 2 changed files with 110 additions and 31 deletions.
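
For orientation, a minimal usage sketch of what this change enables, stitched together from the tests below. It is illustrative only: the dataset name, the `data_dir` handling, and the final assertion are assumptions based on the test code in this commit, not part of the commit itself.

import tempfile

from tensorflow_datasets import testing
from tensorflow_datasets.core.dataset_builders import croissant_builder

# A dummy Croissant file now declares `train`/`test` splits, and the
# builder exposes each split as its own data source.
with testing.dummy_croissant_file(
    dataset_name='dummy_dataset',  # Assumed name, for illustration.
    split_names=['train', 'test'],
) as croissant_file:
  with tempfile.TemporaryDirectory() as data_dir:
    builder = croissant_builder.CroissantBuilder(
        jsonld=croissant_file,
        data_dir=data_dir,
    )
    builder.download_and_prepare()
    train = builder.as_data_source(split='train')
    assert len(train) == 1  # Each default dummy entry lands in one split.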
30 changes: 20 additions & 10 deletions tensorflow_datasets/core/dataset_builders/croissant_builder_test.py
@@ -31,16 +31,21 @@
 FileFormat = file_adapters.FileFormat
 
 
-DUMMY_ENTRIES = entries = [
-    {"index": i, "text": f"Dummy example {i}"} for i in range(2)
+DUMMY_ENTRIES = [
+    {
+        "index": i,
+        "text": f"Dummy example {i}",
+        "split": "train" if i == 0 else "test",
+    }
+    for i in range(2)
 ]
 DUMMY_ENTRIES_WITH_NONE_VALUES = [
-    {"index": 0, "text": "Dummy example 0"},
-    {"index": 1, "text": None},
+    {"split": "train", "index": 0, "text": "Dummy example 0"},
+    {"split": "test", "index": 1, "text": None},
 ]
 DUMMY_ENTRIES_WITH_CONVERTED_NONE_VALUES = [
-    {"index": 0, "text": "Dummy example 0"},
-    {"index": 1, "text": ""},
+    {"split": "train", "index": 0, "text": "Dummy example 0"},
+    {"split": "test", "index": 1, "text": ""},
 ]


@@ -173,6 +178,7 @@ def mock_croissant_dataset_builder(tmp_path, request):
   with testing.dummy_croissant_file(
       dataset_name=dataset_name,
       entries=request.param["entries"],
+      split_names=["train", "test"],
   ) as croissant_file:
     builder = croissant_builder.CroissantBuilder(
         jsonld=croissant_file,
@@ -203,8 +209,12 @@ def test_croissant_builder(crs_builder):
   assert crs_builder._info().description == "Dummy description."
   assert crs_builder._info().homepage == "https://dummy_url"
   assert crs_builder._info().redistribution_info.license == "Public"
-  assert len(crs_builder.metadata.record_sets) == 1
-  assert crs_builder.metadata.record_sets[0].id == "jsonl"
+  # One `split` and one `jsonl` recordset.
+  assert len(crs_builder.metadata.record_sets) == 2
+  assert set([rs.id for rs in crs_builder.metadata.record_sets]) == {
+      "jsonl",
+      "split",
+  }
   assert (
       crs_builder.metadata.ctx.conforms_to.value
       == "http://mlcommons.org/croissant/1.0"
@@ -228,11 +238,11 @@ def test_croissant_builder(crs_builder):
     ],
     indirect=["crs_builder"],
 )
-@pytest.mark.parametrize("split_name", ["all", "default"])
+@pytest.mark.parametrize("split_name", ["train", "test"])
 def test_download_and_prepare(crs_builder, expected_entries, split_name):
   crs_builder.download_and_prepare()
   data_source = crs_builder.as_data_source(split=split_name)
-  assert len(data_source) == 2
+  assert len(data_source) == 1
   for entry, expected_entry in zip(data_source, expected_entries):
     assert entry["index"] == expected_entry["index"]
     assert entry["text"].decode() == expected_entry["text"]
111 changes: 90 additions & 21 deletions tensorflow_datasets/testing/test_utils.py
@@ -723,6 +723,7 @@ def dummy_croissant_file(
     entries: Sequence[dict[str, Any]] | None = None,
     raw_data_filename: epath.PathLike = 'raw_data.jsonl',
     croissant_filename: epath.PathLike = 'croissant.json',
+    split_names: Sequence[str] | None = None,
 ) -> Iterator[epath.Path]:
   """Yields temporary path to a dummy Croissant file.
 
@@ -732,13 +733,29 @@
   Args:
     dataset_name: The name of the dataset.
     entries: A list of dictionaries representing the dataset's entries. Each
-      dictionary should contain an 'index' and a 'text' key. If None, the
-      function will create two entries with indices 0 and 1 and dummy text.
-    raw_data_filename: Filename of the raw data file.
+      dictionary should contain an 'index', a 'text', and a 'split' key. If
+      None, the function creates two entries with indices 0 and 1 and dummy
+      text, the first belonging to the 'train' split and the second to
+      'test'.
+    raw_data_filename: Filename of the raw data file. If `split_names` is
+      provided, one raw data file is created per split, with the split name
+      inserted before the file extension.
     croissant_filename: Filename of the Croissant JSON-LD file.
+    split_names: A list of split names to populate the split record set with.
+      If provided, the 'split' key of each entry must match one of these
+      names. If None, no split record set is created.
   """
   if entries is None:
-    entries = [{'index': i, 'text': f'Dummy example {i}'} for i in range(2)]
+    entries = [
+        {
+            'index': i,
+            'text': f'Dummy example {i}',
+            'split': 'train' if i % 2 == 0 else 'test',
+        }
+        for i in range(2)
+    ]
 
   fields = [
       mlc.Field(
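
For concreteness, when `entries is None` the comprehension above produces:

[
    {'index': 0, 'text': 'Dummy example 0', 'split': 'train'},
    {'index': 1, 'text': 'Dummy example 1', 'split': 'test'},
]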
@@ -771,29 +788,82 @@
         fields=fields,
     )
   ]
+  if split_names:
+    record_sets[0].fields.append(
+        mlc.Field(
+            id='jsonl/split',
+            name='jsonl/split',
+            description='The dummy split.',
+            data_types=mlc.DataType.TEXT,
+            source=mlc.Source(
+                file_object='raw_data',
+                extract=mlc.Extract(file_property='fullpath'),
+                transforms=[mlc.Transform(regex=r'.*_(.+)\.jsonl$')],
+            ),
+            references=mlc.Source(field='split/name'),
+        ),
+    )
+    record_sets.append(
+        mlc.RecordSet(
+            id='split',
+            name='split',
+            key='split/name',
+            data_types=[mlc.DataType.SPLIT],
+            description='Dummy split.',
+            fields=[
+                mlc.Field(
+                    id='split/name',
+                    name='split/name',
+                    description='The dummy split name.',
+                    data_types=mlc.DataType.TEXT,
+                )
+            ],
+            data=[{'split/name': split_name} for split_name in split_names],
+        )
+    )
 
   with tempfile.TemporaryDirectory() as tempdir:
     tempdir = epath.Path(tempdir)
 
     # Write raw examples to tempdir/data.
     raw_data_dir = tempdir / 'data'
     raw_data_dir.mkdir()
-    raw_data_file = raw_data_dir / raw_data_filename
-    raw_data_file.write_text('\n'.join(map(json.dumps, entries)))
-
-    # Get the actual raw file's hash, set distribution and metadata.
-    raw_data_file_content = raw_data_file.read_text()
-    sha256 = hashlib.sha256(raw_data_file_content.encode()).hexdigest()
-    distribution = [
-        mlc.FileObject(
-            id='raw_data',
-            name='raw_data',
-            description='File with the data.',
-            encoding_format='application/jsonlines',
-            content_url=f'data/{raw_data_filename}',
-            sha256=sha256,
-        ),
-    ]
+    if split_names:
+      parts = str(raw_data_filename).split('.')
+      file_name, extension = '.'.join(parts[:-1]), parts[-1]
+      for split_name in split_names:
+        raw_data_file = raw_data_dir / (
+            file_name + '_' + split_name + '.' + extension
+        )
+        split_entries = [
+            entry for entry in entries if entry['split'] == split_name
+        ]
+        raw_data_file.write_text('\n'.join(map(json.dumps, split_entries)))
+      distribution = [
+          mlc.FileSet(
+              id='raw_data',
+              name='raw_data',
+              description='Files with the data.',
+              encoding_format='application/jsonlines',
+              includes=f'data/{file_name}*.{extension}',
+          ),
+      ]
+    else:
+      raw_data_file = raw_data_dir / raw_data_filename
+      raw_data_file.write_text('\n'.join(map(json.dumps, entries)))
+      # Get the actual raw file's hash, set distribution and metadata.
+      raw_data_file_content = raw_data_file.read_text()
+      sha256 = hashlib.sha256(raw_data_file_content.encode()).hexdigest()
+      distribution = [
+          mlc.FileObject(
+              id='raw_data',
+              name='raw_data',
+              description='File with the data.',
+              encoding_format='application/jsonlines',
+              content_url=f'data/{raw_data_filename}',
+              sha256=sha256,
+          ),
+      ]
     dummy_metadata = mlc.Metadata(
         name=dataset_name,
         description='Dummy description.',
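
As a sanity check on the scheme above: per-split files are written as `<stem>_<split>.<ext>`, and the `jsonl/split` field recovers the split name from each file's `fullpath` via the regex transform. A standalone sketch of that round trip; the regex is an assumption, chosen to be consistent with the filenames generated above (the transform as extracted was garbled).

import re

# Files generated for split_names=['train', 'test'] with the default
# raw_data_filename='raw_data.jsonl':
paths = ['data/raw_data_train.jsonl', 'data/raw_data_test.jsonl']

# Capture everything between the last '_' and the '.jsonl' suffix.
split_re = re.compile(r'.*_(.+)\.jsonl$')

for path in paths:
  match = split_re.match(path)
  assert match is not None
  print(match.group(1))  # -> train, then test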
@@ -807,7 +877,6 @@ def dummy_croissant_file(
         version='1.2.0',
         license='Public',
     )
-
     # Write Croissant JSON-LD to tempdir.
     croissant_file = tempdir / croissant_filename
     croissant_file.write_text(json.dumps(dummy_metadata.to_json(), indent=2))
