diff --git a/src/m4st/process_demetr.py b/src/m4st/process_demetr.py index 9cbb756..1c542de 100644 --- a/src/m4st/process_demetr.py +++ b/src/m4st/process_demetr.py @@ -14,7 +14,10 @@ class ProcessDEMETR: def __init__( - self, output_filepath: os.PathLike | str, demetr_root: os.PathLike | str + self, + output_filepath: os.PathLike | str, + demetr_root: os.PathLike | str, + metrics_to_use: list, ) -> None: # Conversion from DEMETR language tag to SONAR language code @@ -32,32 +35,20 @@ def __init__( } self.output_path = output_filepath self.demetr_root = demetr_root + self.metrics_to_use = metrics_to_use + + colnames = ["category", *metrics_to_use] with open(self.output_path, "w") as output_file: writer = csv.writer(output_file) - writer.writerow( - [ - "category", - "BLEU", - "SacreBLEU", - "BLASER_ref", - "BLASER_qe", - "COMET_ref", - "COMET_qe", - ] - ) - - def get_accuracy_score( - self, - mt_scores: list, - dfluent_scores: list, - num_samples: int, - reverse_accuracy: bool, - ) -> float: + writer.writerow(colnames) - mask = np.array(mt_scores) > np.array(dfluent_scores) - result = np.count_nonzero(~mask) if reverse_accuracy else np.count_nonzero(mask) - return result / num_samples * 100 + if "Sacre_BLEU" in metrics_to_use: + self.sacre_bleu = SacreBLEUScore() + if "BLASER_ref" in metrics_to_use or "BLASER_qe" in metrics_to_use: + self.blaser = BLASERScore() + if "COMET_ref" in metrics_to_use or "COMET_qe" in metrics_to_use: + self.comet = COMETScore() def process_demetr_category( self, @@ -70,29 +61,10 @@ def process_demetr_category( curr_ds_path = os.path.join(self.demetr_root, cat_fp) json_data = load_json(curr_ds_path) - nltk_bleu_mt = [] - nltk_bleu_d = [] - - sacre_bleu_mt = [] - sacre_bleu_d = [] - - blaser_ref_mt = [] - blaser_ref_d = [] + mt_results = np.zeros((num_samples, len(self.metrics_to_use))) + dis_results = np.zeros((num_samples, len(self.metrics_to_use))) - blaser_qe_mt = [] - blaser_qe_d = [] - - comet_ref_mt = [] - comet_ref_d = [] - - comet_qe_mt = [] - comet_qe_d = [] - - sacre_bleu = SacreBLEUScore() - blaser = BLASERScore() - comet = COMETScore() - - for sentence in json_data: + for i, sentence in enumerate(json_data): ref_txt = sentence["eng_sent"] # Human translation mt_txt = sentence["mt_sent"] # Original machine translation @@ -101,68 +73,53 @@ def process_demetr_category( src_lang = sentence["lang_tag"] # Source language blaser_lang_code = self.language_codes[src_lang] - # String-based metrics - nltk_bleu_mt.append(nltk_bleu_score(ref_txt, mt_txt)) - nltk_bleu_d.append(nltk_bleu_score(ref_txt, dfluent_txt)) - sacre_bleu_mt.append(sacre_bleu.get_score(ref_txt, mt_txt)) - sacre_bleu_d.append(sacre_bleu.get_score(ref_txt, dfluent_txt)) - - # Model-based metrics - # BLASER-2.0 - blaser_ref_mt.append( - blaser.blaser_ref_score(ref_txt, mt_txt, src_text, blaser_lang_code) - ) - blaser_ref_d.append( - blaser.blaser_ref_score( - ref_txt, dfluent_txt, src_text, blaser_lang_code - ) - ) - blaser_qe_mt.append( - blaser.blaser_qe_score(mt_txt, src_text, blaser_lang_code) - ) - blaser_qe_d.append( - blaser.blaser_qe_score(dfluent_txt, src_text, blaser_lang_code) - ) - - # COMET - comet_ref_mt.append(comet.comet_ref_score(ref_txt, mt_txt, src_text)) - comet_ref_d.append(comet.comet_ref_score(ref_txt, dfluent_txt, src_text)) - comet_qe_mt.append(comet.comet_qe_score(mt_txt, src_text)) - comet_qe_d.append(comet.comet_qe_score(dfluent_txt, src_text)) - - # Calculate accuracy as in DEMETR paper - nltk_bleu_avg = self.get_accuracy_score( - nltk_bleu_mt, nltk_bleu_d, num_samples, reverse_accuracy - ) - sacre_bleu_avg = self.get_accuracy_score( - sacre_bleu_mt, sacre_bleu_d, num_samples, reverse_accuracy - ) - blaser_ref_avg = self.get_accuracy_score( - blaser_ref_mt, blaser_ref_d, num_samples, reverse_accuracy - ) - blaser_qe_avg = self.get_accuracy_score( - blaser_qe_mt, blaser_qe_d, num_samples, reverse_accuracy - ) - comet_ref_avg = self.get_accuracy_score( - comet_ref_mt, comet_ref_d, num_samples, reverse_accuracy - ) - comet_qe_avg = self.get_accuracy_score( - comet_qe_mt, comet_qe_d, num_samples, reverse_accuracy - ) + for j, metric in enumerate(self.metrics_to_use): + if metric == "BLEU": + mt_results[i, j] = nltk_bleu_score(ref_txt, mt_txt) + dis_results[i, j] = nltk_bleu_score(ref_txt, dfluent_txt) + elif metric == "Sacre_BLEU": + mt_results[i, j] = self.sacre_bleu.get_score(ref_txt, mt_txt) + dis_results[i, j] = self.sacre_bleu.get_score(ref_txt, dfluent_txt) + elif metric == "BLASER_ref": + mt_results[i, j] = self.blaser.blaser_ref_score( + ref_txt, mt_txt, src_text, blaser_lang_code + ) + dis_results[i, j] = self.blaser.blaser_ref_score( + ref_txt, dfluent_txt, src_text, blaser_lang_code + ) + elif metric == "BLASER_qe": + mt_results[i, j] = self.blaser.blaser_qe_score( + mt_txt, src_text, blaser_lang_code + ) + dis_results[i, j] = self.blaser.blaser_qe_score( + dfluent_txt, src_text, blaser_lang_code + ) + elif metric == "COMET_ref": + mt_results[i, j] = self.comet.comet_ref_score( + ref_txt, mt_txt, src_text + ) + dis_results[i, j] = self.comet.comet_ref_score( + ref_txt, dfluent_txt, src_text + ) + elif metric == "COMET_qe": + mt_results[i, j] = self.comet.comet_qe_score(mt_txt, src_text) + dis_results[i, j] = self.comet.comet_qe_score(dfluent_txt, src_text) + else: + print(f"Unknown metric {metric}") + + mask = mt_results > dis_results + if reverse_accuracy: + results = np.count_nonzero(~mask, axis=0) + else: + results = np.count_nonzero(mask, axis=0) + + results = results / num_samples * 100 + + results_str = [category, *results] with open(self.output_path, "a") as output_file: csv_writer = csv.writer(output_file) - csv_writer.writerow( - [ - category, - nltk_bleu_avg, - sacre_bleu_avg, - blaser_ref_avg, - blaser_qe_avg, - comet_ref_avg, - comet_qe_avg, - ] - ) + csv_writer.writerow(results_str) def process_demetr( self, @@ -184,4 +141,9 @@ def process_demetr( print(f"Processing input file {ds}") reverse_acc = ds_cat == 35 - self.process_demetr_category(ds_cat, ds, samples_per_cat, reverse_acc) + self.process_demetr_category( + ds_cat, + ds, + samples_per_cat, + reverse_acc, + )