Comet version update; corresponding changes have been made. #129

Merged 3 commits on Aug 3, 2023
57 changes: 45 additions & 12 deletions jury/metrics/comet/comet_for_language_generation.py
@@ -17,10 +17,11 @@
Comet implementation of evaluate package. See
https://github.com/huggingface/evaluate/blob/master/metrics/comet/comet.py
"""

from typing import Callable, Union
from pathlib import Path
from typing import Any, Callable, Dict, Union

import evaluate
from packaging.version import Version

from jury.metrics import LanguageGenerationInstance
from jury.metrics._core import MetricForCrossLingualEvaluation
@@ -29,6 +30,20 @@
# `import comet` placeholder
comet = PackagePlaceholder(version="1.0.1")

# Comet checkpoints, taken from https://github.com/Unbabel/COMET/blob/master/MODELS.md
COMET_MODELS = {
"emnlp20-comet-rank": "https://unbabel-experimental-models.s3.amazonaws.com/comet/wmt20/emnlp20-comet-rank.tar.gz",
"wmt20-comet-qe-da": "https://unbabel-experimental-models.s3.amazonaws.com/comet/wmt20/wmt20-comet-qe-da.tar.gz",
"wmt21-comet-da": "https://unbabel-experimental-models.s3.amazonaws.com/comet/wmt21/wmt21-comet-da.tar.gz",
"wmt21-comet-mqm": "https://unbabel-experimental-models.s3.amazonaws.com/comet/wmt21/wmt21-comet-mqm.tar.gz",
"wmt21-comet-qe-da": "https://unbabel-experimental-models.s3.amazonaws.com/comet/wmt21/wmt21-comet-qe-da.tar.gz",
"wmt21-comet-qe-mqm": "https://unbabel-experimental-models.s3.amazonaws.com/comet/wmt21/wmt21-comet-qe-mqm.tar.gz",
"wmt21-cometinho-mqm": "https://unbabel-experimental-models.s3.amazonaws.com/comet/wmt21/wmt21-cometinho-mqm.tar.gz",
"wmt21-cometinho-da": "https://unbabel-experimental-models.s3.amazonaws.com/comet/wmt21/wmt21-cometinho-da.tar.gz",
"eamt22-prune-comet-da": "https://unbabel-experimental-models.s3.amazonaws.com/comet/eamt22/eamt22-prune-comet-da.tar.gz",
"eamt22-cometinho-da": "https://unbabel-experimental-models.s3.amazonaws.com/comet/eamt22/eamt22-cometinho-da.tar.gz",
}

_CITATION = """\
@inproceedings{rei-EtAl:2020:WMT,
author = {Rei, Ricardo and Stewart, Craig and Farinha, Ana C and Lavie, Alon},
@@ -102,9 +117,17 @@ def _download_and_prepare(self, dl_manager):
super(CometForCrossLingualEvaluation, self)._download_and_prepare(dl_manager)

if self.config_name == "default":
checkpoint_path = comet.download_model("wmt20-comet-da")
if Version(comet.__version__) >= Version("2.0.0"):
checkpoint_path = comet.download_model("Unbabel/wmt22-comet-da")
else:
checkpoint_path = comet.download_model("wmt20-comet-da")
else:
checkpoint_path = comet.download_model(self.config_name)
model_url = COMET_MODELS.get(self.config_name)
if model_url is not None:
model_dir = dl_manager.download_and_extract(model_url)
checkpoint_path = (Path(model_dir) / f"{self.config_name}/checkpoints/model.ckpt").as_posix()
else:
checkpoint_path = comet.download_model(self.config_name)
self.scorer = comet.load_from_checkpoint(checkpoint_path)

def _info(self):
@@ -122,6 +145,14 @@ def _info(self):
],
)

def _handle_comet_outputs(self, outputs) -> Dict[str, Any]:
if Version(comet.__version__) >= Version("2.0.0"):
scores = outputs.get("scores")
system_score = outputs.get("system_score")
else:
scores, system_score = outputs
return {"scores": scores, "system_score": system_score}

def _compute_single_pred_single_ref(
self,
sources: LanguageGenerationInstance,
@@ -138,7 +169,7 @@ def _compute_single_pred_single_ref(
):
data = {"src": sources, "mt": predictions, "ref": references}
data = [dict(zip(data, t)) for t in zip(*data.values())]
scores, samples = self.scorer.predict(
comet_scores = self.scorer.predict(
data,
batch_size=batch_size,
gpus=gpus,
@@ -148,7 +179,7 @@ def _compute_single_pred_single_ref(
num_workers=num_workers,
length_batching=length_batching,
)
return {"scores": scores, "samples": samples}
return self._handle_comet_outputs(comet_scores)

def _compute_single_pred_multi_ref(
self,
@@ -168,7 +199,7 @@ def _compute_single_pred_multi_ref(
for src, pred, refs in zip(sources, predictions, references):
data = {"src": [src] * len(refs), "mt": [pred] * len(refs), "ref": refs}
data = [dict(zip(data, t)) for t in zip(*data.values())]
pred_scores, _ = self.scorer.predict(
comet_scores = self.scorer.predict(
data,
batch_size=batch_size,
gpus=gpus,
@@ -178,9 +209,10 @@ def _compute_single_pred_multi_ref(
num_workers=num_workers,
length_batching=length_batching,
)
scores.append(float(reduce_fn(pred_scores)))
pred_scores = self._handle_comet_outputs(comet_scores)
scores.append(float(reduce_fn(pred_scores["scores"])))

return {"scores": scores, "samples": sum(scores) / len(scores)}
return {"scores": scores, "system_score": sum(scores) / len(scores)}

def _compute_multi_pred_multi_ref(
self,
@@ -202,7 +234,7 @@ def _compute_multi_pred_multi_ref(
for pred in preds:
data = {"src": [src] * len(refs), "mt": [pred] * len(refs), "ref": refs}
data = [dict(zip(data, t)) for t in zip(*data.values())]
pred_scores, _ = self.scorer.predict(
comet_scores = self.scorer.predict(
data,
batch_size=batch_size,
gpus=gpus,
@@ -212,7 +244,8 @@ def _compute_multi_pred_multi_ref(
num_workers=num_workers,
length_batching=length_batching,
)
all_pred_scores.append(float(reduce_fn(pred_scores)))
pred_scores = self._handle_comet_outputs(comet_scores)
all_pred_scores.append(float(reduce_fn(pred_scores["scores"])))
scores.append(float(reduce_fn(all_pred_scores)))

return {"scores": scores, "samples": sum(scores) / len(scores)}
return {"scores": scores, "system_score": sum(scores) / len(scores)}
4 changes: 3 additions & 1 deletion setup.py
@@ -57,7 +57,9 @@ def add_pywin(reqs: List[str]) -> None:
"jiwer>=2.3.0",
"seqeval==1.2.2",
"sentencepiece==0.1.96",
"unbabel-comet>=1.1.2,<2",
'unbabel-comet>=1.1.2,<2;python_version<"3.8"',
'unbabel-comet>=2.0,<2.1;python_version>="3.8"',
"protobuf<3.20.1",
]

_METRIC_REQUIREMENTS.extend(_PRISM_REQUIREMENTS)
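The setup.py change above replaces the single unbabel-comet pin with two PEP 508 environment-marker variants, so Python 3.7 keeps the COMET 1.x line while Python 3.8+ gets COMET 2.x. A small illustration (not part of the PR) of how such markers evaluate, using the packaging library the metric already imports for version checks:

```python
from packaging.markers import Marker

# The two markers attached to the unbabel-comet requirements in setup.py.
legacy_comet = Marker('python_version < "3.8"')    # -> unbabel-comet>=1.1.2,<2
current_comet = Marker('python_version >= "3.8"')  # -> unbabel-comet>=2.0,<2.1

# Evaluated against the running interpreter; exactly one is True, so pip
# installs exactly one of the two COMET requirement lines.
print(legacy_comet.evaluate(), current_comet.evaluate())
```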
3 changes: 2 additions & 1 deletion tests/jury/metrics/test_comet.py
@@ -8,9 +8,10 @@

@pytest.fixture(scope="module")
def jury_comet():

metric = AutoMetric.load(
"comet",
config_name="wmt21-cometinho-da",
config_name="eamt22-cometinho-da",
compute_kwargs={"gpus": 0, "num_workers": 0, "progress_bar": False, "batch_size": 2},
)
return Jury(metrics=metric)
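For context, a rough usage sketch built around this fixture. The import paths and the exact call signature of Jury (in particular whether sources is accepted as a keyword for cross-lingual metrics) are assumptions here, and the texts and printed numbers are placeholders:

```python
from jury import Jury
from jury.metrics import AutoMetric

# CPU-only settings mirror the test fixture above.
comet = AutoMetric.load(
    "comet",
    config_name="eamt22-cometinho-da",
    compute_kwargs={"gpus": 0, "num_workers": 0, "progress_bar": False, "batch_size": 2},
)
scorer = Jury(metrics=comet)

# COMET scores translations against both the source and the reference.
results = scorer(
    sources=["Die Katze sitzt auf der Matte."],
    predictions=["The cat sat on the mat."],
    references=["The cat is sitting on the mat."],
)

# Mirrors the expected-output JSON below: per-item scores plus a
# corpus-level system_score.
print(results["comet"]["scores"], results["comet"]["system_score"])
```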
18 changes: 9 additions & 9 deletions tests/test_data/expected_outputs/metrics/test_comet.json
@@ -4,32 +4,32 @@
"empty_items": 0,
"comet": {
"scores": [
-0.029062556102871895,
0.47664526104927063
0.4201257526874542,
0.48723894357681274
],
"samples": 0.22379135247319937
"system_score": 0.4536823481321335
}
},
"multiple_ref": {
"total_items": 2,
"empty_items": 0,
"comet": {
"scores": [
0.05850040912628174,
0.49252423644065857
0.421277791261673,
0.48723891377449036
],
"samples": 0.27551232278347015
"system_score": 0.45425835251808167
}
},
"multiple_pred_multiple_ref": {
"total_items": 2,
"empty_items": 0,
"comet": {
"scores": [
0.3315272331237793,
1.4006391763687134
0.7988364100456238,
1.2400819063186646
],
"samples": 0.8660832047462463
"system_score": 1.0194591581821442
}
}
}
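As a quick sanity check on the updated fixtures: in the multi-reference cases the new system_score value is the arithmetic mean of the per-item scores, matching the `sum(scores) / len(scores)` expression in the metric code above.

```python
# Values copied from the "multiple_ref" block above.
scores = [0.421277791261673, 0.48723891377449036]
assert abs(sum(scores) / len(scores) - 0.45425835251808167) < 1e-12
```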