Skip to content

Commit

Permalink
rm hardcoded sample
Browse files Browse the repository at this point in the history
  • Loading branch information
lannelin committed Nov 22, 2024
1 parent ae88310 commit 5d35c90
Showing 1 changed file with 9 additions and 22 deletions.
31 changes: 9 additions & 22 deletions scripts/variational_RTC_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,25 @@
)


def get_random_test_row(train_data):
row_iterator = iter(train_data)
for _ in range(randint(1, 25)):
test_row = next(row_iterator)
return test_row


def load_test_row():
lang_pair = {"source": "fr", "target": "en"}
dataset_dict, metadata_params = load_multieurlex_for_translation(
data_dir="data", level=1, lang_pair=lang_pair
)
train = dataset_dict["train"]
multi_onehot = MultiHot(metadata_params["n_classes"])
test_row = get_test_row(train)
class_labels = multi_onehot(test_row["class_labels"])
test_row = get_random_test_row(train)
class_labels = multi_onehot(test_row["labels"])
return test_row, class_labels, metadata_params


def get_test_row(train_data):
# debug row if needed
return {
"source_text": (
"Le renard brun rapide a sauté par-dessus le chien paresseux."
"Le renard a sauté par-dessus le chien paresseux."
),
"target_text": (
"The quick brown fox jumped over the lazy dog. The fox jumped"
" over the lazy dog"
),
"class_labels": [0, 1],
}
## Normal row
row_iterator = iter(train_data)
for _ in range(randint(1, 25)):
test_row = next(row_iterator)
return test_row


def print_results(clean_output, var_output, class_labels, test_row, comet_model):
# ### TRANSLATION ###
print("\nTranslation:")
Expand Down

0 comments on commit 5d35c90

Please sign in to comment.