
Commit

Made changes suggested in pull request
J-Dymond committed Nov 15, 2024
1 parent 704f5f7 commit 570a507
Showing 7 changed files with 90 additions and 185 deletions.
2 changes: 2 additions & 0 deletions data/MultiEURLEX/data/README.md
@@ -0,0 +1,2 @@
+# Multi-EURLEX files
+This folder contains the files needed to load [Multi-EURLEX](https://aclanthology.org/2021.emnlp-main.559/); these files are taken from the [official repo](https://github.com/nlpaueb/multi-eurlex).
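
The loader updated later in this commit reads two EUROVOC JSON files from this folder. Below is a rough, illustrative sketch (not part of the commit) of how they can be inspected, assuming the "data" root used by scripts/variational_RTC_example.py:

# Illustrative sketch only: the paths mirror those opened by load_multieurlex
# further down in this diff; the "data" root matches the data_dir="data"
# argument used in the example script.
import json

data_dir = "data"

with open(f"{data_dir}/MultiEURLEX/data/eurovoc_concepts.json") as concepts_file:
    class_concepts = json.load(concepts_file)

with open(f"{data_dir}/MultiEURLEX/data/eurovoc_descriptors.json") as descriptors_file:
    class_descriptors = json.load(descriptors_file)

# e.g. the number of level-1 EUROVOC classes used for the classification labels
print(len(class_concepts["level_1"]))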
1 change: 0 additions & 1 deletion pyproject.toml
@@ -34,7 +34,6 @@ dependencies = [
     "numpy",
     "sentencepiece",
     "librosa",
-    "soundfile",
     "torch",
     "torcheval",
     "pillow",
52 changes: 23 additions & 29 deletions scripts/variational_RTC_example.py
@@ -31,7 +31,9 @@ def seed_everything(seed):

 def load_test_row():
     lang_pair = {"source": "fr", "target": "en"}
-    (train, _, _), metadata_params = load_multieurlex(level=1, lang_pair=lang_pair)
+    (train, _, _), metadata_params = load_multieurlex(
+        data_dir="data", level=1, lang_pair=lang_pair
+    )
     multi_onehot = MultiHot(metadata_params["n_classes"])
     test_row = get_test_row(train)
     class_labels = multi_onehot(test_row["class_labels"])
@@ -44,26 +46,22 @@ def get_test_row(train_data):
     test_row = next(row_iterator)

     # debug row if needed
-    # return {
-    #     "source_text": "Le renard brun rapide a sauté par-dessus le chien paresseux. Le renard a sauté par-dessus le chien paresseux.",
-    #     "target_text": "The quick brown fox jumped over the lazy dog. The fox jumped over the lazy dog",
-    #     "class_labels": [0, 1],
-    # }
+    return {
+        "source_text": "Le renard brun rapide a sauté par-dessus le chien paresseux. Le renard a sauté par-dessus le chien paresseux.",
+        "target_text": "The quick brown fox jumped over the lazy dog. The fox jumped over the lazy dog",
+        "class_labels": [0, 1],
+    }
     # Normal row
     return test_row


-def print_results(rtc_variational_pipeline, class_labels, test_row, comet_model):
+def print_results(clean_output, var_output, class_labels, test_row, comet_model):
     # ### TRANSLATION ###
     print("\nTranslation:")
     source_text = test_row["target_text"]
     target_text = test_row["target_text"]
-    clean_translation = rtc_variational_pipeline.clean_output["translation"][
-        "full_output"
-    ]
-    print(
-        f"Semantic density: {rtc_variational_pipeline.var_output['translation']['weighted_semantic_density']}"
-    )
+    clean_translation = clean_output["translation"]["full_output"]
+    print(f"Semantic density: {var_output['translation']['weighted_semantic_density']}")

     # load error model
     comet_inp = [
@@ -83,50 +81,46 @@ def print_results(rtc_variational_pipeline, class_labels, test_row, comet_model)

     # ### CLASSIFICATION ###
     print("\nClassification:")
-    mean_scores = rtc_variational_pipeline.var_output["classification"]["mean_scores"]
+    mean_scores = var_output["classification"]["mean_scores"]
     print(f"BCE: {binary_cross_entropy(mean_scores.float(), class_labels.float())}")
     preds = torch.round(mean_scores)
     hamming_acc = hamming_accuracy(preds=preds, class_labels=class_labels)
     print(f"hamming accuracy: {hamming_acc}")

-    mean_entropy = torch.mean(
-        rtc_variational_pipeline.var_output["classification"]["predicted_entropy"]
-    )
-    mean_variances = torch.mean(
-        rtc_variational_pipeline.var_output["classification"]["var_scores"]
-    )
-    mean_MI = torch.mean(
-        rtc_variational_pipeline.var_output["classification"]["mutual_information"]
-    )
+    mean_entropy = torch.mean(var_output["classification"]["predicted_entropy"])
+    mean_variances = torch.mean(var_output["classification"]["var_scores"])
+    mean_MI = torch.mean(var_output["classification"]["mutual_information"])

     print("Predictive entropy: " f"{mean_entropy}")
     print("MI (model uncertainty): " f"{mean_MI}")
     print("Variance (model uncertainty): " f"{mean_variances}")


-def main(RTC_pars):
+def main(rtc_pars):
     seed_everything(seed=42)

     logging.basicConfig(level=logging.INFO)

     test_row, class_labels, metadata_params = load_test_row()

     # initialise pipeline
-    rtc_variational_pipeline = RTCVariationalPipeline(RTC_pars, metadata_params)
+    rtc_variational_pipeline = RTCVariationalPipeline(rtc_pars, metadata_params)

     # check dropout exists
     rtc_variational_pipeline.check_dropout()

     # perform variational inference
-    rtc_variational_pipeline.variational_inference(test_row["source_text"])
+    clean_output, var_output = rtc_variational_pipeline.variational_inference(
+        test_row["source_text"]
+    )

     comet_model = get_comet_model()

-    print_results(rtc_variational_pipeline, class_labels, test_row, comet_model)
+    print_results(clean_output, var_output, class_labels, test_row, comet_model)


 if __name__ == "__main__":
-    RTC_pars = {
+    rtc_pars = {
         "OCR": {
             "specific_task": "image-to-text",
             "model": "microsoft/trocr-base-handwritten",
@@ -140,4 +134,4 @@ def main(RTC_pars):
             "model": "claritylab/zero-shot-explicit-binary-bert",
         },
     }
-    main(RTC_pars=RTC_pars)
+    main(rtc_pars=rtc_pars)
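
For context on the script above: MultiHot is imported by the script but not shown in this diff; from its use it appears to turn a list of class-label indices into a multi-hot vector over metadata_params["n_classes"]. A hypothetical stand-in with that observable behaviour, offered purely for illustration, might look like:

# Hypothetical stand-in, NOT the repository's MultiHot implementation; it only
# illustrates the index-list -> multi-hot tensor behaviour the script relies on.
import torch


def multi_hot_encode(class_labels: list[int], n_classes: int) -> torch.Tensor:
    # one-hot each label index, then sum into a single multi-hot vector
    return torch.nn.functional.one_hot(
        torch.tensor(class_labels), num_classes=n_classes
    ).sum(dim=0)


print(multi_hot_encode([0, 1], n_classes=5))  # tensor([1, 1, 0, 0, 0])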
45 changes: 21 additions & 24 deletions src/arc_spice/data/multieurlex_utils.py
@@ -5,15 +5,6 @@
 from datasets import load_dataset
 from datasets.formatting.formatting import LazyRow
 from torch.nn.functional import one_hot
-from torch.utils.data import Dataset
-
-with open("data/MultiEURLEX/data/eurovoc_concepts.json", "r") as concepts_file:
-    class_concepts = json.loads(concepts_file.read())
-    concepts_file.close()
-
-with open("data/MultiEURLEX/data/eurovoc_descriptors.json", "r") as descriptors_file:
-    class_descriptors = json.loads(descriptors_file.read())
-    descriptors_file.close()

 # For identifying where the adopted decisions begin
 ARTICLE_1_MARKERS = {"en": "\nArticle 1\n", "fr": "\nArticle premier\n"}
@@ -91,19 +82,31 @@ def __call__(


 def load_multieurlex(
-    level: int, lang_pair: dict[str:str]
-) -> tuple[list[Dataset], dict[str : Union[int, list]]]:
+    data_dir: str, level: int, lang_pair: dict[str:str]
+) -> tuple[list, dict[str : Union[int, list]]]:
     """
     load the multieurlex dataset
     Args:
+        data_dir: root directory for the dataset class descriptors and concepts
         level: level of hierarchy/specificity of the labels
        lang_pair: dictionary specifying the language pair.
     Returns:
         List of datasets and a dictionary with some metadata information
     """
     assert level in [1, 2, 3], "there are 3 levels of hierarchy: 1,2,3."
+    with open(
+        f"{data_dir}/MultiEURLEX/data/eurovoc_concepts.json", "r"
+    ) as concepts_file:
+        class_concepts = json.loads(concepts_file.read())
+        concepts_file.close()
+
+    with open(
+        f"{data_dir}/MultiEURLEX/data/eurovoc_descriptors.json", "r"
+    ) as descriptors_file:
+        class_descriptors = json.loads(descriptors_file.read())
+        descriptors_file.close()
     # format level for the class descriptor dictionary, add these to a list
     level = f"level_{level}"
     classes = class_concepts[level]
@@ -127,20 +130,14 @@ def load_multieurlex(
     # instantiate the preprocessor
     preprocesser = PreProcesser(lang_pair)
     # preprocess each split
-    train_dataset = data["train"].map(preprocesser, remove_columns=["text"])
-    extracted_train = train_dataset.map(
-        extract_articles,
-        fn_kwargs={"lang_pair": lang_pair},
-    )
-    test_dataset = data["test"].map(preprocesser, remove_columns=["text"])
-    extracted_test = train_dataset.map(
-        extract_articles,
-        fn_kwargs={"lang_pair": lang_pair},
-    )
-    val_dataset = data["validation"].map(preprocesser, remove_columns=["text"])
-    extracted_val = train_dataset.map(
+    dataset = data.map(preprocesser, remove_columns=["text"])
+    extracted_dataset = dataset.map(
         extract_articles,
         fn_kwargs={"lang_pair": lang_pair},
     )
     # return datasets and metadata
-    return [extracted_train, extracted_test, extracted_val], meta_data
+    return [
+        extracted_dataset["train"],
+        extracted_dataset["test"],
+        extracted_dataset["validation"],
+    ], meta_data
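
A brief usage sketch of the updated loader (illustrative, not part of the commit); the import path assumes the package's src/ layout maps to arc_spice.data.multieurlex_utils, and data_dir must contain the MultiEURLEX/data folder described in the README above:

# Hedged example: the lang_pair and data_dir values mirror those used in
# scripts/variational_RTC_example.py earlier in this commit.
from arc_spice.data.multieurlex_utils import load_multieurlex

lang_pair = {"source": "fr", "target": "en"}
(train, test, validation), metadata = load_multieurlex(
    data_dir="data", level=1, lang_pair=lang_pair
)
print(metadata["n_classes"], len(train))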
11 changes: 5 additions & 6 deletions src/arc_spice/eval/classification_error.py
@@ -1,23 +1,22 @@
-# hamming accuracy
-# zero-one accuracy
-
 import torch


-def hamming_accuracy(preds: torch.Tensor, class_labels: torch.Tensor):
+def hamming_accuracy(preds: torch.Tensor, class_labels: torch.Tensor) -> torch.Tensor:
     # Inverse of the hamming loss (the fraction of labels incorrectly predicted)
     accuracy = torch.mean((preds.float() == class_labels.float()).float())
     return accuracy


-def aggregate_score(probs: torch.tensor):
+def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
     # average 'distance' from the predicted class
     preds = torch.round(probs).float()
     distance = torch.abs(preds - probs)
     return 1 - torch.mean(distance)


-def MC_dropout_scores(variational_probs, epsilon=1e-14):
+def MC_dropout_scores(
+    variational_probs: list[float], epsilon: float = 1e-14
+) -> dict[str : torch.Tensor]:
     # aggregate over the classes, performing MC Dropout on each class treating it
     # as a binary classification problem
     stacked_probs = torch.stack(
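
To make the two metrics above concrete, here is a small worked example (illustrative, not part of the commit) on a toy multi-label batch; the import path assumes the package's src/ layout:

# Two samples, three labels each; 5 of the 6 rounded predictions match the labels.
import torch

from arc_spice.eval.classification_error import aggregate_score, hamming_accuracy

probs = torch.tensor([[0.9, 0.2, 0.6], [0.1, 0.8, 0.4]])
labels = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

preds = torch.round(probs)
print(hamming_accuracy(preds=preds, class_labels=labels))  # tensor(0.8333)
print(aggregate_score(probs))  # ~0.77: mean closeness of probs to their rounded predictions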
90 changes: 0 additions & 90 deletions src/arc_spice/pipelines/RTC_pipeline.py

This file was deleted.

[The diff for the seventh changed file did not load on the page.]
