Merge pull request #19 from alan-turing-institute/10-ocr-synthetic-data
10 ocr synthetic data
lannelin authored Nov 27, 2024
2 parents 419add8 + 8426046 commit 7df8c51
Showing 5 changed files with 357 additions and 54 deletions.
28 changes: 5 additions & 23 deletions pyproject.toml
@@ -10,7 +10,7 @@ authors = [
]
description = "Sample Level, Pipeline Introduced Cumulative Errors (SPICE). Investigating methods for generalisable measurement of cumulative errors in multi-model (and multi-modal) ML pipelines."
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.10.4"
classifiers = [
"Development Status :: 1 - Planning",
"Intended Audience :: Science/Research",
@@ -27,29 +27,11 @@ classifiers = [
"Topic :: Scientific/Engineering",
"Typing :: Typed",
]
dependencies = [
"transformers",
"huggingface",
"datasets",
"numpy",
"sentencepiece",
"librosa",
"torch",
"torcheval",
"pillow",
"unbabel-comet",
"accelerate",
"tensorboard",
"scikit-learn"
]
dynamic = ["dependencies", "optional-dependencies"]

[project.optional-dependencies]
dev = [
"pytest >=6",
"pytest-cov >=3",
"pre-commit",
"mypy",
]
[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}
optional-dependencies = {dev = { file = ["requirements-dev.txt"] }}

[project.urls]
Homepage = "https://github.com/alan-turing-institute/ARC-SPICE"
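With this change the runtime and dev dependencies become dynamic metadata that setuptools reads from requirements.txt and requirements-dev.txt at build time. A minimal sketch of how the resolved metadata could be checked after installing; the distribution name "arc_spice" is an assumption (the [project] name field is not shown in this diff), and the install command pip install -e ".[dev]" is assumed:

# Hypothetical check; the distribution name "arc_spice" is an assumption.
from importlib import metadata

# With the dynamic table above, this should list the pins from requirements.txt,
# plus the requirements-dev.txt pins carrying the "dev" extra marker.
print(metadata.requires("arc_spice"))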
13 changes: 13 additions & 0 deletions requirements-dev.txt
@@ -0,0 +1,13 @@
cfgv==3.4.0
coverage==7.6.8
distlib==0.3.9
identify==2.6.3
iniconfig==2.0.0
mypy==1.13.0
mypy-extensions==1.0.0
nodeenv==1.9.1
pluggy==1.5.0
pre_commit==4.0.1
pytest==8.3.3
pytest-cov==6.0.0
virtualenv==20.28.0
91 changes: 91 additions & 0 deletions requirements.txt
@@ -0,0 +1,91 @@
absl-py==2.1.0
accelerate==1.1.1
aiohappyeyeballs==2.4.3
aiohttp==3.11.7
aiosignal==1.3.1
arabic-reshaper==2.1.4
attrs==24.2.0
audioread==3.0.1
beautifulsoup4==4.12.3
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.4.0
colorama==0.4.6
datasets==3.1.0
decorator==5.1.1
diffimg==0.2.3
dill==0.3.8
entmax==1.3
filelock==3.16.1
frozenlist==1.5.0
fsspec==2024.9.0
grpcio==1.68.0
huggingface==0.0.1
huggingface-hub==0.26.2
idna==3.10
Jinja2==3.1.4
joblib==1.4.2
jsonargparse==3.13.1
lazy_loader==0.4
librosa==0.10.2.post1
lightning-utilities==0.11.9
llvmlite==0.43.0
lxml==5.3.0
Markdown==3.7
MarkupSafe==3.0.2
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.4.2
numba==0.60.0
numpy==1.26.4
opencv-python==4.10.0.84
packaging==24.2
pandas==2.2.3
pillow==11.0.0
platformdirs==4.3.6
pooch==1.8.2
portalocker==3.0.0
propcache==0.2.0
protobuf==4.25.5
psutil==6.1.0
pyarrow==18.1.0
pycparser==2.22
python-bidi==0.4.2
python-dateutil==2.9.0.post0
pytorch-lightning==2.4.0
pytz==2024.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
sacrebleu==2.4.3
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
sentencepiece==0.1.99
setuptools==75.6.0
six==1.16.0
soundfile==0.12.1
soupsieve==2.6
soxr==0.5.0.post1
sympy==1.13.1
tabulate==0.9.0
tensorboard==2.18.0
tensorboard-data-server==0.7.2
threadpoolctl==3.5.0
tokenizers==0.20.3
torch==2.5.1
torcheval==0.0.7
torchmetrics==0.10.3
tqdm==4.67.1
transformers==4.46.3
trdg @ git+https://github.com/alan-turing-institute/TextRecognitionDataGenerator.git@c17c6162dcbff41e3a0c576454a79561b0f954bf
typing_extensions==4.12.2
tzdata==2024.2
unbabel-comet==2.2.2
urllib3==2.2.3
Werkzeug==3.1.3
wikipedia==1.4.0
xxhash==3.5.0
yarl==1.18.0
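The new trdg pin (a fork of TextRecognitionDataGenerator) supplies the synthetic text-image generator used by the OCR helpers added to multieurlex_utils.py below. A minimal sketch of its behaviour, mirroring how _make_ocr_data drives it; the example words and the output filenames are placeholders:

from trdg.generators import GeneratorFromStrings

words = ["Commission", "Regulation", "Article"]  # placeholder strings
# One rendered text-line image per input string; iterating yields (image, text)
# pairs where the image is a PIL Image and the text is its ground-truth string.
generator = GeneratorFromStrings(words, count=len(words))
for image, target in generator:
    image.save(f"{target}.png")  # e.g. "Commission.png" containing the rendered word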
53 changes: 44 additions & 9 deletions src/arc_spice/data/multieurlex_utils.py
@@ -1,10 +1,12 @@
import json
from typing import Any

import datasets
import torch
from datasets import Dataset, DatasetDict, load_dataset
from datasets.formatting.formatting import LazyRow
from PIL import Image
from torch.nn.functional import one_hot
from trdg.generators import GeneratorFromStrings

# For identifying where the adopted decisions begin
ARTICLE_1_MARKERS = {
@@ -63,6 +65,17 @@ def extract_articles(
}


def _make_ocr_data(text: str) -> list[tuple[Image.Image, str]]:
text_split = text.split()
generator = GeneratorFromStrings(text_split, count=len(text_split))
return list(generator)


def make_ocr_data(item: LazyRow) -> dict[str, tuple[Image.Image] | tuple[str]]:
images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
return {"ocr_images": images, "ocr_targets": targets}


class TranslationPreProcesser:
"""Prepares the data for the translation task"""

@@ -121,7 +134,7 @@ def load_multieurlex(
languages: list[str],
drop_empty: bool = True,
split: str | None = None,
) -> tuple[DatasetDict, dict[str, Any]]:
) -> tuple[datasets.DatasetDict, dict[str, Any]]:
"""
load the multieurlex dataset
@@ -143,7 +156,7 @@

load_langs = languages[0] if len(languages) == 1 else "all_languages"

dataset_dict = load_dataset(
dataset_dict = datasets.load_dataset(
"multi_eurlex",
load_langs,
label_level=f"level_{level}",
@@ -152,13 +165,13 @@
)
# ensure we always return dataset dict even if only single split
if split is not None:
if not isinstance(dataset_dict, Dataset):
if not isinstance(dataset_dict, datasets.Dataset):
msg = (
"Error. load_dataset should return a Dataset object if split specified"
)
raise ValueError(msg)

tmp = DatasetDict()
tmp = datasets.DatasetDict()
tmp[split] = dataset_dict
dataset_dict = tmp

@@ -179,14 +192,36 @@
return dataset_dict, metadata


def load_multieurlex_for_translation(
data_dir: str, level: int, lang_pair: dict[str, str], drop_empty: bool = True
) -> tuple[DatasetDict, dict[str, Any]]:
def load_multieurlex_for_pipeline(
data_dir: str,
level: int,
lang_pair: dict[str, str],
drop_empty: bool = True,
load_ocr_data: bool = False,
) -> tuple[datasets.DatasetDict, dict[str, Any]]:
langs = [lang_pair["source"], lang_pair["target"]]
dataset_dict, meta_data = load_multieurlex(
data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty
)
# instantiate the preprocessor
preprocesser = TranslationPreProcesser(lang_pair)
# preprocess each split
return dataset_dict.map(preprocesser, remove_columns=["text"]), meta_data
dataset_dict = dataset_dict.map(preprocesser, remove_columns=["text"])

# TODO allow for OCR standalone?
if load_ocr_data:
# need to set features so loop through dataset dict manually
for k in dataset_dict:
feats = dataset_dict[k].features
dataset_dict[k] = dataset_dict[k].map(
make_ocr_data,
features=datasets.Features(
{
"ocr_images": datasets.Sequence(datasets.Image(decode=True)),
"ocr_targets": datasets.Sequence(datasets.Value("string")),
**feats,
}
),
)

return dataset_dict, meta_data
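Taken together, load_multieurlex_for_pipeline with load_ocr_data=True now attaches per-word synthetic OCR images and their ground-truth strings to every split. A minimal usage sketch; the data directory, label level, language codes and split name are illustrative assumptions, not values taken from this diff:

from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline

dataset_dict, metadata = load_multieurlex_for_pipeline(
    data_dir="data/multieurlex",                 # assumed download/cache location
    level=1,                                     # assumed label level
    lang_pair={"source": "en", "target": "fr"},  # assumed language pair
    drop_empty=True,
    load_ocr_data=True,
)

# Each row now carries "ocr_images" (one PIL image per word of the source text)
# and "ocr_targets" (the matching ground-truth strings), alongside the existing
# columns such as "source_text".
row = dataset_dict["train"][0]  # "train" split assumed to exist
print(len(row["ocr_images"]), row["ocr_targets"][:3])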
