Skip to content

Commit

Permalink
Fix a bug in LazyMetadataDict where the metadata is not copyable when using **.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 696856102
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed Nov 15, 2024
1 parent 6bcafe4 commit 0ca4911
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
16 changes: 11 additions & 5 deletions tensorflow_datasets/core/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ def alternative_file_formats(self) -> list[file_adapters.FileFormat]:

@property
def metadata(self) -> Metadata | None:
if isinstance(self._metadata, LazyMetadataDict):
self._metadata.load_metadata_if_needed()
return self._metadata

@property
Expand Down Expand Up @@ -1414,26 +1416,30 @@ def __init__(self, data_dir: epath.PathLike) -> None:
self._data_is_loaded = False
super().__init__()

def _load_metadata(self):
def load_metadata_if_needed(self):
if not self._data_is_loaded:
if _metadata_filepath(self._data_dir).exists():
self.load_metadata(self._data_dir)
self._data_is_loaded = True

def __getitem__(self, key, /):
self._load_metadata()
self.load_metadata_if_needed()
return super().__getitem__(key)

def __eq__(self, value, /):
self._load_metadata()
self.load_metadata_if_needed()
return super().__eq__(value)

def keys(self):
self._load_metadata()
self.load_metadata_if_needed()
return super().keys()

def values(self):
self.load_metadata_if_needed()
return super().values()

def items(self):
self._load_metadata()
self.load_metadata_if_needed()
return super().items()

def copy(self):
Expand Down
8 changes: 8 additions & 0 deletions tensorflow_datasets/core/dataset_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,14 @@ def test_metadata(self):
)
self.assertEqual(builder3.info.metadata, {"some_key": 123})

# Test whether the metadata is copyable.
builder = RandomShapedImageGenerator(data_dir=tmp_dir)
self.assertIsInstance(
builder.info.metadata, dataset_info.LazyMetadataDict
)
metadata_copy = {**builder.info.metadata}
self.assertEqual(metadata_copy, {"some_key": 123})

def test_redistribution_info(self):
info = dataset_info.DatasetInfo(
builder=self._builder, license="some license"
Expand Down

0 comments on commit 0ca4911

Please sign in to comment.