From 7df93c356ed37a2776ee84a78981139ef2d7ccf1 Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Sat, 7 Dec 2024 17:10:48 +0000
Subject: [PATCH] Baskerville changes for main: switch OCR to trocr-small-printed, nest OCR fields under ocr_data, rework OCR confidence

---
 config/RTC_configs/roberta-mt5-trained.yaml   |  2 +-
 config/RTC_configs/roberta-mt5-zero-shot.yaml |  3 +-
 config/data_configs/l1_fr_to_en.yaml          |  2 +
 scripts/single_component_inference.py         |  8 ++-
 src/arc_spice/data/multieurlex_utils.py       | 20 ++++--
 src/arc_spice/eval/inference_utils.py         |  6 +-
 src/arc_spice/eval/ocr_error.py               |  4 +-
 .../RTC_single_component_pipeline.py          | 18 ++---
 .../RTC_variational_pipeline.py               |  6 +-
 src/arc_spice/variational_pipelines/utils.py  | 67 ++++++++++---------
 10 files changed, 72 insertions(+), 64 deletions(-)

diff --git a/config/RTC_configs/roberta-mt5-trained.yaml b/config/RTC_configs/roberta-mt5-trained.yaml
index 6d949e8..a613fd9 100644
--- a/config/RTC_configs/roberta-mt5-trained.yaml
+++ b/config/RTC_configs/roberta-mt5-trained.yaml
@@ -1,6 +1,6 @@
 ocr:
   specific_task: "image-to-text"
-  model: "microsoft/trocr-base-handwritten"
+  model: "microsoft/trocr-small-printed"
 
 translator:
   specific_task: "translation_fr_to_en"
diff --git a/config/RTC_configs/roberta-mt5-zero-shot.yaml b/config/RTC_configs/roberta-mt5-zero-shot.yaml
index 85a2d79..5ba8c07 100644
--- a/config/RTC_configs/roberta-mt5-zero-shot.yaml
+++ b/config/RTC_configs/roberta-mt5-zero-shot.yaml
@@ -1,6 +1,5 @@
 ocr:
-  specific_task: "image-to-text"
-  model: "microsoft/trocr-base-handwritten"
+  model: "microsoft/trocr-small-printed"
 
 translator:
   specific_task: "translation_fr_to_en"
diff --git a/config/data_configs/l1_fr_to_en.yaml b/config/data_configs/l1_fr_to_en.yaml
index 58e12f1..42ece9e 100644
--- a/config/data_configs/l1_fr_to_en.yaml
+++ b/config/data_configs/l1_fr_to_en.yaml
@@ -7,3 +7,5 @@ lang_pair:
   target: "en"
 
 drop_length: 1000
+
+load_ocr_data: True
diff --git a/scripts/single_component_inference.py b/scripts/single_component_inference.py
index 6741dbd..ac538c0 100644
--- a/scripts/single_component_inference.py
+++ b/scripts/single_component_inference.py
@@ -46,15 +46,19 @@ def main(
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
+
+    if model_key != "ocr":
+        data_config["load_ocr_data"] = False
+
     data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     if model_key == "ocr":
         rtc_single_component_pipeline = RecognitionVariationalPipeline(
-            model_pars=pipeline_config, data_pars=meta_data
+            model_pars=pipeline_config
         )
     elif model_key == "translator":
         rtc_single_component_pipeline = TranslationVariationalPipeline(
-            model_pars=pipeline_config, data_pars=meta_data
+            model_pars=pipeline_config
        )
     elif model_key == "classifier":
         rtc_single_component_pipeline = ClassificationVariationalPipeline(
diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py
index 767546c..779121a 100644
--- a/src/arc_spice/data/multieurlex_utils.py
+++ b/src/arc_spice/data/multieurlex_utils.py
@@ -67,14 +67,17 @@ def extract_articles(
 
 def _make_ocr_data(text: str) -> list[tuple[Image.Image, str]]:
     text_split = text.split()
-    text_split = [text for text in text_split if text not in ("", " ", None)]
+    text_split = [text for text in text_split if text not in ("", " ")]
     generator = GeneratorFromStrings(text_split, count=len(text_split))
     return list(generator)
 
 
-def make_ocr_data(item: LazyRow) -> dict[str, tuple[Image.Image] | tuple[str]]:
-    images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
-    return {"ocr_images": images, "ocr_targets": targets}
+def make_ocr_data(item: LazyRow) -> dict:
+    try:
+        images, targets = zip(*_make_ocr_data(item["source_text"]), strict=True)
+    except ValueError:
+        return {"ocr_data": {"ocr_images": None, "ocr_targets": None}}
+    return {"ocr_data": {"ocr_images": images, "ocr_targets": targets}}
 
 
 class TranslationPreProcesser:
@@ -229,11 +232,14 @@ def load_multieurlex_for_pipeline(
             make_ocr_data,
             features=datasets.Features(
                 {
-                    "ocr_images": datasets.Sequence(datasets.Image(decode=True)),
-                    "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                    "ocr_data": {
+                        "ocr_images": datasets.Sequence(
+                            datasets.Image(decode=True)
+                        ),
+                        "ocr_targets": datasets.Sequence(datasets.Value("string")),
+                    },
                     **feats,
                 }
             ),
         )
-
     return dataset_dict, meta_data
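Note on the make_ocr_data change above: the new try/except exists because a document with no usable tokens yields no (image, target) pairs, and the two-target unpack of an empty zip raises ValueError. A minimal sketch of the failure mode (illustrative, not part of the patch):

# Sketch: why make_ocr_data needs the ValueError guard. An empty document
# produces no (image, target) pairs, so the two-target unpack fails.
pairs: list[tuple] = []  # what _make_ocr_data would return for empty text
try:
    images, targets = zip(*pairs, strict=True)
except ValueError:  # "not enough values to unpack"
    images, targets = None, None  # stored as null ocr_data fields
print(images, targets)  # -> None None
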
diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py
index 5bf4cdd..4e7e9e1 100644
--- a/src/arc_spice/eval/inference_utils.py
+++ b/src/arc_spice/eval/inference_utils.py
@@ -71,12 +71,12 @@ def get_results(
 
     def recognition_results(
         self,
-        clean_output: dict[str, str | list[dict[str, str | torch.Tensor]]],
-        var_output: dict[str, dict],
+        clean_output: dict,
+        var_output: dict,
         **kwargs,
     ):
         # ### RECOGNITION ###
-        charerror = ocr_error(clean_output)
+        charerror = ocr_error(clean_output["recognition"])
         confidence = var_output["recognition"]["mean_entropy"]
         return RecognitionResults(confidence=confidence, accuracy=charerror)
 
diff --git a/src/arc_spice/eval/ocr_error.py b/src/arc_spice/eval/ocr_error.py
index df370dc..ac925a6 100644
--- a/src/arc_spice/eval/ocr_error.py
+++ b/src/arc_spice/eval/ocr_error.py
@@ -30,7 +30,7 @@ def ocr_error(ocr_output: dict[Any, Any]) -> float:
     Returns:
         Character error rate across entire output of OCR (float)
     """
-    preds = [itm["generated_text"].lower() for itm in ocr_output["full_output"]]
-    targs = [itm["target"].lower() for itm in ocr_output["full_output"]]
+    preds = [itm["generated_text"].lower() for itm in ocr_output["outputs"]]
+    targs = [itm["target"].lower() for itm in ocr_output["outputs"]]
     cer = CharErrorRate()
     return cer(preds, targs).item()
diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
index 6b08c10..ae333f4 100644
--- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
@@ -89,26 +89,23 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
         ocr_batch_size=64,
-        **kwargs,
     ):
         self.set_device()
+        super().__init__(
+            step_name="recognition",
+            input_key="ocr_data",
+            forward_function=self.recognise,
+            confidence_function=self.get_ocr_confidence,
+            n_variational_runs=n_variational_runs,
+        )
         self.ocr: transformers.Pipeline = pipeline(
             model=model_pars["ocr"]["model"],
             device=self.device,
             pipeline_class=CustomOCRPipeline,
             max_new_tokens=20,
             batch_size=ocr_batch_size,
-            **kwargs,
         )
         self.model = self.ocr.model
-        super().__init__(
-            step_name="recognition",
-            input_key="ocr_data",
-            forward_function=self.recognise,
-            confidence_function=self.get_ocr_confidence,
-            n_variational_runs=n_variational_runs,
-            **kwargs,
-        )
         self._init_pipeline_map()
 
 
@@ -118,7 +115,6 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
         translation_batch_size=4,
-        **kwargs,
     ):
         self.set_device()
         # need to initialise the NLI models in this case
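For reference, a minimal usage sketch of the OCR-only path after this change; the import paths follow the repository layout shown in this patch, and plain yaml.safe_load stands in for the script's open_yaml_path helper (both are assumptions, not part of the patch):

# Sketch: OCR-only setup mirroring scripts/single_component_inference.py.
import yaml

from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
    RecognitionVariationalPipeline,
)

with open("config/data_configs/l1_fr_to_en.yaml") as f:
    data_config = yaml.safe_load(f)  # now includes load_ocr_data: True
with open("config/RTC_configs/roberta-mt5-zero-shot.yaml") as f:
    pipeline_config = yaml.safe_load(f)

data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
# data_pars=meta_data is no longer passed; only model_pars is needed
ocr_pipeline = RecognitionVariationalPipeline(model_pars=pipeline_config)
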
diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
index 5a812f2..84e120f 100644
--- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
@@ -79,7 +79,7 @@ def clean_inference(self, x: torch.Tensor) -> dict[str, dict]:
 
         # run the functions
         # UNTIL THE OCR DATA IS AVAILABLE
-        clean_output["recognition"] = self.recognise(x)
+        clean_output["recognition"] = self.recognise(x["ocr_data"])
 
         clean_output["translation"] = self.translate(
             clean_output["recognition"]["outputs"]
@@ -109,8 +109,8 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
         }
         # define the input map for brevity in forward pass
         input_map = {
-            "recognition": x,
-            "translation": clean_output["recognition"]["outputs"],
+            "recognition": x["ocr_data"],
+            "translation": clean_output["recognition"]["full_output"],
             "classification": clean_output["translation"]["full_output"],
         }
 
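The input_map change above follows the key rename applied to recognise in utils.py below: the per-word records move from "full_output" to "outputs", and the pieced-back-together string moves from "output" to "full_output". A toy before/after (field values are illustrative):

# Sketch: the recognition-output key rename that this patch applies.
old_style = {
    "full_output": [{"generated_text": "le", "target": "le"}],  # per-word records
    "output": "le",                                             # joined string
}
new_style = {
    "outputs": old_style["full_output"],  # records, renamed from "full_output"
    "full_output": old_style["output"],   # joined string, renamed from "output"
}
assert new_style["full_output"] == "le"
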
diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py
index bf7a9a9..bee1856 100644
--- a/src/arc_spice/variational_pipelines/utils.py
+++ b/src/arc_spice/variational_pipelines/utils.py
@@ -283,29 +283,27 @@ def recognise(self, inp) -> dict[str, str | list[dict[str, str | torch.Tensor]]]
         Returns:
             dictionary of outputs:
             {
-                'full_output': [
+                'outputs': [
                     {
                         'generated_text': generated text from ocr model (str),
                         'target': original target text (str)
                     }
                 ],
-                'output': pieced back together string (str)
+                'full_output': pieced back together string (str)
             }
         """
-        out = self.ocr(inp["ocr_data"]["ocr_images"])  # type: ignore[misc]
-        text = " ".join([itm[0]["generated_text"] for itm in out])
+        out = self.ocr(inp["ocr_images"])  # type: ignore[misc]
+        text = " ".join([itm["generated_text"] for itm in out])
         return {
-            "full_output": [
+            "outputs": [
                 {
                     "target": target,
                     "generated_text": gen_text["generated_text"],
-                    "entropies": gen_text["entropies"],
+                    "entropies": gen_text["raw_output"]["entropies"],
                 }
-                for target, gen_text in zip(
-                    inp["ocr_data"]["ocr_targets"], out, strict=True
-                )
+                for target, gen_text in zip(inp["ocr_targets"], out, strict=True)
             ],
-            "output": text,
+            "full_output": text,
         }
 
     def translate(self, text: str) -> dict[str, torch.Tensor | str]:
@@ -429,6 +427,31 @@ def stack_variational_outputs(
         # overwrite the existing output dictionary
         return new_var_dict
 
+    def get_ocr_confidence(self, var_output: dict, **kwargs) -> dict[str, float]:
+        """Generate the ocr confidence score.
+
+        Args:
+            var_output: variational run outputs
+
+        Returns:
+            dictionary with metrics
+        """
+        # Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
+        entropies = []
+        recognition_batches = var_output["recognition"]["outputs"]
+        for batch in recognition_batches:
+            for sequence in batch:
+                ent = sequence["entropies"]
+                if ent.dim() == 1:
+                    entropies.append(ent)
+                else:
+                    entropies.append(ent.squeeze())
+        all_entropies = torch.cat(entropies)
+        # mean entropy
+        mean = torch.mean(all_entropies)
+        var_output["recognition"].update({"mean_entropy": mean})
+        return var_output
+
     def sentence_density(
         self,
         clean_sentence: str,
@@ -584,28 +607,6 @@ def get_classification_confidence(
         )
         return var_output
 
-    def get_ocr_confidence(self, var_output: dict) -> dict[str, float]:
-        """Generate the ocr confidence score.
-
-        Args:
-            var_output: variational run outputs
-
-        Returns:
-            dictionary with metrics
-        """
-        # Adapted for variational methods from: https://arxiv.org/pdf/2412.01221
-        stacked_entropies = torch.stack(
-            [
-                [data["entropies"] for data in output["full_output"]]
-                for output in var_output["recognition"]
-            ],
-            dim=1,
-        )
-        # mean entropy
-        mean = torch.mean(stacked_entropies)
-        var_output["recognition"].update({"mean_entropy": mean})
-        return var_output
-
 
 # Translation pipeline with additional functionality to save logits from fwd pass
 class CustomTranslationPipeline(TranslationPipeline):
@@ -699,4 +700,4 @@ def _forward(self, model_inputs, **generate_kwargs):
 
         logits = torch.stack(out.logits, dim=1)
         entropy = Categorical(logits=logits).entropy() / np.log(logits[0].size()[1])
-        return {"model_output": out.sequences, "entropies": entropy}
+        return {"model_output": out.sequences, "entropies": entropy.squeeze()}
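For reference, a toy walk-through of the mean-entropy aggregation introduced by the new get_ocr_confidence above; tensor values are invented, and the shapes mirror the per-sequence entropies that recognise emits:

# Sketch: mean-entropy OCR confidence over variational runs.
import torch

var_output = {
    "recognition": {
        "outputs": [  # one inner list of sequence records per variational run
            [{"entropies": torch.tensor([0.2, 0.4])}],
            [{"entropies": torch.tensor([[0.1, 0.3]])}],  # 2-D case is squeezed
        ]
    }
}

entropies = []
for batch in var_output["recognition"]["outputs"]:
    for sequence in batch:
        ent = sequence["entropies"]
        entropies.append(ent if ent.dim() == 1 else ent.squeeze())
mean = torch.mean(torch.cat(entropies))  # tensor(0.2500)
var_output["recognition"]["mean_entropy"] = mean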