Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow non-translation tasks #25

Merged
merged 10 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,6 @@ slurm_scripts/slurm_logs*
temp
.vscode
local_notebooks

# test caches
tests/testdata/*/*/cache*
173 changes: 0 additions & 173 deletions requirements.txt

This file was deleted.

128 changes: 128 additions & 0 deletions scripts/create_test_ds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import os
import shutil

import datasets
from datasets import load_dataset

from arc_spice.data import multieurlex_utils

# Repository root, derived from this script's location (scripts/ -> root).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# All generated fixtures live under tests/testdata.
TESTDATA_DIR = os.path.join(PROJECT_ROOT, "tests/testdata")
# dataset_info.json templates copied into each saved split (multilingual / en-only).
BASE_DATASET_INFO_MULTILANG = os.path.join(
    TESTDATA_DIR, "base_testdata/dataset_info.json"
)
BASE_DATASET_INFO_EN = os.path.join(TESTDATA_DIR, "base_testdata/dataset_info_en.json")
# MultiEURLEX metadata directory copied alongside each fixture dataset.
BASE_DATASET_METADATA_DIR = os.path.join(TESTDATA_DIR, "base_testdata/MultiEURLEX")

# Five-row replacement texts mirroring the MultiEURLEX structure. Each row
# embeds the language's ARTICLE_1_MARKERS string so splitting-on-marker code
# paths are exercised; row 3 deliberately has no marker (and hence no text
# after it) to exercise the no-marker path.
CONTENT_MULTILANG: list[dict[str, str]] = [
    {
        "en": f"Some text before the marker 1 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 1",  # noqa: E501
        "fr": f"Some text before the marker 1 {multieurlex_utils.ARTICLE_1_MARKERS['fr']} Some text after the marker 1",  # noqa: E501
        "de": f"Some text before the marker 1 {multieurlex_utils.ARTICLE_1_MARKERS['de']} Some text after the marker 1",  # noqa: E501
    },
    {
        "en": f"Some text before the marker 2 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 2",  # noqa: E501
        "fr": f"Some text before the marker 2 {multieurlex_utils.ARTICLE_1_MARKERS['fr']} Some text after the marker 2",  # noqa: E501
        "de": f"Some text before the marker 2 {multieurlex_utils.ARTICLE_1_MARKERS['de']} Some text after the marker 2",  # noqa: E501
    },
    {
        "en": "Some text before the marker 3",  # no marker, no text after marker
        "fr": "Some text before the marker 3",
        "de": "Some text before the marker 3",
    },
    {
        "en": f"Some text before the marker 4 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 4",  # noqa: E501
        "fr": f"Some text before the marker 4 {multieurlex_utils.ARTICLE_1_MARKERS['fr']} Some text after the marker 4",  # noqa: E501
        "de": f"Some text before the marker 4 {multieurlex_utils.ARTICLE_1_MARKERS['de']} Some text after the marker 4",  # noqa: E501
    },
    {
        "en": f"Some text before the marker 5 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 5",  # noqa: E501
        "fr": f"Some text before the marker 5 {multieurlex_utils.ARTICLE_1_MARKERS['fr']} Some text after the marker 5",  # noqa: E501
        "de": f"Some text before the marker 5 {multieurlex_utils.ARTICLE_1_MARKERS['de']} Some text after the marker 5",  # noqa: E501
    },
]
# English-only variant of the same five rows (plain strings, not per-language
# dicts); row 3 again lacks the marker.
CONTENT_EN: list[str] = [
    f"Some text before the marker 1 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 1",  # noqa: E501
    f"Some text before the marker 2 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 2",  # noqa: E501
    "Some text before the marker 3",  # no marker, no text after marker
    f"Some text before the marker 4 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 4",  # noqa: E501
    f"Some text before the marker 5 {multieurlex_utils.ARTICLE_1_MARKERS['en']} Some text after the marker 5",  # noqa: E501
]


def overwrite_text(
_orig,
i: int,
content: list[dict[str, str]] | list[str],
) -> dict[str, str | dict[str, str]]:
return {"text": content[i]}


def create_test_ds(
    testdata_dir: str,
    ds_name: str,
    content: list[dict[str, str]] | list[str],
    dataset_info_fpath: str,
) -> None:
    """Build a small MultiEURLEX fixture dataset on disk for tests.

    Loads the real ``multi_eurlex`` dataset, truncates every split to its
    first 5 rows, overwrites each row's ``text`` with the deterministic
    ``content`` entries, saves the result under ``testdata_dir/ds_name`` and
    copies in the metadata files needed for ``datasets.load_from_disk``.

    Args:
        testdata_dir: Directory in which the fixture dataset is created.
        ds_name: Name of the fixture (subdirectory of ``testdata_dir``).
        content: Per-row replacement text — multilingual dicts or plain
            strings, matching the ``dataset_info_fpath`` schema supplied.
        dataset_info_fpath: ``dataset_info.json`` to copy into each split.
    """
    dataset = load_dataset(
        "multi_eurlex",
        "all_languages",
        label_level="level_1",
        trust_remote_code=True,
    )

    # Keep only the first 5 rows of each split so the fixture stays tiny.
    splits = ("train", "validation", "test")
    for split in splits:
        dataset[split] = dataset[split].take(5)

    dataset = dataset.map(
        overwrite_text,
        with_indices=True,
        fn_kwargs={"content": content},
    )

    out_dir = os.path.join(testdata_dir, ds_name)
    dataset.save_to_disk(out_dir)

    # Overwrite each saved split's dataset_info.json with the supplied
    # template so the fixture advertises the intended feature schema.
    for split in splits:
        shutil.copy(
            dataset_info_fpath,
            os.path.join(out_dir, split, "dataset_info.json"),
        )

    # Metadata copy. dirs_exist_ok lets the script be re-run in place
    # instead of failing with FileExistsError on an existing fixture.
    shutil.copytree(
        BASE_DATASET_METADATA_DIR,
        os.path.join(out_dir, "MultiEURLEX"),
        dirs_exist_ok=True,
    )

    # Smoke-check: the freshly written fixture must be loadable.
    assert datasets.load_from_disk(out_dir) is not None


if __name__ == "__main__":
    os.makedirs(TESTDATA_DIR, exist_ok=True)

    # NOTE: the previous unused local `content` list (dead code — both calls
    # below pass the module-level constants) has been removed.

    # Multilingual fixture: each row carries en/fr/de variants.
    create_test_ds(
        testdata_dir=TESTDATA_DIR,
        ds_name="multieurlex_test",
        content=CONTENT_MULTILANG,
        dataset_info_fpath=BASE_DATASET_INFO_MULTILANG,
    )

    # English-only fixture: rows are plain strings.
    create_test_ds(
        testdata_dir=TESTDATA_DIR,
        ds_name="multieurlex_test_en",
        content=CONTENT_EN,
        dataset_info_fpath=BASE_DATASET_INFO_EN,
    )
28 changes: 8 additions & 20 deletions scripts/variational_RTC_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,33 @@
"""

import logging
import os
import random
from random import randint

import numpy as np
import torch
from torch.nn.functional import binary_cross_entropy

from arc_spice.data.multieurlex_utils import MultiHot, load_multieurlex
from arc_spice.data.multieurlex_utils import MultiHot, load_multieurlex_for_translation
from arc_spice.eval.classification_error import hamming_accuracy
from arc_spice.eval.translation_error import get_comet_model
from arc_spice.utils import seed_everything
from arc_spice.variational_pipelines.RTC_variational_pipeline import (
RTCVariationalPipeline,
)


def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)


def load_test_row():
lang_pair = {"source": "fr", "target": "en"}
(train, _, _), metadata_params = load_multieurlex(
dataset_dict, metadata_params = load_multieurlex_for_translation(
data_dir="data", level=1, lang_pair=lang_pair
)
train = dataset_dict["train"]
multi_onehot = MultiHot(metadata_params["n_classes"])
test_row = get_test_row(train)
class_labels = multi_onehot(test_row["class_labels"])
return test_row, class_labels, metadata_params


def get_test_row(train_data):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be appropriate to split these functionalities into two functions, or pass a debug_flag argument?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've simply removed the manually entered data here. I assume this script will be superseded in time by something that goes over more than 1 sample

row_iterator = iter(train_data)
for _ in range(randint(1, 25)):
test_row = next(row_iterator)

# debug row if needed
return {
"source_text": (
Expand All @@ -57,7 +42,10 @@ def get_test_row(train_data):
),
"class_labels": [0, 1],
}
# Normal row
## Normal row
row_iterator = iter(train_data)
for _ in range(randint(1, 25)):
test_row = next(row_iterator)
return test_row


Expand Down
Loading
Loading