Support Parquet files in TFDS.
PiperOrigin-RevId: 593915075
marcenacp authored and The TensorFlow Datasets Authors committed Dec 27, 2023
1 parent e058008 commit 38953e1
Showing 5 changed files with 142 additions and 1 deletion.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,15 @@ and this project adheres to

### Added

- Support for downloading and preparing datasets using the
  [Parquet](https://parquet.apache.org) data format.
```python
builder = tfds.builder('fashion_mnist', file_format='parquet')
builder.download_and_prepare()
ds = builder.as_dataset(split='train')
print(next(iter(ds)))
```
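
The prepared shards are regular Parquet files with a single binary `data` column
holding the serialized `tf.Example` protos, so they can also be inspected
directly with `pyarrow.parquet`. A minimal sketch, assuming the default
`data_dir` and a hypothetical shard name:

```python
import os

import pyarrow.parquet as pq

# Hypothetical shard path: the actual file name depends on data_dir, the
# dataset version and the shard count.
shard = os.path.expanduser(
    '~/tensorflow_datasets/fashion_mnist/3.0.1/'
    'fashion_mnist-train.parquet-00000-of-00001'
)
table = pq.read_table(shard)
print(table.schema)    # A single binary column named `data`.
print(table.num_rows)  # One row per serialized tf.Example proto.
```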

### Changed

### Deprecated
1 change: 1 addition & 0 deletions setup.py
@@ -71,6 +71,7 @@
'promise',
'protobuf>=3.20',
'psutil',
'pyarrow',
'requests>=2.19.0',
'tensorflow-metadata',
'termcolor',
106 changes: 105 additions & 1 deletion tensorflow_datasets/core/file_adapters.py
@@ -18,17 +18,21 @@
from __future__ import annotations

import abc
from collections.abc import Iterator
import enum
import itertools
import os
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Type, Union
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Type, TypeVar, Union
import uuid

from etils import epath
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf

ExamplePositions = List[Any]
T = TypeVar('T')


class FileFormat(enum.Enum):
@@ -40,6 +44,7 @@ class FileFormat(enum.Enum):
TFRECORD = 'tfrecord'
RIEGELI = 'riegeli'
ARRAY_RECORD = 'array_record'
PARQUET = 'parquet'

@property
def file_suffix(self) -> str:
@@ -214,6 +219,81 @@ def write_examples(
writer.close()


class ParquetFileAdapter(FileAdapter):
  """File adapter for the [Parquet](https://parquet.apache.org) file format.

  This FileAdapter requires `pyarrow` as a dependency and builds upon
  `pyarrow.parquet`.

  At the moment, the Parquet adapter doesn't leverage Parquet's columnar
  features and behaves like any other adapter. Instead of saving the features
  in the columns, we use a single `data` column where we store the serialized
  tf.Example proto.

  TODO(b/317277518): Let Parquet handle the serialization/deserialization.
  """

FILE_SUFFIX = 'parquet'
_PARQUET_FIELD = 'data'
_BATCH_SIZE = 100

@classmethod
def _schema(cls) -> pa.Schema:
"""Returns the Parquet schema as a one-column `data` binary field."""
return pa.schema([pa.field(cls._PARQUET_FIELD, pa.binary())])

@classmethod
def make_tf_data(
cls,
filename: epath.PathLike,
buffer_size: int | None = None,
) -> tf.data.Dataset:
"""Reads a Parquet file as a tf.data.Dataset.
Args:
filename: Path to the Parquet file.
buffer_size: Unused buffer size.
Returns:
A tf.data.Dataset with the serialized examples.
"""
del buffer_size # unused
import pyarrow.parquet as pq # pylint: disable=g-import-not-at-top

def get_data(py_filename: bytes) -> Iterator[tf.Tensor]:
table = pq.read_table(py_filename.decode(), schema=cls._schema())
for batch in table.to_batches():
for example in batch.to_pylist():
yield tf.constant(example[cls._PARQUET_FIELD])

return tf.data.Dataset.from_generator(
get_data,
args=(filename,),
output_signature=tf.TensorSpec(shape=(), dtype=tf.string),
)

@classmethod
def write_examples(
cls,
path: epath.PathLike,
iterator: Iterable[type_utils.KeySerializedExample],
) -> None:
"""Writes the serialized tf.Example proto in a binary field named `data`.
Args:
path: Path to the Parquet file.
iterator: Iterable of serialized examples.
"""
import pyarrow.parquet as pq # pylint: disable=g-import-not-at-top

with pq.ParquetWriter(path, schema=cls._schema()) as writer:
for examples in _batched(iterator, cls._BATCH_SIZE):
examples = [{cls._PARQUET_FIELD: example} for _, example in examples]
batch = pa.RecordBatch.from_pylist(examples)
writer.write_batch(batch)
return None


def _to_bytes(key: type_utils.Key) -> bytes:
"""Convert the key to bytes."""
if isinstance(key, int):
@@ -231,6 +311,7 @@ def _to_bytes(key: type_utils.Key) -> bytes:
FileFormat.RIEGELI: RiegeliFileAdapter,
FileFormat.TFRECORD: TfRecordFileAdapter,
FileFormat.ARRAY_RECORD: ArrayRecordFileAdapter,
FileFormat.PARQUET: ParquetFileAdapter,
}

_FILE_SUFFIX_TO_FORMAT = {
@@ -255,3 +336,26 @@ def is_example_file(filename: str) -> bool:
f'.{adapter.FILE_SUFFIX}' in filename
for adapter in ADAPTER_FOR_FORMAT.values()
)


def _batched(iterator: Iterator[T] | Iterable[T], n: int) -> Iterator[List[T]]:
"""Batches the result of an iterator into lists of length n.
This function is built-in the standard library from 3.12 (source:
https://docs.python.org/3/library/itertools.html#itertools.batched). However,
TFDS supports older versions of Python.
Args:
iterator: The iterator to batch.
n: The maximal length of each batch.
Yields:
The next list of n elements.
"""
  # Iterate over a single iterator so that generator inputs are not skipped.
  iterator = iter(iterator)
  while True:
    batch = list(itertools.islice(iterator, n))
    if not batch:
      return
    yield batch
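
Taken together, `write_examples` and `make_tf_data` give a small round trip
that is useful for sanity checks. A minimal usage sketch, with placeholder byte
strings standing in for serialized `tf.Example` protos and a hypothetical
temporary output path:

```python
import os
import tempfile

from tensorflow_datasets.core import file_adapters

# Placeholder (key, serialized example) pairs; real callers pass serialized
# tf.Example protos.
examples = [(0, b'first'), (1, b'second'), (2, b'third')]
path = os.path.join(tempfile.mkdtemp(), 'dummy.parquet')

adapter = file_adapters.ParquetFileAdapter
adapter.write_examples(path, examples)

ds = adapter.make_tf_data(path)
print([serialized.numpy() for serialized in ds])
# [b'first', b'second', b'third']
```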
26 changes: 26 additions & 0 deletions tensorflow_datasets/core/file_adapters_test.py
@@ -24,6 +24,17 @@
from tensorflow_datasets.core import file_adapters


def test_batched():
assert list(file_adapters._batched(range(10), 5)) == [
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
]
assert list(file_adapters._batched(range(10), 100)) == [
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
]
assert not list(file_adapters._batched(range(10), 0))


def test_is_example_file():
assert file_adapters.is_example_file('example1.tfrecord')
assert file_adapters.is_example_file('example1.riegeli')
@@ -56,12 +67,19 @@ def test_format_suffix():
].FILE_SUFFIX
== 'array_record'
)
assert (
file_adapters.ADAPTER_FOR_FORMAT[
file_adapters.FileFormat.PARQUET
].FILE_SUFFIX
== 'parquet'
)


@pytest.mark.parametrize(
'file_format',
[
file_adapters.FileFormat.TFRECORD,
file_adapters.FileFormat.PARQUET,
],
)
@pytest.mark.parametrize(
@@ -122,3 +140,11 @@ def test_prase_file_format():
file_adapters.FileFormat.from_value(file_adapters.FileFormat.ARRAY_RECORD)
== file_adapters.FileFormat.ARRAY_RECORD
)
assert (
file_adapters.FileFormat.from_value('parquet')
== file_adapters.FileFormat.PARQUET
)
assert (
file_adapters.FileFormat.from_value(file_adapters.FileFormat.PARQUET)
== file_adapters.FileFormat.PARQUET
)
1 change: 1 addition & 0 deletions tensorflow_datasets/core/utils/lazy_imports_utils.py
@@ -239,6 +239,7 @@ def mlcroissant_error_callback(**kwargs):
with lazy_imports():
import apache_beam
import pandas
import pyarrow


with lazy_imports(
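
Registering `pyarrow` under `lazy_imports()` defers the actual import until the
module is first touched (for example by the Parquet adapter above). A generic
sketch of the lazy-import idea, using `importlib` — an illustration of the
pattern, not the TFDS implementation:

```python
import importlib


def lazy_import(name: str):
  """Returns a proxy that performs the real import on first attribute access."""
  module = None

  class _Proxy:

    def __getattr__(self, attr):
      nonlocal module
      if module is None:
        module = importlib.import_module(name)
      return getattr(module, attr)

  return _Proxy()


# Nothing is imported until an attribute such as `pa.binary()` is first used.
pa = lazy_import('pyarrow')
```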
