Skip to content

Commit

Permalink
Merge pull request #32 from alan-turing-institute/14-ocr-reverted
Browse files Browse the repository at this point in the history
14 ocr reverted
  • Loading branch information
J-Dymond authored Dec 6, 2024
2 parents af48055 + 510d863 commit 788dccc
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 24 deletions.
13 changes: 10 additions & 3 deletions src/arc_spice/eval/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tqdm import tqdm

from arc_spice.data.multieurlex_utils import MultiHot
from arc_spice.eval.ocr_error import ocr_error
from arc_spice.eval.translation_error import conditional_probability, get_comet_model
from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
RTCSingleComponentPipeline,
Expand Down Expand Up @@ -68,10 +69,16 @@ def get_results(
)._asdict()
return results_dict

def recognition_results(self, *args, **kwargs):
def recognition_results(
self,
clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]],
var_output: dict[str, dict],
**kwargs,
):
# ### RECOGNITION ###
# TODO: add this into results_getter : issue #14
return RecognitionResults(confidence=None, accuracy=None)
charerror = ocr_error(clean_output)
confidence = var_output["recognition"]["mean_entropy"]
return RecognitionResults(confidence=confidence, accuracy=charerror)

def translation_results(
self,
Expand Down
36 changes: 36 additions & 0 deletions src/arc_spice/eval/ocr_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
OCR error computation for eval.
"""

from typing import Any

from torchmetrics.text import CharErrorRate


def ocr_error(ocr_output: dict[Any, Any]) -> float:
"""
Compute the character error rate for the predicted ocr character.
NB: - this puts all strings into lower case for comparisons.
- ideal error rate is 0, worst case is 1.
Args:
ocr_output: output from the ocr model, with structure,
{
'full_output: [
{
'generated_text': gen text from the ocr model (str)
'target': target text (str)
'entropies': entropies for UQ (torch.Tensor)
}
]
'outpu': pieced back together full string (str)
}
Returns:
Character error rate across entire output of OCR (float)
"""
preds = [itm["generated_text"].lower() for itm in ocr_output["full_output"]]
targs = [itm["target"].lower() for itm in ocr_output["full_output"]]
cer = CharErrorRate()
return cer(preds, targs).item()
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
RTCVariationalPipelineBase,
)
from arc_spice.variational_pipelines.utils import (
CustomOCRPipeline,
CustomTranslationPipeline,
dropout_off,
dropout_on,
Expand Down Expand Up @@ -87,21 +88,24 @@ def __init__(
self,
model_pars: dict[str, dict[str, str]],
n_variational_runs=5,
ocr_batch_size=64,
**kwargs,
):
self.set_device()
self.ocr: transformers.Pipeline = pipeline(
task=model_pars["ocr"]["specific_task"],
model=model_pars["ocr"]["model"],
device=self.device,
pipeline_class=CustomOCRPipeline,
max_new_tokens=20,
batch_size=ocr_batch_size,
**kwargs,
)
self.model = self.ocr.model
super().__init__(
step_name="recognition",
input_key="ocr_data",
forward_function=self.recognise,
confidence_function=self.recognise, # THIS WILL NEED UPDATING : #issue 14
confidence_function=self.get_ocr_confidence,
n_variational_runs=n_variational_runs,
**kwargs,
)
Expand Down Expand Up @@ -160,9 +164,9 @@ def __init__(
super().__init__(
step_name="classification",
input_key="target_text",
forward_function=self.classify_topic_zero_shot
if zero_shot
else self.classify_topic,
forward_function=(
self.classify_topic_zero_shot if zero_shot else self.classify_topic
),
confidence_function=self.get_classification_confidence,
n_variational_runs=n_variational_runs,
**kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from transformers import pipeline

from arc_spice.variational_pipelines.utils import (
CustomOCRPipeline,
CustomTranslationPipeline,
RTCVariationalPipelineBase,
dropout_off,
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(
data_pars: dict[str, Any],
n_variational_runs=5,
translation_batch_size=16,
ocr_batch_size=64,
) -> None:
# are we doing zero-shot-classification?
if model_pars["classifier"]["specific_task"] == "zero-shot-classification":
Expand All @@ -47,9 +49,11 @@ def __init__(
super().__init__(self.zero_shot, n_variational_runs, translation_batch_size)
# defining the pipeline objects
self.ocr = pipeline(
task=model_pars["ocr"]["specific_task"],
model=model_pars["ocr"]["model"],
device=self.device,
pipeline_class=CustomOCRPipeline,
max_new_tokens=20,
batch_size=ocr_batch_size,
)
self.translator = pipeline(
task=model_pars["translator"]["specific_task"],
Expand Down
111 changes: 96 additions & 15 deletions src/arc_spice/variational_pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
from functools import partial
from typing import Any

import numpy as np
import torch
import transformers
from torch.distributions import Categorical
from torch.nn.functional import softmax
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
ImageToTextPipeline,
Pipeline,
TranslationPipeline,
pipeline,
Expand Down Expand Up @@ -149,9 +152,9 @@ def __init__(self, zero_shot: bool, n_variational_runs=5, translation_batch_size
self.func_map = {
"recognition": self.recognise,
"translation": self.translate,
"classification": self.classify_topic_zero_shot
if zero_shot
else self.classify_topic,
"classification": (
self.classify_topic_zero_shot if zero_shot else self.classify_topic
),
}
# the naive outputs of the pipeline stages calculated in self.clean_inference
self.naive_outputs = {
Expand Down Expand Up @@ -266,21 +269,44 @@ def check_dropout(pipeline_map: transformers.Pipeline):
set_dropout(model=pl.model, dropout_flag=False)
logger.debug("-------------------------------------------------------\n\n")

def recognise(self, inp) -> dict[str, str]:
def recognise(self, inp) -> dict[str, str | list[dict[str, str | torch.Tensor]]]:
"""
Function to perform OCR
Function to perform OCR.
Args:
inp: input
inp: input dict with key 'ocr_data', containing dict,
{
'ocr_images': list[ocr images],
'ocr_targets': list[ocr target words]
}
Returns:
dictionary of outputs
dictionary of outputs:
{
'full_output': [
{
'generated_text': generated text from ocr model (str),
'target': original target text (str)
}
],
'output': pieced back together string (str)
}
"""
# Until the OCR data is available
# This will need the below comment:
# type: ignore[misc]
# TODO https://github.com/alan-turing-institute/ARC-SPICE/issues/14
return {"outputs": inp["source_text"]}
out = self.ocr(inp["ocr_data"]["ocr_images"]) # type: ignore[misc]
text = " ".join([itm[0]["generated_text"] for itm in out])
return {
"full_output": [
{
"target": target,
"generated_text": gen_text["generated_text"],
"entropies": gen_text["entropies"],
}
for target, gen_text in zip(
inp["ocr_data"]["ocr_targets"], out, strict=True
)
],
"output": text,
}

def translate(self, text: str) -> dict[str, torch.Tensor | str]:
"""
Expand Down Expand Up @@ -352,9 +378,7 @@ def classify_topic_zero_shot(self, text: str) -> dict[str, list[float] | dict]:
descriptors["en"]
for descriptors in self.dataset_meta_data["class_descriptors"] # type: ignore[index]
]
forward = self.classifier( # type: ignore[misc]
text, labels
)
forward = self.classifier(text, labels) # type: ignore[misc]
return collate_scores(
[
{"label": label, "score": score}
Expand Down Expand Up @@ -560,6 +584,28 @@ def get_classification_confidence(
)
return var_output

def get_ocr_confidence(self, var_output: dict) -> dict[str, float]:
"""Generate the ocr confidence score.
Args:
var_output: variational run outputs
Returns:
dictionary with metrics
"""
# Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
stacked_entropies = torch.stack(
[
[data["entropies"] for data in output["full_output"]]
for output in var_output["recognition"]
],
dim=1,
)
# mean entropy
mean = torch.mean(stacked_entropies)
var_output["recognition"].update({"mean_entropy": mean})
return var_output


# Translation pipeline with additional functionality to save logits from fwd pass
class CustomTranslationPipeline(TranslationPipeline):
Expand Down Expand Up @@ -619,3 +665,38 @@ def _forward(self, model_inputs, **generate_kwargs):
"scores": max_token_scores,
"entropy": normalised_entropy,
}


class CustomOCRPipeline(ImageToTextPipeline):
"""
custom OCR pipeline to return logits with the generated text.
"""

def postprocess(self, model_outputs: dict, **postprocess_params):
raw_out = copy.deepcopy(model_outputs)
processed = super().postprocess(
model_outputs["model_output"], **postprocess_params
)

return {"generated_text": processed[0]["generated_text"], "raw_output": raw_out}

def _forward(self, model_inputs, **generate_kwargs):
if (
"input_ids" in model_inputs
and isinstance(model_inputs["input_ids"], list)
and all(x is None for x in model_inputs["input_ids"])
):
model_inputs["input_ids"] = None

inputs = model_inputs.pop(self.model.main_input_name)
out = self.model.generate(
inputs,
**model_inputs,
**generate_kwargs,
output_logits=True,
return_dict_in_generate=True,
)

logits = torch.stack(out.logits, dim=1)
entropy = Categorical(logits=logits).entropy() / np.log(logits[0].size()[1])
return {"model_output": out.sequences, "entropies": entropy}

0 comments on commit 788dccc

Please sign in to comment.