rationalise tests
lannelin committed Nov 26, 2024
1 parent fd25f70 commit 8426046
Showing 1 changed file with 41 additions and 46 deletions.
87 changes: 41 additions & 46 deletions tests/test_multieurlex_utils.py
@@ -1,4 +1,5 @@
import os
from typing import Any
from unittest.mock import patch

import datasets
@@ -178,12 +179,12 @@ def _check_pipeline_text(
]


def test_load_multieurlex_for_pipeline():
def _test_load_multieurlex_for_pipeline(expected_keys: set[str], load_ocr_data: bool):
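    # Shared body for the pipeline-loading tests, parametrised on the expected
    # feature keys and on whether OCR data is loaded alongside the text.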
data_dir = f"{TEST_ROOT}/testdata/multieurlex_test"
lang_pair = {"source": "de", "target": "en"}

ds = datasets.load_from_disk(data_dir)
expected_keys = {"celex_id", "labels", "source_text", "target_text"}

expected_non_empty_indices = [0, 1, 3, 4]
text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
with patch(
@@ -194,7 +195,7 @@ def test_load_multieurlex_for_pipeline():
level=1,
lang_pair=lang_pair,
drop_empty=True,
load_ocr_data=False,
load_ocr_data=load_ocr_data,
)
assert len(dataset_dict) == 3
for split in ["train", "validation", "test"]:
@@ -213,63 +214,57 @@ def test_load_multieurlex_for_pipeline():
ignore_keys=["source_text", "target_text", "ocr_images", "ocr_targets"],
)

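    # hand the loaded objects back so callers can layer on extra checks
    # (the OCR test below inspects the returned dataset_dict further)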
return dataset_dict, metadata

def test_load_multieurlex_for_pipeline_ocr():
data_dir = f"{TEST_ROOT}/testdata/multieurlex_test"
lang_pair = {"source": "de", "target": "en"}

ds = datasets.load_from_disk(data_dir)
expected_keys = {
"celex_id",
"labels",
"source_text",
"target_text",
"ocr_images",
"ocr_targets",
}
expected_non_empty_indices = [0, 1, 3, 4]
text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
def test_load_multieurlex_for_pipeline():
expected_keys = {"celex_id", "labels", "source_text", "target_text"}
_test_load_multieurlex_for_pipeline(
expected_keys=expected_keys, load_ocr_data=False
)


def _create_ocr_data(
expected_n_rows: int,
) -> tuple[list[dict[str, Any]], tuple[list[PILImage.Image], ...], tuple[list[str], ...]]:
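    # Build deterministic fake OCR rows: row i carries i + 1 copies of a
    # constant 5x5 RGB image (every pixel equal to i) and i + 1 matching
    # target strings.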
dummy_ocr_data = [
{
"ocr_images": [PILImage.fromarray(np.ones((5, 5, 3)).astype(np.uint8) * i)]
* (i + 1),
"ocr_targets": [f"foo {i}"] * (i + 1),
}
for i in expected_non_empty_indices
for i in range(expected_n_rows)
    ] # different value at each call and a different, nonzero number of items each time (i + 1)
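    # e.g. expected_n_rows=2 ->
    #   [{"ocr_images": [img_0], "ocr_targets": ["foo 0"]},
    #    {"ocr_images": [img_1, img_1], "ocr_targets": ["foo 1", "foo 1"]}]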

    # unpack data to give expected values
expected_ocr_data, expected_ocr_targets = zip(
*[x.values() for x in dummy_ocr_data], strict=True
)
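    # zip(*...) transposes the per-row (images, targets) pairs into one tuple
    # of image lists and one tuple of target lists, preserving row order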
with patch( # noqa: SIM117
"arc_spice.data.multieurlex_utils.datasets.load_dataset", return_value=ds
):
with patch("arc_spice.data.multieurlex_utils.make_ocr_data") as mock_mod:
mock_mod.side_effect = dummy_ocr_data * 3 # handle all 3 splits
dataset_dict, metadata = multieurlex_utils.load_multieurlex_for_pipeline(
data_dir=data_dir,
level=1,
lang_pair=lang_pair,
drop_empty=True,
load_ocr_data=True,
)
assert len(dataset_dict) == 3
for split in ["train", "validation", "test"]:
assert set(dataset_dict[split].features.keys()) == expected_keys
assert len(dataset_dict[split]) == 4 # 5 items, 1 is empty so dropped
_check_pipeline_text(
dataset=dataset_dict[split],
text_indices_kept=text_expected_non_empty_indices, # inds start at 1
source_lang=lang_pair["source"],
target_lang=lang_pair["target"],
)
_check_keys_untouched(
original_dataset=ds[split],
dataset=dataset_dict[split],
indices_kept=expected_non_empty_indices, # inds start at 0
ignore_keys=["source_text", "target_text", "ocr_images", "ocr_targets"],
)
return dummy_ocr_data, expected_ocr_data, expected_ocr_targets


# same as above but with OCR data checks
def test_load_multieurlex_for_pipeline_ocr():
expected_keys = {
"celex_id",
"labels",
"source_text",
"target_text",
"ocr_images",
"ocr_targets",
}

expected_n_rows = 4
dummy_ocr_data, expected_ocr_data, expected_ocr_targets = _create_ocr_data(
expected_n_rows=expected_n_rows
)

with patch("arc_spice.data.multieurlex_utils.make_ocr_data") as mock_mod:
mock_mod.side_effect = dummy_ocr_data * 3 # handle all 3 splits
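        # side_effect yields one dict per call to the mocked make_ocr_data;
        # repeating the rows three times covers train/validation/test in turn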
dataset_dict, metadata = _test_load_multieurlex_for_pipeline(
expected_keys=expected_keys, load_ocr_data=True
)
for split in ["train", "validation", "test"]:
for row_index in range(len(dataset_dict[split])):
# OCR - images
# PIL.PngImagePlugin.PngImageFile vs PIL.Image.Image so compare as np