From 00e759bc919d1f142a65e0723f17cb5ee4596657 Mon Sep 17 00:00:00 2001 From: eddableheath Date: Tue, 3 Dec 2024 13:02:21 +0000 Subject: [PATCH 1/6] :hammer: added ocr error --- src/arc_spice/eval/ocr_error.py | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 src/arc_spice/eval/ocr_error.py diff --git a/src/arc_spice/eval/ocr_error.py b/src/arc_spice/eval/ocr_error.py new file mode 100644 index 0000000..8013cd9 --- /dev/null +++ b/src/arc_spice/eval/ocr_error.py @@ -0,0 +1,35 @@ +""" +OCR error computation for eval. +""" + +from typing import Any + +from torchmetrics.text import CharErrorRate + + +def ocr_error(ocr_output: dict[Any, Any]) -> float: + """ + Compute the character error rate for the predicted ocr character. + + NB: - this puts all strings into lower case for comparisons. + - ideal error rate is 0, worst case is 1. + + Args: + ocr_output: output from the ocr model, with structure, + { + 'full_output': [ + { + 'generated_text': gen text from the ocr model (str) + 'target': target text (str) + } + ] + 'outpu': pieced back together full string (str) + } + + Returns: + Character error rate across entire output of OCR (float) + """ + preds = [itm["generated_text"].lower() for itm in ocr_output["full_output"]] + targs = [itm["target"].lower() for itm in ocr_output["full_output"]] + cer = CharErrorRate() + return cer(preds, targs).item() From 51a706b360c845d4d0ef882950b52dc3815b8f18 Mon Sep 17 00:00:00 2001 From: eddableheath Date: Tue, 3 Dec 2024 13:08:46 +0000 Subject: [PATCH 2/6] :hammer: added ocr inference --- src/arc_spice/variational_pipelines/utils.py | 37 +++++++++++++++----- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py index c9e2692..8c680ba 100644 --- a/src/arc_spice/variational_pipelines/utils.py +++ b/src/arc_spice/variational_pipelines/utils.py @@ -217,21 +217,40 @@ def check_dropout(self): 
set_dropout(model=pl.model, dropout_flag=False) logger.debug("-------------------------------------------------------\n\n") - def recognise(self, inp) -> dict[str, str]: + def recognise(self, inp) -> dict[str, str | list[dict[str, str]]]: """ - Function to perform OCR + Function to perform OCR. Args: - inp: input + inp: input dict with key 'ocr_data', containing dict, + { + 'ocr_images': list[ocr images], + 'ocr_targets': list[ocr target words] + } Returns: - dictionary of outputs + dictionary of outputs: + { + 'full_output': [ + { + 'generated_text': generated text from ocr model (str), + 'target': original target text (str) + } + ], + 'output': pieced back together string (str) + } """ - # Until the OCR data is available - # This will need the below comment: - # type: ignore[misc] - # TODO https://github.com/alan-turing-institute/ARC-SPICE/issues/14 - return {"outputs": inp["source_text"]} + out = self.ocr(inp["ocr_data"]["ocr_images"]) # type: ignore[misc] + text = " ".join([itm[0]["generated_text"] for itm in out]) + return { + "full_output": [ + {"target": target, "generated_text": gen_text[0]["generated_text"]} + for target, gen_text in zip( + inp["ocr_data"]["ocr_targets"], out, strict=True + ) + ], + "output": text, + } def translate(self, text: str) -> dict[str, torch.Tensor | str]: """ From daed9c043100bde4c9d735a1a55f2af4f446d924 Mon Sep 17 00:00:00 2001 From: eddableheath Date: Fri, 6 Dec 2024 16:40:56 +0000 Subject: [PATCH 3/6] :hammer: added ocr UQ --- .../RTC_single_component_pipeline.py | 6 ++- .../RTC_variational_pipeline.py | 39 ++++++++++++++++++- src/arc_spice/variational_pipelines/utils.py | 30 +++++++++++++- 3 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py index 9a13646..e292a47 100644 --- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py +++ 
b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py @@ -4,6 +4,7 @@ from transformers import pipeline from arc_spice.variational_pipelines.RTC_variational_pipeline import ( + CustomOCRPipeline, CustomTranslationPipeline, RTCVariationalPipelineBase, ) @@ -94,13 +95,16 @@ def __init__( self, model_pars: dict[str, dict[str, str]], n_variational_runs=5, + ocr_batch_size=64, **kwargs, ): self.set_device() self.ocr = pipeline( - task=model_pars["OCR"]["specific_task"], model=model_pars["OCR"]["model"], device=self.device, + pipeline_class=CustomOCRPipeline, + max_new_tokens=20, + batch_size=ocr_batch_size, **kwargs, ) self.model = self.ocr.model diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py index eba0c16..cbd0293 100644 --- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py @@ -1,8 +1,10 @@ import copy from typing import Any +import numpy as np import torch -from transformers import TranslationPipeline, pipeline +from torch.distributions import Categorical +from transformers import ImageToTextPipeline, TranslationPipeline, pipeline from arc_spice.variational_pipelines.utils import ( RTCVariationalPipelineBase, @@ -182,3 +184,38 @@ def _forward(self, model_inputs, **generate_kwargs): logits = torch.stack(out["logits"], dim=1) return {"output_ids": output_ids, "logits": logits} + + +class CustomOCRPipeline(ImageToTextPipeline): + """ + custom OCR pipeline to return logits with the generated text. 
+ """ + + def postprocess(self, model_outputs: dict, **postprocess_params): + raw_out = copy.deepcopy(model_outputs) + processed = super().postprocess( + model_outputs["model_output"], **postprocess_params + ) + + return {"generated_text": processed[0]["generated_text"], "raw_output": raw_out} + + def _forward(self, model_inputs, **generate_kwargs): + if ( + "input_ids" in model_inputs + and isinstance(model_inputs["input_ids"], list) + and all(x is None for x in model_inputs["input_ids"]) + ): + model_inputs["input_ids"] = None + + inputs = model_inputs.pop(self.model.main_input_name) + out = self.model.generate( + inputs, + **model_inputs, + **generate_kwargs, + output_logits=True, + return_dict_in_generate=True, + ) + + logits = torch.stack(out.logits, dim=1) + entropy = Categorical(logits=logits).entropy() / np.log(logits[0].size()[1]) + return {"model_output": out.sequences, "entropies": entropy} diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py index 8c680ba..affbf08 100644 --- a/src/arc_spice/variational_pipelines/utils.py +++ b/src/arc_spice/variational_pipelines/utils.py @@ -217,7 +217,7 @@ def check_dropout(self): set_dropout(model=pl.model, dropout_flag=False) logger.debug("-------------------------------------------------------\n\n") - def recognise(self, inp) -> dict[str, str | list[dict[str, str]]]: + def recognise(self, inp) -> dict[str, str | list[dict[str, str | torch.Tensor]]]: """ Function to perform OCR. 
@@ -244,7 +244,11 @@ def recognise(self, inp) -> dict[str, str | list[dict[str, str]]]: text = " ".join([itm[0]["generated_text"] for itm in out]) return { "full_output": [ - {"target": target, "generated_text": gen_text[0]["generated_text"]} + { + "target": target, + "generated_text": gen_text["generated_text"], + "entropies": gen_text["entropies"], + } for target, gen_text in zip( inp["ocr_data"]["ocr_targets"], out, strict=True ) @@ -502,3 +506,25 @@ def get_classification_confidence( } ) return var_output + + def get_ocr_confidence(self, var_output: dict) -> dict[str, float]: + """Generate the ocr confidence score. + + Args: + var_output: variational run outputs + + Returns: + dictionary with metrics + """ + # Adapted for variational methods from: https://arxiv.org/pdf/2412.01221 + stacked_entropies = torch.stack( + [ + output["entropies"] + for output in var_output["recognition"]["full_output"] + ], + dim=1, + ) + # mean entropy + mean = torch.mean(stacked_entropies) + var_output["recognition"].update({"mean_entropy": mean}) + return var_output From 2361023522dec6100e92eba21ef9e87ccfb1715d Mon Sep 17 00:00:00 2001 From: eddableheath Date: Fri, 6 Dec 2024 16:55:57 +0000 Subject: [PATCH 4/6] :hammer: added ocr UQ --- src/arc_spice/eval/inference_utils.py | 13 ++++++++++--- src/arc_spice/eval/ocr_error.py | 1 + .../RTC_single_component_pipeline.py | 2 +- src/arc_spice/variational_pipelines/utils.py | 4 ++-- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py index 0794fa2..8c5ec32 100644 --- a/src/arc_spice/eval/inference_utils.py +++ b/src/arc_spice/eval/inference_utils.py @@ -8,6 +8,7 @@ from tqdm import tqdm from arc_spice.data.multieurlex_utils import MultiHot +from arc_spice.eval.ocr_error import ocr_error from arc_spice.eval.translation_error import get_comet_model from arc_spice.variational_pipelines.RTC_single_component_pipeline import ( RTCSingleComponentPipeline, @@ 
-63,10 +64,16 @@ def get_results( )._asdict() return results_dict - def recognition_results(self, *args, **kwargs): + def recognition_results( + self, + clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]], + var_output: dict[str, dict], + **kwargs, + ): # ### RECOGNITION ### - # TODO: add this into results_getter : issue #14 - return RecognitionResults(confidence=None, accuracy=None) + charerror = ocr_error(clean_output) + confidence = var_output["recognition"]["mean_entropy"] + return RecognitionResults(confidence=confidence, accuracy=charerror) def translation_results( self, diff --git a/src/arc_spice/eval/ocr_error.py b/src/arc_spice/eval/ocr_error.py index 8013cd9..df370dc 100644 --- a/src/arc_spice/eval/ocr_error.py +++ b/src/arc_spice/eval/ocr_error.py @@ -21,6 +21,7 @@ def ocr_error(ocr_output: dict[Any, Any]) -> float: { 'generated_text': gen text from the ocr model (str) 'target': target text (str) + 'entropies': entropies for UQ (torch.Tensor) } ] 'outpu': pieced back together full string (str) diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py index e292a47..397e2e0 100644 --- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py @@ -112,7 +112,7 @@ def __init__( step_name="recognition", input_key="ocr_data", forward_function=self.recognise, - confidence_function=self.recognise, # THIS WILL NEED UPDATING : #issue 14 + confidence_function=self.get_ocr_confidence, n_variational_runs=n_variational_runs, **kwargs, ) diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py index affbf08..8eedcd1 100644 --- a/src/arc_spice/variational_pipelines/utils.py +++ b/src/arc_spice/variational_pipelines/utils.py @@ -519,8 +519,8 @@ def get_ocr_confidence(self, var_output: dict) -> dict[str, float]: # Adapted for 
variational methods from: https://arxiv.org/pdf/2412.01221 stacked_entropies = torch.stack( [ - output["entropies"] - for output in var_output["recognition"]["full_output"] + [data["entropies"] for data in output["full_output"]] + for output in var_output["recognition"] ], dim=1, ) From 807129b437fea0354579c45fa22d76ffde49401d Mon Sep 17 00:00:00 2001 From: eddableheath Date: Fri, 6 Dec 2024 17:13:21 +0000 Subject: [PATCH 5/6] small fix --- src/arc_spice/eval/inference_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py index 8ab9a5d..9aa36b4 100644 --- a/src/arc_spice/eval/inference_utils.py +++ b/src/arc_spice/eval/inference_utils.py @@ -159,5 +159,4 @@ def run_inference( test_row=inp, ) results.append({inp["celex_id"]: row_results_dict}) - break return results From 510d863178989b5727471e327d4d6f9e8d5487df Mon Sep 17 00:00:00 2001 From: eddableheath Date: Fri, 6 Dec 2024 17:20:10 +0000 Subject: [PATCH 6/6] fixed bug --- src/arc_spice/variational_pipelines/RTC_variational_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py index bfc05b7..5a812f2 100644 --- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py @@ -49,7 +49,7 @@ def __init__( super().__init__(self.zero_shot, n_variational_runs, translation_batch_size) # defining the pipeline objects self.ocr = pipeline( - model=model_pars["OCR"]["model"], + model=model_pars["ocr"]["model"], device=self.device, pipeline_class=CustomOCRPipeline, max_new_tokens=20,