Commit fd25f70: ocr outputs

lannelin committed Nov 26, 2024
1 parent 848f653
Showing 2 changed files with 246 additions and 43 deletions.
src/arc_spice/data/multieurlex_utils.py (58 changes: 37 additions & 21 deletions)
@@ -1,9 +1,10 @@
 import json
 from typing import Any
 
+import datasets
 import torch
-from datasets import Dataset, DatasetDict, Image, load_dataset
 from datasets.formatting.formatting import LazyRow
+from PIL import Image
 from torch.nn.functional import one_hot
 from trdg.generators import GeneratorFromStrings
 
@@ -64,18 +65,15 @@ def extract_articles(
     }
 
 
-def _make_ocr_data(text: str):
+def _make_ocr_data(text: str) -> list[tuple[Image.Image, str]]:
     text_split = text.split()
     generator = GeneratorFromStrings(text_split, count=len(text_split))
-    feature = Image(decode=False)
-    return {
-        str(idx): {"image": feature.encode_example(gen[0]), "target": gen[1]}
-        for idx, gen in enumerate(generator)
-    }
+    return list(generator)
 
 
-def make_ocr_data(item: LazyRow):
-    return {"ocr_data": _make_ocr_data(item["source_text"])}
+def make_ocr_data(item: LazyRow) -> dict[str, tuple[Image.Image] | tuple[str]]:
+    images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
+    return {"ocr_images": images, "ocr_targets": targets}
 
 
 class TranslationPreProcesser:
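
For context, a minimal sketch (not part of the commit) of what the refactored helpers now return; the sample text, and the plain dict standing in for a LazyRow, are illustrative assumptions:

    # Illustrative only: trdg's GeneratorFromStrings yields (PIL image, text)
    # pairs, one per whitespace-split word of the input text.
    from arc_spice.data.multieurlex_utils import _make_ocr_data, make_ocr_data

    pairs = _make_ocr_data("an example sentence")
    # pairs is a list of (PIL.Image.Image, str) tuples
    row = make_ocr_data({"source_text": "an example sentence"})
    # row["ocr_images"]  -> tuple of rendered PIL images
    # row["ocr_targets"] -> tuple of matching strings, e.g. ("an", "example", "sentence")
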
@@ -136,8 +134,7 @@ def load_multieurlex(
     languages: list[str],
     drop_empty: bool = True,
     split: str | None = None,
-    load_ocr_data: bool = False,
-) -> tuple[DatasetDict, dict[str, Any]]:
+) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     """
     load the multieurlex dataset
 
@@ -159,7 +156,7 @@ def load_multieurlex(
 
     load_langs = languages[0] if len(languages) == 1 else "all_languages"
 
-    dataset_dict = load_dataset(
+    dataset_dict = datasets.load_dataset(
         "multi_eurlex",
         load_langs,
         label_level=f"level_{level}",
@@ -168,13 +165,13 @@ def load_multieurlex(
     )
     # ensure we always return dataset dict even if only single split
     if split is not None:
-        if not isinstance(dataset_dict, Dataset):
+        if not isinstance(dataset_dict, datasets.Dataset):
             msg = (
                 "Error. load_dataset should return a Dataset object if split specified"
             )
             raise ValueError(msg)
 
-        tmp = DatasetDict()
+        tmp = datasets.DatasetDict()
         tmp[split] = dataset_dict
         dataset_dict = tmp
 
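An illustrative sketch (not from the commit) of the case this isinstance check guards against; the config name and split are assumptions, and newer datasets versions may additionally need trust_remote_code=True:

    # When a split is requested, datasets.load_dataset returns a Dataset,
    # not a DatasetDict, so the loader re-wraps it under the split name.
    import datasets

    ds = datasets.load_dataset("multi_eurlex", "en", split="test")
    assert isinstance(ds, datasets.Dataset)
    wrapped = datasets.DatasetDict({"test": ds})
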
@@ -191,21 +188,40 @@ def load_multieurlex(
             lambda x: all(x is not None for x in x["text"].values())
         )
 
-    if load_ocr_data:
-        dataset_dict = dataset_dict.map(make_ocr_data)
-
     # return datasets and metadata
     return dataset_dict, metadata
 
 
-def load_multieurlex_for_translation(
-    data_dir: str, level: int, lang_pair: dict[str, str], drop_empty: bool = True
-) -> tuple[DatasetDict, dict[str, Any]]:
+def load_multieurlex_for_pipeline(
+    data_dir: str,
+    level: int,
+    lang_pair: dict[str, str],
+    drop_empty: bool = True,
+    load_ocr_data: bool = False,
+) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     langs = [lang_pair["source"], lang_pair["target"]]
     dataset_dict, meta_data = load_multieurlex(
         data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty
     )
     # instantiate the preprocessor
     preprocesser = TranslationPreProcesser(lang_pair)
     # preprocess each split
-    return dataset_dict.map(preprocesser, remove_columns=["text"]), meta_data
+    dataset_dict = dataset_dict.map(preprocesser, remove_columns=["text"])
+
+    # TODO allow for OCR standalone?
+    if load_ocr_data:
+        # need to set features so loop through dataset dict manually
+        for k in dataset_dict:
+            feats = dataset_dict[k].features
+            dataset_dict[k] = dataset_dict[k].map(
+                make_ocr_data,
+                features=datasets.Features(
+                    {
+                        "ocr_images": datasets.Sequence(datasets.Image(decode=True)),
+                        "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                        **feats,
+                    }
+                ),
+            )
+
+    return dataset_dict, meta_data
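
A minimal usage sketch (not from the commit) of the renamed loader with OCR data enabled; the data_dir path, language pair, and "train" split key are illustrative assumptions:

    from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline

    dataset_dict, metadata = load_multieurlex_for_pipeline(
        data_dir="data/multieurlex",  # hypothetical path
        level=1,
        lang_pair={"source": "fr", "target": "en"},
        load_ocr_data=True,
    )
    row = dataset_dict["train"][0]
    # With the Features declared above, row["ocr_images"] decodes to a list of
    # PIL images and row["ocr_targets"] holds the matching ground-truth strings.
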
[The diff for the second changed file did not render on the page.]
