Skip to content

Commit

Permalink
🔨 added ocr UQ
Browse files Browse the repository at this point in the history
  • Loading branch information
eddableheath committed Dec 6, 2024
1 parent daed9c0 commit 2361023
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
13 changes: 10 additions & 3 deletions src/arc_spice/eval/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tqdm import tqdm

from arc_spice.data.multieurlex_utils import MultiHot
from arc_spice.eval.ocr_error import ocr_error
from arc_spice.eval.translation_error import get_comet_model
from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
RTCSingleComponentPipeline,
Expand Down Expand Up @@ -63,10 +64,16 @@ def get_results(
)._asdict()
return results_dict

def recognition_results(self, *args, **kwargs):
def recognition_results(
self,
clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]],
var_output: dict[str, dict],
**kwargs,
):
# ### RECOGNITION ###
# TODO: add this into results_getter : issue #14
return RecognitionResults(confidence=None, accuracy=None)
charerror = ocr_error(clean_output)
confidence = var_output["recognition"]["mean_entropy"]
return RecognitionResults(confidence=confidence, accuracy=charerror)

def translation_results(
self,
Expand Down
1 change: 1 addition & 0 deletions src/arc_spice/eval/ocr_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def ocr_error(ocr_output: dict[Any, Any]) -> float:
{
'generated_text': gen text from the ocr model (str)
'target': target text (str)
'entropies': entropies for UQ (torch.Tensor)
}
]
'outpu': pieced back together full string (str)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
step_name="recognition",
input_key="ocr_data",
forward_function=self.recognise,
confidence_function=self.recognise, # THIS WILL NEED UPDATING : #issue 14
confidence_function=self.get_ocr_confidence,
n_variational_runs=n_variational_runs,
**kwargs,
)
Expand Down
4 changes: 2 additions & 2 deletions src/arc_spice/variational_pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,8 @@ def get_ocr_confidence(self, var_output: dict) -> dict[str, float]:
# Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
stacked_entropies = torch.stack(
[
output["entropies"]
for output in var_output["recognition"]["full_output"]
[data["entropies"] for data in output["full_output"]]
for output in var_output["recognition"]
],
dim=1,
)
Expand Down

0 comments on commit 2361023

Please sign in to comment.