From 2361023522dec6100e92eba21ef9e87ccfb1715d Mon Sep 17 00:00:00 2001 From: eddableheath Date: Fri, 6 Dec 2024 16:55:57 +0000 Subject: [PATCH] :hammer: added ocr UQ --- src/arc_spice/eval/inference_utils.py | 13 ++++++++++--- src/arc_spice/eval/ocr_error.py | 1 + .../RTC_single_component_pipeline.py | 2 +- src/arc_spice/variational_pipelines/utils.py | 4 ++-- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py index 0794fa2..8c5ec32 100644 --- a/src/arc_spice/eval/inference_utils.py +++ b/src/arc_spice/eval/inference_utils.py @@ -8,6 +8,7 @@ from tqdm import tqdm from arc_spice.data.multieurlex_utils import MultiHot +from arc_spice.eval.ocr_error import ocr_error from arc_spice.eval.translation_error import get_comet_model from arc_spice.variational_pipelines.RTC_single_component_pipeline import ( RTCSingleComponentPipeline, @@ -63,10 +64,16 @@ def get_results( )._asdict() return results_dict - def recognition_results(self, *args, **kwargs): + def recognition_results( + self, + clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]], + var_output: dict[str, dict], + **kwargs, + ): # ### RECOGNITION ### - # TODO: add this into results_getter : issue #14 - return RecognitionResults(confidence=None, accuracy=None) + charerror = ocr_error(clean_output) + confidence = var_output["recognition"]["mean_entropy"] + return RecognitionResults(confidence=confidence, accuracy=charerror) def translation_results( self, diff --git a/src/arc_spice/eval/ocr_error.py b/src/arc_spice/eval/ocr_error.py index 8013cd9..df370dc 100644 --- a/src/arc_spice/eval/ocr_error.py +++ b/src/arc_spice/eval/ocr_error.py @@ -21,6 +21,7 @@ def ocr_error(ocr_output: dict[Any, Any]) -> float: { 'generated_text': gen text from the ocr model (str) 'target': target text (str) + 'entropies': entropies for UQ (torch.Tensor) } ] 'outpu': pieced back together full string (str) diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py index e292a47..397e2e0 100644 --- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py @@ -112,7 +112,7 @@ def __init__( step_name="recognition", input_key="ocr_data", forward_function=self.recognise, - confidence_function=self.recognise, # THIS WILL NEED UPDATING : #issue 14 + confidence_function=self.get_ocr_confidence, n_variational_runs=n_variational_runs, **kwargs, ) diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py index affbf08..8eedcd1 100644 --- a/src/arc_spice/variational_pipelines/utils.py +++ b/src/arc_spice/variational_pipelines/utils.py @@ -519,8 +519,8 @@ def get_ocr_confidence(self, var_output: dict) -> dict[str, float]: # Adapted for variational methods from: https://arxiv.org/pdf/2412.01221 stacked_entropies = torch.stack( [ - output["entropies"] - for output in var_output["recognition"]["full_output"] + [data["entropies"] for data in output["full_output"]] + for output in var_output["recognition"] ], dim=1, )