Skip to content

Commit

Permalink
Refactor croissantbuilder code
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697925679
  • Loading branch information
The TensorFlow Datasets Authors committed Nov 19, 2024
1 parent 778b3b4 commit 753ed0c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
27 changes: 16 additions & 11 deletions tensorflow_datasets/core/dataset_builders/croissant_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def datatype_converter(
if not field_data_type:
# Fields with sub fields are of type None
if field.sub_fields:
return features_dict.FeaturesDict(
feature = features_dict.FeaturesDict(
{
subfield.id: datatype_converter(
subfield, int_dtype=int_dtype, float_dtype=float_dtype
Expand All @@ -109,26 +109,33 @@ def datatype_converter(
},
doc=field.description,
)
return None
else:
feature = None
elif field_data_type == int:
return int_dtype
feature = int_dtype
elif field_data_type == float:
return float_dtype
feature = float_dtype
elif field_data_type == bool:
return np.bool_
feature = np.bool_
elif field_data_type == bytes:
return text_feature.Text(doc=field.description)
feature = text_feature.Text(doc=field.description)
# We return a text feature for mlc.DataType.DATE features.
elif field_data_type == pd.Timestamp:
return text_feature.Text(doc=field.description)
feature = text_feature.Text(doc=field.description)
elif field_data_type == mlc.DataType.IMAGE_OBJECT:
return image_feature.Image(doc=field.description)
feature = image_feature.Image(doc=field.description)
elif field_data_type == mlc.DataType.BOUNDING_BOX:
# TFDS uses REL_YXYX by default, but Hugging Face doesn't enforce a format.
return bounding_boxes.BBoxFeature(doc=field.description, bbox_format=None)
feature = bounding_boxes.BBoxFeature(
doc=field.description, bbox_format=None
)
else:
raise ValueError(f'Unknown data type: {field_data_type}.')

if feature and field.repeated:
feature = sequence_feature.Sequence(feature, doc=field.description)
return feature


def _extract_license(license_: Any) -> str | None:
"""Extracts the full terms of a license as a string.
Expand Down Expand Up @@ -271,8 +278,6 @@ def get_features(self) -> features_dict.FeaturesDict:
feature = datatype_converter(
field, int_dtype=self._int_dtype, float_dtype=self._float_dtype
)
if field.repeated:
feature = sequence_feature.Sequence(feature)
features[field.id] = feature
features = _strip_record_set_prefix(features, record_set.id)
return features_dict.FeaturesDict(features)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tensorflow_datasets.core.features import bounding_boxes
from tensorflow_datasets.core.features import features_dict
from tensorflow_datasets.core.features import image_feature
from tensorflow_datasets.core.features import sequence_feature
from tensorflow_datasets.core.features import tensor_feature
from tensorflow_datasets.core.features import text_feature
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
Expand Down Expand Up @@ -146,15 +147,26 @@ def test_simple_datatype_converter(field, feature_type, int_dtype, float_dtype):
)
def test_complex_datatype_converter(field, feature_type, subfield_types):
actual_feature = croissant_builder.datatype_converter(field)
assert isinstance(actual_feature, feature_type)
assert actual_feature.doc.desc == field.description
assert isinstance(actual_feature, feature_type)
if subfield_types:
for feature_name in actual_feature.keys():
assert isinstance(
actual_feature[feature_name], subfield_types[feature_name]
)


def test_sequence_feature_datatype_converter():
field = mlc.Field(
data_types=mlc.DataType.TEXT,
description="Text feature",
repeated=True,
)
actual_feature = croissant_builder.datatype_converter(field)
assert isinstance(actual_feature, sequence_feature.Sequence)
assert isinstance(actual_feature.feature, text_feature.Text)


@pytest.fixture(name="crs_builder")
def mock_croissant_dataset_builder(tmp_path, request):
dataset_name = request.param["dataset_name"]
Expand Down

0 comments on commit 753ed0c

Please sign in to comment.