Commit 3fe9180

fixed pre-commit

J-Dymond committed Nov 15, 2024
1 parent 570a507 commit 3fe9180
Showing 10 changed files with 62 additions and 869 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -34,6 +34,6 @@ repos:
hooks:
- id: mypy
files: src
args: []
args: [--ignore-missing-imports]
additional_dependencies:
- pytest
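For context, `--ignore-missing-imports` makes mypy type imports from packages that ship no type information as `Any` instead of erroring, so the hook no longer fails on untyped third-party dependencies. A minimal sketch of running the same check programmatically (assumes mypy is installed and a local `src/` directory exists):

```python
# Minimal sketch using mypy's documented programmatic API (mypy.api.run),
# mirroring the arguments the pre-commit hook now passes.
from mypy import api

stdout, stderr, exit_status = api.run(["--ignore-missing-imports", "src"])
print(stdout or stderr)
print("mypy exit status:", exit_status)
```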
2 changes: 1 addition & 1 deletion data/MultiEURLEX/data/README.md
@@ -1,2 +1,2 @@
# Multi-EURLEX files
This folder contains files for the loading of [Multi-EURLEX](https://aclanthology.org/2021.emnlp-main.559/), these files are taken from the [official repo](https://github.com/nlpaueb/multi-eurlex).
This folder contains files for the loading of [Multi-EURLEX](https://aclanthology.org/2021.emnlp-main.559/), these files are taken from the [official repo](https://github.com/nlpaueb/multi-eurlex).
2 changes: 1 addition & 1 deletion data/MultiEURLEX/data/eurovoc_concepts.json
@@ -14797,4 +14797,4 @@
"4355",
"5318"
]
}
}
809 changes: 0 additions & 809 deletions notebooks/tts_pipeline_nb.ipynb

This file was deleted.

14 changes: 10 additions & 4 deletions scripts/variational_RTC_example.py
@@ -1,5 +1,5 @@
"""
An example use of the transcription, translation and summarisation pipeline.
An example use of the transcription, translation and summarisation pipeline.
"""

import logging
@@ -42,13 +42,19 @@ def load_test_row():

def get_test_row(train_data):
row_iterator = iter(train_data)
for _ in range(0, randint(1, 25)):
for _ in range(randint(1, 25)):
test_row = next(row_iterator)

# debug row if needed
return {
"source_text": "Le renard brun rapide a sauté par-dessus le chien paresseux. Le renard a sauté par-dessus le chien paresseux.",
"target_text": "The quick brown fox jumped over the lazy dog. The fox jumped over the lazy dog",
"source_text": (
"Le renard brun rapide a sauté par-dessus le chien paresseux."
"Le renard a sauté par-dessus le chien paresseux."
),
"target_text": (
"The quick brown fox jumped over the lazy dog. The fox jumped"
" over the lazy dog"
),
"class_labels": [0, 1],
}
# Normal row
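As a side note on the string reformatting above: adjacent string literals are concatenated with nothing in between, so any inter-sentence whitespace has to be written explicitly inside one of the fragments. A quick sketch using the target text from this diff:

```python
# Adjacent string literals are joined with no separator, so the space
# between sentences must live inside one of the fragments.
joined = (
    "The quick brown fox jumped over the lazy dog. The fox jumped"
    " over the lazy dog"
)
assert joined == "The quick brown fox jumped over the lazy dog. The fox jumped over the lazy dog"
print(joined)
```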
4 changes: 3 additions & 1 deletion src/arc_spice/__init__.py
@@ -1,5 +1,7 @@
"""
ARC-SPICE: Sample Level, Pipeline Introduced Cumulative Errors (SPICE). Investigating methods for generalisable measurement of cumulative errors in multi-model (and multi-modal) ML pipelines.
ARC-SPICE: Sample Level, Pipeline Introduced Cumulative Errors (SPICE).
Investigating methods for generalisable measurement of cumulative errors in
multi-model (and multi-modal) ML pipelines.
"""

from __future__ import annotations
28 changes: 11 additions & 17 deletions src/arc_spice/data/multieurlex_utils.py
@@ -1,5 +1,4 @@
import json
from typing import Union

import torch
from datasets import load_dataset
@@ -23,8 +22,7 @@ def __call__(self, class_labels: list[int]) -> torch.Tensor:
torch.tensor(class_labels),
num_classes=self.n_classes,
)
one_hot_multi_class = torch.sum(one_hot_class_labels, dim=0)
return one_hot_multi_class
return torch.sum(one_hot_class_labels, dim=0)


def _extract_articles(text: str, article_1_marker: str):
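For reference, the `__call__` change above keeps the same multi-hot encoding: one-hot rows summed over the class dimension. A small worked example (`n_classes = 4` is an arbitrary illustrative value):

```python
import torch
from torch.nn.functional import one_hot

# Class labels [0, 1] with 4 classes become one binary vector with both
# positions set: the one-hot rows are summed over dim 0.
class_labels = [0, 1]
one_hot_rows = one_hot(torch.tensor(class_labels), num_classes=4)
multi_hot = torch.sum(one_hot_rows, dim=0)
print(multi_hot)  # tensor([1, 1, 0, 0])
```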
@@ -36,7 +34,7 @@ def _extract_articles(text: str, article_1_marker: str):
return text[start:]


def extract_articles(item: LazyRow, lang_pair: dict[str:str]):
def extract_articles(item: LazyRow, lang_pair: dict[str, str]):
lang_source = lang_pair["source"]
lang_target = lang_pair["target"]
return {
@@ -54,13 +52,13 @@ def extract_articles(item: LazyRow, lang_pair: dict[str:str]):
class PreProcesser:
"""Function to preprocess the data, for the purposes of removing unused languages"""

def __init__(self, language_pair: dict[str:str]) -> None:
def __init__(self, language_pair: dict[str, str]) -> None:
self.source_language = language_pair["source"]
self.target_language = language_pair["target"]

def __call__(
self, data_row: dict[str : Union[str, list]]
) -> dict[str : Union[str, list]]:
self, data_row: dict[str, dict[str, str]]
) -> dict[str, str | dict[str, str]]:
"""
processes the row in the dataset
@@ -73,17 +71,16 @@ def __call__(
source_text = data_row["text"][self.source_language]
target_text = data_row["text"][self.target_language]
labels = data_row["labels"]
row = {
return {
"source_text": source_text,
"target_text": target_text,
"class_labels": labels,
}
return row


def load_multieurlex(
data_dir: str, level: int, lang_pair: dict[str:str]
) -> tuple[list, dict[str : Union[int, list]]]:
data_dir: str, level: int, lang_pair: dict[str, str]
) -> tuple[list, dict[str, int | list]]:
"""
load the multieurlex dataset
@@ -96,19 +93,16 @@ def load_multieurlex(
List of datasets and a dictionary with some metadata information
"""
assert level in [1, 2, 3], "there are 3 levels of hierarchy: 1,2,3."
with open(
f"{data_dir}/MultiEURLEX/data/eurovoc_concepts.json", "r"
) as concepts_file:
with open(f"{data_dir}/MultiEURLEX/data/eurovoc_concepts.json") as concepts_file:
class_concepts = json.loads(concepts_file.read())
concepts_file.close()

with open(
f"{data_dir}/MultiEURLEX/data/eurovoc_descriptors.json", "r"
f"{data_dir}/MultiEURLEX/data/eurovoc_descriptors.json"
) as descriptors_file:
class_descriptors = json.loads(descriptors_file.read())
descriptors_file.close()
# format level for the class descriptor dictionary, add these to a list
level = f"level_{level}"
classes = class_concepts[level]
descriptors = []
for class_id in classes:
@@ -118,7 +112,7 @@ def load_multieurlex(
data = load_dataset(
"multi_eurlex",
"all_languages",
label_level=level,
label_level=f"level_{level}",
trust_remote_code=True,
)
# define metadata
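Most of the remaining edits in this file replace `dict[str:str]`-style annotations with `dict[str, str]`. The colon form subscripts `dict` with a slice object rather than a key/value pair, so it is not a usable type annotation and type checkers reject it. A quick check:

```python
# dict[str:str] evaluates the subscript as a slice, not as (key, value),
# so it cannot describe a mapping type; dict[str, str] is the correct form.
wrong = dict[str:str]
right = dict[str, str]

print(wrong.__args__)  # (slice(<class 'str'>, <class 'str'>, None),)
print(right.__args__)  # (<class 'str'>, <class 'str'>)

def describe(names: dict[str, str]) -> str:
    # with the correct annotation, mypy can check key/value usage
    return ", ".join(f"{k}: {v}" for k, v in names.items())
```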
5 changes: 2 additions & 3 deletions src/arc_spice/eval/classification_error.py
@@ -3,8 +3,7 @@

def hamming_accuracy(preds: torch.Tensor, class_labels: torch.Tensor) -> torch.Tensor:
# Inverse of the hamming loss (the fraction of labels incorrectly predicted)
accuracy = torch.mean((preds.float() == class_labels.float()).float())
return accuracy
return torch.mean((preds.float() == class_labels.float()).float())


def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
@@ -16,7 +15,7 @@ def aggregate_score(probs: torch.Tensor) -> torch.Tensor:

def MC_dropout_scores(
variational_probs: list[float], epsilon: float = 1e-14
) -> dict[str : torch.Tensor]:
) -> dict[str, torch.Tensor]:
# aggregate over the classes, performing MC Dropout on each class treating it
# as a binary classification problem
stacked_probs = torch.stack(
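For reference, `hamming_accuracy` above is the complement of the Hamming loss: the fraction of binary labels predicted correctly. A worked example with hypothetical tensors:

```python
import torch

# 3 of the 4 binary labels match, so the hamming accuracy is 0.75.
preds = torch.tensor([1, 0, 1, 0])
class_labels = torch.tensor([1, 1, 1, 0])
accuracy = torch.mean((preds.float() == class_labels.float()).float())
print(accuracy)  # tensor(0.7500)
```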
3 changes: 1 addition & 2 deletions src/arc_spice/eval/translation_error.py
@@ -9,5 +9,4 @@ def get_bleu_score(target, translation):
def get_comet_model(model_path="Unbabel/wmt22-comet-da"):
# Load the model checkpoint:
comet_model_pth = download_model(model=model_path)
comet_model = load_from_checkpoint(comet_model_pth)
return comet_model
return load_from_checkpoint(comet_model_pth)
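A hedged usage sketch of the model returned by `get_comet_model`, assuming the standard Unbabel COMET interface (scoring src/mt/ref triples via `predict`); the batch size, `gpus=0`, and the sample sentences are illustrative only:

```python
from comet import download_model, load_from_checkpoint

# Illustrative only: assumes the standard Unbabel COMET API.
comet_model = load_from_checkpoint(download_model(model="Unbabel/wmt22-comet-da"))

data = [
    {
        "src": "Le renard a sauté par-dessus le chien paresseux.",
        "mt": "The fox jumped over the lazy dog.",
        "ref": "The quick brown fox jumped over the lazy dog.",
    }
]
# In recent COMET releases, predict returns per-segment scores and a
# corpus-level system score.
model_output = comet_model.predict(data, batch_size=8, gpus=0)
print(model_output.system_score)
```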
62 changes: 32 additions & 30 deletions src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
@@ -1,6 +1,6 @@
import copy
import logging
from typing import Any, Union
from typing import Any

import torch
from torch.nn.functional import softmax
@@ -42,20 +42,21 @@ class RTCVariationalPipeline:

def __init__(
self,
model_pars: dict[str : dict[str:str]],
model_pars: dict[str, dict[str, str]],
data_pars,
n_variational_runs=5,
translation_batch_size=8,
) -> None:

# device for inference
device = (
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)

logging.info(f"Loading pipeline on device: {device}")
debug_msg_device = f"Loading pipeline on device: {device}"
logging.info(debug_msg_device)

# defining the pipeline objects
self.ocr = pipeline(
@@ -144,8 +145,7 @@ def split_translate_inputs(text: str, split_key: str) -> list[str]:
# for when string ends with with the delimiter
if split_rows[-1] == "":
split_rows = split_rows[:-1]
recovered_splits = [split + split_key for split in split_rows]
return recovered_splits
return [split + split_key for split in split_rows]

def check_dropout(self):
"""
@@ -158,17 +158,20 @@ def check_dropout(self):
for model_key, pl in self.pipeline_map.items():
# turn on dropout for this model
set_dropout(model=pl.model, dropout_flag=True)
logger.debug(f"Model key: {model_key}")
debug_msg_key = f"Model key: {model_key}"
logger.debug(debug_msg_key)
dropout_count = count_dropout(pipe=pl, dropout_flag=True)
logger.debug(
debug_msg_count = (
f"{dropout_count} dropout layers found in correct configuration."
)
logger.debug(debug_msg_count)
if dropout_count == 0:
raise ValueError(f"No dropout layers found in {model_key}")
error_message = f"No dropout layers found in {model_key}"
raise ValueError(error_message)
set_dropout(model=pl.model, dropout_flag=False)
logger.debug("-------------------------------------------------------\n\n")

def recognise(self, inp) -> dict[str:str]:
def recognise(self, inp) -> dict[str, str]:
"""
Function to perform OCR
@@ -182,7 +185,7 @@ def recognise(self, inp) -> dict[str:str]:
# TODO https://github.com/alan-turing-institute/ARC-SPICE/issues/14
return {"outputs": inp}

def translate(self, text: str) -> dict[str : [torch.Tensor, str]]:
def translate(self, text: str) -> dict[str, torch.Tensor | str]:
"""
Function to perform translation
@@ -224,12 +227,12 @@ def translate(self, text: str) -> dict[str : [torch.Tensor, str]]:
confidence_metrics
)
# add full output to the output dict
outputs = {"full_output": full_translation}
outputs: dict[str, Any] = {"full_output": full_translation}
outputs.update(stacked_conf_metrics)
# {full translation, sentence translations, logits, semantic embeddings}
return outputs

def classify_topic(self, text: str) -> dict[str:str]:
def classify_topic(self, text: str) -> dict[str, str]:
"""
Runs the classification model
@@ -240,8 +243,8 @@ def classify_topic(self, text: str) -> dict[str:str]:
return {"scores": forward["scores"]}

def stack_translator_sentence_metrics(
self, all_sentence_metrics: list[dict[str:Any]]
) -> dict[str : list[Any]]:
self, all_sentence_metrics: list[dict[str, Any]]
) -> dict[str, list[Any]]:
"""
Stacks values from dictionary list into lists under a single key
@@ -256,15 +259,15 @@ def stack_translator_sentence_metrics(
]
return stacked

def stack_variational_outputs(self, var_output):
def stack_variational_outputs(self, var_output: dict[str, list[Any]]):
"""
Similar to above but this stacks variational output dictinaries into lists
under a single key.
"""
# Create new dict
new_var_dict = {}
new_var_dict: dict[str, Any] = {}
# For each key create a new dict
for step in var_output.keys():
for step in var_output:
new_var_dict[step] = {}
# for each metric in a clean inference run (naive_ouputs)
for metric in self.naive_outputs[step]:
Expand Down Expand Up @@ -333,7 +336,7 @@ def sentence_density(

def translation_semantic_density(
self, clean_output, var_output: dict
) -> dict[str : Union[float, list[float]]]:
) -> dict[str, float | list[Any]]:
"""
Runs the semantic density measurement from https://arxiv.org/pdf/2405.13845.
@@ -353,8 +356,8 @@ def translation_semantic_density(
var_steps = var_output["translation"]
n_sentences = len(clean_out)
# define empty lists for the measurements
densities = [None] * n_sentences
sequence_lengths = [None] * n_sentences
densities: list[Any] = [None] * n_sentences
sequence_lengths: list[Any] = [None] * n_sentences
# stack the variational runs according to their sentences, then loop and pass to
# density calculation function
for sentence_index, clean_sentence in enumerate(clean_out):
@@ -387,7 +390,7 @@ def translation_semantic_density(

def get_classification_confidence(
self, var_output: dict, epsilon: float = 1e-15
) -> dict[str : Union[float, torch.Tensor]]:
) -> dict[str, float | torch.Tensor]:
"""
_summary_
@@ -431,10 +434,10 @@ def get_classification_confidence(
)
return var_output

def clean_inference(self, x: torch.Tensor) -> dict[str:dict]:
def clean_inference(self, x: torch.Tensor) -> dict[str, dict]:
"""Run the pipeline on an input x"""
# define output dictionary
clean_output = {
clean_output: dict[str, Any] = {
"recognition": {},
"translation": {},
"classification": {},
@@ -452,14 +455,14 @@ def clean_inference(self, x: torch.Tensor) -> dict[str:dict]:
)
return clean_output

def variational_inference(self, x: torch.Tensor) -> dict[str:dict]:
def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
"""
runs the variational inference with the pipeline
"""
# ...first run clean inference
clean_output = self.clean_inference(x)
# define output dictionary
var_output = {
var_output: dict[str, Any] = {
"recognition": [None] * self.n_variational_runs,
"translation": [None] * self.n_variational_runs,
"classification": [None] * self.n_variational_runs,
@@ -517,11 +520,10 @@ def postprocess(
raw_out = copy.deepcopy(model_outputs)
processed = super().postprocess(model_outputs, **postprocess_params)

new_output = {
return {
"translation_text": processed[0]["translation_text"],
"raw_outputs": raw_out,
}
return new_output

def _forward(self, model_inputs, **generate_kwargs):
if self.framework == "pt":
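The dropout checks in this file rely on the repo's own `set_dropout` and `count_dropout` helpers, which are not shown in this diff. A generic PyTorch sketch of the Monte Carlo dropout idea they implement, with a toy model for illustration:

```python
import torch
from torch import nn

# Generic sketch (not the repo's set_dropout/count_dropout utilities):
# Monte Carlo dropout keeps dropout active at inference time by putting
# only the nn.Dropout modules into train mode.
def enable_mc_dropout(model: nn.Module, dropout_flag: bool = True) -> int:
    count = 0
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train(dropout_flag)
            count += 1
    return count

# Toy model: both dropout layers are toggled while the rest stays in eval mode.
toy = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1), nn.ReLU(), nn.Dropout(0.1))
toy.eval()
print(enable_mc_dropout(toy), "dropout layers set to train mode")
```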
