diff --git a/.gitignore b/.gitignore index 1285fd4..4cefc4c 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,6 @@ slurm_scripts/slurm_logs* temp .vscode local_notebooks + +# test caches +tests/testdata/*/*/cache* diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 4e0d0d1..0000000 --- a/requirements.txt +++ /dev/null @@ -1,173 +0,0 @@ -absl-py==2.1.0 -aiohappyeyeballs==2.4.0 -aiohttp==3.10.6 -aiosignal==1.3.1 -annotated-types==0.7.0 -appnope==0.1.4 --e git+https://github.com/alan-turing-institute/ARC-SPICE.git@1ae06a2e9bff17854af1aa01cb6d068642b69358#egg=ARC_SPICE -asttokens==2.4.1 -async-timeout==4.0.3 -attrs==24.2.0 -audioread==3.0.1 -blis==1.0.1 -boto3==1.35.26 -botocore==1.35.26 -catalogue==2.0.10 -certifi==2024.8.30 -cffi==1.17.1 -charset-normalizer==3.3.2 -click==8.1.7 -cloudpathlib==0.20.0 -colorama==0.4.6 -comm==0.2.2 -confection==0.1.5 -contourpy==1.3.0 -cycler==0.12.1 -cymem==2.0.8 -datasets==3.0.0 -debugpy==1.8.6 -decorator==5.1.1 -dill==0.3.8 -en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85 -entmax==1.3 -exceptiongroup==1.2.2 -executing==2.1.0 -filelock==3.16.1 -filetype==1.2.0 -fire==0.6.0 -fonttools==4.54.1 -frozenlist==1.4.1 -fsspec==2024.6.1 -gitdb==4.0.11 -GitPython==3.1.43 -grpcio==1.66.1 -huggingface==0.0.1 -huggingface-hub==0.25.1 -hypothesis==6.112.2 -idna==3.7 -ipykernel==6.29.5 -ipython==8.27.0 -jedi==0.19.1 -Jinja2==3.1.4 -jmespath==1.0.1 -joblib==1.4.2 -jsonargparse==3.13.1 -jupyter_client==8.6.3 -jupyter_core==5.7.2 -kenlm==0.2.0 -kiwisolver==1.4.7 -langcodes==3.4.1 -language_data==1.2.0 -lazy_loader==0.4 -librosa==0.10.2.post1 -lightning-utilities==0.11.8 -llvmlite==0.43.0 -lxml==5.3.0 -marisa-trie==1.2.1 -Markdown==3.7 -markdown-it-py==3.0.0 -MarkupSafe==2.1.5 -matplotlib==3.9.2 -matplotlib-inline==0.1.7 -mdurl==0.1.2 -mpmath==1.3.0 -msgpack==1.1.0 -multidict==6.1.0 -multiprocess==0.70.16 -murmurhash==1.0.10 -nest-asyncio==1.6.0 -networkx==3.3 -nltk==3.9.1 -numba==0.60.0 -numpy==1.26.4 -opencv-python==4.9.0.80 -opencv-python-headless==4.10.0.84 -packaging==24.1 -pandas==2.2.3 -parso==0.8.4 -pexpect==4.9.0 -pillow==10.4.0 -platformdirs==4.3.6 -pooch==1.8.2 -portalocker==2.10.1 -preshed==3.0.9 -prompt_toolkit==3.0.48 -protobuf==4.25.5 -psutil==6.0.0 -ptyprocess==0.7.0 -pure_eval==0.2.3 -py-cpuinfo==9.0.0 -pyarrow==17.0.0 -pybboxes==0.1.6 -pycparser==2.22 -pyctcdecode==0.5.0 -pydantic==2.9.2 -pydantic_core==2.23.4 -Pygments==2.18.0 -pygtrie==2.5.0 -pyparsing==3.1.4 -python-dateutil==2.9.0.post0 -python-dotenv==1.0.1 -pytorch-lightning==2.4.0 -pytz==2024.2 -PyYAML==6.0.2 -pyzmq==26.2.0 -regex==2024.9.11 -requests==2.32.3 -requests-toolbelt==1.0.0 -rich==13.9.3 -roboflow==1.1.45 -s3transfer==0.10.2 -sacrebleu==2.4.3 -safetensors==0.4.5 -sahi==0.11.18 -scikit-learn==1.5.2 -scipy==1.14.1 -seaborn==0.13.2 -sentencepiece==0.1.99 -shapely==2.0.6 -shellingham==1.5.4 -six==1.16.0 -smart-open==7.0.5 -smmap==5.0.1 -sortedcontainers==2.4.0 -soundfile==0.12.1 -soxr==0.5.0.post1 -spacy==3.8.2 -spacy-legacy==3.0.12 -spacy-loggers==1.0.5 -srsly==2.4.8 -stack-data==0.6.3 -sympy==1.13.3 -tabulate==0.9.0 -tensorboard==2.17.1 -tensorboard-data-server==0.7.2 -termcolor==2.4.0 -terminaltables==3.1.10 -thinc==8.3.2 -thop==0.1.1.post2209072238 -threadpoolctl==3.5.0 -tokenizers==0.19.1 -torch==2.4.1 -torcheval==0.0.7 -torchmetrics==0.10.3 -torchvision==0.19.1 -tornado==6.4.1 -tqdm==4.66.5 
-traitlets==5.14.3 -transformers==4.44.2 -typer==0.12.5 -typing_extensions==4.12.2 -tzdata==2024.2 -ultralytics==8.2.101 -ultralytics-thop==2.0.8 -unbabel-comet==2.2.2 -urllib3==2.2.3 -wasabi==1.1.3 -wcwidth==0.2.13 -weasel==0.4.1 -Werkzeug==3.0.4 -wrapt==1.16.0 -xxhash==3.5.0 -yarl==1.12.1 -yolov5==7.0.13 diff --git a/scripts/create_test_ds.py b/scripts/create_test_ds.py new file mode 100644 index 0000000..488a061 --- /dev/null +++ b/scripts/create_test_ds.py @@ -0,0 +1,128 @@ +import os +import shutil + +import datasets +from datasets import load_dataset + +from arc_spice.data import multieurlex_utils + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +TESTDATA_DIR = os.path.join(PROJECT_ROOT, "tests/testdata") +BASE_DATASET_INFO_MULTILANG = os.path.join( + TESTDATA_DIR, "base_testdata/dataset_info.json" +) +BASE_DATASET_INFO_EN = os.path.join(TESTDATA_DIR, "base_testdata/dataset_info_en.json") +BASE_DATASET_METADATA_DIR = os.path.join(TESTDATA_DIR, "base_testdata/MultiEURLEX") + +# TODO +CONTENT_MULTILANG: list[dict[str, str]] = [ + { + "en": f"Some text before the marker 1 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 1", # noqa: E501 + "fr": f"Some text before the marker 1 {multieurlex_utils.ARTICLE_1_MARKERS['fr']} Some text after the marker 1", # noqa: E501 + "de": f"Some text before the marker 1 {multieurlex_utils.ARTICLE_1_MARKERS['de']} Some text after the marker 1", # noqa: E501 + }, + { + "en": f"Some text before the marker 2 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 2", # noqa: E501 + "fr": f"Some text before the marker 2 {multieurlex_utils.ARTICLE_1_MARKERS['fr']} Some text after the marker 2", # noqa: E501 + "de": f"Some text before the marker 2 {multieurlex_utils.ARTICLE_1_MARKERS['de']} Some text after the marker 2", # noqa: E501 + }, + { + "en": "Some text before the marker 3", # no marker, no text after marker + "fr": "Some text before the marker 3", + "de": "Some text before the marker 3", + }, + { + "en": f"Some text before the marker 4 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 4", # noqa: E501 + "fr": f"Some text before the marker 4 {multieurlex_utils.ARTICLE_1_MARKERS['fr']} Some text after the marker 4", # noqa: E501 + "de": f"Some text before the marker 4 {multieurlex_utils.ARTICLE_1_MARKERS['de']} Some text after the marker 4", # noqa: E501 + }, + { + "en": f"Some text before the marker 5 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 5", # noqa: E501 + "fr": f"Some text before the marker 5 {multieurlex_utils.ARTICLE_1_MARKERS['fr']} Some text after the marker 5", # noqa: E501 + "de": f"Some text before the marker 5 {multieurlex_utils.ARTICLE_1_MARKERS['de']} Some text after the marker 5", # noqa: E501 + }, +] +CONTENT_EN: list[str] = [ + f"Some text before the marker 1 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 1", # noqa: E501 + f"Some text before the marker 2 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 2", # noqa: E501 + "Some text before the marker 3", # no marker, no text after marker + f"Some text before the marker 4 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 4", # noqa: E501 + f"Some text before the marker 5 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 5", # noqa: E501 +] + + +def overwrite_text( + _orig, + i: int, + content: list[dict[str, str]] | list[str], +) -> dict[str, str | dict[str, str]]: + return {"text": 
content[i]} + + +def create_test_ds( + testdata_dir: str, + ds_name: str, + content: list[dict[str, str]] | list[str], + dataset_info_fpath: str, +) -> None: + dataset = load_dataset( + "multi_eurlex", + "all_languages", + label_level="level_1", + trust_remote_code=True, + ) + + dataset["train"] = dataset["train"].take(5) + dataset["validation"] = dataset["validation"].take(5) + dataset["test"] = dataset["test"].take(5) + + dataset = dataset.map( + overwrite_text, + with_indices=True, + fn_kwargs={"content": content}, + ) + + dataset.save_to_disk(os.path.join(testdata_dir, ds_name)) + + shutil.copy( + dataset_info_fpath, + os.path.join(testdata_dir, ds_name, "train/dataset_info.json"), + ) + shutil.copy( + dataset_info_fpath, + os.path.join(testdata_dir, ds_name, "validation/dataset_info.json"), + ) + shutil.copy( + dataset_info_fpath, + os.path.join(testdata_dir, ds_name, "test/dataset_info.json"), + ) + # metadata copy + shutil.copytree( + BASE_DATASET_METADATA_DIR, + os.path.join(testdata_dir, ds_name, "MultiEURLEX"), + ) + + assert datasets.load_from_disk(os.path.join(testdata_dir, ds_name)) is not None + + +if __name__ == "__main__": + os.makedirs(TESTDATA_DIR, exist_ok=True) + + content = [ + "Some text before the marker en Some text after the marker", + "Some text before the marker fr Some text after the marker", + ] + + create_test_ds( + testdata_dir=TESTDATA_DIR, + ds_name="multieurlex_test", + content=CONTENT_MULTILANG, + dataset_info_fpath=BASE_DATASET_INFO_MULTILANG, + ) + + create_test_ds( + testdata_dir=TESTDATA_DIR, + ds_name="multieurlex_test_en", + content=CONTENT_EN, + dataset_info_fpath=BASE_DATASET_INFO_EN, + ) diff --git a/scripts/variational_RTC_example.py b/scripts/variational_RTC_example.py index 5c615f1..fdeea22 100644 --- a/scripts/variational_RTC_example.py +++ b/scripts/variational_RTC_example.py @@ -3,64 +3,39 @@ """ import logging -import os -import random from random import randint -import numpy as np import torch from torch.nn.functional import binary_cross_entropy -from arc_spice.data.multieurlex_utils import MultiHot, load_multieurlex +from arc_spice.data.multieurlex_utils import MultiHot, load_multieurlex_for_translation from arc_spice.eval.classification_error import hamming_accuracy from arc_spice.eval.translation_error import get_comet_model +from arc_spice.utils import seed_everything from arc_spice.variational_pipelines.RTC_variational_pipeline import ( RTCVariationalPipeline, ) -def seed_everything(seed): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - np.random.seed(seed) - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) +def get_random_test_row(train_data): + row_iterator = iter(train_data) + for _ in range(randint(1, 25)): + test_row = next(row_iterator) + return test_row def load_test_row(): lang_pair = {"source": "fr", "target": "en"} - (train, _, _), metadata_params = load_multieurlex( + dataset_dict, metadata_params = load_multieurlex_for_translation( data_dir="data", level=1, lang_pair=lang_pair ) + train = dataset_dict["train"] multi_onehot = MultiHot(metadata_params["n_classes"]) - test_row = get_test_row(train) - class_labels = multi_onehot(test_row["class_labels"]) + test_row = get_random_test_row(train) + class_labels = multi_onehot(test_row["labels"]) return test_row, class_labels, metadata_params -def get_test_row(train_data): - row_iterator = iter(train_data) - for _ in range(randint(1, 25)): - test_row = 
next(row_iterator) - - # debug row if needed - return { - "source_text": ( - "Le renard brun rapide a sauté par-dessus le chien paresseux." - "Le renard a sauté par-dessus le chien paresseux." - ), - "target_text": ( - "The quick brown fox jumped over the lazy dog. The fox jumped" - " over the lazy dog" - ), - "class_labels": [0, 1], - } - # Normal row - return test_row - - def print_results(clean_output, var_output, class_labels, test_row, comet_model): # ### TRANSLATION ### print("\nTranslation:") diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py index bbdedaa..e38fdab 100644 --- a/src/arc_spice/data/multieurlex_utils.py +++ b/src/arc_spice/data/multieurlex_utils.py @@ -1,12 +1,17 @@ import json +from typing import Any import torch -from datasets import load_dataset +from datasets import DatasetDict, load_dataset from datasets.formatting.formatting import LazyRow from torch.nn.functional import one_hot # For identifying where the adopted decisions begin -ARTICLE_1_MARKERS = {"en": "\nArticle 1\n", "fr": "\nArticle premier\n"} +ARTICLE_1_MARKERS = { + "en": "\nArticle 1\n", + "fr": "\nArticle premier\n", + "de": "\nArtikel 1\n", +} # creates a multi-hot vector for classification loss @@ -34,23 +39,32 @@ def _extract_articles(text: str, article_1_marker: str): return text[start:] -def extract_articles(item: LazyRow, lang_pair: dict[str, str]): - lang_source = lang_pair["source"] - lang_target = lang_pair["target"] +def extract_articles( + item: LazyRow, languages: list[str] +) -> dict[str, str] | dict[str, dict[str, str]]: + # a single-language config stores "text" as a flat string, not a per-language dict + if len(languages) == 1 and isinstance(item["text"], str): + return { + "text": _extract_articles( + text=item["text"], + article_1_marker=ARTICLE_1_MARKERS[languages[0]], + ) + } + + # multilingual configs nest text by language code return { - "source_text": _extract_articles( - text=item["source_text"], - article_1_marker=ARTICLE_1_MARKERS[lang_source], - ), - "target_text": _extract_articles( - text=item["target_text"], - article_1_marker=ARTICLE_1_MARKERS[lang_target], - ), + "text": { + lang: _extract_articles( + text=item["text"][lang], + article_1_marker=ARTICLE_1_MARKERS[lang], + ) + for lang in languages + } } -class PreProcesser: - """Function to preprocess the data, for the purposes of removing unused languages""" +class TranslationPreProcesser: + """Prepares the data for the translation task""" def __init__(self, language_pair: dict[str, str]) -> None: self.source_language = language_pair["source"] @@ -70,28 +84,13 @@ def __call__( """ source_text = data_row["text"][self.source_language] target_text = data_row["text"][self.target_language] - labels = data_row["labels"] return { "source_text": source_text, "target_text": target_text, - "class_labels": labels, } -def load_multieurlex( - data_dir: str, level: int, lang_pair: dict[str, str] -) -> tuple[list, dict[str, int | list]]: - """ - load the multieurlex dataset - - Args: - data_dir: root directory for the dataset class descriptors and concepts - level: level of hierarchy/specicifity of the labels - lang_pair: dictionary specifying the language pair. - - Returns: - List of datasets and a dictionary with some metadata information - """ +def load_multieurlex_metadata(data_dir: str, level: int) -> dict[str, Any]: assert level in [1, 2, 3], "there are 3 levels of hierarchy: 1,2,3."
with open(f"{data_dir}/MultiEURLEX/data/eurovoc_concepts.json") as concepts_file: class_concepts = json.loads(concepts_file.read()) @@ -103,35 +102,78 @@ def load_multieurlex( class_descriptors = json.loads(descriptors_file.read()) descriptors_file.close() # format level for the class descriptor dictionary, add these to a list - classes = class_concepts[level] + classes = class_concepts[f"level_{level}"] descriptors = [] for class_id in classes: descriptors.append(class_descriptors[class_id]) - # load the dataset with huggingface API - data = load_dataset( - "multi_eurlex", - "all_languages", - label_level=f"level_{level}", - trust_remote_code=True, - ) # define metadata - meta_data = { + return { "n_classes": len(classes), "class_labels": classes, "class_descriptors": descriptors, } - # instantiate the preprocessor - preprocesser = PreProcesser(lang_pair) - # preprocess each split - dataset = data.map(preprocesser, remove_columns=["text"]) - extracted_dataset = dataset.map( + + +def load_multieurlex( + data_dir: str, + level: int, + languages: list[str], + drop_empty: bool = True, +) -> tuple[DatasetDict, dict[str, Any]]: + """ + load the multieurlex dataset + + Args: + data_dir: root directory for the dataset class descriptors and concepts + level: level of hierarchy/specicifity of the labels + languages: a list of iso codes for languages to be used + + Returns: + List of datasets and a dictionary with some metadata information + """ + metadata = load_mutlieurlex_metadata(data_dir=data_dir, level=level) + + # load the dataset with huggingface API + if isinstance(languages, list): + if len(languages) == 0: + msg = "languages list cannot be empty" + raise Exception(msg) + + load_langs = languages[0] if len(languages) == 1 else "all_languages" + + dataset_dict = load_dataset( + "multi_eurlex", + load_langs, + label_level=f"level_{level}", + trust_remote_code=True, + ) + + dataset_dict = dataset_dict.map( extract_articles, - fn_kwargs={"lang_pair": lang_pair}, + fn_kwargs={"languages": languages}, ) + + if drop_empty: + if len(languages) == 1: + dataset_dict = dataset_dict.filter(lambda x: x["text"] is not None) + else: + dataset_dict = dataset_dict.filter( + lambda x: all(x is not None for x in x["text"].values()) + ) + # return datasets and metadata - return [ - extracted_dataset["train"], - extracted_dataset["test"], - extracted_dataset["validation"], - ], meta_data + return dataset_dict, metadata + + +def load_multieurlex_for_translation( + data_dir: str, level: int, lang_pair: dict[str, str], drop_empty: bool = True +) -> tuple[DatasetDict, dict[str, Any]]: + langs = [lang_pair["source"], lang_pair["target"]] + dataset_dict, meta_data = load_multieurlex( + data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty + ) + # instantiate the preprocessor + preprocesser = TranslationPreProcesser(lang_pair) + # preprocess each split + return dataset_dict.map(preprocesser, remove_columns=["text"]), meta_data diff --git a/src/arc_spice/utils.py b/src/arc_spice/utils.py new file mode 100644 index 0000000..d3430c1 --- /dev/null +++ b/src/arc_spice/utils.py @@ -0,0 +1,15 @@ +import os +import random + +import numpy as np +import torch + + +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) diff --git a/tests/test_multieurlex_utils.py b/tests/test_multieurlex_utils.py new file mode 
100644 index 0000000..4701f51 --- /dev/null +++ b/tests/test_multieurlex_utils.py @@ -0,0 +1,98 @@ +import os +from unittest.mock import patch + +import datasets +import pyarrow as pa +from datasets.formatting import PythonFormatter +from datasets.formatting.formatting import LazyRow + +from arc_spice.data import multieurlex_utils + +TEST_ROOT = os.path.dirname(os.path.abspath(__file__)) + + +def _create_row(text: str) -> LazyRow: + pa_table = pa.Table.from_pydict({"text": [text]}) + formatter = PythonFormatter(lazy=True) + return formatter.format_row(pa_table) + + +def _create_multilang_row(texts_by_lang: dict[str, str]) -> LazyRow: + d = [{"text": texts_by_lang}] + pa_table = pa.Table.from_pylist(d) + formatter = PythonFormatter(lazy=True) + return formatter.format_row(pa_table) + + +def test_extract_articles_single_lang(): + langs = ["en"] + pre_text = "Some text before the marker" + post_text = "Some text after the marker" + row = _create_row( + text=f"{pre_text} {multieurlex_utils.ARTICLE_1_MARKERS['en']} {post_text}" + ) + out = multieurlex_utils.extract_articles(item=row, languages=langs) + assert out == {"text": f"{multieurlex_utils.ARTICLE_1_MARKERS['en']} {post_text}"} + + +def test_extract_articles_multi_lang(): + langs = ["en", "fr"] + pre_text = "Some text before the marker" + post_text = "Some text after the marker" + texts = { + lang: f"{pre_text} {multieurlex_utils.ARTICLE_1_MARKERS[lang]} {post_text}" + for lang in langs + } + row = _create_multilang_row(texts_by_lang=texts) + out = multieurlex_utils.extract_articles(item=row, languages=langs) + assert out == { + "text": { + "en": f"{multieurlex_utils.ARTICLE_1_MARKERS['en']} {post_text}", + "fr": f"{multieurlex_utils.ARTICLE_1_MARKERS['fr']} {post_text}", + } + } + + +def test_load_multieurlex_en(): + data_dir = f"{TEST_ROOT}/testdata/multieurlex_test_en" + level = 1 + languages = ["en"] + drop_empty = True + + ds = datasets.load_from_disk(data_dir) + with patch("arc_spice.data.multieurlex_utils.load_dataset", return_value=ds): + dataset_dict, metadata = multieurlex_utils.load_multieurlex( + data_dir=data_dir, level=level, languages=languages, drop_empty=drop_empty + ) + assert len(dataset_dict) == 3 + assert len(dataset_dict["train"]) == 4 # 5 items, 1 is empty so dropped + assert len(dataset_dict["validation"]) == 4 # 5 items, 1 is empty so dropped + assert len(dataset_dict["test"]) == 4 # 5 items, 1 is empty so dropped + assert dataset_dict["train"]["text"] == [ + f"{multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker {i}" # noqa: E501 + for i in [1, 2, 4, 5] # 3 dropped + ] + + +def test_load_multieurlex_for_translation(): + data_dir = f"{TEST_ROOT}/testdata/multieurlex_test" + level = 1 + languages = ["de", "en", "fr"] + drop_empty = True + + ds = datasets.load_from_disk(data_dir) + with patch("arc_spice.data.multieurlex_utils.load_dataset", return_value=ds): + dataset_dict, metadata = multieurlex_utils.load_multieurlex( + data_dir=data_dir, level=level, languages=languages, drop_empty=drop_empty + ) + assert len(dataset_dict) == 3 + assert len(dataset_dict["train"]) == 4 # 5 items, 1 is empty so dropped + assert len(dataset_dict["validation"]) == 4 # 5 items, 1 is empty so dropped + assert len(dataset_dict["test"]) == 4 # 5 items, 1 is empty so dropped + assert dataset_dict["train"]["text"] == [ + { + lang: f"{multieurlex_utils.ARTICLE_1_MARKERS[lang]} Some text after the marker {i}" # noqa: E501 + for lang in languages + } + for i in [1, 2, 4, 5] # 3 dropped + ] diff --git 
a/tests/testdata/base_testdata/MultiEURLEX/data/eurovoc_concepts.json b/tests/testdata/base_testdata/MultiEURLEX/data/eurovoc_concepts.json new file mode 100644 index 0000000..8a6033f --- /dev/null +++ b/tests/testdata/base_testdata/MultiEURLEX/data/eurovoc_concepts.json @@ -0,0 +1,5 @@ +{ + "level_1": [ + "0" + ] +} diff --git a/tests/testdata/base_testdata/MultiEURLEX/data/eurovoc_descriptors.json b/tests/testdata/base_testdata/MultiEURLEX/data/eurovoc_descriptors.json new file mode 100644 index 0000000..3159b9b --- /dev/null +++ b/tests/testdata/base_testdata/MultiEURLEX/data/eurovoc_descriptors.json @@ -0,0 +1,5 @@ +{ + "0": { + "en": "something" + } +} diff --git a/tests/testdata/base_testdata/dataset_info.json b/tests/testdata/base_testdata/dataset_info.json new file mode 100644 index 0000000..ddf9a0d --- /dev/null +++ b/tests/testdata/base_testdata/dataset_info.json @@ -0,0 +1,72 @@ +{ + "builder_name": "multi_eurlex_test", + "config_name": "all_languages", + "dataset_name": "multi_eurlex", + "features": { + "celex_id": { + "dtype": "string", + "_type": "Value" + }, + "text": { + "languages": [ + "en", + "de", + "fr" + ], + "_type": "Translation" + }, + "labels": { + "feature": { + "names": [ + "100149", + "100160", + "100148", + "100147", + "100152", + "100143", + "100156", + "100158", + "100154", + "100153", + "100142", + "100145", + "100150", + "100162", + "100159", + "100144", + "100151", + "100157", + "100161", + "100146", + "100155" + ], + "_type": "ClassLabel" + }, + "_type": "Sequence" + } + }, + "splits": { + "train": { + "name": "train", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + }, + "test": { + "name": "test", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + }, + "validation": { + "name": "validation", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + } + }, + "version": { + "version_str": "1.0.0", + "description": "", + "major": 1, + "minor": 0, + "patch": 0 + } +} diff --git a/tests/testdata/base_testdata/dataset_info_en.json b/tests/testdata/base_testdata/dataset_info_en.json new file mode 100644 index 0000000..f55d6bd --- /dev/null +++ b/tests/testdata/base_testdata/dataset_info_en.json @@ -0,0 +1,68 @@ +{ + "builder_name": "multi_eurlex_test_en", + "config_name": "en", + "dataset_name": "multi_eurlex_test_en", + "features": { + "celex_id": { + "dtype": "string", + "_type": "Value" + }, + "text": { + "dtype": "string", + "_type": "Value" + }, + "labels": { + "feature": { + "names": [ + "100149", + "100160", + "100148", + "100147", + "100152", + "100143", + "100156", + "100158", + "100154", + "100153", + "100142", + "100145", + "100150", + "100162", + "100159", + "100144", + "100151", + "100157", + "100161", + "100146", + "100155" + ], + "_type": "ClassLabel" + }, + "_type": "Sequence" + } + }, + "splits": { + "train": { + "name": "train", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + }, + "test": { + "name": "test", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + }, + "validation": { + "name": "validation", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + } + }, + "version": { + "version_str": "1.0.0", + "description": "", + "major": 1, + "minor": 0, + "patch": 0 + } +} diff --git a/tests/testdata/multieurlex_test/MultiEURLEX/data/eurovoc_concepts.json b/tests/testdata/multieurlex_test/MultiEURLEX/data/eurovoc_concepts.json new file mode 100644 index 0000000..8a6033f --- /dev/null +++ b/tests/testdata/multieurlex_test/MultiEURLEX/data/eurovoc_concepts.json @@ -0,0 +1,5 @@ +{ + 
"level_1": [ + "0" + ] +} diff --git a/tests/testdata/multieurlex_test/MultiEURLEX/data/eurovoc_descriptors.json b/tests/testdata/multieurlex_test/MultiEURLEX/data/eurovoc_descriptors.json new file mode 100644 index 0000000..3159b9b --- /dev/null +++ b/tests/testdata/multieurlex_test/MultiEURLEX/data/eurovoc_descriptors.json @@ -0,0 +1,5 @@ +{ + "0": { + "en": "something" + } +} diff --git a/tests/testdata/multieurlex_test/dataset_dict.json b/tests/testdata/multieurlex_test/dataset_dict.json new file mode 100644 index 0000000..f15a9f8 --- /dev/null +++ b/tests/testdata/multieurlex_test/dataset_dict.json @@ -0,0 +1 @@ +{"splits": ["train", "test", "validation"]} diff --git a/tests/testdata/multieurlex_test/test/data-00000-of-00001.arrow b/tests/testdata/multieurlex_test/test/data-00000-of-00001.arrow new file mode 100644 index 0000000..653d4f2 Binary files /dev/null and b/tests/testdata/multieurlex_test/test/data-00000-of-00001.arrow differ diff --git a/tests/testdata/multieurlex_test/test/dataset_info.json b/tests/testdata/multieurlex_test/test/dataset_info.json new file mode 100644 index 0000000..ddf9a0d --- /dev/null +++ b/tests/testdata/multieurlex_test/test/dataset_info.json @@ -0,0 +1,72 @@ +{ + "builder_name": "multi_eurlex_test", + "config_name": "all_languages", + "dataset_name": "multi_eurlex", + "features": { + "celex_id": { + "dtype": "string", + "_type": "Value" + }, + "text": { + "languages": [ + "en", + "de", + "fr" + ], + "_type": "Translation" + }, + "labels": { + "feature": { + "names": [ + "100149", + "100160", + "100148", + "100147", + "100152", + "100143", + "100156", + "100158", + "100154", + "100153", + "100142", + "100145", + "100150", + "100162", + "100159", + "100144", + "100151", + "100157", + "100161", + "100146", + "100155" + ], + "_type": "ClassLabel" + }, + "_type": "Sequence" + } + }, + "splits": { + "train": { + "name": "train", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + }, + "test": { + "name": "test", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + }, + "validation": { + "name": "validation", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + } + }, + "version": { + "version_str": "1.0.0", + "description": "", + "major": 1, + "minor": 0, + "patch": 0 + } +} diff --git a/tests/testdata/multieurlex_test/test/state.json b/tests/testdata/multieurlex_test/test/state.json new file mode 100644 index 0000000..7820d12 --- /dev/null +++ b/tests/testdata/multieurlex_test/test/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "339bb159123fd537", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": "test" +} diff --git a/tests/testdata/multieurlex_test/train/data-00000-of-00001.arrow b/tests/testdata/multieurlex_test/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..62bd2af Binary files /dev/null and b/tests/testdata/multieurlex_test/train/data-00000-of-00001.arrow differ diff --git a/tests/testdata/multieurlex_test/train/dataset_info.json b/tests/testdata/multieurlex_test/train/dataset_info.json new file mode 100644 index 0000000..ddf9a0d --- /dev/null +++ b/tests/testdata/multieurlex_test/train/dataset_info.json @@ -0,0 +1,72 @@ +{ + "builder_name": "multi_eurlex_test", + "config_name": "all_languages", + "dataset_name": "multi_eurlex", + "features": { + "celex_id": { + "dtype": "string", + "_type": "Value" + }, + "text": { + "languages": [ + "en", + "de", + "fr" + ], + 
"_type": "Translation" + }, + "labels": { + "feature": { + "names": [ + "100149", + "100160", + "100148", + "100147", + "100152", + "100143", + "100156", + "100158", + "100154", + "100153", + "100142", + "100145", + "100150", + "100162", + "100159", + "100144", + "100151", + "100157", + "100161", + "100146", + "100155" + ], + "_type": "ClassLabel" + }, + "_type": "Sequence" + } + }, + "splits": { + "train": { + "name": "train", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + }, + "test": { + "name": "test", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + }, + "validation": { + "name": "validation", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + } + }, + "version": { + "version_str": "1.0.0", + "description": "", + "major": 1, + "minor": 0, + "patch": 0 + } +} diff --git a/tests/testdata/multieurlex_test/train/state.json b/tests/testdata/multieurlex_test/train/state.json new file mode 100644 index 0000000..43f32db --- /dev/null +++ b/tests/testdata/multieurlex_test/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "95b517d0a3072460", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": "train" +} diff --git a/tests/testdata/multieurlex_test/validation/data-00000-of-00001.arrow b/tests/testdata/multieurlex_test/validation/data-00000-of-00001.arrow new file mode 100644 index 0000000..fe90448 Binary files /dev/null and b/tests/testdata/multieurlex_test/validation/data-00000-of-00001.arrow differ diff --git a/tests/testdata/multieurlex_test/validation/dataset_info.json b/tests/testdata/multieurlex_test/validation/dataset_info.json new file mode 100644 index 0000000..ddf9a0d --- /dev/null +++ b/tests/testdata/multieurlex_test/validation/dataset_info.json @@ -0,0 +1,72 @@ +{ + "builder_name": "multi_eurlex_test", + "config_name": "all_languages", + "dataset_name": "multi_eurlex", + "features": { + "celex_id": { + "dtype": "string", + "_type": "Value" + }, + "text": { + "languages": [ + "en", + "de", + "fr" + ], + "_type": "Translation" + }, + "labels": { + "feature": { + "names": [ + "100149", + "100160", + "100148", + "100147", + "100152", + "100143", + "100156", + "100158", + "100154", + "100153", + "100142", + "100145", + "100150", + "100162", + "100159", + "100144", + "100151", + "100157", + "100161", + "100146", + "100155" + ], + "_type": "ClassLabel" + }, + "_type": "Sequence" + } + }, + "splits": { + "train": { + "name": "train", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + }, + "test": { + "name": "test", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + }, + "validation": { + "name": "validation", + "num_examples": 5, + "dataset_name": "multi_eurlex_test" + } + }, + "version": { + "version_str": "1.0.0", + "description": "", + "major": 1, + "minor": 0, + "patch": 0 + } +} diff --git a/tests/testdata/multieurlex_test/validation/state.json b/tests/testdata/multieurlex_test/validation/state.json new file mode 100644 index 0000000..ee9c01e --- /dev/null +++ b/tests/testdata/multieurlex_test/validation/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "71189a1028d8d0fe", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": "validation" +} diff --git a/tests/testdata/multieurlex_test_en/MultiEURLEX/data/eurovoc_concepts.json 
b/tests/testdata/multieurlex_test_en/MultiEURLEX/data/eurovoc_concepts.json new file mode 100644 index 0000000..8a6033f --- /dev/null +++ b/tests/testdata/multieurlex_test_en/MultiEURLEX/data/eurovoc_concepts.json @@ -0,0 +1,5 @@ +{ + "level_1": [ + "0" + ] +} diff --git a/tests/testdata/multieurlex_test_en/MultiEURLEX/data/eurovoc_descriptors.json b/tests/testdata/multieurlex_test_en/MultiEURLEX/data/eurovoc_descriptors.json new file mode 100644 index 0000000..3159b9b --- /dev/null +++ b/tests/testdata/multieurlex_test_en/MultiEURLEX/data/eurovoc_descriptors.json @@ -0,0 +1,5 @@ +{ + "0": { + "en": "something" + } +} diff --git a/tests/testdata/multieurlex_test_en/dataset_dict.json b/tests/testdata/multieurlex_test_en/dataset_dict.json new file mode 100644 index 0000000..f15a9f8 --- /dev/null +++ b/tests/testdata/multieurlex_test_en/dataset_dict.json @@ -0,0 +1 @@ +{"splits": ["train", "test", "validation"]} diff --git a/tests/testdata/multieurlex_test_en/test/data-00000-of-00001.arrow b/tests/testdata/multieurlex_test_en/test/data-00000-of-00001.arrow new file mode 100644 index 0000000..af6087a Binary files /dev/null and b/tests/testdata/multieurlex_test_en/test/data-00000-of-00001.arrow differ diff --git a/tests/testdata/multieurlex_test_en/test/dataset_info.json b/tests/testdata/multieurlex_test_en/test/dataset_info.json new file mode 100644 index 0000000..f55d6bd --- /dev/null +++ b/tests/testdata/multieurlex_test_en/test/dataset_info.json @@ -0,0 +1,68 @@ +{ + "builder_name": "multi_eurlex_test_en", + "config_name": "en", + "dataset_name": "multi_eurlex_test_en", + "features": { + "celex_id": { + "dtype": "string", + "_type": "Value" + }, + "text": { + "dtype": "string", + "_type": "Value" + }, + "labels": { + "feature": { + "names": [ + "100149", + "100160", + "100148", + "100147", + "100152", + "100143", + "100156", + "100158", + "100154", + "100153", + "100142", + "100145", + "100150", + "100162", + "100159", + "100144", + "100151", + "100157", + "100161", + "100146", + "100155" + ], + "_type": "ClassLabel" + }, + "_type": "Sequence" + } + }, + "splits": { + "train": { + "name": "train", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + }, + "test": { + "name": "test", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + }, + "validation": { + "name": "validation", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + } + }, + "version": { + "version_str": "1.0.0", + "description": "", + "major": 1, + "minor": 0, + "patch": 0 + } +} diff --git a/tests/testdata/multieurlex_test_en/test/state.json b/tests/testdata/multieurlex_test_en/test/state.json new file mode 100644 index 0000000..fd06043 --- /dev/null +++ b/tests/testdata/multieurlex_test_en/test/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "78e077eaa0509910", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": "test" +} diff --git a/tests/testdata/multieurlex_test_en/train/data-00000-of-00001.arrow b/tests/testdata/multieurlex_test_en/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..1834103 Binary files /dev/null and b/tests/testdata/multieurlex_test_en/train/data-00000-of-00001.arrow differ diff --git a/tests/testdata/multieurlex_test_en/train/dataset_info.json b/tests/testdata/multieurlex_test_en/train/dataset_info.json new file mode 100644 index 0000000..f55d6bd --- /dev/null +++ 
b/tests/testdata/multieurlex_test_en/train/dataset_info.json @@ -0,0 +1,68 @@ +{ + "builder_name": "multi_eurlex_test_en", + "config_name": "en", + "dataset_name": "multi_eurlex_test_en", + "features": { + "celex_id": { + "dtype": "string", + "_type": "Value" + }, + "text": { + "dtype": "string", + "_type": "Value" + }, + "labels": { + "feature": { + "names": [ + "100149", + "100160", + "100148", + "100147", + "100152", + "100143", + "100156", + "100158", + "100154", + "100153", + "100142", + "100145", + "100150", + "100162", + "100159", + "100144", + "100151", + "100157", + "100161", + "100146", + "100155" + ], + "_type": "ClassLabel" + }, + "_type": "Sequence" + } + }, + "splits": { + "train": { + "name": "train", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + }, + "test": { + "name": "test", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + }, + "validation": { + "name": "validation", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + } + }, + "version": { + "version_str": "1.0.0", + "description": "", + "major": 1, + "minor": 0, + "patch": 0 + } +} diff --git a/tests/testdata/multieurlex_test_en/train/state.json b/tests/testdata/multieurlex_test_en/train/state.json new file mode 100644 index 0000000..3fea560 --- /dev/null +++ b/tests/testdata/multieurlex_test_en/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "8ae6be9ed4cd9e2d", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": "train" +} diff --git a/tests/testdata/multieurlex_test_en/validation/data-00000-of-00001.arrow b/tests/testdata/multieurlex_test_en/validation/data-00000-of-00001.arrow new file mode 100644 index 0000000..ec81ae5 Binary files /dev/null and b/tests/testdata/multieurlex_test_en/validation/data-00000-of-00001.arrow differ diff --git a/tests/testdata/multieurlex_test_en/validation/dataset_info.json b/tests/testdata/multieurlex_test_en/validation/dataset_info.json new file mode 100644 index 0000000..f55d6bd --- /dev/null +++ b/tests/testdata/multieurlex_test_en/validation/dataset_info.json @@ -0,0 +1,68 @@ +{ + "builder_name": "multi_eurlex_test_en", + "config_name": "en", + "dataset_name": "multi_eurlex_test_en", + "features": { + "celex_id": { + "dtype": "string", + "_type": "Value" + }, + "text": { + "dtype": "string", + "_type": "Value" + }, + "labels": { + "feature": { + "names": [ + "100149", + "100160", + "100148", + "100147", + "100152", + "100143", + "100156", + "100158", + "100154", + "100153", + "100142", + "100145", + "100150", + "100162", + "100159", + "100144", + "100151", + "100157", + "100161", + "100146", + "100155" + ], + "_type": "ClassLabel" + }, + "_type": "Sequence" + } + }, + "splits": { + "train": { + "name": "train", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + }, + "test": { + "name": "test", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + }, + "validation": { + "name": "validation", + "num_examples": 5, + "dataset_name": "multi_eurlex_test_en" + } + }, + "version": { + "version_str": "1.0.0", + "description": "", + "major": 1, + "minor": 0, + "patch": 0 + } +} diff --git a/tests/testdata/multieurlex_test_en/validation/state.json b/tests/testdata/multieurlex_test_en/validation/state.json new file mode 100644 index 0000000..654a15e --- /dev/null +++ b/tests/testdata/multieurlex_test_en/validation/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": 
"data-00000-of-00001.arrow" + } + ], + "_fingerprint": "06c8600442860ac7", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": "validation" +}