fix eval selection
markus583 committed Mar 29, 2024
1 parent ab776f8 commit 52b6a48
Showing 1 changed file with 22 additions and 22 deletions.
wtpsplit/train/train.py
@@ -566,16 +566,16 @@ def compute_metrics(trainer):
             if trainer.args.process_index == 0 and args.do_sentence_training:
                 # with training_args.main_process_first():
                 for dataset_name, dataset in lang_data["sentence"].items():
-                    # score, _ = evaluate_sentence(
-                    #     lang_code,
-                    #     dataset["data"],
-                    #     model,
-                    #     stride=args.eval_stride,
-                    #     block_size=args.block_size,
-                    #     batch_size=training_args.per_device_eval_batch_size,
-                    # )
-                    # metrics[f"{lang_code}_{dataset_name}_pr_auc"] = score
-                    # avg_metrics[f"average_{dataset_name}_pr_auc"].append(score)
+                    score, _ = evaluate_sentence(
+                        lang_code,
+                        dataset["data"],
+                        model,
+                        stride=args.eval_stride,
+                        block_size=args.block_size,
+                        batch_size=training_args.per_device_eval_batch_size,
+                    )
+                    metrics[f"{lang_code}_{dataset_name}_pr_auc"] = score
+                    avg_metrics[f"average_{dataset_name}_pr_auc"].append(score)
                     # if lang_code in ["zh", "ja", "my", "km"]:
                     #     avg_metrics[f"average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
                     # else:
@@ -612,24 +612,24 @@ def compute_metrics(trainer):
                         avg_metrics[f"k_{k}_average_{dataset_name}_pr_auc"].append(score)
                         metrics[f"k_{k}_{lang_code}_{dataset_name}_acc"] = avg_acc
                         avg_metrics[f"k_{k}_average_{dataset_name}_acc"].append(avg_acc)
-                        if lang_code in ["zh", "ja", "my", "km"]:
-                            avg_metrics[f"k_{k}_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
-                            avg_metrics[f"k_{k}_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
-                        else:
-                            avg_metrics[f"k_{k}_average_whitespace_{dataset_name}_pr_auc"].append(score)
-                            avg_metrics[f"k_{k}_average_whitespace_{dataset_name}_acc"].append(avg_acc)
+                        # if lang_code in ["zh", "ja", "my", "km"]:
+                        #     avg_metrics[f"k_{k}_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
+                        #     avg_metrics[f"k_{k}_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
+                        # else:
+                        #     avg_metrics[f"k_{k}_average_whitespace_{dataset_name}_pr_auc"].append(score)
+                        #     avg_metrics[f"k_{k}_average_whitespace_{dataset_name}_acc"].append(avg_acc)
                         if k == 2:
                             # keep keys for backwards compat in wandb
                             metrics[f"pairwise_{lang_code}_{dataset_name}_pr_auc"] = score
                             avg_metrics[f"pairwise_average_{dataset_name}_pr_auc"].append(score)
                             metrics[f"pairwise_{lang_code}_{dataset_name}_acc"] = avg_acc
                             avg_metrics[f"pairwise_average_{dataset_name}_acc"].append(avg_acc)
-                            if lang_code in ["zh", "ja", "my", "km"]:
-                                avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
-                                avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
-                            else:
-                                avg_metrics[f"pairwise_average_whitespace_{dataset_name}_pr_auc"].append(score)
-                                avg_metrics[f"pairwise_average_whitespace_{dataset_name}_acc"].append(avg_acc)
+                            # if lang_code in ["zh", "ja", "my", "km"]:
+                            #     avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
+                            #     avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
+                            # else:
+                            #     avg_metrics[f"pairwise_average_whitespace_{dataset_name}_pr_auc"].append(score)
+                            #     avg_metrics[f"pairwise_average_whitespace_{dataset_name}_acc"].append(avg_acc)

         for name, values in avg_metrics.items():
             if len(values) > 1:
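For reference, the bookkeeping this commit re-enables follows one pattern: each per-language score is stored in metrics under a language-specific key and also appended to a shared bucket in avg_metrics, which the loop at the end of the second hunk then averages. Below is a minimal, self-contained sketch of that pattern; the scores are made-up stand-ins for evaluate_sentence results, and avg_metrics being a defaultdict of lists is an assumption inferred from the append-style usage in the diff.

from collections import defaultdict
from statistics import mean

metrics = {}                     # per-language, per-dataset results
avg_metrics = defaultdict(list)  # buckets averaged across languages

# Made-up PR AUC scores standing in for evaluate_sentence() output.
fake_scores = {("en", "ud"): 0.97, ("de", "ud"): 0.95, ("ja", "ud"): 0.88}

for (lang_code, dataset_name), score in fake_scores.items():
    metrics[f"{lang_code}_{dataset_name}_pr_auc"] = score
    avg_metrics[f"average_{dataset_name}_pr_auc"].append(score)

# Mirrors the "if len(values) > 1:" check closing the second hunk:
# a bucket is only reported as an average when several languages feed it.
for name, values in avg_metrics.items():
    if len(values) > 1:
        metrics[name] = mean(values)

print(round(metrics["average_ud_pr_auc"], 3))  # 0.933

Keeping the per-language keys alongside the averaged buckets lets wandb chart both individual languages and cross-language aggregates, which is also why the pairwise_* keys are kept "for backwards compat" even though the k == 2 keys cover the same evaluation.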