commented data utils
J-Dymond committed Nov 13, 2024
1 parent 7f6db59 commit 704f5f7
Showing 1 changed file with 41 additions and 8 deletions.
49 changes: 41 additions & 8 deletions src/arc_spice/data/multieurlex_utils.py
@@ -1,4 +1,5 @@
 import json
+from typing import Union
 
 import torch
 from datasets import load_dataset
@@ -14,14 +15,19 @@
 class_descriptors = json.loads(descriptors_file.read())
 descriptors_file.close()
 
+# For identifying where the adopted decisions begin
 ARTICLE_1_MARKERS = {"en": "\nArticle 1\n", "fr": "\nArticle premier\n"}
 
 
+# creates a multi-hot vector for classification loss
 class MultiHot:
-    def __init__(self, n_classes):
+    """Class that will multi-hot encode the classes for classification."""
+
+    def __init__(self, n_classes: int) -> None:
         self.n_classes = n_classes
 
-    def __call__(self, class_labels):
+    def __call__(self, class_labels: list[int]) -> torch.Tensor:
+        # create list of one-hots and sum down the class axis
         one_hot_class_labels = one_hot(
             torch.tensor(class_labels),
             num_classes=self.n_classes,
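
For context, a minimal sketch of how the MultiHot encoder could be used. The summation over the one-hot rows is truncated in this hunk, so the standalone function below is an assumption based on the class docstring and the inline comment:

    import torch
    from torch.nn.functional import one_hot

    # Hypothetical standalone equivalent of MultiHot.__call__, assuming the
    # truncated lines sum the one-hot rows down the class axis.
    def multi_hot(class_labels: list[int], n_classes: int) -> torch.Tensor:
        one_hot_class_labels = one_hot(
            torch.tensor(class_labels), num_classes=n_classes
        )
        return one_hot_class_labels.sum(dim=0)

    print(multi_hot([0, 2], n_classes=4))  # tensor([1, 0, 1, 0])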
@@ -55,11 +61,24 @@ def extract_articles(item: LazyRow, lang_pair: dict[str:str]):


 class PreProcesser:
-    def __init__(self, language_pair):
+    """Preprocesses the dataset rows, removing the unused languages."""
+
+    def __init__(self, language_pair: dict[str, str]) -> None:
         self.source_language = language_pair["source"]
         self.target_language = language_pair["target"]
 
-    def __call__(self, data_row):
+    def __call__(
+        self, data_row: dict[str, Union[str, list]]
+    ) -> dict[str, Union[str, list]]:
+        """
+        Processes a row of the dataset.
+
+        Args:
+            data_row: input row
+
+        Returns:
+            row: processed row with the relevant items
+        """
         source_text = data_row["text"][self.source_language]
         target_text = data_row["text"][self.target_language]
         labels = data_row["labels"]
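
A short usage sketch for PreProcesser. The toy row below mimics the multi_eurlex "all_languages" layout, and since the tail of __call__ is truncated in this hunk, the exact keys of the returned row are an assumption:

    # Hypothetical toy row; real rows come from load_dataset("multi_eurlex", ...)
    row = {
        "text": {"en": "Article 1 ...", "fr": "Article premier ..."},
        "labels": [3, 17],
    }
    preprocesser = PreProcesser({"source": "en", "target": "fr"})
    processed = preprocesser(row)
    # processed should now carry only the "en" source text, the "fr" target
    # text, and the labels, with the remaining languages dropped.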
@@ -71,30 +90,43 @@ def __call__(self, data_row):
         return row
 
 
-def load_multieurlex(level, lang_pair):
+def load_multieurlex(
+    level: int, lang_pair: dict[str, str]
+) -> tuple[list[Dataset], dict[str, Union[int, list]]]:
+    """
+    Loads the multieurlex dataset.
+
+    Args:
+        level: level of hierarchy/specificity of the labels
+        lang_pair: dictionary specifying the language pair.
+
+    Returns:
+        List of datasets and a dictionary with some metadata information
+    """
     assert level in [1, 2, 3], "there are 3 levels of hierarchy: 1,2,3."
+    # format level for the class descriptor dictionary, add these to a list
     level = f"level_{level}"
     classes = class_concepts[level]
     descriptors = []
     for class_id in classes:
         descriptors.append(class_descriptors[class_id])
 
+    # load the dataset with the huggingface API
     data = load_dataset(
         "multi_eurlex",
         "all_languages",
         label_level=level,
         trust_remote_code=True,
     )
 
+    # define metadata
     meta_data = {
         "n_classes": len(classes),
         "class_labels": classes,
         "class_descriptors": descriptors,
     }
+    # instantiate the preprocessor
     preprocesser = PreProcesser(lang_pair)
 
+    # preprocess each split
     train_dataset = data["train"].map(preprocesser, remove_columns=["text"])
     extracted_train = train_dataset.map(
         extract_articles,
@@ -110,4 +142,5 @@ def load_multieurlex(level, lang_pair):
         extract_articles,
         fn_kwargs={"lang_pair": lang_pair},
     )
+    # return datasets and metadata
     return [extracted_train, extracted_test, extracted_val], meta_data
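
Putting the pieces together, a minimal sketch of calling the loader as defined in this file; the split order follows the return statement above, and the row keys are those produced by PreProcesser, so treat them as an assumption:

    # Assumes the multi_eurlex dataset is downloadable via the huggingface API.
    lang_pair = {"source": "en", "target": "fr"}
    [train, test, val], meta_data = load_multieurlex(level=1, lang_pair=lang_pair)

    # Multi-hot encode the labels of the first training row.
    multi_hot = MultiHot(meta_data["n_classes"])
    targets = multi_hot(train[0]["labels"])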
