diff --git a/simuleval/evaluator/evaluator.py b/simuleval/evaluator/evaluator.py
index 7b6617d..b417f13 100644
--- a/simuleval/evaluator/evaluator.py
+++ b/simuleval/evaluator/evaluator.py
@@ -170,7 +170,7 @@ def build_instances_from_log(self):
         if self.output is not None:
             with open(self.output / "instances.log", "r") as f:
                 for line in f:
-                    instance = LogInstance(line.strip())
+                    instance = LogInstance(line.strip(), self.args.eval_latency_unit)
                     index = instance.index - self.start_index
                     self.instances[index] = instance
                     self.instances[index].set_target_spm_model(self.target_spm_model)
diff --git a/simuleval/evaluator/instance.py b/simuleval/evaluator/instance.py
index c4d5ddd..9f33466 100644
--- a/simuleval/evaluator/instance.py
+++ b/simuleval/evaluator/instance.py
@@ -456,7 +456,7 @@ class SpeechToSpeechInstance(SpeechInputInstance, SpeechOutputInstance):
 
 
 class LogInstance:
-    def __init__(self, info: str) -> None:
+    def __init__(self, info: str, latency_unit: str = "word") -> None:
         self.info = json.loads(info.strip())
         self.intervals = []
         for key, value in self.info.items():
@@ -464,9 +464,7 @@ def __init__(self, info: str) -> None:
         self.index = self.info["index"]
         self.reference = self.info.get("reference", "")
-        self.reference_length = len(
-            self.reference.split(" ")
-        )  # ToDo: temporary solution, make it configurable
+        self.latency_unit = latency_unit
         self.source_length = self.info.get("source_length")
         # just for testing!
         self.finish_prediction = True
         self.metrics = {}
@@ -474,3 +472,15 @@ def __init__(self, info: str) -> None:
 
     def set_target_spm_model(self, spm_model):
         self.target_spm_model = spm_model
+
+    @property
+    def reference_length(self) -> int:
+        if self.latency_unit == "word":
+            return len(self.reference.split(" "))
+        elif self.latency_unit == "char":
+            return len(self.reference.strip())
+        elif self.latency_unit == "spm":
+            assert self.target_spm_model is not None
+            return len(self.target_spm_model.encode(self.reference, out_type=str))
+        else:
+            raise NotImplementedError
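For context, the change turns `reference_length` into a property that dispatches on the configured latency unit instead of always counting whitespace-separated words. The snippet below is a minimal, standalone sketch of that dispatch, not simuleval's own API: `count_reference_tokens` and the sample sentences are illustrative, and the `spm` branch is stubbed because it needs a trained SentencePiece model (set in the real code via `set_target_spm_model`).

```python
from typing import Optional


def count_reference_tokens(
    reference: str, latency_unit: str = "word", spm_model: Optional[object] = None
) -> int:
    """Illustrative mirror of the dispatch added to LogInstance.reference_length."""
    if latency_unit == "word":
        # Whitespace-split word count, as before this change.
        return len(reference.split(" "))
    elif latency_unit == "char":
        # Character count of the stripped reference.
        return len(reference.strip())
    elif latency_unit == "spm":
        # Needs a SentencePiece model; encode to subword pieces and count them.
        assert spm_model is not None
        return len(spm_model.encode(reference, out_type=str))
    raise NotImplementedError(f"unknown latency unit: {latency_unit}")


print(count_reference_tokens("hello world", "word"))  # 2
print(count_reference_tokens("hello world", "char"))  # 11
```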