Commit

simplify score calculation
markus583 committed Mar 10, 2024
1 parent 32b9c14 commit 61ef0de
Showing 2 changed files with 23 additions and 45 deletions.
wtpsplit/evaluation/intrinsic_list.py: 2 changes (1 addition & 1 deletion)
@@ -44,7 +44,7 @@ class Args:
     eval_data_path: str = "data/eval.pth"
     valid_text_path: str = None  # "data/sentence/valid.parquet"
     device: str = "cpu"
-    block_size: int = 512
+    block_size: int = 510
     stride: int = 64
     batch_size: int = 1
     include_langs: List[str] = None
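
The only change here is the default block_size, 512 → 510. The commit does not say why; a common reason to shave two positions off a 512-token window is to leave room for the special tokens a tokenizer adds around each block. The sketch below only illustrates that assumption, with hypothetical constants that do not come from wtpsplit:

    # Assumption, not stated in the commit: the encoder accepts at most 512
    # positions and the tokenizer adds two special tokens (e.g. an opening
    # and a closing token) per block, so the text itself may use 512 - 2 = 510.
    MAX_MODEL_POSITIONS = 512   # hypothetical encoder limit
    NUM_SPECIAL_TOKENS = 2      # hypothetical: one opening, one closing token

    block_size = MAX_MODEL_POSITIONS - NUM_SPECIAL_TOKENS
    assert block_size == 510
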
wtpsplit/train/train_adapter.py: 66 changes (22 additions & 44 deletions)
@@ -464,55 +464,33 @@ def compute_metrics(trainer):
 
             with training_args.main_process_first():
                 if args.one_sample_per_line:
-                    score = []
-                    info = []
-                    for chunk in eval_data:
-                        score_chunk, info_chunk = evaluate_sentence(
-                            lang,
-                            chunk,
-                            model,
-                            stride=64,
-                            block_size=512,
-                            batch_size=training_args.per_device_eval_batch_size,
-                            do_lowercase=args.do_lowercase,
-                            do_remove_punct=args.do_remove_punct,
-                        )
-                        score.append(score_chunk)
-                        info.append(info_chunk)
-
-                    score = np.mean(score)
-                    info = {
-                        "f1": np.mean([i["f1"] for i in info]),
-                        "f1_best": np.mean([i["f1_best"] for i in info]),
-                        "threshold_best": np.mean([i["threshold_best"] for i in info]),
-                    }
-                else:
-                    score, info = evaluate_sentence(
-                        lang,
-                        eval_data,
-                        model,
-                        stride=64,
-                        block_size=512,
-                        batch_size=training_args.per_device_eval_batch_size,
-                        do_lowercase=args.do_lowercase,
-                        do_remove_punct=args.do_remove_punct,
-                    )
-            metrics[f"{dataset_name}/{lang}/pr_auc"] = score
-            metrics[f"{dataset_name}/{lang}/f1"] = info["f1"]
-            metrics[f"{dataset_name}/{lang}/f1_best"] = info["f1_best"]
-            metrics[f"{dataset_name}/{lang}/threshold_best"] = info["threshold_best"]
-            if args.eval_pairwise:
-                score_pairwise, avg_acc = evaluate_sentence_pairwise(
+                    eval_data = [item for sublist in eval_data for item in sublist]
+                score, info = evaluate_sentence(
                     lang,
                     eval_data,
                     model,
-                    stride=args.eval_stride,
-                    block_size=args.block_size,
+                    stride=64,
+                    block_size=512,
                     batch_size=training_args.per_device_eval_batch_size,
-                    threshold=0.1,
+                    do_lowercase=args.do_lowercase,
+                    do_remove_punct=args.do_remove_punct,
                 )
-                metrics[f"{dataset_name}/{lang}/pairwise/pr_auc"] = score_pairwise
-                metrics[f"{dataset_name}/{lang}/pairwise/acc"] = avg_acc
+                metrics[f"{dataset_name}/{lang}/pr_auc"] = score
+                metrics[f"{dataset_name}/{lang}/f1"] = info["f1"]
+                metrics[f"{dataset_name}/{lang}/f1_best"] = info["f1_best"]
+                metrics[f"{dataset_name}/{lang}/threshold_best"] = info["threshold_best"]
+                if args.eval_pairwise:
+                    score_pairwise, avg_acc = evaluate_sentence_pairwise(
+                        lang,
+                        eval_data,
+                        model,
+                        stride=args.eval_stride,
+                        block_size=args.block_size,
+                        batch_size=training_args.per_device_eval_batch_size,
+                        threshold=0.1,
+                    )
+                    metrics[f"{dataset_name}/{lang}/pairwise/pr_auc"] = score_pairwise
+                    metrics[f"{dataset_name}/{lang}/pairwise/acc"] = avg_acc
 
             return metrics
 
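For one_sample_per_line datasets, the old code scored each chunk of eval_data separately with evaluate_sentence and then averaged the per-chunk PR-AUC and F1 values; the new code flattens eval_data into one list and scores it with a single evaluate_sentence call. A mean of per-chunk scores and one score over the pooled data are generally different numbers. The toy sketch below illustrates that with made-up confusion counts; it is not wtpsplit code, and how evaluate_sentence aggregates internally is not shown in this diff.

    # The flattening step from the diff, on a toy nested list:
    nested = [["s1", "s2"], ["s3"]]
    flat = [item for sublist in nested for item in sublist]
    assert flat == ["s1", "s2", "s3"]

    def f1(tp: int, fp: int, fn: int) -> float:
        """Standard F1 from true/false positive and false negative counts."""
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    # Hypothetical per-chunk counts (e.g. predicted sentence boundaries).
    chunks = [
        {"tp": 90, "fp": 10, "fn": 10},  # large, easy chunk
        {"tp": 1, "fp": 4, "fn": 5},     # tiny, hard chunk
    ]

    # Old behaviour (sketch): score each chunk, then average the scores.
    macro_f1 = sum(f1(**c) for c in chunks) / len(chunks)

    # New behaviour (sketch): pool the data (here: sum the counts), score once.
    pooled = {k: sum(c[k] for c in chunks) for k in ("tp", "fp", "fn")}
    pooled_f1 = f1(**pooled)

    print(f"macro-averaged F1: {macro_f1:.3f}")  # ~0.54: the tiny chunk drags the mean down
    print(f"pooled F1:         {pooled_f1:.3f}")  # ~0.86: weighted by actual data size

In other words, the single-call version weights every evaluated position equally, whereas the old macro average gave each chunk equal weight regardless of its length.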
