From fd25f70db6a0cac2d0efafcb3fd44b6212604814 Mon Sep 17 00:00:00 2001
From: James Bishop
Date: Tue, 26 Nov 2024 18:55:35 +0000
Subject: [PATCH] ocr outputs

---
 src/arc_spice/data/multieurlex_utils.py |  58 +++---
 tests/test_multieurlex_utils.py         | 231 +++++++++++++++++++++---
 2 files changed, 246 insertions(+), 43 deletions(-)

diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py
index 1790ced..d46da5e 100644
--- a/src/arc_spice/data/multieurlex_utils.py
+++ b/src/arc_spice/data/multieurlex_utils.py
@@ -1,9 +1,10 @@
 import json
 from typing import Any
 
+import datasets
 import torch
-from datasets import Dataset, DatasetDict, Image, load_dataset
 from datasets.formatting.formatting import LazyRow
+from PIL import Image
 from torch.nn.functional import one_hot
 from trdg.generators import GeneratorFromStrings
 
@@ -64,18 +65,15 @@ def extract_articles(
     }
 
 
-def _make_ocr_data(text: str):
+def _make_ocr_data(text: str) -> list[tuple[Image.Image, str]]:
     text_split = text.split()
     generator = GeneratorFromStrings(text_split, count=len(text_split))
-    feature = Image(decode=False)
-    return {
-        str(idx): {"image": feature.encode_example(gen[0]), "target": gen[1]}
-        for idx, gen in enumerate(generator)
-    }
+    return list(generator)
 
 
-def make_ocr_data(item: LazyRow):
-    return {"ocr_data": _make_ocr_data(item["source_text"])}
+def make_ocr_data(item: LazyRow) -> dict[str, tuple[Image.Image] | tuple[str]]:
+    images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
+    return {"ocr_images": images, "ocr_targets": targets}
 
 
 class TranslationPreProcesser:
@@ -136,8 +134,7 @@ def load_multieurlex(
     languages: list[str],
     drop_empty: bool = True,
     split: str | None = None,
-    load_ocr_data: bool = False,
-) -> tuple[DatasetDict, dict[str, Any]]:
+) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     """
     load the multieurlex dataset
 
@@ -159,7 +156,7 @@
 
     load_langs = languages[0] if len(languages) == 1 else "all_languages"
 
-    dataset_dict = load_dataset(
+    dataset_dict = datasets.load_dataset(
         "multi_eurlex",
         load_langs,
         label_level=f"level_{level}",
@@ -168,13 +165,13 @@
     )
     # ensure we always return dataset dict even if only single split
     if split is not None:
-        if not isinstance(dataset_dict, Dataset):
+        if not isinstance(dataset_dict, datasets.Dataset):
             msg = (
                 "Error. load_dataset should return a Dataset object if split specified"
             )
             raise ValueError(msg)
 
-        tmp = DatasetDict()
+        tmp = datasets.DatasetDict()
         tmp[split] = dataset_dict
         dataset_dict = tmp
 
@@ -191,16 +188,17 @@
             lambda x: all(x is not None for x in x["text"].values())
         )
 
-    if load_ocr_data:
-        dataset_dict = dataset_dict.map(make_ocr_data)
-
     # return datasets and metadata
     return dataset_dict, metadata
 
 
-def load_multieurlex_for_translation(
-    data_dir: str, level: int, lang_pair: dict[str, str], drop_empty: bool = True
-) -> tuple[DatasetDict, dict[str, Any]]:
+def load_multieurlex_for_pipeline(
+    data_dir: str,
+    level: int,
+    lang_pair: dict[str, str],
+    drop_empty: bool = True,
+    load_ocr_data: bool = False,
+) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     langs = [lang_pair["source"], lang_pair["target"]]
     dataset_dict, meta_data = load_multieurlex(
         data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty
@@ -208,4 +206,22 @@ )
     # instantiate the preprocessor
     preprocesser = TranslationPreProcesser(lang_pair)
     # preprocess each split
-    return dataset_dict.map(preprocesser, remove_columns=["text"]), meta_data
+    dataset_dict = dataset_dict.map(preprocesser, remove_columns=["text"])
+
+    # TODO allow for OCR standalone?
+    if load_ocr_data:
+        # need to set features so loop through dataset dict manually
+        for k in dataset_dict:
+            feats = dataset_dict[k].features
+            dataset_dict[k] = dataset_dict[k].map(
+                make_ocr_data,
+                features=datasets.Features(
+                    {
+                        "ocr_images": datasets.Sequence(datasets.Image(decode=True)),
+                        "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                        **feats,
+                    }
+                ),
+            )
+
+    return dataset_dict, meta_data
diff --git a/tests/test_multieurlex_utils.py b/tests/test_multieurlex_utils.py
index 4701f51..2958d79 100644
--- a/tests/test_multieurlex_utils.py
+++ b/tests/test_multieurlex_utils.py
@@ -2,16 +2,18 @@
 from unittest.mock import patch
 
 import datasets
+import numpy as np
 import pyarrow as pa
 from datasets.formatting import PythonFormatter
 from datasets.formatting.formatting import LazyRow
+from PIL import Image as PILImage
 
 from arc_spice.data import multieurlex_utils
 
 TEST_ROOT = os.path.dirname(os.path.abspath(__file__))
 
 
-def _create_row(text) -> LazyRow:
+def _create_single_lang_row(text) -> LazyRow:
     pa_table = pa.Table.from_pydict({"text": [text]})
     formatter = PythonFormatter(lazy=True)
     return formatter.format_row(pa_table)
@@ -24,11 +26,19 @@ def _create_multilang_row(texts_by_lang: dict[str, str]) -> LazyRow:
     return formatter.format_row(pa_table)
 
 
+def _create_translation_row(source_text: str, target_text: str) -> LazyRow:
+    pa_table = pa.Table.from_pydict(
+        {"source_text": [source_text], "target_text": [target_text]}
+    )
+    formatter = PythonFormatter(lazy=True)
+    return formatter.format_row(pa_table)
+
+
 def test_extract_articles_single_lang():
     langs = ["en"]
     pre_text = "Some text before the marker"
     post_text = "Some text after the marker"
-    row = _create_row(
+    row = _create_single_lang_row(
         text=f"{pre_text} {multieurlex_utils.ARTICLE_1_MARKERS['en']} {post_text}"
     )
     out = multieurlex_utils.extract_articles(item=row, languages=langs)
@@ -53,6 +63,38 @@ def test_extract_articles_multi_lang():
     }
 
 
+def test_make_ocr_data():
+    source_text = "Some text to make into an image"
+    row = _create_translation_row(source_text=source_text, target_text="foo")
+    dummy_im1 = PILImage.fromarray(
+        np.random.randint(0, 255, (10, 10, 3)).astype(np.uint8)
+    )
+    dummy_im2 = PILImage.fromarray(
+        np.random.randint(0, 255, (10, 10, 3)).astype(np.uint8)
+    )
+
+    with patch("arc_spice.data.multieurlex_utils.GeneratorFromStrings") as mock_gen:
+        mock_gen.return_value = [(dummy_im1, "target1"), (dummy_im2, "target2")]
+        output = multieurlex_utils.make_ocr_data(row)
+
+    assert output == {
+        "ocr_images": (dummy_im1, dummy_im2),
+        "ocr_targets": ("target1", "target2"),
+    }
+
+
+def _check_keys_untouched(
+    original_dataset: datasets.Dataset,
+    dataset: datasets.Dataset,
+    indices_kept: list[int],
+    ignore_keys: list[str],
+) -> None:
+    # check remaining keys are untouched
+    for key in dataset.features:
+        if key not in ignore_keys:
+            assert dataset[key] == [original_dataset[key][i] for i in indices_kept]
+
+
 def test_load_multieurlex_en():
     data_dir = f"{TEST_ROOT}/testdata/multieurlex_test_en"
     level = 1
@@ -60,39 +102,184 @@ def test_load_multieurlex_en():
     drop_empty = True
 
     ds = datasets.load_from_disk(data_dir)
-    with patch("arc_spice.data.multieurlex_utils.load_dataset", return_value=ds):
+    expected_keys = {"celex_id", "text", "labels"}
+    expected_non_empty_indices = [0, 1, 3, 4]
+    text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
+    with patch(
+        "arc_spice.data.multieurlex_utils.datasets.load_dataset", return_value=ds
+    ):
         dataset_dict, metadata = multieurlex_utils.load_multieurlex(
             data_dir=data_dir, level=level, languages=languages, drop_empty=drop_empty
         )
     assert len(dataset_dict) == 3
-    assert len(dataset_dict["train"]) == 4  # 5 items, 1 is empty so dropped
-    assert len(dataset_dict["validation"]) == 4  # 5 items, 1 is empty so dropped
-    assert len(dataset_dict["test"]) == 4  # 5 items, 1 is empty so dropped
-    assert dataset_dict["train"]["text"] == [
-        f"{multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker {i}"  # noqa: E501
-        for i in [1, 2, 4, 5]  # 3 dropped
-    ]
+    for split in ["train", "validation", "test"]:
+        assert set(dataset_dict[split].features.keys()) == expected_keys
+        assert len(dataset_dict[split]) == 4  # 5 items, 1 is empty so dropped
+        assert dataset_dict[split]["text"] == [
+            f"{multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker {i}"  # noqa: E501
+            for i in text_expected_non_empty_indices  # 3 dropped
+        ]
+        _check_keys_untouched(
+            original_dataset=ds[split],
+            dataset=dataset_dict[split],
+            indices_kept=expected_non_empty_indices,
+            ignore_keys=["text"],
+        )
 
 
-def test_load_multieurlex_for_translation():
+def test_load_multieurlex_multi_lang():
     data_dir = f"{TEST_ROOT}/testdata/multieurlex_test"
     level = 1
     languages = ["de", "en", "fr"]
     drop_empty = True
 
     ds = datasets.load_from_disk(data_dir)
-    with patch("arc_spice.data.multieurlex_utils.load_dataset", return_value=ds):
+    expected_keys = {"celex_id", "text", "labels"}
+    expected_non_empty_indices = [0, 1, 3, 4]
+    text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
+    with patch(
+        "arc_spice.data.multieurlex_utils.datasets.load_dataset", return_value=ds
+    ):
         dataset_dict, metadata = multieurlex_utils.load_multieurlex(
             data_dir=data_dir, level=level, languages=languages, drop_empty=drop_empty
         )
     assert len(dataset_dict) == 3
-    assert len(dataset_dict["train"]) == 4  # 5 items, 1 is empty so dropped
-    assert len(dataset_dict["validation"]) == 4  # 5 items, 1 is empty so dropped
-    assert len(dataset_dict["test"]) == 4  # 5 items, 1 is empty so dropped
-    assert dataset_dict["train"]["text"] == [  #
-        {
-            lang: f"{multieurlex_utils.ARTICLE_1_MARKERS[lang]} Some text after the marker {i}"  # noqa: E501
-            for lang in languages
-        }
-        for i in [1, 2, 4, 5]  # 3 dropped
-    ]
+    for split in ["train", "validation", "test"]:
+        assert set(dataset_dict[split].features.keys()) == expected_keys
+        assert len(dataset_dict[split]) == 4  # 5 items, 1 is empty so dropped
+        assert dataset_dict[split]["text"] == [  #
+            {
+                lang: f"{multieurlex_utils.ARTICLE_1_MARKERS[lang]} Some text after the marker {i}"  # noqa: E501
+                for lang in languages
+            }
+            for i in text_expected_non_empty_indices  # 3 dropped
+        ]
+        _check_keys_untouched(
+            original_dataset=ds[split],
+            dataset=dataset_dict[split],
+            indices_kept=expected_non_empty_indices,
+            ignore_keys=["text"],
+        )
+
+
+def _check_pipeline_text(
+    dataset: datasets.Dataset,
+    text_indices_kept: list[int],
+    source_lang: str,
+    target_lang: str,
+):
+    assert dataset["source_text"] == [
+        f"{multieurlex_utils.ARTICLE_1_MARKERS[source_lang]} Some text after the marker {i}"  # noqa: E501
+        for i in text_indices_kept
+    ]
+    assert dataset["target_text"] == [
+        f"{multieurlex_utils.ARTICLE_1_MARKERS[target_lang]} Some text after the marker {i}"  # noqa: E501
+        for i in text_indices_kept
+    ]
+
+
+def test_load_multieurlex_for_pipeline():
+    data_dir = f"{TEST_ROOT}/testdata/multieurlex_test"
+    lang_pair = {"source": "de", "target": "en"}
+
+    ds = datasets.load_from_disk(data_dir)
+    expected_keys = {"celex_id", "labels", "source_text", "target_text"}
+    expected_non_empty_indices = [0, 1, 3, 4]
+    text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
+    with patch(
+        "arc_spice.data.multieurlex_utils.datasets.load_dataset", return_value=ds
+    ):
+        dataset_dict, metadata = multieurlex_utils.load_multieurlex_for_pipeline(
+            data_dir=data_dir,
+            level=1,
+            lang_pair=lang_pair,
+            drop_empty=True,
+            load_ocr_data=False,
+        )
+    assert len(dataset_dict) == 3
+    for split in ["train", "validation", "test"]:
+        assert set(dataset_dict[split].features.keys()) == expected_keys
+        assert len(dataset_dict[split]) == 4  # 5 items, 1 is empty so dropped
+        _check_pipeline_text(
+            dataset=dataset_dict[split],
+            text_indices_kept=text_expected_non_empty_indices,  # inds start at 1
+            source_lang=lang_pair["source"],
+            target_lang=lang_pair["target"],
+        )
+        _check_keys_untouched(
+            original_dataset=ds[split],
+            dataset=dataset_dict[split],
+            indices_kept=expected_non_empty_indices,  # inds start at 0
+            ignore_keys=["source_text", "target_text", "ocr_images", "ocr_targets"],
+        )
+
+
+def test_load_multieurlex_for_pipeline_ocr():
+    data_dir = f"{TEST_ROOT}/testdata/multieurlex_test"
+    lang_pair = {"source": "de", "target": "en"}
+
+    ds = datasets.load_from_disk(data_dir)
+    expected_keys = {
+        "celex_id",
+        "labels",
+        "source_text",
+        "target_text",
+        "ocr_images",
+        "ocr_targets",
+    }
+    expected_non_empty_indices = [0, 1, 3, 4]
+    text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
+    dummy_ocr_data = [
+        {
+            "ocr_images": [PILImage.fromarray(np.ones((5, 5, 3)).astype(np.uint8) * i)]
+            * (i + 1),
+            "ocr_targets": [f"foo {i}"] * (i + 1),
+        }
+        for i in expected_non_empty_indices
+    ]  # different value at each call, different number of items each time (nonzero i+1)
+
+    expected_ocr_data, expected_ocr_targets = zip(
+        *[x.values() for x in dummy_ocr_data], strict=True
+    )
+    with patch(  # noqa: SIM117
+        "arc_spice.data.multieurlex_utils.datasets.load_dataset", return_value=ds
+    ):
+        with patch("arc_spice.data.multieurlex_utils.make_ocr_data") as mock_mod:
+            mock_mod.side_effect = dummy_ocr_data * 3  # handle all 3 splits
+            dataset_dict, metadata = multieurlex_utils.load_multieurlex_for_pipeline(
+                data_dir=data_dir,
+                level=1,
+                lang_pair=lang_pair,
+                drop_empty=True,
+                load_ocr_data=True,
+            )
+    assert len(dataset_dict) == 3
+    for split in ["train", "validation", "test"]:
+        assert set(dataset_dict[split].features.keys()) == expected_keys
+        assert len(dataset_dict[split]) == 4  # 5 items, 1 is empty so dropped
+        _check_pipeline_text(
+            dataset=dataset_dict[split],
+            text_indices_kept=text_expected_non_empty_indices,  # inds start at 1
+            source_lang=lang_pair["source"],
+            target_lang=lang_pair["target"],
+        )
+        _check_keys_untouched(
+            original_dataset=ds[split],
+            dataset=dataset_dict[split],
+            indices_kept=expected_non_empty_indices,  # inds start at 0
+            ignore_keys=["source_text", "target_text", "ocr_images", "ocr_targets"],
+        )
+
+        for row_index in range(len(dataset_dict[split])):
+            # OCR - images
+            # PIL.PngImagePlugin.PngImageFile vs PIL.Image.Image so compare as np
+            output_as_numpy = [
+                np.asarray(im) for im in dataset_dict[split]["ocr_images"][row_index]
+            ]
+            expected_as_numpy = [np.asarray(im) for im in expected_ocr_data[row_index]]
+            np.testing.assert_array_equal(output_as_numpy, expected_as_numpy)
+            # OCR - targets
+            assert (
+                dataset_dict[split]["ocr_targets"][row_index]
+                == expected_ocr_targets[row_index]
+            )
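
Reviewer note, not part of the commit: below is a minimal usage sketch of the new OCR loading path. The function name, arguments and the new "ocr_images"/"ocr_targets" columns come from the patch above; the data_dir value and the language pair are illustrative assumptions only, and image generation via trdg runs eagerly during the map, so this is slow on real splits.

# Sketch only: exercise load_multieurlex_for_pipeline() with load_ocr_data=True.
# "data/multieurlex" is a hypothetical local cache directory, not defined in this patch.
from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline

dataset_dict, metadata = load_multieurlex_for_pipeline(
    data_dir="data/multieurlex",  # assumption, adjust to your local setup
    level=1,
    lang_pair={"source": "de", "target": "en"},
    drop_empty=True,
    load_ocr_data=True,  # adds "ocr_images" and "ocr_targets" columns to each split
)

row = dataset_dict["train"][0]
print(row["source_text"][:80])  # source-language article text
print(len(row["ocr_images"]))   # one generated word image per whitespace token
print(row["ocr_targets"][:5])   # the matching OCR target strings

This mirrors what test_load_multieurlex_for_pipeline_ocr asserts, without the mocking.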