diff --git a/tests/test_multieurlex_utils.py b/tests/test_multieurlex_utils.py
index 2958d79..41b162c 100644
--- a/tests/test_multieurlex_utils.py
+++ b/tests/test_multieurlex_utils.py
@@ -1,4 +1,5 @@
 import os
+from typing import Any
 from unittest.mock import patch

 import datasets
@@ -178,12 +179,12 @@ def _check_pipeline_text(
     ]


-def test_load_multieurlex_for_pipeline():
+def _test_load_multieurlex_for_pipeline(expected_keys: set[str], load_ocr_data: bool):
     data_dir = f"{TEST_ROOT}/testdata/multieurlex_test"
     lang_pair = {"source": "de", "target": "en"}
     ds = datasets.load_from_disk(data_dir)
-    expected_keys = {"celex_id", "labels", "source_text", "target_text"}
+
     expected_non_empty_indices = [0, 1, 3, 4]
     text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]

     with patch(
@@ -194,7 +195,7 @@ def test_load_multieurlex_for_pipeline():
             level=1,
             lang_pair=lang_pair,
             drop_empty=True,
-            load_ocr_data=False,
+            load_ocr_data=load_ocr_data,
         )
         assert len(dataset_dict) == 3
         for split in ["train", "validation", "test"]:
@@ -213,63 +214,57 @@ def test_load_multieurlex_for_pipeline():
                 ignore_keys=["source_text", "target_text", "ocr_images", "ocr_targets"],
             )

+    return dataset_dict, metadata
+

-def test_load_multieurlex_for_pipeline_ocr():
-    data_dir = f"{TEST_ROOT}/testdata/multieurlex_test"
-    lang_pair = {"source": "de", "target": "en"}
-    ds = datasets.load_from_disk(data_dir)
-    expected_keys = {
-        "celex_id",
-        "labels",
-        "source_text",
-        "target_text",
-        "ocr_images",
-        "ocr_targets",
-    }
-    expected_non_empty_indices = [0, 1, 3, 4]
-    text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
+def test_load_multieurlex_for_pipeline():
+    expected_keys = {"celex_id", "labels", "source_text", "target_text"}
+    _test_load_multieurlex_for_pipeline(
+        expected_keys=expected_keys, load_ocr_data=False
+    )
+
+
+def _create_ocr_data(
+    expected_n_rows: int,
+) -> tuple[list[dict[str, Any]], tuple[list[PILImage.Image], ...], tuple[list[str], ...]]:
     dummy_ocr_data = [
         {
             "ocr_images": [PILImage.fromarray(np.ones((5, 5, 3)).astype(np.uint8) * i)]
             * (i + 1),
             "ocr_targets": [f"foo {i}"] * (i + 1),
         }
-        for i in expected_non_empty_indices
+        for i in range(expected_n_rows)
     ]  # different value at each call, different number of items each time (nonzero i+1)
+    # unpack data to give expected values
     expected_ocr_data, expected_ocr_targets = zip(
         *[x.values() for x in dummy_ocr_data], strict=True
     )

-    with patch(  # noqa: SIM117
-        "arc_spice.data.multieurlex_utils.datasets.load_dataset", return_value=ds
-    ):
-        with patch("arc_spice.data.multieurlex_utils.make_ocr_data") as mock_mod:
-            mock_mod.side_effect = dummy_ocr_data * 3  # handle all 3 splits
-            dataset_dict, metadata = multieurlex_utils.load_multieurlex_for_pipeline(
-                data_dir=data_dir,
-                level=1,
-                lang_pair=lang_pair,
-                drop_empty=True,
-                load_ocr_data=True,
-            )
-            assert len(dataset_dict) == 3
-            for split in ["train", "validation", "test"]:
-                assert set(dataset_dict[split].features.keys()) == expected_keys
-                assert len(dataset_dict[split]) == 4  # 5 items, 1 is empty so dropped
-                _check_pipeline_text(
-                    dataset=dataset_dict[split],
-                    text_indices_kept=text_expected_non_empty_indices,  # inds start at 1
-                    source_lang=lang_pair["source"],
-                    target_lang=lang_pair["target"],
-                )
-                _check_keys_untouched(
-                    original_dataset=ds[split],
-                    dataset=dataset_dict[split],
-                    indices_kept=expected_non_empty_indices,  # inds start at 0
-                    ignore_keys=["source_text", "target_text", "ocr_images", "ocr_targets"],
-                )
+    return dummy_ocr_data, expected_ocr_data, expected_ocr_targets
+
+
+# same as above but with OCR data checks
+def test_load_multieurlex_for_pipeline_ocr():
+    expected_keys = {
+        "celex_id",
+        "labels",
+        "source_text",
+        "target_text",
+        "ocr_images",
+        "ocr_targets",
+    }
+
+    expected_n_rows = 4
+    dummy_ocr_data, expected_ocr_data, expected_ocr_targets = _create_ocr_data(
+        expected_n_rows=expected_n_rows
+    )
+
+    with patch("arc_spice.data.multieurlex_utils.make_ocr_data") as mock_mod:
+        mock_mod.side_effect = dummy_ocr_data * 3  # handle all 3 splits
+        dataset_dict, metadata = _test_load_multieurlex_for_pipeline(
+            expected_keys=expected_keys, load_ocr_data=True
+        )
+
+    for split in ["train", "validation", "test"]:
         for row_index in range(len(dataset_dict[split])):
             # OCR - images
             # PIL.PngImagePlugin.PngImageFile vs PIL.Image.Image so compare as np