Skip to content

Commit

Permalink
🔨 Make metrics optional
Browse files Browse the repository at this point in the history
  • Loading branch information
klh5 committed Dec 12, 2024
1 parent 76d8d88 commit a4523ba
Showing 1 changed file with 67 additions and 105 deletions.
172 changes: 67 additions & 105 deletions src/m4st/process_demetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

class ProcessDEMETR:
def __init__(
self, output_filepath: os.PathLike | str, demetr_root: os.PathLike | str
self,
output_filepath: os.PathLike | str,
demetr_root: os.PathLike | str,
metrics_to_use: list,
) -> None:

# Conversion from DEMETR language tag to SONAR language code
Expand All @@ -32,32 +35,20 @@ def __init__(
}
self.output_path = output_filepath
self.demetr_root = demetr_root
self.metrics_to_use = metrics_to_use

colnames = ["category", *metrics_to_use]

with open(self.output_path, "w") as output_file:
writer = csv.writer(output_file)
writer.writerow(
[
"category",
"BLEU",
"SacreBLEU",
"BLASER_ref",
"BLASER_qe",
"COMET_ref",
"COMET_qe",
]
)

def get_accuracy_score(
self,
mt_scores: list,
dfluent_scores: list,
num_samples: int,
reverse_accuracy: bool,
) -> float:
writer.writerow(colnames)

mask = np.array(mt_scores) > np.array(dfluent_scores)
result = np.count_nonzero(~mask) if reverse_accuracy else np.count_nonzero(mask)
return result / num_samples * 100
if "Sacre_BLEU" in metrics_to_use:
self.sacre_bleu = SacreBLEUScore()
if "BLASER_ref" in metrics_to_use or "BLASER_qe" in metrics_to_use:
self.blaser = BLASERScore()
if "COMET_ref" in metrics_to_use or "COMET_qe" in metrics_to_use:
self.comet = COMETScore()

def process_demetr_category(
self,
Expand All @@ -70,29 +61,10 @@ def process_demetr_category(
curr_ds_path = os.path.join(self.demetr_root, cat_fp)
json_data = load_json(curr_ds_path)

nltk_bleu_mt = []
nltk_bleu_d = []

sacre_bleu_mt = []
sacre_bleu_d = []

blaser_ref_mt = []
blaser_ref_d = []
mt_results = np.zeros((num_samples, len(self.metrics_to_use)))
dis_results = np.zeros((num_samples, len(self.metrics_to_use)))

blaser_qe_mt = []
blaser_qe_d = []

comet_ref_mt = []
comet_ref_d = []

comet_qe_mt = []
comet_qe_d = []

sacre_bleu = SacreBLEUScore()
blaser = BLASERScore()
comet = COMETScore()

for sentence in json_data:
for i, sentence in enumerate(json_data):

ref_txt = sentence["eng_sent"] # Human translation
mt_txt = sentence["mt_sent"] # Original machine translation
Expand All @@ -101,68 +73,53 @@ def process_demetr_category(
src_lang = sentence["lang_tag"] # Source language
blaser_lang_code = self.language_codes[src_lang]

# String-based metrics
nltk_bleu_mt.append(nltk_bleu_score(ref_txt, mt_txt))
nltk_bleu_d.append(nltk_bleu_score(ref_txt, dfluent_txt))
sacre_bleu_mt.append(sacre_bleu.get_score(ref_txt, mt_txt))
sacre_bleu_d.append(sacre_bleu.get_score(ref_txt, dfluent_txt))

# Model-based metrics
# BLASER-2.0
blaser_ref_mt.append(
blaser.blaser_ref_score(ref_txt, mt_txt, src_text, blaser_lang_code)
)
blaser_ref_d.append(
blaser.blaser_ref_score(
ref_txt, dfluent_txt, src_text, blaser_lang_code
)
)
blaser_qe_mt.append(
blaser.blaser_qe_score(mt_txt, src_text, blaser_lang_code)
)
blaser_qe_d.append(
blaser.blaser_qe_score(dfluent_txt, src_text, blaser_lang_code)
)

# COMET
comet_ref_mt.append(comet.comet_ref_score(ref_txt, mt_txt, src_text))
comet_ref_d.append(comet.comet_ref_score(ref_txt, dfluent_txt, src_text))
comet_qe_mt.append(comet.comet_qe_score(mt_txt, src_text))
comet_qe_d.append(comet.comet_qe_score(dfluent_txt, src_text))

# Calculate accuracy as in DEMETR paper
nltk_bleu_avg = self.get_accuracy_score(
nltk_bleu_mt, nltk_bleu_d, num_samples, reverse_accuracy
)
sacre_bleu_avg = self.get_accuracy_score(
sacre_bleu_mt, sacre_bleu_d, num_samples, reverse_accuracy
)
blaser_ref_avg = self.get_accuracy_score(
blaser_ref_mt, blaser_ref_d, num_samples, reverse_accuracy
)
blaser_qe_avg = self.get_accuracy_score(
blaser_qe_mt, blaser_qe_d, num_samples, reverse_accuracy
)
comet_ref_avg = self.get_accuracy_score(
comet_ref_mt, comet_ref_d, num_samples, reverse_accuracy
)
comet_qe_avg = self.get_accuracy_score(
comet_qe_mt, comet_qe_d, num_samples, reverse_accuracy
)
for j, metric in enumerate(self.metrics_to_use):
if metric == "BLEU":
mt_results[i, j] = nltk_bleu_score(ref_txt, mt_txt)
dis_results[i, j] = nltk_bleu_score(ref_txt, dfluent_txt)
elif metric == "Sacre_BLEU":
mt_results[i, j] = self.sacre_bleu.get_score(ref_txt, mt_txt)
dis_results[i, j] = self.sacre_bleu.get_score(ref_txt, dfluent_txt)
elif metric == "BLASER_ref":
mt_results[i, j] = self.blaser.blaser_ref_score(
ref_txt, mt_txt, src_text, blaser_lang_code
)
dis_results[i, j] = self.blaser.blaser_ref_score(
ref_txt, dfluent_txt, src_text, blaser_lang_code
)
elif metric == "BLASER_qe":
mt_results[i, j] = self.blaser.blaser_qe_score(
mt_txt, src_text, blaser_lang_code
)
dis_results[i, j] = self.blaser.blaser_qe_score(
dfluent_txt, src_text, blaser_lang_code
)
elif metric == "COMET_ref":
mt_results[i, j] = self.comet.comet_ref_score(
ref_txt, mt_txt, src_text
)
dis_results[i, j] = self.comet.comet_ref_score(
ref_txt, dfluent_txt, src_text
)
elif metric == "COMET_qe":
mt_results[i, j] = self.comet.comet_qe_score(mt_txt, src_text)
dis_results[i, j] = self.comet.comet_qe_score(dfluent_txt, src_text)
else:
print(f"Unknown metric {metric}")

mask = mt_results > dis_results
if reverse_accuracy:
results = np.count_nonzero(~mask, axis=0)
else:
results = np.count_nonzero(mask, axis=0)

results = results / num_samples * 100

results_str = [category, *results]

with open(self.output_path, "a") as output_file:
csv_writer = csv.writer(output_file)
csv_writer.writerow(
[
category,
nltk_bleu_avg,
sacre_bleu_avg,
blaser_ref_avg,
blaser_qe_avg,
comet_ref_avg,
comet_qe_avg,
]
)
csv_writer.writerow(results_str)

def process_demetr(
self,
Expand All @@ -184,4 +141,9 @@ def process_demetr(
print(f"Processing input file {ds}")
reverse_acc = ds_cat == 35

self.process_demetr_category(ds_cat, ds, samples_per_cat, reverse_acc)
self.process_demetr_category(
ds_cat,
ds,
samples_per_cat,
reverse_acc,
)

0 comments on commit a4523ba

Please sign in to comment.