diff --git a/pyproject.toml b/pyproject.toml
index f2d8cad..43b14ef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ authors = [
 ]
 description = "Sample Level, Pipeline Introduced Cumulative Errors (SPICE). Investigating methods for generalisable measurement of cumulative errors in multi-model (and multi-modal) ML pipelines."
 readme = "README.md"
-requires-python = ">=3.10"
+requires-python = ">=3.10.4"
 classifiers = [
   "Development Status :: 1 - Planning",
   "Intended Audience :: Science/Research",
@@ -27,29 +27,11 @@ classifiers = [
   "Topic :: Scientific/Engineering",
   "Typing :: Typed",
 ]
-dependencies = [
-  "transformers",
-  "huggingface",
-  "datasets",
-  "numpy",
-  "sentencepiece",
-  "librosa",
-  "torch",
-  "torcheval",
-  "pillow",
-  "unbabel-comet",
-  "accelerate",
-  "tensorboard",
-  "scikit-learn"
-]
+dynamic = ["dependencies", "optional-dependencies"]
 
-[project.optional-dependencies]
-dev = [
-  "pytest >=6",
-  "pytest-cov >=3",
-  "pre-commit",
-  "mypy",
-]
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
+optional-dependencies = {dev = { file = ["requirements-dev.txt"] }}
 
 [project.urls]
 Homepage = "https://github.com/alan-turing-institute/ARC-SPICE"
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..d3ed352
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,13 @@
+cfgv==3.4.0
+coverage==7.6.8
+distlib==0.3.9
+identify==2.6.3
+iniconfig==2.0.0
+mypy==1.13.0
+mypy-extensions==1.0.0
+nodeenv==1.9.1
+pluggy==1.5.0
+pre_commit==4.0.1
+pytest==8.3.3
+pytest-cov==6.0.0
+virtualenv==20.28.0
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..145f100
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,91 @@
+absl-py==2.1.0
+accelerate==1.1.1
+aiohappyeyeballs==2.4.3
+aiohttp==3.11.7
+aiosignal==1.3.1
+arabic-reshaper==2.1.4
+attrs==24.2.0
+audioread==3.0.1
+beautifulsoup4==4.12.3
+certifi==2024.8.30
+cffi==1.17.1
+charset-normalizer==3.4.0
+colorama==0.4.6
+datasets==3.1.0
+decorator==5.1.1
+diffimg==0.2.3
+dill==0.3.8
+entmax==1.3
+filelock==3.16.1
+frozenlist==1.5.0
+fsspec==2024.9.0
+grpcio==1.68.0
+huggingface==0.0.1
+huggingface-hub==0.26.2
+idna==3.10
+Jinja2==3.1.4
+joblib==1.4.2
+jsonargparse==3.13.1
+lazy_loader==0.4
+librosa==0.10.2.post1
+lightning-utilities==0.11.9
+llvmlite==0.43.0
+lxml==5.3.0
+Markdown==3.7
+MarkupSafe==3.0.2
+mpmath==1.3.0
+msgpack==1.1.0
+multidict==6.1.0
+multiprocess==0.70.16
+networkx==3.4.2
+numba==0.60.0
+numpy==1.26.4
+opencv-python==4.10.0.84
+packaging==24.2
+pandas==2.2.3
+pillow==11.0.0
+platformdirs==4.3.6
+pooch==1.8.2
+portalocker==3.0.0
+propcache==0.2.0
+protobuf==4.25.5
+psutil==6.1.0
+pyarrow==18.1.0
+pycparser==2.22
+python-bidi==0.4.2
+python-dateutil==2.9.0.post0
+pytorch-lightning==2.4.0
+pytz==2024.2
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+sacrebleu==2.4.3
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.14.1
+sentencepiece==0.1.99
+setuptools==75.6.0
+six==1.16.0
+soundfile==0.12.1
+soupsieve==2.6
+soxr==0.5.0.post1
+sympy==1.13.1
+tabulate==0.9.0
+tensorboard==2.18.0
+tensorboard-data-server==0.7.2
+threadpoolctl==3.5.0
+tokenizers==0.20.3
+torch==2.5.1
+torcheval==0.0.7
+torchmetrics==0.10.3
+tqdm==4.67.1
+transformers==4.46.3
+trdg @ git+https://github.com/alan-turing-institute/TextRecognitionDataGenerator.git@c17c6162dcbff41e3a0c576454a79561b0f954bf
+typing_extensions==4.12.2
+tzdata==2024.2
+unbabel-comet==2.2.2
+urllib3==2.2.3
+Werkzeug==3.1.3
+wikipedia==1.4.0
+xxhash==3.5.0
+yarl==1.18.0
diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py
index 2965a66..d46da5e 100644
--- a/src/arc_spice/data/multieurlex_utils.py
+++ b/src/arc_spice/data/multieurlex_utils.py
@@ -1,10 +1,12 @@
 import json
 from typing import Any
 
+import datasets
 import torch
-from datasets import Dataset, DatasetDict, load_dataset
 from datasets.formatting.formatting import LazyRow
+from PIL import Image
 from torch.nn.functional import one_hot
+from trdg.generators import GeneratorFromStrings
 
 # For identifying where the adopted decisions begin
 ARTICLE_1_MARKERS = {
@@ -63,6 +65,17 @@ def extract_articles(
     }
 
 
+def _make_ocr_data(text: str) -> list[tuple[Image.Image, str]]:
+    text_split = text.split()
+    generator = GeneratorFromStrings(text_split, count=len(text_split))
+    return list(generator)
+
+
+def make_ocr_data(item: LazyRow) -> dict[str, tuple[Image.Image] | tuple[str]]:
+    images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
+    return {"ocr_images": images, "ocr_targets": targets}
+
+
 class TranslationPreProcesser:
     """Prepares the data for the translation task"""
 
@@ -121,7 +134,7 @@ def load_multieurlex(
     languages: list[str],
     drop_empty: bool = True,
     split: str | None = None,
-) -> tuple[DatasetDict, dict[str, Any]]:
+) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     """
     load the multieurlex dataset
 
@@ -143,7 +156,7 @@
 
     load_langs = languages[0] if len(languages) == 1 else "all_languages"
 
-    dataset_dict = load_dataset(
+    dataset_dict = datasets.load_dataset(
         "multi_eurlex",
         load_langs,
         label_level=f"level_{level}",
@@ -152,13 +165,13 @@
     )
 
     # ensure we always return dataset dict even if only single split
     if split is not None:
-        if not isinstance(dataset_dict, Dataset):
+        if not isinstance(dataset_dict, datasets.Dataset):
             msg = (
                 "Error. load_dataset should return a Dataset object if split specified"
             )
             raise ValueError(msg)
 
-        tmp = DatasetDict()
+        tmp = datasets.DatasetDict()
         tmp[split] = dataset_dict
         dataset_dict = tmp
@@ -179,9 +192,13 @@
     return dataset_dict, metadata
 
 
-def load_multieurlex_for_translation(
-    data_dir: str, level: int, lang_pair: dict[str, str], drop_empty: bool = True
-) -> tuple[DatasetDict, dict[str, Any]]:
+def load_multieurlex_for_pipeline(
+    data_dir: str,
+    level: int,
+    lang_pair: dict[str, str],
+    drop_empty: bool = True,
+    load_ocr_data: bool = False,
+) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     langs = [lang_pair["source"], lang_pair["target"]]
     dataset_dict, meta_data = load_multieurlex(
         data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty
     )
@@ -189,4 +206,22 @@
     # instantiate the preprocessor
     preprocesser = TranslationPreProcesser(lang_pair)
     # preprocess each split
-    return dataset_dict.map(preprocesser, remove_columns=["text"]), meta_data
+    dataset_dict = dataset_dict.map(preprocesser, remove_columns=["text"])
+
+    # TODO allow for OCR standalone?
+    if load_ocr_data:
+        # need to set features so loop through dataset dict manually
+        for k in dataset_dict:
+            feats = dataset_dict[k].features
+            dataset_dict[k] = dataset_dict[k].map(
+                make_ocr_data,
+                features=datasets.Features(
+                    {
+                        "ocr_images": datasets.Sequence(datasets.Image(decode=True)),
+                        "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                        **feats,
+                    }
+                ),
+            )
+
+    return dataset_dict, meta_data
diff --git a/tests/test_multieurlex_utils.py b/tests/test_multieurlex_utils.py
index 4701f51..41b162c 100644
--- a/tests/test_multieurlex_utils.py
+++ b/tests/test_multieurlex_utils.py
@@ -1,17 +1,20 @@
 import os
+from typing import Any
 from unittest.mock import patch
 
 import datasets
+import numpy as np
 import pyarrow as pa
 from datasets.formatting import PythonFormatter
 from datasets.formatting.formatting import LazyRow
+from PIL import Image as PILImage
 
 from arc_spice.data import multieurlex_utils
 
 TEST_ROOT = os.path.dirname(os.path.abspath(__file__))
 
 
-def _create_row(text) -> LazyRow:
+def _create_single_lang_row(text) -> LazyRow:
     pa_table = pa.Table.from_pydict({"text": [text]})
     formatter = PythonFormatter(lazy=True)
     return formatter.format_row(pa_table)
@@ -24,11 +27,19 @@ def _create_multilang_row(texts_by_lang: dict[str, str]) -> LazyRow:
     return formatter.format_row(pa_table)
 
 
+def _create_translation_row(source_text: str, target_text: str) -> LazyRow:
+    pa_table = pa.Table.from_pydict(
+        {"source_text": [source_text], "target_text": [target_text]}
+    )
+    formatter = PythonFormatter(lazy=True)
+    return formatter.format_row(pa_table)
+
+
 def test_extract_articles_single_lang():
     langs = ["en"]
     pre_text = "Some text before the marker"
     post_text = "Some text after the marker"
-    row = _create_row(
+    row = _create_single_lang_row(
         text=f"{pre_text} {multieurlex_utils.ARTICLE_1_MARKERS['en']} {post_text}"
     )
     out = multieurlex_utils.extract_articles(item=row, languages=langs)
@@ -53,6 +64,38 @@
     }
 
 
+def test_make_ocr_data():
+    source_text = "Some text to make into an image"
+    row = _create_translation_row(source_text=source_text, target_text="foo")
+    dummy_im1 = PILImage.fromarray(
+        np.random.randint(0, 255, (10, 10, 3)).astype(np.uint8)
+    )
+    dummy_im2 = PILImage.fromarray(
+        np.random.randint(0, 255, (10, 10, 3)).astype(np.uint8)
+    )
+
+    with patch("arc_spice.data.multieurlex_utils.GeneratorFromStrings") as mock_gen:
+        mock_gen.return_value = [(dummy_im1, "target1"), (dummy_im2, "target2")]
+        output = multieurlex_utils.make_ocr_data(row)
+
+    assert output == {
+        "ocr_images": (dummy_im1, dummy_im2),
+        "ocr_targets": ("target1", "target2"),
+    }
+
+
+def _check_keys_untouched(
+    original_dataset: datasets.Dataset,
+    dataset: datasets.Dataset,
+    indices_kept: list[int],
+    ignore_keys: list[str],
+) -> None:
+    # check remaining keys are untouched
+    for key in dataset.features:
+        if key not in ignore_keys:
+            assert dataset[key] == [original_dataset[key][i] for i in indices_kept]
+
+
 def test_load_multieurlex_en():
     data_dir = f"{TEST_ROOT}/testdata/multieurlex_test_en"
     level = 1
@@ -60,39 +103,178 @@
     languages = ["en"]
     drop_empty = True
 
     ds = datasets.load_from_disk(data_dir)
 
-    with patch("arc_spice.data.multieurlex_utils.load_dataset", return_value=ds):
+    expected_keys = {"celex_id", "text", "labels"}
+    expected_non_empty_indices = [0, 1, 3, 4]
+    text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
+    with patch(
+        "arc_spice.data.multieurlex_utils.datasets.load_dataset", return_value=ds
+    ):
         dataset_dict, metadata = multieurlex_utils.load_multieurlex(
             data_dir=data_dir, level=level, languages=languages, drop_empty=drop_empty
         )
     assert len(dataset_dict) == 3
-    assert len(dataset_dict["train"]) == 4  # 5 items, 1 is empty so dropped
-    assert len(dataset_dict["validation"]) == 4  # 5 items, 1 is empty so dropped
-    assert len(dataset_dict["test"]) == 4  # 5 items, 1 is empty so dropped
-    assert dataset_dict["train"]["text"] == [
-        f"{multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker {i}"  # noqa: E501
-        for i in [1, 2, 4, 5]  # 3 dropped
-    ]
+    for split in ["train", "validation", "test"]:
+        assert set(dataset_dict[split].features.keys()) == expected_keys
+        assert len(dataset_dict[split]) == 4  # 5 items, 1 is empty so dropped
+        assert dataset_dict[split]["text"] == [
+            f"{multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker {i}"  # noqa: E501
+            for i in text_expected_non_empty_indices  # 3 dropped
+        ]
+        _check_keys_untouched(
+            original_dataset=ds[split],
+            dataset=dataset_dict[split],
+            indices_kept=expected_non_empty_indices,
+            ignore_keys=["text"],
+        )
 
 
-def test_load_multieurlex_for_translation():
+def test_load_multieurlex_multi_lang():
     data_dir = f"{TEST_ROOT}/testdata/multieurlex_test"
     level = 1
     languages = ["de", "en", "fr"]
     drop_empty = True
 
     ds = datasets.load_from_disk(data_dir)
 
-    with patch("arc_spice.data.multieurlex_utils.load_dataset", return_value=ds):
+    expected_keys = {"celex_id", "text", "labels"}
+    expected_non_empty_indices = [0, 1, 3, 4]
+    text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
+    with patch(
+        "arc_spice.data.multieurlex_utils.datasets.load_dataset", return_value=ds
+    ):
         dataset_dict, metadata = multieurlex_utils.load_multieurlex(
             data_dir=data_dir, level=level, languages=languages, drop_empty=drop_empty
         )
     assert len(dataset_dict) == 3
-    assert len(dataset_dict["train"]) == 4  # 5 items, 1 is empty so dropped
-    assert len(dataset_dict["validation"]) == 4  # 5 items, 1 is empty so dropped
-    assert len(dataset_dict["test"]) == 4  # 5 items, 1 is empty so dropped
-    assert dataset_dict["train"]["text"] == [  #
-        {
-            lang: f"{multieurlex_utils.ARTICLE_1_MARKERS[lang]} Some text after the marker {i}"  # noqa: E501
-            for lang in languages
-        }
-        for i in [1, 2, 4, 5]  # 3 dropped
-    ]
+    for split in ["train", "validation", "test"]:
+        assert set(dataset_dict[split].features.keys()) == expected_keys
+        assert len(dataset_dict[split]) == 4  # 5 items, 1 is empty so dropped
+        assert dataset_dict[split]["text"] == [  #
+            {
+                lang: f"{multieurlex_utils.ARTICLE_1_MARKERS[lang]} Some text after the marker {i}"  # noqa: E501
+                for lang in languages
+            }
+            for i in text_expected_non_empty_indices  # 3 dropped
+        ]
+        _check_keys_untouched(
+            original_dataset=ds[split],
+            dataset=dataset_dict[split],
+            indices_kept=expected_non_empty_indices,
+            ignore_keys=["text"],
+        )
+
+
+def _check_pipeline_text(
+    dataset: datasets.Dataset,
+    text_indices_kept: list[int],
+    source_lang: str,
+    target_lang: str,
+):
+    assert dataset["source_text"] == [
+        f"{multieurlex_utils.ARTICLE_1_MARKERS[source_lang]} Some text after the marker {i}"  # noqa: E501
+        for i in text_indices_kept
+    ]
+    assert dataset["target_text"] == [
+        f"{multieurlex_utils.ARTICLE_1_MARKERS[target_lang]} Some text after the marker {i}"  # noqa: E501
+        for i in text_indices_kept
+    ]
+
+
+def _test_load_multieurlex_for_pipeline(expected_keys: set[str], load_ocr_data: bool):
+    data_dir = f"{TEST_ROOT}/testdata/multieurlex_test"
+    lang_pair = {"source": "de", "target": "en"}
+
+    ds = datasets.load_from_disk(data_dir)
+
+    expected_non_empty_indices = [0, 1, 3, 4]
+    text_expected_non_empty_indices = [i + 1 for i in expected_non_empty_indices]
+    with patch(
+        "arc_spice.data.multieurlex_utils.datasets.load_dataset", return_value=ds
+    ):
+        dataset_dict, metadata = multieurlex_utils.load_multieurlex_for_pipeline(
+            data_dir=data_dir,
+            level=1,
+            lang_pair=lang_pair,
+            drop_empty=True,
+            load_ocr_data=load_ocr_data,
+        )
+    assert len(dataset_dict) == 3
+    for split in ["train", "validation", "test"]:
+        assert set(dataset_dict[split].features.keys()) == expected_keys
+        assert len(dataset_dict[split]) == 4  # 5 items, 1 is empty so dropped
+        _check_pipeline_text(
+            dataset=dataset_dict[split],
+            text_indices_kept=text_expected_non_empty_indices,  # inds start at 1
+            source_lang=lang_pair["source"],
+            target_lang=lang_pair["target"],
+        )
+        _check_keys_untouched(
+            original_dataset=ds[split],
+            dataset=dataset_dict[split],
+            indices_kept=expected_non_empty_indices,  # inds start at 0
+            ignore_keys=["source_text", "target_text", "ocr_images", "ocr_targets"],
+        )
+
+    return dataset_dict, metadata
+
+
+def test_load_multieurlex_for_pipeline():
+    expected_keys = {"celex_id", "labels", "source_text", "target_text"}
+    _test_load_multieurlex_for_pipeline(
+        expected_keys=expected_keys, load_ocr_data=False
+    )
+
+
+def _create_ocr_data(
+    expected_n_rows: int,
+) -> tuple[list[dict[str, Any]], tuple[list[PILImage.Image]], tuple[list[str]]]:
+    dummy_ocr_data = [
+        {
+            "ocr_images": [PILImage.fromarray(np.ones((5, 5, 3)).astype(np.uint8) * i)]
+            * (i + 1),
+            "ocr_targets": [f"foo {i}"] * (i + 1),
+        }
+        for i in range(expected_n_rows)
+    ]  # different value at each call, different number of items each time (nonzero i+1)
+
+    # unpack data to give expected values
+    expected_ocr_data, expected_ocr_targets = zip(
+        *[x.values() for x in dummy_ocr_data], strict=True
+    )
+    return dummy_ocr_data, expected_ocr_data, expected_ocr_targets
+
+
+# same as above but with OCR data checks
+def test_load_multieurlex_for_pipeline_ocr():
+    expected_keys = {
+        "celex_id",
+        "labels",
+        "source_text",
+        "target_text",
+        "ocr_images",
+        "ocr_targets",
+    }
+
+    expected_n_rows = 4
+    dummy_ocr_data, expected_ocr_data, expected_ocr_targets = _create_ocr_data(
+        expected_n_rows=expected_n_rows
+    )
+
+    with patch("arc_spice.data.multieurlex_utils.make_ocr_data") as mock_mod:
+        mock_mod.side_effect = dummy_ocr_data * 3  # handle all 3 splits
+        dataset_dict, metadata = _test_load_multieurlex_for_pipeline(
+            expected_keys=expected_keys, load_ocr_data=True
+        )
+    for split in ["train", "validation", "test"]:
+        for row_index in range(len(dataset_dict[split])):
+            # OCR - images
+            # PIL.PngImagePlugin.PngImageFile vs PIL.Image.Image so compare as np
+            output_as_numpy = [
+                np.asarray(im) for im in dataset_dict[split]["ocr_images"][row_index]
+            ]
+            expected_as_numpy = [np.asarray(im) for im in expected_ocr_data[row_index]]
+            np.testing.assert_array_equal(output_as_numpy, expected_as_numpy)
+            # OCR - targets
+            assert (
+                dataset_dict[split]["ocr_targets"][row_index]
+                == expected_ocr_targets[row_index]
+            )
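For reviewers, a minimal usage sketch of the new entry point introduced by this change (not part of the patch itself). The data directory, label level, and language pair below are illustrative values only; the function, parameters, and column names come from the diff above.

# Usage sketch, assuming the changes in this diff are installed.
# data_dir, level and lang_pair are hypothetical example values.
from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline

dataset_dict, metadata = load_multieurlex_for_pipeline(
    data_dir="data/multieurlex",  # hypothetical local data directory
    level=1,
    lang_pair={"source": "de", "target": "en"},
    drop_empty=True,
    load_ocr_data=True,  # adds the "ocr_images" / "ocr_targets" columns via make_ocr_data
)

row = dataset_dict["train"][0]
# each source-text token yields one rendered image and one target string
assert len(row["ocr_images"]) == len(row["ocr_targets"])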