flatten list, add strip flag
markus583 committed Mar 26, 2024
1 parent f804804 commit e380c2a
Showing 2 changed files with 20 additions and 1 deletion.
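
Both files gain the same preprocessing pattern: flatten a possible list of lists into a single list of sentences, and, when the new do_strip flag is set, strip each sentence of leading "-" characters and surrounding whitespace. A minimal standalone sketch of that pattern follows; the preprocess_sentences helper is illustrative only and not part of the commit.

# Illustrative helper (not in the repository): mirrors the flatten + strip
# preprocessing added in this commit.
def preprocess_sentences(sentences, do_strip=False):
    # Some datasets store sentences as a list of lists; flatten to one list.
    if isinstance(sentences[0], list):
        sentences = [item for sublist in sentences for item in sublist]
    if do_strip:
        # Drop leading "-" markers and surrounding whitespace.
        sentences = [sentence.lstrip("-").strip() for sentence in sentences]
    return sentences

print(preprocess_sentences([["- Hello ", "- world "], ["bye"]], do_strip=True))
# ['Hello', 'world', 'bye']
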
14 changes: 14 additions & 0 deletions wtpsplit/evaluation/intrinsic.py
@@ -54,6 +54,7 @@ class Args:
save_suffix: str = ""
do_lowercase: bool = False
do_remove_punct: bool = False
do_strip: bool = False


def process_logits(text, model, lang_code, args):
@@ -166,6 +167,11 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st

if "test_logits" not in dset_group:
test_sentences = dataset["data"]
# if list of lists: flatten
if isinstance(test_sentences[0], list):
test_sentences = [item for sublist in test_sentences for item in sublist]
if args.do_strip:
test_sentences = [sentence.lstrip("-").strip() for sentence in test_sentences]
test_sentences = [
corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct)
for sentence in test_sentences
@@ -184,6 +190,10 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st

train_sentences = dataset["meta"].get("train_data")
if train_sentences is not None and "train_logits" not in dset_group:
if isinstance(train_sentences[0], list):
train_sentences = [item for sublist in train_sentences for item in sublist]
if args.do_strip:
train_sentences = [sentence.lstrip("-").strip() for sentence in train_sentences]
train_sentences = [
corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct)
for sentence in train_sentences
@@ -272,6 +282,10 @@ def main(args):

for dataset_name, dataset in dsets["sentence"].items():
sentences = dataset["data"]
if isinstance(sentences[0], list):
sentences = [item for sublist in sentences for item in sublist]
if args.do_strip:
sentences = [sentence.lstrip("-").strip() for sentence in sentences]
sentences = [
corrupt(sentence, do_lowercase=args.do_lowercase, do_remove_punct=args.do_remove_punct)
for sentence in sentences
7 changes: 6 additions & 1 deletion wtpsplit/train/train_adapter.py
@@ -133,6 +133,9 @@ def prepare_dataset(
dataset = datasets.Dataset.from_list(processed_dataset)

else:
if isinstance(dataset[0], list):
# flatten
dataset = [item for sublist in dataset for item in sublist]
dataset = datasets.Dataset.from_list(
[
{
@@ -465,6 +468,8 @@ def compute_metrics(trainer):
with training_args.main_process_first():
if args.one_sample_per_line:
eval_data = [item for sublist in eval_data for item in sublist]
elif isinstance(eval_data[0], list):
eval_data = [item for sublist in eval_data for item in sublist]
score, info = evaluate_sentence(
lang,
eval_data,
@@ -579,7 +584,7 @@ def compute_metrics(trainer):
cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_lines.pt --save_suffix lines"
elif "verses" in args.text_path:
if args.do_lowercase and args.do_remove_punct:
cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n.pt --save_suffix verses --do_lowercase --do_remove_punct"
cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n_single.pt --save_suffix verses --do_lowercase --do_remove_punct"
else:
cmd = f"python3 wtpsplit/evaluation/{eval_function}.py --model_path {args.model_name_or_path} --adapter_path {training_args.output_dir} --threshold 0.1 --custom_language_list data/lyrics_langs.csv --eval_data_path data/lyrics_verses_strip_n.pt --save_suffix verses"
elif args.do_lowercase and args.do_remove_punct: