
Commit 914d5fc

actually use lowercase eval

markus583 committed Jan 14, 2024
1 parent dfba4de
Showing 2 changed files with 19 additions and 24 deletions.
wtpsplit/train/evaluate.py: 6 changes (3 additions & 3 deletions)
@@ -65,7 +65,7 @@ def evaluate_sentence(
     batch_size,
     use_pysbd=False,
     positive_index=None,
-    # do_lowercase=False,
+    do_lowercase=False,
 ):
     if positive_index is None:
         positive_index = Constants.NEWLINE_INDEX
@@ -75,8 +75,8 @@ def evaluate_sentence(

     separator = Constants.SEPARATORS[lang_code]
     text = separator.join(sentences)
-    # if do_lowercase:
-    #     text = text.lower()
+    if do_lowercase:
+        text = text.lower()

     logits, offsets_mapping, tokenizer = extract(
         [text],
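For context, the branch enabled here only lowercases the joined evaluation text before logits are extracted. A minimal sketch of that step, using made-up sentences and a plain space separator as stand-ins for the evaluation dataset and Constants.SEPARATORS[lang_code]:

# Illustrative stand-ins; the real sentences come from the eval dataset and the
# separator from Constants.SEPARATORS[lang_code].
sentences = ["Das ist ein Satz.", "Und noch einer."]
separator = " "
do_lowercase = True

text = separator.join(sentences)
if do_lowercase:
    text = text.lower()  # the step re-enabled in evaluate_sentence

print(text)  # das ist ein satz. und noch einer.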
wtpsplit/train/train.py: 37 changes (16 additions & 21 deletions)
@@ -551,8 +551,6 @@ def maybe_pad(text):
     def compute_metrics(trainer):
         metrics = {}
         avg_metrics = defaultdict(lambda: [])
-        # metrics_lower = {}
-        # avg_metrics_lower = defaultdict(lambda: [])

         model = trainer._wrap_model(trainer.model, training=False)

@@ -576,30 +574,27 @@ def compute_metrics(trainer):
                     avg_metrics[f"average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
                 else:
                     avg_metrics[f"average_whitespace_{dataset_name}_pr_auc"].append(score)
-            # for dataset_name, dataset in lang_data["sentence"].items():
-            #     score, _ = evaluate_sentence(
-            #         lang_code,
-            #         dataset["data"],
-            #         model,
-            #         stride=args.eval_stride,
-            #         block_size=args.block_size,
-            #         batch_size=training_args.per_device_eval_batch_size,
-            #         do_lowercase=True,
-            #     )
-            #     metrics_lower[f"lower_{lang_code}_{dataset_name}_pr_auc"] = score
-            #     avg_metrics_lower[f"lower_average_{dataset_name}_pr_auc"].append(score)
-            #     if lang_code in ["zh", "ja", "my", "km"]:
-            #         avg_metrics_lower[f"lower_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
-            #     else:
-            #         avg_metrics_lower[f"lower_average_whitespace_{dataset_name}_pr_auc"].append(score)
+            for dataset_name, dataset in lang_data["sentence"].items():
+                score, _ = evaluate_sentence(
+                    lang_code,
+                    dataset["data"],
+                    model,
+                    stride=args.eval_stride,
+                    block_size=args.block_size,
+                    batch_size=training_args.per_device_eval_batch_size,
+                    do_lowercase=True,
+                )
+                metrics[f"lower_{lang_code}_{dataset_name}_pr_auc"] = score
+                avg_metrics[f"lower_average_{dataset_name}_pr_auc"].append(score)
+                if lang_code in ["zh", "ja", "my", "km"]:
+                    avg_metrics[f"lower_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
+                else:
+                    avg_metrics[f"lower_average_whitespace_{dataset_name}_pr_auc"].append(score)


         for name, values in avg_metrics.items():
             if len(values) > 1:
                 metrics[name] = np.mean(values)
-        # for name, values in avg_metrics_lower.items():
-        #     if len(values) > 1:
-        #         metrics_lower[name] = np.mean(values)

         return metrics

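The re-enabled loop writes the lowercased results into the same metrics dict under lower_-prefixed keys, rather than the old, separate metrics_lower/avg_metrics_lower containers, so they are averaged by the existing loop at the end of compute_metrics. A minimal sketch of that aggregation pattern, with hypothetical per-language scores for a single dataset name ("ud" is only an example):

from collections import defaultdict

import numpy as np

# Hypothetical lowercase PR-AUC scores per language for one dataset.
scores = {"en": 0.91, "de": 0.88, "fr": 0.90}

metrics = {}
avg_metrics = defaultdict(list)

for lang_code, score in scores.items():
    metrics[f"lower_{lang_code}_ud_pr_auc"] = score
    avg_metrics["lower_average_ud_pr_auc"].append(score)

# Same averaging rule as compute_metrics: only keys seen more than once get a mean.
for name, values in avg_metrics.items():
    if len(values) > 1:
        metrics[name] = np.mean(values)

print(round(metrics["lower_average_ud_pr_auc"], 3))  # 0.897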
