diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py
index accec55..5bf4cdd 100644
--- a/src/arc_spice/eval/inference_utils.py
+++ b/src/arc_spice/eval/inference_utils.py
@@ -8,6 +8,7 @@
 from tqdm import tqdm
 
 from arc_spice.data.multieurlex_utils import MultiHot
+from arc_spice.eval.ocr_error import ocr_error
 from arc_spice.eval.translation_error import conditional_probability, get_comet_model
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     RTCSingleComponentPipeline,
@@ -68,10 +69,16 @@ def get_results(
         )._asdict()
         return results_dict
 
-    def recognition_results(self, *args, **kwargs):
+    def recognition_results(
+        self,
+        clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]],
+        var_output: dict[str, dict],
+        **kwargs,
+    ):
         # ### RECOGNITION ###
-        # TODO: add this into results_getter : issue #14
-        return RecognitionResults(confidence=None, accuracy=None)
+        charerror = ocr_error(clean_output)
+        confidence = var_output["recognition"]["mean_entropy"]
+        return RecognitionResults(confidence=confidence, accuracy=charerror)
 
     def translation_results(
         self,
diff --git a/src/arc_spice/eval/ocr_error.py b/src/arc_spice/eval/ocr_error.py
new file mode 100644
index 0000000..df370dc
--- /dev/null
+++ b/src/arc_spice/eval/ocr_error.py
@@ -0,0 +1,36 @@
+"""
+OCR error computation for eval.
+"""
+
+from typing import Any
+
+from torchmetrics.text import CharErrorRate
+
+
+def ocr_error(ocr_output: dict[Any, Any]) -> float:
+    """
+    Compute the character error rate for the predicted OCR output.
+
+    NB: - this puts all strings into lower case for comparisons.
+        - ideal error rate is 0, worst case is 1.
+
+    Args:
+        ocr_output: output from the OCR model, with structure:
+            {
+                'full_output': [
+                    {
+                        'generated_text': generated text from the OCR model (str),
+                        'target': target text (str),
+                        'entropies': entropies for UQ (torch.Tensor)
+                    }
+                ],
+                'output': pieced back together full string (str)
+            }
+
+    Returns:
+        Character error rate across entire output of OCR (float)
+    """
+    preds = [itm["generated_text"].lower() for itm in ocr_output["full_output"]]
+    targs = [itm["target"].lower() for itm in ocr_output["full_output"]]
+    cer = CharErrorRate()
+    return cer(preds, targs).item()
diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
index d077e1c..6b08c10 100644
--- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
@@ -8,6 +8,7 @@
     RTCVariationalPipelineBase,
 )
 from arc_spice.variational_pipelines.utils import (
+    CustomOCRPipeline,
     CustomTranslationPipeline,
     dropout_off,
     dropout_on,
@@ -87,13 +88,16 @@ def __init__(
         self,
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
+        ocr_batch_size=64,
         **kwargs,
     ):
         self.set_device()
         self.ocr: transformers.Pipeline = pipeline(
-            task=model_pars["ocr"]["specific_task"],
             model=model_pars["ocr"]["model"],
             device=self.device,
+            pipeline_class=CustomOCRPipeline,
+            max_new_tokens=20,
+            batch_size=ocr_batch_size,
             **kwargs,
         )
         self.model = self.ocr.model
@@ -101,7 +105,7 @@ def __init__(
             step_name="recognition",
             input_key="ocr_data",
             forward_function=self.recognise,
-            confidence_function=self.recognise,  # THIS WILL NEED UPDATING : #issue 14
+            confidence_function=self.get_ocr_confidence,
             n_variational_runs=n_variational_runs,
             **kwargs,
        )
@@ -160,9 +164,9 @@ def __init__(
         super().__init__(
             step_name="classification",
input_key="target_text", - forward_function=self.classify_topic_zero_shot - if zero_shot - else self.classify_topic, + forward_function=( + self.classify_topic_zero_shot if zero_shot else self.classify_topic + ), confidence_function=self.get_classification_confidence, n_variational_runs=n_variational_runs, **kwargs, diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py index e3ff420..5a812f2 100644 --- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py @@ -4,6 +4,7 @@ from transformers import pipeline from arc_spice.variational_pipelines.utils import ( + CustomOCRPipeline, CustomTranslationPipeline, RTCVariationalPipelineBase, dropout_off, @@ -38,6 +39,7 @@ def __init__( data_pars: dict[str, Any], n_variational_runs=5, translation_batch_size=16, + ocr_batch_size=64, ) -> None: # are we doing zero-shot-classification? if model_pars["classifier"]["specific_task"] == "zero-shot-classification": @@ -47,9 +49,11 @@ def __init__( super().__init__(self.zero_shot, n_variational_runs, translation_batch_size) # defining the pipeline objects self.ocr = pipeline( - task=model_pars["ocr"]["specific_task"], model=model_pars["ocr"]["model"], device=self.device, + pipeline_class=CustomOCRPipeline, + max_new_tokens=20, + batch_size=ocr_batch_size, ) self.translator = pipeline( task=model_pars["translator"]["specific_task"], diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py index 03a5937..bf7a9a9 100644 --- a/src/arc_spice/variational_pipelines/utils.py +++ b/src/arc_spice/variational_pipelines/utils.py @@ -5,12 +5,15 @@ from functools import partial from typing import Any +import numpy as np import torch import transformers +from torch.distributions import Categorical from torch.nn.functional import softmax from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, + ImageToTextPipeline, Pipeline, TranslationPipeline, pipeline, @@ -149,9 +152,9 @@ def __init__(self, zero_shot: bool, n_variational_runs=5, translation_batch_size self.func_map = { "recognition": self.recognise, "translation": self.translate, - "classification": self.classify_topic_zero_shot - if zero_shot - else self.classify_topic, + "classification": ( + self.classify_topic_zero_shot if zero_shot else self.classify_topic + ), } # the naive outputs of the pipeline stages calculated in self.clean_inference self.naive_outputs = { @@ -266,21 +269,44 @@ def check_dropout(pipeline_map: transformers.Pipeline): set_dropout(model=pl.model, dropout_flag=False) logger.debug("-------------------------------------------------------\n\n") - def recognise(self, inp) -> dict[str, str]: + def recognise(self, inp) -> dict[str, str | list[dict[str, str | torch.Tensor]]]: """ - Function to perform OCR + Function to perform OCR. 
 
         Args:
-            inp: input
+            inp: input dict with key 'ocr_data', containing dict:
+                {
+                    'ocr_images': list[ocr images],
+                    'ocr_targets': list[ocr target words]
+                }
 
         Returns:
-            dictionary of outputs
+            dictionary of outputs:
+                {
+                    'full_output': [
+                        {
+                            'generated_text': generated text from ocr model (str),
+                            'target': original target text (str),
+                            'entropies': token entropies for UQ (torch.Tensor)
+                        }
+                    ],
+                    'output': pieced back together string (str)
+                }
         """
-        # Until the OCR data is available
-        # This will need the below comment:
-        # type: ignore[misc]
-        # TODO https://github.com/alan-turing-institute/ARC-SPICE/issues/14
-        return {"outputs": inp["source_text"]}
+        out = self.ocr(inp["ocr_data"]["ocr_images"])  # type: ignore[misc]
+        # CustomOCRPipeline.postprocess returns a dict per image
+        text = " ".join([itm["generated_text"] for itm in out])
+        return {
+            "full_output": [
+                {
+                    "target": target,
+                    "generated_text": gen_text["generated_text"],
+                    "entropies": gen_text["raw_output"]["entropies"],
+                }
+                for target, gen_text in zip(
+                    inp["ocr_data"]["ocr_targets"], out, strict=True
+                )
+            ],
+            "output": text,
+        }
 
     def translate(self, text: str) -> dict[str, torch.Tensor | str]:
         """
@@ -352,9 +378,7 @@ def classify_topic_zero_shot(self, text: str) -> dict[str, list[float] | dict]:
             descriptors["en"]
             for descriptors in self.dataset_meta_data["class_descriptors"]  # type: ignore[index]
         ]
-        forward = self.classifier(  # type: ignore[misc]
-            text, labels
-        )
+        forward = self.classifier(text, labels)  # type: ignore[misc]
         return collate_scores(
             [
                 {"label": label, "score": score}
@@ -560,6 +584,28 @@ def get_classification_confidence(
         )
         return var_output
 
+    def get_ocr_confidence(self, var_output: dict) -> dict:
+        """Generate the OCR confidence score.
+
+        Args:
+            var_output: variational run outputs
+
+        Returns:
+            dictionary with metrics
+        """
+        # Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
+        # generated sequences can differ in length, so flatten the token
+        # entropies across items and variational runs before averaging
+        flat_entropies = torch.cat(
+            [
+                data["entropies"].flatten()
+                for output in var_output["recognition"]
+                for data in output["full_output"]
+            ]
+        )
+        # mean entropy
+        mean = torch.mean(flat_entropies)
+        var_output["recognition"].update({"mean_entropy": mean})
+        return var_output
+
 
 # Translation pipeline with additional functionality to save logits from fwd pass
 class CustomTranslationPipeline(TranslationPipeline):
@@ -619,3 +665,38 @@ def _forward(self, model_inputs, **generate_kwargs):
             "scores": max_token_scores,
             "entropy": normalised_entropy,
         }
+
+
+class CustomOCRPipeline(ImageToTextPipeline):
+    """
+    Custom OCR pipeline to return token entropies with the generated text.
+    """
+
+    def postprocess(self, model_outputs: dict, **postprocess_params):
+        raw_out = copy.deepcopy(model_outputs)
+        processed = super().postprocess(
+            model_outputs["model_output"], **postprocess_params
+        )
+
+        return {"generated_text": processed[0]["generated_text"], "raw_output": raw_out}
+
+    def _forward(self, model_inputs, **generate_kwargs):
+        if (
+            "input_ids" in model_inputs
+            and isinstance(model_inputs["input_ids"], list)
+            and all(x is None for x in model_inputs["input_ids"])
+        ):
+            model_inputs["input_ids"] = None
+
+        inputs = model_inputs.pop(self.model.main_input_name)
+        out = self.model.generate(
+            inputs,
+            **model_inputs,
+            **generate_kwargs,
+            output_logits=True,
+            return_dict_in_generate=True,
+        )
+
+        logits = torch.stack(out.logits, dim=1)
+        # normalise token entropies by the maximum possible entropy, log(vocab size)
+        entropy = Categorical(logits=logits).entropy() / np.log(logits.size(-1))
+        return {"model_output": out.sequences, "entropies": entropy}
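
Reviewer note: a minimal sketch of how `ocr_error` consumes the structure that `recognise` now returns. The sample strings are hypothetical; `ocr_error` only reads the `generated_text` and `target` keys, so the `entropies` entries are omitted here.

```python
from arc_spice.eval.ocr_error import ocr_error

# hypothetical OCR output matching the structure returned by recognise()
sample_output = {
    "full_output": [
        {"generated_text": "Offical", "target": "Official"},
        {"generated_text": "Journal", "target": "Journal"},
    ],
    "output": "Offical Journal",
}

# comparisons are lower-cased inside ocr_error; one edit (the missing 'i')
# over 15 target characters gives a character error rate of ~0.067
print(ocr_error(sample_output))
```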
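The entropy normalisation in `CustomOCRPipeline._forward` divides each token's entropy by log(vocab size), the entropy of a uniform distribution over the vocabulary, so values land in [0, 1]; `get_ocr_confidence` then averages these, with higher mean entropy indicating lower recognition confidence. A minimal standalone sketch with made-up toy logits:

```python
import numpy as np
import torch
from torch.distributions import Categorical

# toy logits: batch of 1, two generated tokens, vocabulary of 5
logits = torch.tensor(
    [[[10.0, 0.0, 0.0, 0.0, 0.0],   # sharply peaked -> near-zero entropy
      [0.0, 0.0, 0.0, 0.0, 0.0]]]   # uniform -> maximum entropy
)

# same normalisation as _forward: divide by log(vocab size)
entropy = Categorical(logits=logits).entropy() / np.log(logits.size(-1))
print(entropy)  # ~0.001 for the confident token, exactly 1.0 for the uniform one
```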