diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py
index 8607b07..d914cea 100644
--- a/src/arc_spice/data/multieurlex_utils.py
+++ b/src/arc_spice/data/multieurlex_utils.py
@@ -1,4 +1,5 @@
 import json
+from typing import Union
 
 import torch
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
@@ -14,14 +15,19 @@
 class_descriptors = json.loads(descriptors_file.read())
 descriptors_file.close()
 
+# For identifying where the adopted decisions begin
 ARTICLE_1_MARKERS = {"en": "\nArticle 1\n", "fr": "\nArticle premier\n"}
 
 
+# creates a multi-hot vector for the classification loss
 class MultiHot:
-    def __init__(self, n_classes):
+    """Multi-hot encode class labels for classification."""
+
+    def __init__(self, n_classes: int) -> None:
         self.n_classes = n_classes
 
-    def __call__(self, class_labels):
+    def __call__(self, class_labels: list[int]) -> torch.Tensor:
+        # create a list of one-hots and sum down the class axis
         one_hot_class_labels = one_hot(
             torch.tensor(class_labels),
             num_classes=self.n_classes,
@@ -55,11 +61,24 @@ def extract_articles(item: LazyRow, lang_pair: dict[str:str]):
 
 
 class PreProcesser:
-    def __init__(self, language_pair):
+    """Preprocessor that removes unused languages from the data."""
+
+    def __init__(self, language_pair: dict[str, str]) -> None:
         self.source_language = language_pair["source"]
         self.target_language = language_pair["target"]
 
-    def __call__(self, data_row):
+    def __call__(
+        self, data_row: dict[str, Union[str, list]]
+    ) -> dict[str, Union[str, list]]:
+        """
+        Process a row of the dataset.
+
+        Args:
+            data_row: input row
+
+        Returns:
+            row: processed row with the relevant items
+        """
         source_text = data_row["text"][self.source_language]
         target_text = data_row["text"][self.target_language]
         labels = data_row["labels"]
@@ -71,30 +90,43 @@ def __call__(self, data_row):
         return row
 
 
-def load_multieurlex(level, lang_pair):
+def load_multieurlex(
+    level: int, lang_pair: dict[str, str]
+) -> tuple[list[Dataset], dict[str, Union[int, list]]]:
+    """
+    Load the MultiEURLEX dataset.
 
-    assert level in [1, 2, 3], "there are 3 levels of hierarchy: 1,2,3."
+    Args:
+        level: level of hierarchy/specificity of the labels
+        lang_pair: dictionary specifying the language pair.
 
+    Returns:
+        List of datasets (train, test, validation) and a dictionary of metadata.
+    """
+    assert level in [1, 2, 3], "there are 3 levels of hierarchy: 1,2,3."
+    # format the level key for the class-descriptor lookup and collect descriptors
     level = f"level_{level}"
     classes = class_concepts[level]
    descriptors = []
     for class_id in classes:
         descriptors.append(class_descriptors[class_id])
 
+    # load the dataset with the Hugging Face API
     data = load_dataset(
         "multi_eurlex",
         "all_languages",
         label_level=level,
         trust_remote_code=True,
     )
-
+    # define the metadata
     meta_data = {
         "n_classes": len(classes),
         "class_labels": classes,
         "class_descriptors": descriptors,
     }
+    # instantiate the preprocessor
     preprocesser = PreProcesser(lang_pair)
-
+    # preprocess each split
     train_dataset = data["train"].map(preprocesser, remove_columns=["text"])
     extracted_train = train_dataset.map(
         extract_articles,
@@ -110,4 +142,5 @@ def load_multieurlex(level, lang_pair):
         extract_articles,
         fn_kwargs={"lang_pair": lang_pair},
     )
+    # return datasets and metadata
    return [extracted_train, extracted_test, extracted_val], meta_data
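
For reviewers: a minimal sketch of the multi-hot encoding that `MultiHot.__call__` performs, assuming the body elided by this hunk sums the one-hot rows down the label axis (that summation line falls outside the hunk):

```python
import torch
from torch.nn.functional import one_hot

# Assumed equivalent of MultiHot.__call__: one-hot each label index,
# then sum over the label axis to get a single multi-hot vector.
n_classes = 5
class_labels = [0, 3]

one_hot_class_labels = one_hot(torch.tensor(class_labels), num_classes=n_classes)
multi_hot = one_hot_class_labels.sum(dim=0)
print(multi_hot)  # tensor([1, 0, 0, 1, 0])
```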
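
And a hypothetical end-to-end use of the new signatures. The split order follows the return statement of `load_multieurlex`; the `"labels"` key on processed rows is an assumption based on the preprocessor, since the construction of `row` sits outside these hunks:

```python
from arc_spice.data.multieurlex_utils import MultiHot, load_multieurlex

# hypothetical call matching the new annotations
lang_pair = {"source": "en", "target": "fr"}
[train, test, val], meta_data = load_multieurlex(level=1, lang_pair=lang_pair)

# encode one row's labels for the classification loss;
# the "labels" key is assumed to survive preprocessing
encoder = MultiHot(meta_data["n_classes"])
label_vector = encoder(train[0]["labels"])
```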