From 92a98f06617d5b5471e5b700ac694c26f68e63ef Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Fri, 6 Dec 2024 11:36:53 +0000
Subject: [PATCH] Addressed pull request comments; added further outputs to
 translation to allow additional confidence measures

---
 src/arc_spice/eval/classification_error.py   |  7 --
 src/arc_spice/eval/inference_utils.py        | 26 ++++---
 .../RTC_single_component_pipeline.py         | 21 ++----
 .../RTC_variational_pipeline.py              | 57 +-------------
 src/arc_spice/variational_pipelines/utils.py | 74 ++++++++++++++++++-
 tests/test_inference.py                      | 31 --------
 6 files changed, 98 insertions(+), 118 deletions(-)

diff --git a/src/arc_spice/eval/classification_error.py b/src/arc_spice/eval/classification_error.py
index 71292d3..2e8f36c 100644
--- a/src/arc_spice/eval/classification_error.py
+++ b/src/arc_spice/eval/classification_error.py
@@ -1,7 +1,4 @@
-import math
-
 import torch
-from sklearn.metrics import zero_one_loss
 
 
 def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
@@ -11,10 +8,6 @@ def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
     return 1 - torch.mean(distance)
 
 
-def zero_one_loss_ceil(y_target, y_pred):
-    return math.ceil(zero_one_loss(y_target, y_pred, normalize=True))
-
-
 def MC_dropout_scores(
     variational_probs: list[float], epsilon: float = 1e-14
 ) -> dict[str, torch.Tensor]:
diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py
index 27513cf..bc8c774 100644
--- a/src/arc_spice/eval/inference_utils.py
+++ b/src/arc_spice/eval/inference_utils.py
@@ -17,15 +17,7 @@
 )
 
 RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
-ClassificationResults = namedtuple(
-    "ClassificationResults",
-    [
-        "clean_scores",
-        "mean_scores",
-        "hamming_accuracy",
-        "mean_predicted_entropy",
-    ],
-)
+
 TranslationResults = namedtuple(
     "TranslationResults",
     [
@@ -33,6 +25,18 @@
         "clean_conditional_probability",
         "comet_score",
         "weighted_semantic_density",
+        "mean_entropy",
+        "sequence_lengths",
+    ],
+)
+
+ClassificationResults = namedtuple(
+    "ClassificationResults",
+    [
+        "clean_scores",
+        "mean_scores",
+        "hamming_accuracy",
+        "mean_predicted_entropy",
     ],
 )
 
@@ -79,6 +83,8 @@ def translation_results(
     source_text = test_row["target_text"]
     target_text = test_row["target_text"]
     clean_translation = clean_output["translation"]["full_output"]
+    clean_entropy: torch.Tensor = clean_output["translation"]["mean_entropy"]
+    seq_lens: torch.Tensor = var_output["translation"]["sequence_length"]
     probs: list[torch.Tensor] = clean_output["translation"]["probs"]
     clean_cond_prob = [
         conditional_probability(prob.squeeze()).detach().tolist() for prob in probs
@@ -102,6 +108,8 @@ def translation_results(
         comet_score=comet_output["scores"][0],
         full_output=clean_translation,
         clean_conditional_probability=clean_cond_prob,
+        mean_entropy=clean_entropy,
+        sequence_lengths=seq_lens,
         weighted_semantic_density=var_output["translation"][
             "weighted_semantic_density"
         ],
diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
index 0be0025..a0788d1 100644
--- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
@@ -4,10 +4,14 @@
 from transformers import pipeline
 
 from arc_spice.variational_pipelines.RTC_variational_pipeline import (
-    CustomTranslationPipeline,
     RTCVariationalPipelineBase,
 )
-from arc_spice.variational_pipelines.utils import dropout_off, dropout_on, set_dropout
+from arc_spice.variational_pipelines.utils import (
+    CustomTranslationPipeline,
+    dropout_off,
+    dropout_on,
+    set_dropout,
+)
 
 
 class RTCSingleComponentPipeline(RTCVariationalPipelineBase):
@@ -34,19 +38,6 @@ def __init__(
         # define objects that are needed and nothing else
         # naive outputs can remain the same, though only the appropriate outputs will
         # be outputted
-        self.naive_outputs = {
-            "recognition": [
-                "outputs",
-            ],
-            "translation": [
-                "full_output",
-                "outputs",
-                "probs",
-            ],
-            "classification": [
-                "scores",
-            ],
-        }
         self.step_name = step_name
         self.input_key = input_key
         self.forward_function = forward_function
diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
index 15bf990..2d9f1ea 100644
--- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
@@ -1,11 +1,10 @@
-import copy
 from typing import Any
 
 import torch
-from torch.nn.functional import softmax
-from transformers import TranslationPipeline, pipeline
+from transformers import pipeline
 
 from arc_spice.variational_pipelines.utils import (
+    CustomTranslationPipeline,
     RTCVariationalPipelineBase,
     dropout_off,
     dropout_on,
@@ -134,55 +133,3 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
     # on standard call return the clean output
     def __call__(self, x):
         return self.clean_inference(x)
-
-
-# Translation pipeline with additional functionality to save logits from fwd pass
-class CustomTranslationPipeline(TranslationPipeline):
-    """
-    custom translation pipeline to return the logits with the generated text. Largely
-    the same as the pytorch version with some additional arguments passed to the
-    `generate` method.
-    """
-
-    def postprocess(
-        self,
-        model_outputs: dict,
-        **postprocess_params,
-    ):
-        # model_outputs gets overwritten in the super().postprocess call
-        # make a copy here so we retain the information we want
-        raw_out = copy.deepcopy(model_outputs)
-        processed = super().postprocess(model_outputs, **postprocess_params)
-
-        return {
-            "translation_text": processed[0]["translation_text"],
-            "raw_outputs": raw_out,
-        }
-
-    def _forward(self, model_inputs, **generate_kwargs):
-        if self.framework == "pt":
-            in_b, input_length = model_inputs["input_ids"].shape
-        elif self.framework == "tf":
-            raise NotImplementedError
-
-        self.check_inputs(
-            input_length,
-            generate_kwargs.get("min_length", self.model.config.min_length),
-            generate_kwargs.get("max_length", self.model.config.max_length),
-        )
-        out = self.model.generate(**model_inputs, **generate_kwargs)
-        output_ids = out["sequences"]
-        out_b = output_ids.shape[0]
-        if self.framework == "pt":
-            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
-        elif self.framework == "tf":
-            raise NotImplementedError
-
-        # logits are a tuple of length output_ids[-1]-1
-        # each element is a tensor of shape (batch_size, vocab_size)
-        logits = torch.stack(out["logits"], dim=1)
-        # get softmax of the logits to get token probabilities
-        softmax_logits = softmax(logits, dim=-1)
-        max_token_scores = torch.max(softmax_logits, dim=-1).values
-
-        return {"output_ids": output_ids, "scores": max_token_scores}
diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py
index 1de7427..f37ca30 100644
--- a/src/arc_spice/variational_pipelines/utils.py
+++ b/src/arc_spice/variational_pipelines/utils.py
@@ -1,11 +1,18 @@
+import copy
 import logging
+import math
 from abc import ABC, abstractmethod
 from functools import partial
 from typing import Any
 
 import torch
 from torch.nn.functional import softmax
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, Pipeline
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    Pipeline,
+    TranslationPipeline,
+)
 
 logger = logging.Logger("RTC_variational_pipeline")
 
@@ -117,6 +124,7 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
                 "full_output",
                 "outputs",
                 "probs",
+                "mean_entropy",
             ],
             "classification": [
                 "scores",
@@ -264,6 +272,9 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
             {
                 "outputs": translator_output["translation_text"],
                 "probs": translator_output["raw_outputs"]["scores"],
+                "mean_entropy": torch.mean(translator_output["raw_outputs"]["entropy"])
+                .detach()
+                .tolist(),
             }
             for translator_output in translator_outputs
         ]
@@ -430,6 +441,7 @@ def translation_semantic_density(
             {
                 "semantic_densities": densities,
                 "weighted_semantic_density": weighted_average.item(),
+                "sequence_length": sequence_lengths,
             }
         )
 
@@ -480,3 +492,63 @@ def get_classification_confidence(
         }
     )
     return var_output
+
+
+# Translation pipeline with additional functionality to save logits from fwd pass
+class CustomTranslationPipeline(TranslationPipeline):
+    """
+    Custom translation pipeline that returns the logits with the generated text.
+    Largely the same as the PyTorch version, with some additional arguments passed
+    to the `generate` method.
+    """
+
+    def postprocess(
+        self,
+        model_outputs: dict,
+        **postprocess_params,
+    ):
+        # model_outputs gets overwritten in the super().postprocess call
+        # make a copy here so we retain the information we want
+        raw_out = copy.deepcopy(model_outputs)
+        processed = super().postprocess(model_outputs, **postprocess_params)
+
+        return {
+            "translation_text": processed[0]["translation_text"],
+            "raw_outputs": raw_out,
+        }
+
+    def _forward(self, model_inputs, **generate_kwargs):
+        if self.framework == "pt":
+            in_b, input_length = model_inputs["input_ids"].shape
+        elif self.framework == "tf":
+            raise NotImplementedError
+
+        self.check_inputs(
+            input_length,
+            generate_kwargs.get("min_length", self.model.config.min_length),
+            generate_kwargs.get("max_length", self.model.config.max_length),
+        )
+        out = self.model.generate(**model_inputs, **generate_kwargs)
+        output_ids = out["sequences"]
+        out_b = output_ids.shape[0]
+        if self.framework == "pt":
+            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
+        elif self.framework == "tf":
+            raise NotImplementedError
+
+        # out["logits"] is a tuple with one element per generated token
+        # each element is a tensor of shape (batch_size, vocab_size)
+        logits = torch.stack(out["logits"], dim=1)
+        # get softmax of the logits to get token probabilities
+        softmax_logits = softmax(logits, dim=-1)
+        vocab_size = softmax_logits.shape[-1]
+        normalised_entropy = torch.distributions.Categorical(
+            probs=softmax_logits
+        ).entropy() / math.log(vocab_size)
+        max_token_scores = torch.max(softmax_logits, dim=-1).values
+
+        return {
+            "output_ids": output_ids,
+            "scores": max_token_scores,
+            "entropy": normalised_entropy,
+        }
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 19fa240..a7e6b08 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -2,9 +2,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from sklearn.metrics import hamming_loss
 
-from arc_spice.eval.classification_error import zero_one_loss_ceil
 from arc_spice.utils import open_yaml_path
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     ClassificationVariationalPipeline,
@@ -42,35 +40,6 @@ def dummy_metadata():
     }
 
 
-def test_errors():
-    dummy_target = [0, 1, 0, 1, 0]
-    dummy_middle_output = [1, 1, 0, 1, 0]
-
-    assert hamming_loss(dummy_target, dummy_middle_output) == pytest.approx(
-        0.2, abs=1e-5
-    )
-    assert zero_one_loss_ceil(dummy_target, dummy_middle_output) == pytest.approx(
-        1.0, abs=1e-5
-    )
-
-    dummy_correct_output = [0, 1, 0, 1, 0]
-
-    assert hamming_loss(dummy_target, dummy_correct_output) == pytest.approx(
-        0.0, abs=1e-5
-    )
-    assert zero_one_loss_ceil(dummy_target, dummy_correct_output) == pytest.approx(
-        0.0, abs=1e-5
-    )
-
-    dummy_incorrect_output = [1, 0, 1, 0, 1]
-    assert hamming_loss(dummy_target, dummy_incorrect_output) == pytest.approx(
-        1.0, abs=1e-5
-    )
-    assert zero_one_loss_ceil(dummy_target, dummy_incorrect_output) == pytest.approx(
-        1.0, abs=1e-5
-    )
-
-
 def test_pipeline_inputs(dummy_data, dummy_metadata):
     pipeline_config = open_yaml_path(PIPELINE_PATH)
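
Note: for reference, a minimal sketch (not part of the patch) of the
normalised-entropy computation that `CustomTranslationPipeline._forward` now
returns. Dummy random logits stand in for real `model.generate` output, and the
batch, token, and vocabulary sizes are illustrative assumptions only:

    import math

    import torch
    from torch.nn.functional import softmax

    batch_size, n_tokens, vocab_size = 2, 4, 16
    # stand-in for torch.stack(out["logits"], dim=1) inside _forward
    logits = torch.randn(batch_size, n_tokens, vocab_size)
    softmax_logits = softmax(logits, dim=-1)

    # Shannon entropy of each per-token distribution, divided by log(vocab_size)
    # so that a uniform distribution scores 1 and a one-hot distribution scores 0
    normalised_entropy = torch.distributions.Categorical(
        probs=softmax_logits
    ).entropy() / math.log(vocab_size)

    assert normalised_entropy.shape == (batch_size, n_tokens)
    assert bool((normalised_entropy >= 0).all())
    assert bool((normalised_entropy <= 1).all())

    # translate() then averages over tokens, exposing a scalar `mean_entropy`
    mean_entropy = torch.mean(normalised_entropy).detach().tolist()
    print(mean_entropy)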