Skip to content

Commit

Permalink
addressed comments from pull request, also added additional outputs t…
Browse files Browse the repository at this point in the history
…o translation to allow additional confidence measures
  • Loading branch information
J-Dymond committed Dec 6, 2024
1 parent 7ac88b8 commit 92a98f0
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 118 deletions.
7 changes: 0 additions & 7 deletions src/arc_spice/eval/classification_error.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import math

import torch
from sklearn.metrics import zero_one_loss


def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
Expand All @@ -11,10 +8,6 @@ def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
return 1 - torch.mean(distance)


def zero_one_loss_ceil(y_target, y_pred):
return math.ceil(zero_one_loss(y_target, y_pred, normalize=True))


def MC_dropout_scores(
variational_probs: list[float], epsilon: float = 1e-14
) -> dict[str, torch.Tensor]:
Expand Down
27 changes: 18 additions & 9 deletions src/arc_spice/eval/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,26 @@
)

RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
ClassificationResults = namedtuple(
"ClassificationResults",
[
"clean_scores",
"mean_scores",
"hamming_accuracy",
"mean_predicted_entropy",
],
)

TranslationResults = namedtuple(
"TranslationResults",
[
"full_output",
"clean_conditional_probability",
"comet_score",
"weighted_semantic_density",
"mean_entropy",
"sequence_lengths",
],
)

ClassificationResults = namedtuple(
"ClassificationResults",
[
"clean_scores",
"mean_scores",
"hamming_accuracy",
"mean_predicted_entropy",
],
)

Expand Down Expand Up @@ -79,6 +83,8 @@ def translation_results(
source_text = test_row["target_text"]
target_text = test_row["target_text"]
clean_translation = clean_output["translation"]["full_output"]
clean_entropy: torch.Tensor = clean_output["translation"]["mean_entropy"]
seq_lens: torch.Tensor = var_output["translation"]["sequence_length"]
probs: list[torch.Tensor] = clean_output["translation"]["probs"]
clean_cond_prob = [
conditional_probability(prob.squeeze()).detach().tolist() for prob in probs
Expand All @@ -102,6 +108,8 @@ def translation_results(
comet_score=comet_output["scores"][0],
full_output=clean_translation,
clean_conditional_probability=clean_cond_prob,
mean_entropy=clean_entropy,
sequence_lengths=seq_lens,
weighted_semantic_density=var_output["translation"][
"weighted_semantic_density"
],
Expand Down Expand Up @@ -144,4 +152,5 @@ def run_inference(
test_row=inp,
)
results.append({inp["celex_id"]: row_results_dict})
break
return results
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
from transformers import pipeline

from arc_spice.variational_pipelines.RTC_variational_pipeline import (
CustomTranslationPipeline,
RTCVariationalPipelineBase,
)
from arc_spice.variational_pipelines.utils import dropout_off, dropout_on, set_dropout
from arc_spice.variational_pipelines.utils import (
CustomTranslationPipeline,
dropout_off,
dropout_on,
set_dropout,
)


class RTCSingleComponentPipeline(RTCVariationalPipelineBase):
Expand All @@ -34,19 +38,6 @@ def __init__(
# define objects that are needed and nothing else
# naive outputs can remain the same, though only the appropriate outputs will
# be outputted
self.naive_outputs = {
"recognition": [
"outputs",
],
"translation": [
"full_output",
"outputs",
"probs",
],
"classification": [
"scores",
],
}
self.step_name = step_name
self.input_key = input_key
self.forward_function = forward_function
Expand Down
57 changes: 2 additions & 55 deletions src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import copy
from typing import Any

import torch
from torch.nn.functional import softmax
from transformers import TranslationPipeline, pipeline
from transformers import pipeline

from arc_spice.variational_pipelines.utils import (
CustomTranslationPipeline,
RTCVariationalPipelineBase,
dropout_off,
dropout_on,
Expand Down Expand Up @@ -134,55 +133,3 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
# on standard call return the clean output
def __call__(self, x):
return self.clean_inference(x)


# Translation pipeline with additional functionality to save logits from fwd pass
class CustomTranslationPipeline(TranslationPipeline):
"""
custom translation pipeline to return the logits with the generated text. Largely
the same as the pytorch version with some additional arguments passed to the
`generate` method.
"""

def postprocess(
self,
model_outputs: dict,
**postprocess_params,
):
# model_outputs gets overwritten in the super().postprocess call
# make a copy here so we retain the information we want
raw_out = copy.deepcopy(model_outputs)
processed = super().postprocess(model_outputs, **postprocess_params)

return {
"translation_text": processed[0]["translation_text"],
"raw_outputs": raw_out,
}

def _forward(self, model_inputs, **generate_kwargs):
if self.framework == "pt":
in_b, input_length = model_inputs["input_ids"].shape
elif self.framework == "tf":
raise NotImplementedError

self.check_inputs(
input_length,
generate_kwargs.get("min_length", self.model.config.min_length),
generate_kwargs.get("max_length", self.model.config.max_length),
)
out = self.model.generate(**model_inputs, **generate_kwargs)
output_ids = out["sequences"]
out_b = output_ids.shape[0]
if self.framework == "pt":
output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
elif self.framework == "tf":
raise NotImplementedError

# logits are a tuple of length output_ids[-1]-1
# each element is a tensor of shape (batch_size, vocab_size)
logits = torch.stack(out["logits"], dim=1)
# get softmax of the logits to get token probabilities
softmax_logits = softmax(logits, dim=-1)
max_token_scores = torch.max(softmax_logits, dim=-1).values

return {"output_ids": output_ids, "scores": max_token_scores}
74 changes: 73 additions & 1 deletion src/arc_spice/variational_pipelines/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import copy
import logging
import math
from abc import ABC, abstractmethod
from functools import partial
from typing import Any

import torch
from torch.nn.functional import softmax
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Pipeline
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
Pipeline,
TranslationPipeline,
)

logger = logging.Logger("RTC_variational_pipeline")

Expand Down Expand Up @@ -117,6 +124,7 @@ def __init__(self, n_variational_runs=5, translation_batch_size=8):
"full_output",
"outputs",
"probs",
"mean_entropy",
],
"classification": [
"scores",
Expand Down Expand Up @@ -264,6 +272,9 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
{
"outputs": translator_output["translation_text"],
"probs": translator_output["raw_outputs"]["scores"],
"mean_entropy": torch.mean(translator_output["raw_outputs"]["entropy"])
.detach()
.tolist(),
}
for translator_output in translator_outputs
]
Expand Down Expand Up @@ -430,6 +441,7 @@ def translation_semantic_density(
{
"semantic_densities": densities,
"weighted_semantic_density": weighted_average.item(),
"sequence_length": sequence_lengths,
}
)

Expand Down Expand Up @@ -480,3 +492,63 @@ def get_classification_confidence(
}
)
return var_output


# Translation pipeline with additional functionality to save logits from fwd pass
class CustomTranslationPipeline(TranslationPipeline):
"""
custom translation pipeline to return the logits with the generated text. Largely
the same as the pytorch version with some additional arguments passed to the
`generate` method.
"""

def postprocess(
self,
model_outputs: dict,
**postprocess_params,
):
# model_outputs gets overwritten in the super().postprocess call
# make a copy here so we retain the information we want
raw_out = copy.deepcopy(model_outputs)
processed = super().postprocess(model_outputs, **postprocess_params)

return {
"translation_text": processed[0]["translation_text"],
"raw_outputs": raw_out,
}

def _forward(self, model_inputs, **generate_kwargs):
if self.framework == "pt":
in_b, input_length = model_inputs["input_ids"].shape
elif self.framework == "tf":
raise NotImplementedError

self.check_inputs(
input_length,
generate_kwargs.get("min_length", self.model.config.min_length),
generate_kwargs.get("max_length", self.model.config.max_length),
)
out = self.model.generate(**model_inputs, **generate_kwargs)
output_ids = out["sequences"]
out_b = output_ids.shape[0]
if self.framework == "pt":
output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
elif self.framework == "tf":
raise NotImplementedError

# logits are a tuple of length output_ids[-1]-1
# each element is a tensor of shape (batch_size, vocab_size)
logits = torch.stack(out["logits"], dim=1)
# get softmax of the logits to get token probabilities
softmax_logits = softmax(logits, dim=-1)
vocab_size = softmax_logits.shape[-1]
normalised_entropy = torch.distributions.Categorical(
probs=softmax_logits
).entropy() / math.log(vocab_size)
max_token_scores = torch.max(softmax_logits, dim=-1).values

return {
"output_ids": output_ids,
"scores": max_token_scores,
"entropy": normalised_entropy,
}
31 changes: 0 additions & 31 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from unittest.mock import MagicMock, patch

import pytest
from sklearn.metrics import hamming_loss

from arc_spice.eval.classification_error import zero_one_loss_ceil
from arc_spice.utils import open_yaml_path
from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
ClassificationVariationalPipeline,
Expand Down Expand Up @@ -42,35 +40,6 @@ def dummy_metadata():
}


def test_errors():
dummy_target = [0, 1, 0, 1, 0]
dummy_middle_output = [1, 1, 0, 1, 0]

assert hamming_loss(dummy_target, dummy_middle_output) == pytest.approx(
0.2, abs=1e-5
)
assert zero_one_loss_ceil(dummy_target, dummy_middle_output) == pytest.approx(
1.0, abs=1e-5
)

dummy_correct_output = [0, 1, 0, 1, 0]

assert hamming_loss(dummy_target, dummy_correct_output) == pytest.approx(
0.0, abs=1e-5
)
assert zero_one_loss_ceil(dummy_target, dummy_correct_output) == pytest.approx(
0.0, abs=1e-5
)

dummy_incorrect_output = [1, 0, 1, 0, 1]
assert hamming_loss(dummy_target, dummy_incorrect_output) == pytest.approx(
1.0, abs=1e-5
)
assert zero_one_loss_ceil(dummy_target, dummy_incorrect_output) == pytest.approx(
1.0, abs=1e-5
)


def test_pipeline_inputs(dummy_data, dummy_metadata):
pipeline_config = open_yaml_path(PIPELINE_PATH)

Expand Down

0 comments on commit 92a98f0

Please sign in to comment.