Skip to content

Commit

Permalink
Comet version update; corresponding changes have been made. (#129)
Browse files Browse the repository at this point in the history
* Comet version update; corresponding changes have been made.

* Hotfix: typo for pathlib.Path.

* Trying to downgrade protobuf.
  • Loading branch information
devrimcavusoglu authored Aug 3, 2023
1 parent f8a2545 commit f8be123
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 23 deletions.
57 changes: 45 additions & 12 deletions jury/metrics/comet/comet_for_language_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
Comet implementation of evaluate package. See
https://github.com/huggingface/evaluate/blob/master/metrics/comet/comet.py
"""

from typing import Callable, Union
from pathlib import Path
from typing import Any, Callable, Dict, Union

import evaluate
from packaging.version import Version

from jury.metrics import LanguageGenerationInstance
from jury.metrics._core import MetricForCrossLingualEvaluation
Expand All @@ -29,6 +30,20 @@
# `import comet` placeholder
comet = PackagePlaceholder(version="1.0.1")

# Comet checkpoints, taken from https://github.com/Unbabel/COMET/blob/master/MODELS.md
# All checkpoints live under the same S3 bucket, grouped by release folder;
# each archive is named "<model-name>.tar.gz".
_COMET_S3_BUCKET = "https://unbabel-experimental-models.s3.amazonaws.com/comet"
_COMET_MODEL_GROUPS = {
    "wmt20": ("emnlp20-comet-rank", "wmt20-comet-qe-da"),
    "wmt21": (
        "wmt21-comet-da",
        "wmt21-comet-mqm",
        "wmt21-comet-qe-da",
        "wmt21-comet-qe-mqm",
        "wmt21-cometinho-mqm",
        "wmt21-cometinho-da",
    ),
    "eamt22": ("eamt22-prune-comet-da", "eamt22-cometinho-da"),
}
COMET_MODELS = {
    model_name: f"{_COMET_S3_BUCKET}/{folder}/{model_name}.tar.gz"
    for folder, model_names in _COMET_MODEL_GROUPS.items()
    for model_name in model_names
}

_CITATION = """\
@inproceedings{rei-EtAl:2020:WMT,
author = {Rei, Ricardo and Stewart, Craig and Farinha, Ana C and Lavie, Alon},
Expand Down Expand Up @@ -102,9 +117,17 @@ def _download_and_prepare(self, dl_manager):
super(CometForCrossLingualEvaluation, self)._download_and_prepare(dl_manager)

if self.config_name == "default":
checkpoint_path = comet.download_model("wmt20-comet-da")
if Version(comet.__version__) >= Version("2.0.0"):
checkpoint_path = comet.download_model("Unbabel/wmt22-comet-da")
else:
checkpoint_path = comet.download_model("wmt20-comet-da")
else:
checkpoint_path = comet.download_model(self.config_name)
model_url = COMET_MODELS.get(self.config_name)
if model_url is not None:
model_dir = dl_manager.download_and_extract(model_url)
checkpoint_path = (Path(model_dir) / f"{self.config_name}/checkpoints/model.ckpt").as_posix()
else:
checkpoint_path = comet.download_model(self.config_name)
self.scorer = comet.load_from_checkpoint(checkpoint_path)

def _info(self):
Expand All @@ -122,6 +145,14 @@ def _info(self):
],
)

def _handle_comet_outputs(self, outputs) -> Dict[str, Any]:
    """Normalize comet ``predict`` results across package versions.

    comet < 2.0 yields a ``(scores, system_score)`` tuple, whereas comet >= 2.0
    yields a mapping exposing ``scores`` and ``system_score`` keys. Either form
    is converted to a plain dict with exactly those two keys.
    """
    if Version(comet.__version__) < Version("2.0.0"):
        # Legacy API: predict() returned a plain 2-tuple.
        scores, system_score = outputs
    else:
        scores = outputs.get("scores")
        system_score = outputs.get("system_score")
    return {"scores": scores, "system_score": system_score}

def _compute_single_pred_single_ref(
self,
sources: LanguageGenerationInstance,
Expand All @@ -138,7 +169,7 @@ def _compute_single_pred_single_ref(
):
data = {"src": sources, "mt": predictions, "ref": references}
data = [dict(zip(data, t)) for t in zip(*data.values())]
scores, samples = self.scorer.predict(
comet_scores = self.scorer.predict(
data,
batch_size=batch_size,
gpus=gpus,
Expand All @@ -148,7 +179,7 @@ def _compute_single_pred_single_ref(
num_workers=num_workers,
length_batching=length_batching,
)
return {"scores": scores, "samples": samples}
return self._handle_comet_outputs(comet_scores)

def _compute_single_pred_multi_ref(
self,
Expand All @@ -168,7 +199,7 @@ def _compute_single_pred_multi_ref(
for src, pred, refs in zip(sources, predictions, references):
data = {"src": [src] * len(refs), "mt": [pred] * len(refs), "ref": refs}
data = [dict(zip(data, t)) for t in zip(*data.values())]
pred_scores, _ = self.scorer.predict(
comet_scores = self.scorer.predict(
data,
batch_size=batch_size,
gpus=gpus,
Expand All @@ -178,9 +209,10 @@ def _compute_single_pred_multi_ref(
num_workers=num_workers,
length_batching=length_batching,
)
scores.append(float(reduce_fn(pred_scores)))
pred_scores = self._handle_comet_outputs(comet_scores)
scores.append(float(reduce_fn(pred_scores["scores"])))

return {"scores": scores, "samples": sum(scores) / len(scores)}
return {"scores": scores, "system_score": sum(scores) / len(scores)}

def _compute_multi_pred_multi_ref(
self,
Expand All @@ -202,7 +234,7 @@ def _compute_multi_pred_multi_ref(
for pred in preds:
data = {"src": [src] * len(refs), "mt": [pred] * len(refs), "ref": refs}
data = [dict(zip(data, t)) for t in zip(*data.values())]
pred_scores, _ = self.scorer.predict(
comet_scores = self.scorer.predict(
data,
batch_size=batch_size,
gpus=gpus,
Expand All @@ -212,7 +244,8 @@ def _compute_multi_pred_multi_ref(
num_workers=num_workers,
length_batching=length_batching,
)
all_pred_scores.append(float(reduce_fn(pred_scores)))
pred_scores = self._handle_comet_outputs(comet_scores)
all_pred_scores.append(float(reduce_fn(pred_scores["scores"])))
scores.append(float(reduce_fn(all_pred_scores)))

return {"scores": scores, "samples": sum(scores) / len(scores)}
return {"scores": scores, "system_score": sum(scores) / len(scores)}
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def add_pywin(reqs: List[str]) -> None:
"jiwer>=2.3.0",
"seqeval==1.2.2",
"sentencepiece==0.1.96",
"unbabel-comet>=1.1.2,<2",
'unbabel-comet>=1.1.2,<2;python_version<"3.8"',
'unbabel-comet>=2.0,<2.1;python_version>="3.8"',
"protobuf<3.20.1",
]

_METRIC_REQUIREMENTS.extend(_PRISM_REQUIREMENTS)
Expand Down
3 changes: 2 additions & 1 deletion tests/jury/metrics/test_comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

@pytest.fixture(scope="module")
def jury_comet():

metric = AutoMetric.load(
"comet",
config_name="wmt21-cometinho-da",
config_name="eamt22-cometinho-da",
compute_kwargs={"gpus": 0, "num_workers": 0, "progress_bar": False, "batch_size": 2},
)
return Jury(metrics=metric)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_data/expected_outputs/metrics/test_comet.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,32 @@
"empty_items": 0,
"comet": {
"scores": [
-0.029062556102871895,
0.47664526104927063
0.4201257526874542,
0.48723894357681274
],
"samples": 0.22379135247319937
"system_score": 0.4536823481321335
}
},
"multiple_ref": {
"total_items": 2,
"empty_items": 0,
"comet": {
"scores": [
0.05850040912628174,
0.49252423644065857
0.421277791261673,
0.48723891377449036
],
"samples": 0.27551232278347015
"system_score": 0.45425835251808167
}
},
"multiple_pred_multiple_ref": {
"total_items": 2,
"empty_items": 0,
"comet": {
"scores": [
0.3315272331237793,
1.4006391763687134
0.7988364100456238,
1.2400819063186646
],
"samples": 0.8660832047462463
"system_score": 1.0194591581821442
}
}
}

0 comments on commit f8be123

Please sign in to comment.