diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py index a69563c..36d7539 100644 --- a/machine/jobs/nmt_engine_build_job.py +++ b/machine/jobs/nmt_engine_build_job.py @@ -81,7 +81,7 @@ def run( inference_step_count = sum(1 for _ in src_pretranslations) with ExitStack() as stack: phase_progress = stack.enter_context(progress_reporter.start_next_phase()) - model = stack.enter_context(self._nmt_model_factory.create_engine()) + engine = stack.enter_context(self._nmt_model_factory.create_engine()) src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations()) writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer()) current_inference_step = 0 @@ -90,7 +90,7 @@ def run( for pi_batch in batch(src_pretranslations, batch_size): if check_canceled is not None: check_canceled() - _translate_batch(model, pi_batch, writer) + _translate_batch(engine, pi_batch, writer) current_inference_step += len(pi_batch) phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) diff --git a/machine/tokenization/__init__.py b/machine/tokenization/__init__.py index 74337b8..aab2dbf 100644 --- a/machine/tokenization/__init__.py +++ b/machine/tokenization/__init__.py @@ -7,6 +7,7 @@ from .range_tokenizer import RangeTokenizer from .string_detokenizer import StringDetokenizer from .string_tokenizer import StringTokenizer +from .tokenization_utils import get_ranges, split from .tokenizer import Tokenizer from .whitespace_detokenizer import WHITESPACE_DETOKENIZER, WhitespaceDetokenizer from .whitespace_tokenizer import WHITESPACE_TOKENIZER, WhitespaceTokenizer @@ -15,12 +16,14 @@ __all__ = [ "Detokenizer", + "get_ranges", "LatinSentenceTokenizer", "LatinWordDetokenizer", "LatinWordTokenizer", "LineSegmentTokenizer", "NullTokenizer", "RangeTokenizer", + "split", "StringDetokenizer", "StringTokenizer", "Tokenizer", diff --git a/machine/tokenization/tokenization_utils.py b/machine/tokenization/tokenization_utils.py new file mode 100644 index 0000000..3ceda14 --- /dev/null +++ b/machine/tokenization/tokenization_utils.py @@ -0,0 +1,17 @@ +from typing import Generator, Iterable, List + +from ..annotations.range import Range + + +def split(s: str, ranges: Iterable[Range[int]]) -> List[str]: + return [s[range.start : range.end] for range in ranges] + + +def get_ranges(s: str, tokens: Iterable[str]) -> Generator[Range[int], None, None]: + start = 0 + for token in tokens: + index = s.find(token, start) + if index == -1: + raise ValueError(f"The string does not contain the specified token: {token}.") + yield Range.create(index, index + len(token)) + start = index + len(token) diff --git a/machine/translation/__init__.py b/machine/translation/__init__.py index 89c9329..51a9c2e 100644 --- a/machine/translation/__init__.py +++ b/machine/translation/__init__.py @@ -1,26 +1,33 @@ from .corpus_ops import translate_corpus, word_align_corpus from .edit_operation import EditOperation +from .error_correction_model import ErrorCorrectionModel from .evaluation import compute_bleu from .fuzzy_edit_distance_word_alignment_method import FuzzyEditDistanceWordAlignmentMethod from .hmm_word_alignment_model import HmmWordAlignmentModel from .ibm1_word_alignment_model import Ibm1WordAlignmentModel from .ibm1_word_confidence_estimator import Ibm1WordConfidenceEstimator from .ibm2_word_alignment_model import Ibm2WordAlignmentModel -from .interactive_translation_engine import InterativeTranslationEngine +from .interactive_translation_engine import InteractiveTranslationEngine from .interactive_translation_model import InteractiveTranslationModel +from .interactive_translator import InteractiveTranslator +from .interactive_translator_factory import InteractiveTranslatorFactory from .null_trainer import NullTrainer from .phrase import Phrase +from .phrase_translation_suggester import PhraseTranslationSuggester from .segment_scorer import SegmentScorer from .symmetrization_heuristic import SymmetrizationHeuristic from .symmetrized_word_aligner import SymmetrizedWordAligner from .symmetrized_word_alignment_model import SymmetrizedWordAlignmentModel from .symmetrized_word_alignment_model_trainer import SymmetrizedWordAlignmentModelTrainer from .trainer import Trainer, TrainStats +from .translation_constants import MAX_SEGMENT_LENGTH from .translation_engine import TranslationEngine from .translation_model import TranslationModel from .translation_result import TranslationResult from .translation_result_builder import TranslationResultBuilder from .translation_sources import TranslationSources +from .translation_suggester import TranslationSuggester +from .translation_suggestion import TranslationSuggestion from .word_aligner import WordAligner from .word_alignment_matrix import WordAlignmentMatrix from .word_alignment_method import WordAlignmentMethod @@ -29,21 +36,23 @@ from .word_graph import WordGraph from .word_graph_arc import WordGraphArc -MAX_SEGMENT_LENGTH = 200 - __all__ = [ "compute_bleu", "EditOperation", + "ErrorCorrectionModel", "FuzzyEditDistanceWordAlignmentMethod", "HmmWordAlignmentModel", "Ibm1WordAlignmentModel", "Ibm1WordConfidenceEstimator", "Ibm2WordAlignmentModel", + "InteractiveTranslationEngine", "InteractiveTranslationModel", - "InterativeTranslationEngine", + "InteractiveTranslator", + "InteractiveTranslatorFactory", "MAX_SEGMENT_LENGTH", "NullTrainer", "Phrase", + "PhraseTranslationSuggester", "SegmentScorer", "SymmetrizationHeuristic", "SymmetrizedWordAligner", @@ -57,6 +66,8 @@ "TranslationResult", "TranslationResultBuilder", "TranslationSources", + "TranslationSuggester", + "TranslationSuggestion", "word_align_corpus", "WordAligner", "WordAlignmentMatrix", diff --git a/machine/translation/ecm_score_info.py b/machine/translation/ecm_score_info.py new file mode 100644 index 0000000..dee406d --- /dev/null +++ b/machine/translation/ecm_score_info.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import List + +from .edit_operation import EditOperation + + +class EcmScoreInfo: + def __init__(self) -> None: + self._scores: List[float] = [] + self._operations: List[EditOperation] = [] + + @property + def scores(self) -> List[float]: + return self._scores + + @property + def operations(self) -> List[EditOperation]: + return self._operations + + def update_positions(self, prev_esi: EcmScoreInfo, positions: List[int]) -> None: + while len(self.scores) < len(prev_esi.scores): + self.scores.append(0.0) + + while len(self.operations) < len(prev_esi.operations): + self.operations.append(EditOperation.NONE) + + for i in range(len(positions)): + self.scores[positions[i]] = prev_esi.scores[positions[i]] + if len(prev_esi.operations) > i: + self.operations[positions[i]] = prev_esi.operations[positions[i]] + + def remove_last(self) -> None: + if len(self.scores) > 1: + self.scores.pop() + if len(self.operations) > 1: + self.operations.pop() + + def get_last_ins_prefix_word_from_esi(self) -> List[int]: + results = [0] * len(self.operations) + + for j in range(len(self.operations) - 1, -1, -1): + if self.operations[j] == EditOperation.HIT: + results[j] = j - 1 + elif self.operations[j] == EditOperation.INSERT: + tj = j + while tj >= 0 and self.operations[tj] == EditOperation.INSERT: + tj -= 1 + if self.operations[tj] == EditOperation.HIT or self.operations[tj] == EditOperation.SUBSTITUTE: + tj -= 1 + results[j] = tj + elif self.operations[j] == EditOperation.DELETE: + results[j] = j + elif self.operations[j] == EditOperation.SUBSTITUTE: + results[j] = j - 1 + elif self.operations[j] == EditOperation.NONE: + results[j] = 0 + + return results diff --git a/machine/translation/edit_distance.py b/machine/translation/edit_distance.py new file mode 100644 index 0000000..2f776d3 --- /dev/null +++ b/machine/translation/edit_distance.py @@ -0,0 +1,133 @@ +from abc import ABC, abstractmethod +from typing import Generic, Iterable, List, Tuple, TypeVar + +from .edit_operation import EditOperation + +Seq = TypeVar("Seq") +Item = TypeVar("Item") + + +class EditDistance(ABC, Generic[Seq, Item]): + @abstractmethod + def _get_count(self, seq: Seq) -> int: + ... + + @abstractmethod + def _get_item(self, seq: Seq, index: int) -> Item: + ... + + @abstractmethod + def _get_hit_cost(self, x: Item, y: Item, is_complete: bool) -> float: + ... + + @abstractmethod + def _get_substitution_cost(self, x: Item, y: Item, is_complete: bool) -> float: + ... + + @abstractmethod + def _get_deletion_cost(self, x: Item) -> float: + ... + + @abstractmethod + def _get_insertion_cost(self, y: Item) -> float: + ... + + @abstractmethod + def _is_hit(self, x: Item, y: Item, is_complete: bool) -> bool: + ... + + def _init_dist_matrix(self, x: Seq, y: Seq) -> List[List[float]]: + x_count = self._get_count(x) + y_count = self._get_count(y) + dim = max(x_count, y_count) + dist_matrix = [[0.0 for _ in range(dim + 1)] for _ in range(dim + 1)] + return dist_matrix + + def _compute_dist_matrix( + self, x: Seq, y: Seq, is_last_item_complete: bool, use_prefix_del_op: bool + ) -> Tuple[float, List[List[float]]]: + dist_matrix = self._init_dist_matrix(x, y) + + x_count = self._get_count(x) + y_count = self._get_count(y) + for i in range(x_count + 1): + for j in range(y_count + 1): + dist_matrix[i][j], _, _, _ = self._process_dist_matrix_cell( + x, y, dist_matrix, use_prefix_del_op, j != y_count or is_last_item_complete, i, j + ) + + return dist_matrix[x_count][y_count], dist_matrix + + def _process_dist_matrix_cell( + self, x: Seq, y: Seq, dist_matrix: List[List[float]], use_prefix_del_op: bool, is_complete: bool, i: int, j: int + ) -> Tuple[float, int, int, EditOperation]: + if i != 0 and j != 0: + x_item = self._get_item(x, i - 1) + y_item = self._get_item(y, j - 1) + if self._is_hit(x_item, y_item, is_complete): + subst_cost = self._get_hit_cost(x_item, y_item, is_complete) + op = EditOperation.HIT + else: + subst_cost = self._get_substitution_cost(x_item, y_item, is_complete) + op = EditOperation.SUBSTITUTE + + cost = dist_matrix[i - 1][j - 1] + subst_cost + min = cost + i_pred = i - 1 + j_pred = j - 1 + + del_cost = 0 if use_prefix_del_op and j == self._get_count(y) else self._get_deletion_cost(x_item) + cost = dist_matrix[i - 1][j] + del_cost + if cost < min: + min = cost + i_pred = i - 1 + j_pred = j + op = EditOperation.PREFIX_DELETE if del_cost == 0 else EditOperation.DELETE + + cost = dist_matrix[i][j - 1] + self._get_insertion_cost(y_item) + if cost < min: + min = cost + i_pred = i + j_pred = j - 1 + op = EditOperation.INSERT + + return (min, i_pred, j_pred, op) + + if i == 0 and j == 0: + return (0.0, 0, 0, EditOperation.NONE) + + if i == 0: + return ( + dist_matrix[0][j - 1] + self._get_insertion_cost(self._get_item(y, j - 1)), + 0, + j - 1, + EditOperation.INSERT, + ) + + return ( + dist_matrix[i - 1][0] + self._get_deletion_cost(self._get_item(x, i - 1)), + i - 1, + 0, + EditOperation.DELETE, + ) + + def _get_operations( + self, + x: Seq, + y: Seq, + dist_matrix: List[List[float]], + is_last_item_complete: bool, + use_prefix_del_op: bool, + i: int, + j: int, + ) -> Iterable[EditOperation]: + y_count = self._get_count(y) + ops: List[EditOperation] = [] + while i > 0 or j > 0: + _, i, j, op = self._process_dist_matrix_cell( + x, y, dist_matrix, use_prefix_del_op, j != y_count or is_last_item_complete, i, j + ) + if op != EditOperation.PREFIX_DELETE: + ops.append(op) + ops.reverse() + return ops diff --git a/machine/translation/error_correction_model.py b/machine/translation/error_correction_model.py new file mode 100644 index 0000000..b3eb2e2 --- /dev/null +++ b/machine/translation/error_correction_model.py @@ -0,0 +1,81 @@ +from math import log +from typing import Sequence + +from .ecm_score_info import EcmScoreInfo +from .edit_operation import EditOperation +from .segment_edit_distance import SegmentEditDistance +from .translation_result_builder import TranslationResultBuilder +from .translation_sources import TranslationSources + + +class ErrorCorrectionModel: + def __init__(self) -> None: + self._segment_edit_distance = SegmentEditDistance() + self.set_error_model_parameters(voc_size=128, hit_prob=0.8, ins_factor=1, subst_factor=1, del_factor=1) + + def set_error_model_parameters( + self, voc_size: int, hit_prob: float, ins_factor: float, subst_factor: float, del_factor: float + ) -> None: + if voc_size == 0: + e = (1 - hit_prob) / (ins_factor + subst_factor + del_factor) + else: + e = (1 - hit_prob) / ((ins_factor * voc_size) + (subst_factor * (voc_size - 1)) + del_factor) + + ins_prob = e * ins_factor + subst_prob = e * subst_factor + del_prob = e * del_factor + + self._segment_edit_distance.hit_cost = -log(hit_prob) + self._segment_edit_distance.insertion_cost = -log(ins_prob) + self._segment_edit_distance.substitution_cost = -log(subst_prob) + self._segment_edit_distance.deletion_cost = -log(del_prob) + + def setup_initial_esi(self, initial_esi: EcmScoreInfo) -> None: + score = self._segment_edit_distance.compute([], []) + initial_esi.scores.clear() + initial_esi.scores.append(score) + initial_esi.operations.clear() + + def setup_esi(self, esi: EcmScoreInfo, prev_esi: EcmScoreInfo, word: str) -> None: + score = self._segment_edit_distance.compute([word], []) + esi.scores.clear() + esi.scores.append(prev_esi.scores[0] + score) + esi.operations.clear() + esi.operations.append(EditOperation.NONE) + + def extend_initial_esi( + self, initial_esi: EcmScoreInfo, prev_initial_esi: EcmScoreInfo, prefix_diff: Sequence[str] + ) -> None: + self._segment_edit_distance.incr_compute_prefix_first_row( + initial_esi.scores, prev_initial_esi.scores, prefix_diff + ) + + def extend_esi( + self, + esi: EcmScoreInfo, + prev_esi: EcmScoreInfo, + word: str, + prefix_diff: Sequence[str], + is_last_word_complete: bool, + ) -> None: + ops = self._segment_edit_distance.incr_compute_prefix( + esi.scores, prev_esi.scores, word, prefix_diff, is_last_word_complete + ) + esi.operations.extend(ops) + + def correct_prefix( + self, + builder: TranslationResultBuilder, + uncorrected_prefix_len: int, + prefix: Sequence[str], + is_last_word_complete: bool, + ) -> int: + if uncorrected_prefix_len == 0: + for w in prefix: + builder.append_token(w, TranslationSources.PREFIX, -1) + return len(prefix) + + _, word_ops, char_ops = self._segment_edit_distance.compute_prefix( + builder.target_tokens[:uncorrected_prefix_len], prefix, is_last_word_complete, use_prefix_del_op=False + ) + return builder.correct_prefix(word_ops, char_ops, prefix, is_last_word_complete) diff --git a/machine/translation/error_correction_word_graph_processor.py b/machine/translation/error_correction_word_graph_processor.py new file mode 100644 index 0000000..e303258 --- /dev/null +++ b/machine/translation/error_correction_word_graph_processor.py @@ -0,0 +1,394 @@ +from __future__ import annotations + +import sys +from dataclasses import dataclass, field +from functools import total_ordering +from queue import PriorityQueue +from typing import Any, Generator, List, Sequence, Set + +from ..statistics.log_space import LOG_SPACE_ZERO +from ..tokenization.detokenizer import Detokenizer +from .ecm_score_info import EcmScoreInfo +from .error_correction_model import ErrorCorrectionModel +from .translation_result import TranslationResult +from .translation_result_builder import TranslationResultBuilder +from .word_alignment_matrix import WordAlignmentMatrix +from .word_graph import WORD_GRAPH_INITIAL_STATE, WordGraph +from .word_graph_arc import WordGraphArc + + +class ErrorCorrectionWordGraphProcessor: + def __init__( + self, + ecm: ErrorCorrectionModel, + target_detokenizer: Detokenizer[str, str], + word_graph: WordGraph, + ecm_weight: float = 1, + word_graph_weight: float = 1, + ) -> None: + self.confidence_threshold = 0.0 + self._ecm = ecm + self._target_detokenizer = target_detokenizer + self._word_graph = word_graph + self._ecm_weight = ecm_weight + self._word_graph_weight = word_graph_weight + + self._rest_scores = self._word_graph.compute_rest_scores() + self._state_ecm_score_infos: List[EcmScoreInfo] = [] + self._arc_ecm_score_infos: List[List[EcmScoreInfo]] = [] + self._state_best_scores: List[List[float]] = [] + self._state_word_graph_scores: List[float] = [] + self._state_best_prev_arcs: List[List[int]] = [] + self._states_involved_in_arcs: Set[int] = set() + self._prev_prefix: List[str] = [] + self._prev_is_last_word_complete = False + + self._init_states() + self._init_arcs() + + @property + def ecm_weight(self) -> float: + return self._ecm_weight + + @property + def word_graph_weight(self) -> float: + return self._word_graph_weight + + def _init_states(self) -> None: + for _ in range(self._word_graph.state_count): + self._state_ecm_score_infos.append(EcmScoreInfo()) + self._state_word_graph_scores.append(LOG_SPACE_ZERO) + self._state_best_scores.append([]) + self._state_best_prev_arcs.append([]) + + if not self._word_graph.is_empty: + self._ecm.setup_initial_esi(self._state_ecm_score_infos[WORD_GRAPH_INITIAL_STATE]) + self._update_initial_state_best_scores() + + def _init_arcs(self) -> None: + for arc_index in range(len(self._word_graph.arcs)): + arc = self._word_graph.arcs[arc_index] + + # init ecm score info for each word of arc + prev_esi = self._state_ecm_score_infos[arc.prev_state] + esis: List[EcmScoreInfo] = [] + for word in arc.target_tokens: + esi = EcmScoreInfo() + self._ecm.setup_esi(esi, prev_esi, word) + esis.append(esi) + prev_esi = esi + self._arc_ecm_score_infos.append(esis) + + # init best scores for the arc's successive state + self._update_state_best_scores(arc_index, 0) + + self._states_involved_in_arcs.add(arc.prev_state) + self._states_involved_in_arcs.add(arc.next_state) + + def _update_initial_state_best_scores(self) -> None: + esi = self._state_ecm_score_infos[WORD_GRAPH_INITIAL_STATE] + + self._state_word_graph_scores[WORD_GRAPH_INITIAL_STATE] = self._word_graph.initial_state_score + + best_scores = self._state_best_scores[WORD_GRAPH_INITIAL_STATE] + best_prev_arcs = self._state_best_prev_arcs[WORD_GRAPH_INITIAL_STATE] + + best_scores.clear() + best_prev_arcs.clear() + + for score in esi.scores: + best_scores.append( + (self.ecm_weight * -score) + (self.word_graph_weight * self._word_graph.initial_state_score) + ) + best_prev_arcs.append(sys.maxsize) + + def _update_state_best_scores(self, arc_index: int, prefix_diff_size: int) -> None: + arc = self._word_graph.arcs[arc_index] + arc_esis = self._arc_ecm_score_infos[arc_index] + + prev_esi = self._state_ecm_score_infos[arc.prev_state] if len(arc_esis) == 0 else arc_esis[-1] + + word_graph_score = self._state_word_graph_scores[arc.prev_state] + arc.score + + next_state_best_scores = self._state_best_scores[arc.next_state] + next_state_best_prev_arcs = self._state_best_prev_arcs[arc.next_state] + + positions: List[int] = [] + start_pos = 0 if prefix_diff_size == 0 else len(prev_esi.scores) - prefix_diff_size + for i in range(start_pos, len(prev_esi.scores)): + new_score = (self.ecm_weight * -prev_esi.scores[i]) + (self.word_graph_weight * word_graph_score) + + if i == len(next_state_best_scores) or next_state_best_scores[i] < new_score: + _add_or_replace(next_state_best_scores, i, new_score) + positions.append(i) + _add_or_replace(next_state_best_prev_arcs, i, arc_index) + + self._state_ecm_score_infos[arc.next_state].update_positions(prev_esi, positions) + + if word_graph_score > self._state_word_graph_scores[arc.next_state]: + self._state_word_graph_scores[arc.next_state] = word_graph_score + + def correct(self, prefix: Sequence[str], is_last_word_complete: bool) -> None: + # get valid portion of the processed prefix vector + valid_proc_prefix_count = 0 + for i in range(len(self._prev_prefix)): + if i >= len(prefix): + break + + if i == len(self._prev_prefix) - 1 and i == len(prefix) - 1: + if self._prev_prefix[i] == prefix[i] and self._prev_is_last_word_complete == is_last_word_complete: + valid_proc_prefix_count += 1 + elif self._prev_prefix[i] == prefix[i]: + valid_proc_prefix_count += 1 + + diff_size = len(self._prev_prefix) - valid_proc_prefix_count + if diff_size > 0: + # adjust size of info for arcs + for esis in self._arc_ecm_score_infos: + for esi in esis: + for i in range(diff_size): + esi.remove_last() + + # adjust size of info for states + for state in self._states_involved_in_arcs: + for i in range(diff_size): + self._state_ecm_score_infos[state].remove_last() + self._state_best_scores[state].pop() + self._state_best_prev_arcs[state].pop() + + # get difference between prefix and valid portion of processed prefix + prefix_diff: List[str] = [] + for i in range(len(prefix) - valid_proc_prefix_count): + prefix_diff.append(prefix[valid_proc_prefix_count + i]) + + # process word-graph given prefix difference + self._process_word_graph_prefix_diff(prefix_diff, is_last_word_complete) + + self._prev_prefix = list(prefix) + self._prev_is_last_word_complete = is_last_word_complete + + def get_results(self) -> Generator[TranslationResult, None, None]: + queue = self._get_hypotheses() + + for hypothesis in self._search(queue): + builder = TranslationResultBuilder(self._word_graph.source_tokens, self._target_detokenizer) + self._build_correction_from_hypothesis( + builder, self._prev_prefix, self._prev_is_last_word_complete, hypothesis + ) + yield builder.to_result() + + def _search(self, queue: PriorityQueue) -> Generator[_Hypothesis, None, None]: + while queue.not_empty: + hypothesis: _Hypothesis = queue.get() + last_state = hypothesis.start_state if len(hypothesis.arcs) == 0 else hypothesis.arcs[-1].next_state + + if last_state in self._word_graph.final_states: + yield hypothesis + elif self.confidence_threshold <= 0: + hypothesis.arcs.extend(self._word_graph.get_best_path_from_state_to_final_state(last_state)) + yield hypothesis + else: + score = hypothesis.score - (self.word_graph_weight * self._rest_scores[last_state]) + arc_indices = self._word_graph.get_next_arc_indices(last_state) + enqueued_arc = False + for i in range(len(arc_indices)): + arc_index = arc_indices[i] + arc = self._word_graph.arcs[arc_index] + if self._is_arc_pruned(arc): + continue + + new_hypothesis = hypothesis + if i < len(arc_indices) - 1: + new_hypothesis = new_hypothesis.clone() + new_hypothesis.score = score + new_hypothesis.score += arc.score + new_hypothesis.score += self._rest_scores[arc.next_state] + new_hypothesis.arcs.append(arc) + queue.put(new_hypothesis) + enqueued_arc = True + + if not enqueued_arc and (hypothesis.start_arc_index != -1 or len(hypothesis.arcs) > 0): + hypothesis.arcs.extend(self._word_graph.get_best_path_from_state_to_final_state(last_state)) + yield hypothesis + + def _get_hypotheses(self) -> PriorityQueue: + queue = PriorityQueue(maxsize=1000) + + # add hypotheses starting before each word in each arc + for arc_index in range(len(self._word_graph.arcs)): + arc = self._word_graph.arcs[arc_index] + if not self._is_arc_pruned(arc): + word_graph_score = self._state_word_graph_scores[arc.prev_state] + arc.score + + for i in range(-1, len(arc.target_tokens) - 1): + esi = ( + self._state_ecm_score_infos[arc.prev_state] + if i == -1 + else self._arc_ecm_score_infos[arc_index][i] + ) + score = ( + (self.word_graph_weight * word_graph_score) + + (self.ecm_weight * -esi.scores[-1]) + + (self.word_graph_weight * self._rest_scores[arc.next_state]) + ) + queue.put(_Hypothesis(score, arc.next_state, arc_index, i)) + + # add hypotheses starting before each final state + for state in self._word_graph.final_states: + rest_score = self._rest_scores[state] + best_scores = self._state_best_scores[state] + + score = best_scores[-1] + (self.word_graph_weight * rest_score) + queue.put(_Hypothesis(score, state)) + + return queue + + def _is_arc_pruned(self, arc: WordGraphArc) -> bool: + return not arc.is_unknown and any(c < self.confidence_threshold for c in arc.confidences) + + def _build_correction_from_hypothesis( + self, + builder: TranslationResultBuilder, + prefix: Sequence[str], + is_last_word_complete: bool, + hypothesis: _Hypothesis, + ) -> None: + if hypothesis.start_arc_index == -1: + self._add_best_uncorrected_prefix_state(builder, len(prefix), hypothesis.start_state) + uncorrected_prefix_len = len(builder.target_tokens) + else: + self._add_best_uncorrected_prefix_sub_state( + builder, len(prefix), hypothesis.start_arc_index, hypothesis.start_arc_word_index + ) + first_arc = self._word_graph.arcs[hypothesis.start_arc_index] + uncorrected_prefix_len = ( + len(builder.target_tokens) - (len(first_arc.target_tokens) - hypothesis.start_arc_word_index) + 1 + ) + + alignment_cols_to_add_count = self._ecm.correct_prefix( + builder, uncorrected_prefix_len, prefix, is_last_word_complete + ) + + for arc in hypothesis.arcs: + self._update_correction_from_arc(builder, arc, alignment_cols_to_add_count) + alignment_cols_to_add_count = 0 + + def _add_best_uncorrected_prefix_state( + self, builder: TranslationResultBuilder, proc_prefix_pos: int, state: int + ) -> None: + arcs: List[WordGraphArc] = [] + + cur_state = state + cur_proc_prefix_pos = proc_prefix_pos + while cur_state != WORD_GRAPH_INITIAL_STATE: + arc_index = self._state_best_prev_arcs[cur_state][cur_proc_prefix_pos] + arc = self._word_graph.arcs[arc_index] + + for i in range(len(arc.target_tokens) - 1, -1, -1): + pred_prefix_words = self._arc_ecm_score_infos[arc_index][i].get_last_ins_prefix_word_from_esi() + cur_proc_prefix_pos = pred_prefix_words[cur_proc_prefix_pos] + + arcs.append(arc) + + cur_state = arc.prev_state + + for arc in reversed(arcs): + self._update_correction_from_arc(builder, arc, 0) + + def _add_best_uncorrected_prefix_sub_state( + self, builder: TranslationResultBuilder, proc_prefix_pos: int, arc_index: int, arc_word_index: int + ) -> None: + arc = self._word_graph.arcs[arc_index] + + cur_proc_prefix_pos = proc_prefix_pos + for i in range(arc_word_index, -1, -1): + pred_prefix_words = self._arc_ecm_score_infos[arc_index][i].get_last_ins_prefix_word_from_esi() + cur_proc_prefix_pos = pred_prefix_words[cur_proc_prefix_pos] + + self._add_best_uncorrected_prefix_state(builder, cur_proc_prefix_pos, arc.prev_state) + + self._update_correction_from_arc(builder, arc, 0) + + def _update_correction_from_arc( + self, builder: TranslationResultBuilder, arc: WordGraphArc, alignment_cols_to_add_count: int + ) -> None: + for i in range(len(arc.target_tokens)): + builder.append_token(arc.target_tokens[i], arc.sources[i], arc.confidences[i]) + + alignment = arc.alignment + if alignment_cols_to_add_count > 0: + new_alignment = WordAlignmentMatrix.from_word_pairs( + alignment.row_count, alignment.column_count + alignment_cols_to_add_count + ) + for j in range(alignment.column_count): + for i in range(alignment.row_count): + new_alignment[i, alignment_cols_to_add_count + j] = alignment[i, j] + alignment = new_alignment + + builder.mark_phrase(arc.source_segment_range, alignment) + + def _process_word_graph_prefix_diff(self, prefix_diff: List[str], is_last_word_complete: bool) -> None: + if len(prefix_diff) == 0: + return + + if not self._word_graph.is_empty: + prev_initial_esi = self._state_ecm_score_infos[WORD_GRAPH_INITIAL_STATE] + self._ecm.extend_initial_esi( + self._state_ecm_score_infos[WORD_GRAPH_INITIAL_STATE], prev_initial_esi, prefix_diff + ) + self._update_initial_state_best_scores() + + for arc_index in range(len(self._word_graph.arcs)): + arc = self._word_graph.arcs[arc_index] + + # extend ecm score info for each word of arc + prev_esi = self._state_ecm_score_infos[arc.prev_state] + esis = self._arc_ecm_score_infos[arc_index] + while len(esis) < len(arc.target_tokens): + esis.append(EcmScoreInfo()) + for i in range(len(arc.target_tokens)): + esi = esis[i] + self._ecm.extend_esi( + esi, prev_esi, "" if arc.is_unknown else arc.target_tokens[i], prefix_diff, is_last_word_complete + ) + prev_esi = esi + + # update best scores for the arc's successive state + self._update_state_best_scores(arc_index, len(prefix_diff)) + + +@dataclass +@total_ordering +class _Hypothesis: + score: float + start_state: int + start_arc_index: int = -1 + start_arc_word_index: int = -1 + arcs: List[WordGraphArc] = field(default_factory=list) + + def __lt__(self, other: _Hypothesis) -> bool: + return self.score > other.score + + def __le__(self, other: _Hypothesis) -> bool: + return self.score >= other.score + + def __gt__(self, other: _Hypothesis) -> bool: + return self.score < other.score + + def __ge__(self, other: _Hypothesis) -> bool: + return self.score <= other.score + + def clone(self) -> _Hypothesis: + return _Hypothesis( + self.score, self.start_state, self.start_arc_index, self.start_arc_word_index, list(self.arcs) + ) + + +def _add_or_replace(list: list, index: int, item: Any) -> None: + if index > len(list): + raise ValueError("index out of range") + + if index == len(list): + list.append(item) + else: + list[index] = item diff --git a/machine/translation/interactive_translation_engine.py b/machine/translation/interactive_translation_engine.py index cb0af6d..3258556 100644 --- a/machine/translation/interactive_translation_engine.py +++ b/machine/translation/interactive_translation_engine.py @@ -7,7 +7,7 @@ from .word_graph import WordGraph -class InterativeTranslationEngine(TranslationEngine): +class InteractiveTranslationEngine(TranslationEngine): @abstractmethod def get_word_graph(self, segment: Union[str, Sequence[str]]) -> WordGraph: ... @@ -20,5 +20,5 @@ def train_segment( ) -> None: ... - def __enter__(self) -> InterativeTranslationEngine: + def __enter__(self) -> InteractiveTranslationEngine: return self diff --git a/machine/translation/interactive_translation_model.py b/machine/translation/interactive_translation_model.py index 5ffb078..5a06660 100644 --- a/machine/translation/interactive_translation_model.py +++ b/machine/translation/interactive_translation_model.py @@ -1,10 +1,10 @@ from __future__ import annotations -from .interactive_translation_engine import InterativeTranslationEngine +from .interactive_translation_engine import InteractiveTranslationEngine from .translation_model import TranslationModel -class InteractiveTranslationModel(TranslationModel, InterativeTranslationEngine): +class InteractiveTranslationModel(TranslationModel, InteractiveTranslationEngine): def save(self) -> None: ... diff --git a/machine/translation/interactive_translator.py b/machine/translation/interactive_translator.py new file mode 100644 index 0000000..44212e0 --- /dev/null +++ b/machine/translation/interactive_translator.py @@ -0,0 +1,118 @@ +from typing import Generator, List, Sequence + +from ..annotations.range import Range +from ..tokenization.detokenizer import Detokenizer +from ..tokenization.range_tokenizer import RangeTokenizer +from ..tokenization.tokenization_utils import get_ranges, split +from .error_correction_model import ErrorCorrectionModel +from .error_correction_word_graph_processor import ErrorCorrectionWordGraphProcessor +from .interactive_translation_engine import InteractiveTranslationEngine +from .translation_constants import MAX_SEGMENT_LENGTH +from .translation_result import TranslationResult +from .word_graph import WordGraph + + +class InteractiveTranslator: + def __init__( + self, + ecm: ErrorCorrectionModel, + engine: InteractiveTranslationEngine, + target_tokenizer: RangeTokenizer[str, int, str], + target_detokenizer: Detokenizer[str, str], + segment: str, + word_graph: WordGraph, + sentence_start: bool, + ) -> None: + self._segment = segment + self._segment_word_ranges = list(get_ranges(self._segment, word_graph.source_tokens)) + self._engine = engine + self._target_tokenizer = target_tokenizer + self._prefix_word_ranges: List[Range[int]] = [] + self._prefix = "" + self._is_last_word_complete = True + self._word_graph_processor = ErrorCorrectionWordGraphProcessor(ecm, target_detokenizer, word_graph) + self._target_detokenizer = target_detokenizer + self._sentence_start = sentence_start + self._correct() + + @property + def target_detokenizer(self) -> Detokenizer[str, str]: + return self._target_detokenizer + + @property + def segment(self) -> str: + return self._segment + + @property + def segment_word_ranges(self) -> Sequence[Range[int]]: + return self._segment_word_ranges + + @property + def prefix(self) -> str: + return self._prefix + + @property + def prefix_word_ranges(self) -> Sequence[Range[int]]: + return self._prefix_word_ranges + + @property + def is_last_word_complete(self) -> bool: + return self._is_last_word_complete + + @property + def sentence_start(self) -> bool: + return self._sentence_start + + @property + def is_segment_valid(self) -> bool: + return len(self.segment_word_ranges) <= MAX_SEGMENT_LENGTH + + def set_prefix(self, prefix: str) -> None: + if self._prefix != prefix: + self._prefix = prefix + self._correct() + + def append_to_prefix(self, addition: str) -> None: + if addition != "": + self._prefix += addition + self._correct() + + def approve(self, aligned_only: bool) -> None: + if not self.is_segment_valid or len(self.prefix_word_ranges) > MAX_SEGMENT_LENGTH: + return + + segment_word_ranges = self._segment_word_ranges + if aligned_only: + best_result = next(self.get_current_results(), None) + if best_result is None: + return + segment_word_ranges = self._get_aligned_source_segment(best_result) + + if len(segment_word_ranges) > 0: + source_segment = self._segment[segment_word_ranges[0].start : segment_word_ranges[-1].end] + target_segment = self._prefix[self._prefix_word_ranges[0].start : self._prefix_word_ranges[-1].end] + self._engine.train_segment(source_segment, target_segment, self._sentence_start) + + def get_current_results(self) -> Generator[TranslationResult, None, None]: + return self._word_graph_processor.get_results() + + def _correct(self) -> None: + self._prefix_word_ranges = list(self._target_tokenizer.tokenize_as_ranges(self._prefix)) + self._is_last_word_complete = len(self._prefix_word_ranges) == 0 or self._prefix_word_ranges[-1].end < len( + self._prefix + ) + self._word_graph_processor.correct(split(self._prefix, self._prefix_word_ranges), self._is_last_word_complete) + + def _get_aligned_source_segment(self, result: TranslationResult) -> Sequence[Range[int]]: + source_length = 0 + for phrase in result.phrases: + if phrase.target_segment_cut > len(self._prefix_word_ranges): + break + if phrase.source_segment_range.end > source_length: + source_length = phrase.source_segment_range.end + + return ( + self._segment_word_ranges + if source_length == len(self._segment_word_ranges) + else self._segment_word_ranges[:source_length] + ) diff --git a/machine/translation/interactive_translator_factory.py b/machine/translation/interactive_translator_factory.py new file mode 100644 index 0000000..5472620 --- /dev/null +++ b/machine/translation/interactive_translator_factory.py @@ -0,0 +1,39 @@ +from ..tokenization.detokenizer import Detokenizer +from ..tokenization.range_tokenizer import RangeTokenizer +from ..tokenization.whitespace_detokenizer import WHITESPACE_DETOKENIZER +from ..tokenization.whitespace_tokenizer import WHITESPACE_TOKENIZER +from .error_correction_model import ErrorCorrectionModel +from .interactive_translation_engine import InteractiveTranslationEngine +from .interactive_translator import InteractiveTranslator + + +class InteractiveTranslatorFactory: + def __init__( + self, + engine: InteractiveTranslationEngine, + target_tokenizer: RangeTokenizer[str, int, str] = WHITESPACE_TOKENIZER, + target_detokenizer: Detokenizer[str, str] = WHITESPACE_DETOKENIZER, + ) -> None: + self._engine = engine + self._ecm = ErrorCorrectionModel() + self.target_tokenizer = target_tokenizer + self.target_detokenizer = target_detokenizer + + @property + def engine(self) -> InteractiveTranslationEngine: + return self._engine + + @property + def error_correction_model(self) -> ErrorCorrectionModel: + return self._ecm + + def create(self, segment: str, sentence_start: bool = True) -> InteractiveTranslator: + return InteractiveTranslator( + self._ecm, + self._engine, + self.target_tokenizer, + self.target_detokenizer, + segment, + self._engine.get_word_graph(segment), + sentence_start, + ) diff --git a/machine/translation/phrase_translation_suggester.py b/machine/translation/phrase_translation_suggester.py new file mode 100644 index 0000000..10364ed --- /dev/null +++ b/machine/translation/phrase_translation_suggester.py @@ -0,0 +1,127 @@ +from typing import Iterable, List, Optional, Sequence + +from ..utils.string_utils import is_punctuation +from .translation_result import TranslationResult +from .translation_sources import TranslationSources +from .translation_suggester import TranslationSuggester +from .translation_suggestion import TranslationSuggestion + + +class PhraseTranslationSuggester(TranslationSuggester): + def get_suggestions( + self, n: int, prefix_count: int, is_last_word_complete: bool, results: Iterable[TranslationResult] + ) -> Sequence[TranslationSuggestion]: + suggestions: List[TranslationSuggestion] = [] + for result in results: + starting_j = prefix_count + if not is_last_word_complete: + # if the prefix ends with a partial word and it has been completed, + # then make sure it is included as a suggestion, + # otherwise, don't return any suggestions + if TranslationSources.SMT in result.sources[starting_j - 1]: + starting_j -= 1 + else: + break + + k = 0 + while k < len(result.phrases) and result.phrases[k].target_segment_cut <= starting_j: + k += 1 + + suggestion_confidence = -1 + indices: List[int] = [] + for k in range(k, len(result.phrases)): + phrase = result.phrases[k] + phrase_confidence = 1.0 + ending_j = starting_j + for j in range(starting_j, phrase.target_segment_cut): + if result.sources[j] == TranslationSources.NONE: + phrase_confidence = 0.0 + break + + word = result.target_tokens[j] + if self.break_on_punctuation and all(is_punctuation(c) for c in word): + break + + phrase_confidence = min(phrase_confidence, result.confidences[j]) + if phrase_confidence < self.confidence_threshold: + break + + ending_j = j + 1 + + if phrase_confidence >= self.confidence_threshold: + suggestion_confidence = ( + phrase_confidence + if suggestion_confidence == -1 + else min(suggestion_confidence, phrase_confidence) + ) + + if starting_j == ending_j: + break + + for j in range(starting_j, ending_j): + indices.append(j) + + starting_j = phrase.target_segment_cut + else: + # hit a phrase with a low confidence, so don't include any more words in this suggestion + break + if suggestion_confidence == -1: + break + elif len(indices) == 0: + continue + + new_suggestion = TranslationSuggestion(result, indices, suggestion_confidence) + duplicate = False + new_suggestion_words: Optional[List[str]] = None + table: Optional[List[int]] = None + for suggestion in suggestions: + if len(suggestion.target_word_indices) >= len(new_suggestion.target_word_indices): + if new_suggestion_words is None: + new_suggestion_words = list(new_suggestion.target_words) + if table is None: + table = _compute_kmp_table(new_suggestion_words) + if _is_subsequence(table, new_suggestion_words, list(suggestion.target_words)): + duplicate = True + break + + if not duplicate: + suggestions.append(new_suggestion) + if len(suggestions) == n: + break + return suggestions + + +def _is_subsequence(table: List[int], new_suggestion: Sequence[str], suggestion: Sequence[str]) -> bool: + j = 0 + i = 0 + while i < len(suggestion): + if new_suggestion[j] == suggestion[i]: + j += 1 + i += 1 + if j == len(new_suggestion): + return True + elif i < len(suggestion) and new_suggestion[j] != suggestion[i]: + if j != 0: + j = table[j - 1] + else: + i += 1 + return False + + +def _compute_kmp_table(new_suggestion: Sequence[str]) -> List[int]: + table = [0] * len(new_suggestion) + length = 0 + i = 1 + table[0] = 0 + + while i < len(new_suggestion): + if new_suggestion[i] == new_suggestion[length]: + length += 1 + table[i] = length + i += 1 + elif length != 0: + length = table[length - 1] + else: + table[i] = length + i += 1 + return table diff --git a/machine/translation/segment_edit_distance.py b/machine/translation/segment_edit_distance.py new file mode 100644 index 0000000..4649bc8 --- /dev/null +++ b/machine/translation/segment_edit_distance.py @@ -0,0 +1,176 @@ +from typing import Iterable, List, Sequence, Tuple + +from .edit_distance import EditDistance +from .edit_operation import EditOperation +from .word_edit_distance import WordEditDistance + + +class SegmentEditDistance(EditDistance[Sequence[str], str]): + def __init__(self) -> None: + self._word_edit_distance = WordEditDistance() + + @property + def hit_cost(self) -> float: + return self._word_edit_distance.hit_cost + + @hit_cost.setter + def hit_cost(self, cost: float) -> None: + self._word_edit_distance.hit_cost = cost + + @property + def substitution_cost(self) -> float: + return self._word_edit_distance.substitution_cost + + @substitution_cost.setter + def substitution_cost(self, cost: float) -> None: + self._word_edit_distance.substitution_cost = cost + + @property + def insertion_cost(self) -> float: + return self._word_edit_distance.insertion_cost + + @insertion_cost.setter + def insertion_cost(self, cost: float) -> None: + self._word_edit_distance.insertion_cost = cost + + @property + def deletion_cost(self) -> float: + return self._word_edit_distance.deletion_cost + + @deletion_cost.setter + def deletion_cost(self, cost: float) -> None: + self._word_edit_distance.deletion_cost = cost + + def compute(self, x: Sequence[str], y: Sequence[str]) -> float: + dist, _ = self._compute_dist_matrix(x, y, is_last_item_complete=True, use_prefix_del_op=False) + return dist + + def compute_prefix( + self, x: Sequence[str], y: Sequence[str], is_last_item_complete: bool, use_prefix_del_op: bool + ) -> Tuple[float, Iterable[EditOperation], Iterable[EditOperation]]: + dist, dist_matrix = self._compute_dist_matrix(x, y, is_last_item_complete, use_prefix_del_op) + + i = len(x) + j = len(y) + ops: List[EditOperation] = [] + char_ops: Iterable[EditOperation] = [] + while i > 0 or j > 0: + _, i, j, op = self._process_dist_matrix_cell( + x, y, dist_matrix, use_prefix_del_op, j != len(y) or is_last_item_complete, i, j + ) + if op != EditOperation.PREFIX_DELETE: + ops.append(op) + + if j + 1 == len(y) and not is_last_item_complete and op == EditOperation.HIT: + _, char_ops = self._word_edit_distance.compute_prefix( + x[i], y[-1], is_last_item_complete=True, use_prefix_del_op=True + ) + + ops.reverse() + return (dist, ops, char_ops) + + def incr_compute_prefix_first_row( + self, scores: List[float], prev_scores: List[float], y_incr: Sequence[str] + ) -> None: + if scores is not prev_scores: + scores.clear() + scores.extend(prev_scores) + + start_pos = len(scores) + for j_incr in range(len(y_incr)): + j = start_pos + j_incr + if j == 0: + scores.append(self._get_insertion_cost(y_incr[j_incr])) + else: + scores.append(scores[j - 1] + self._get_insertion_cost(y_incr[j_incr])) + + def incr_compute_prefix( + self, + scores: List[float], + prev_scores: List[float], + x_word: str, + y_incr: Sequence[str], + is_last_item_complete: bool, + ) -> Iterable[EditOperation]: + x = [x_word] + y = [""] * (len(prev_scores) - 1) + for i in range(len(y_incr)): + y[len(prev_scores) - len(y_incr) - 1 + i] = y_incr[i] + + dist_matrix = self._init_dist_matrix(x, y) + + for j in range(len(prev_scores)): + dist_matrix[0][j] = prev_scores[j] + for j in range(len(scores)): + dist_matrix[1][j] = scores[j] + + while len(scores) < len(prev_scores): + scores.append(0.0) + + start_pos = len(prev_scores) - len(y_incr) + + ops: List[EditOperation] = [] + for j_incr in range(len(y_incr)): + j = start_pos + j_incr + dist, _, _, op = self._process_dist_matrix_cell( + x, y, dist_matrix, use_prefix_del_op=False, is_complete=j != len(y) or is_last_item_complete, i=1, j=j + ) + scores[j] = dist + dist_matrix[1][j] = dist + ops.append(op) + + return ops + + def _get_count(self, seq: Sequence[str]) -> int: + return len(seq) + + def _get_item(self, seq: Sequence[str], index: int) -> str: + return seq[index] + + def _get_hit_cost(self, x: str, y: str, is_complete: bool) -> float: + return self.hit_cost * len(y) + + def _get_substitution_cost(self, x: str, y: str, is_complete: bool) -> float: + if x == "": + return (self.substitution_cost * 0.99) * len(y) + + if is_complete: + _, ops = self._word_edit_distance.compute(x, y) + else: + _, ops = self._word_edit_distance.compute_prefix(x, y, is_last_item_complete=True, use_prefix_del_op=True) + + hit_count, ins_count, subst_count, del_count = _get_op_counts(ops) + return ( + self.hit_cost * hit_count + + self.insertion_cost * ins_count + + self.substitution_cost * subst_count + + self.deletion_cost * del_count + ) + + def _get_deletion_cost(self, x: str) -> float: + if x == "": + return self.deletion_cost + return self.deletion_cost * len(x) + + def _get_insertion_cost(self, y: str) -> float: + return self.insertion_cost * len(y) + + def _is_hit(self, x: str, y: str, is_complete: bool) -> bool: + return x == y or (not is_complete and x.startswith(y)) + + +def _get_op_counts(ops: Iterable[EditOperation]) -> Tuple[int, int, int, int]: + hit_count = 0 + ins_count = 0 + subst_count = 0 + del_count = 0 + for op in ops: + if op == EditOperation.HIT: + hit_count += 1 + elif op == EditOperation.INSERT: + ins_count += 1 + elif op == EditOperation.SUBSTITUTE: + subst_count += 1 + elif op == EditOperation.DELETE: + del_count += 1 + return (hit_count, ins_count, subst_count, del_count) diff --git a/machine/translation/translation_constants.py b/machine/translation/translation_constants.py new file mode 100644 index 0000000..7de2111 --- /dev/null +++ b/machine/translation/translation_constants.py @@ -0,0 +1 @@ +MAX_SEGMENT_LENGTH = 200 diff --git a/machine/translation/translation_result_builder.py b/machine/translation/translation_result_builder.py index 299120b..e3bcbd7 100644 --- a/machine/translation/translation_result_builder.py +++ b/machine/translation/translation_result_builder.py @@ -64,7 +64,7 @@ def correct_prefix( self, word_ops: Iterable[EditOperation], char_ops: Iterable[EditOperation], - prefix: List[str], + prefix: Sequence[str], is_last_word_complete: bool, ) -> int: alignment_cols_to_copy: List[int] = [] @@ -74,7 +74,7 @@ def correct_prefix( k = 0 for word_op in word_ops: if word_op == EditOperation.INSERT: - self._target_tokens.insert(j, prefix[i]) + self._target_tokens.insert(j, prefix[j]) self._sources.insert(j, TranslationSources.PREFIX) self._confidences.insert(j, -1) alignment_cols_to_copy.append(-1) @@ -103,9 +103,9 @@ def correct_prefix( k += 1 elif word_op in {EditOperation.HIT, EditOperation.SUBSTITUTE}: if word_op == EditOperation.SUBSTITUTE or j < len(prefix) - 1 or is_last_word_complete: - self._target_tokens[j] = prefix[i] + self._target_tokens[j] = prefix[j] else: - self._target_tokens[j] = self._correct_word(char_ops, self._target_tokens[j], prefix[i]) + self._target_tokens[j] = self._correct_word(char_ops, self._target_tokens[j], prefix[j]) if word_op == EditOperation.SUBSTITUTE: self._confidences[j] = -1 diff --git a/machine/translation/translation_suggester.py b/machine/translation/translation_suggester.py new file mode 100644 index 0000000..8d4e21d --- /dev/null +++ b/machine/translation/translation_suggester.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from typing import Iterable, Sequence + +from .translation_result import TranslationResult +from .translation_suggestion import TranslationSuggestion + + +class TranslationSuggester(ABC): + def __init__(self, confidence_threshold: float = 0, break_on_punctuation: bool = True) -> None: + self.confidence_threshold = confidence_threshold + self.break_on_punctuation = break_on_punctuation + + @abstractmethod + def get_suggestions( + self, n: int, prefix_count: int, is_last_word_complete: bool, results: Iterable[TranslationResult] + ) -> Sequence[TranslationSuggestion]: + ... diff --git a/machine/translation/translation_suggestion.py b/machine/translation/translation_suggestion.py new file mode 100644 index 0000000..43def27 --- /dev/null +++ b/machine/translation/translation_suggestion.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass, field +from typing import Iterable, Sequence + +from .translation_result import TranslationResult + + +@dataclass +class TranslationSuggestion: + result: TranslationResult + target_word_indices: Sequence[int] = field(default_factory=list) + confidence: float = 0 + + @property + def target_words(self) -> Iterable[str]: + return (self.result.target_tokens[i] for i in self.target_word_indices) diff --git a/machine/translation/word_edit_distance.py b/machine/translation/word_edit_distance.py new file mode 100644 index 0000000..3552452 --- /dev/null +++ b/machine/translation/word_edit_distance.py @@ -0,0 +1,55 @@ +from typing import Iterable, Tuple + +from .edit_distance import EditDistance +from .edit_operation import EditOperation + + +class WordEditDistance(EditDistance[str, str]): + def __init__(self) -> None: + self.hit_cost = 0.0 + self.insertion_cost = 0.0 + self.deletion_cost = 0.0 + self.substitution_cost = 0.0 + + def compute(self, x: str, y: str) -> Tuple[float, Iterable[EditOperation]]: + dist, dist_matrix = self._compute_dist_matrix(x, y, is_last_item_complete=True, use_prefix_del_op=False) + ops = self._get_operations( + x, + y, + dist_matrix, + is_last_item_complete=True, + use_prefix_del_op=False, + i=self._get_count(x), + j=self._get_count(y), + ) + return (dist, ops) + + def compute_prefix( + self, x: str, y: str, is_last_item_complete: bool, use_prefix_del_op: bool + ) -> Tuple[float, Iterable[EditOperation]]: + dist, dist_matrix = self._compute_dist_matrix(x, y, is_last_item_complete, use_prefix_del_op) + ops = self._get_operations( + x, y, dist_matrix, is_last_item_complete, use_prefix_del_op, i=self._get_count(x), j=self._get_count(y) + ) + return (dist, ops) + + def _get_count(self, seq: str) -> int: + return len(seq) + + def _get_item(self, seq: str, index: int) -> str: + return seq[index] + + def _get_hit_cost(self, x: str, y: str, is_complete: bool) -> float: + return self.hit_cost + + def _get_substitution_cost(self, x: str, y: str, is_complete: bool) -> float: + return self.substitution_cost + + def _get_deletion_cost(self, x: str) -> float: + return self.deletion_cost + + def _get_insertion_cost(self, y: str) -> float: + return self.insertion_cost + + def _is_hit(self, x: str, y: str, is_complete: bool) -> bool: + return x == y diff --git a/machine/translation/word_graph.py b/machine/translation/word_graph.py index 589e53f..d009ad7 100644 --- a/machine/translation/word_graph.py +++ b/machine/translation/word_graph.py @@ -1,9 +1,11 @@ from dataclasses import dataclass, field -from typing import AbstractSet, Dict, Iterable, List, Sequence +from typing import AbstractSet, Dict, Generator, Iterable, List, Sequence, Set, Tuple +from ..statistics.log_space import LOG_SPACE_ZERO, log_space_multiple from .word_graph_arc import WordGraphArc EMPTY_ARC_INDICES: Sequence[int] = [] +WORD_GRAPH_INITIAL_STATE = 0 @dataclass(frozen=True) @@ -15,12 +17,12 @@ class StateInfo: class WordGraph: def __init__( self, - source_words: Iterable[str], + source_tokens: Iterable[str], arcs: Iterable[WordGraphArc] = [], final_states: Iterable[int] = [], initial_state_score: float = 0, ) -> None: - self._source_words = list(source_words) + self._source_tokens = list(source_tokens) self._states: Dict[int, StateInfo] = {} arc_list: List[WordGraphArc] = [] max_state = -1 @@ -40,8 +42,8 @@ def __init__( self._initial_state_score = initial_state_score @property - def source_words(self) -> Sequence[str]: - return self._source_words + def source_tokens(self) -> Sequence[str]: + return self._source_tokens @property def initial_state_score(self) -> float: @@ -75,6 +77,71 @@ def get_next_arc_indices(self, state: int) -> Sequence[int]: return EMPTY_ARC_INDICES return state_info.next_arc_indices + def compute_rest_scores(self) -> List[float]: + rest_scores: List[float] = [LOG_SPACE_ZERO] * self._state_count + for state in self._final_states: + rest_scores[state] = self._initial_state_score + + for arc in reversed(self._arcs): + score = log_space_multiple(arc.score, rest_scores[arc.next_state]) + if score > rest_scores[arc.prev_state]: + rest_scores[arc.prev_state] = score + return rest_scores + + def get_best_path_from_state_to_final_state(self, state: int) -> Iterable[WordGraphArc]: + arcs = list(self._get_best_path_from_final_state_to_state(state)) + arcs.reverse() + return arcs + + def _get_best_path_from_final_state_to_state(self, state: int) -> Generator[WordGraphArc, None, None]: + prev_scores, state_best_pred_arcs = self._compute_prev_scores(state) + + best_final_state_score: float = LOG_SPACE_ZERO + best_final_state = WORD_GRAPH_INITIAL_STATE + for final_state in self._final_states: + score = prev_scores[final_state] + if best_final_state_score < score: + best_final_state = final_state + best_final_state_score = score + + if best_final_state in self._final_states: + cur_state = best_final_state + end = False + while not end: + if cur_state == state: + end = True + else: + arc_index = state_best_pred_arcs[cur_state] + arc = self.arcs[arc_index] + yield arc + cur_state = arc.prev_state + + def _compute_prev_scores(self, state: int) -> Tuple[List[float], List[int]]: + if self.is_empty: + return [], [] + + prev_scores: List[float] = [LOG_SPACE_ZERO] * self._state_count + state_best_prev_arcs = [0] * self._state_count + + if state == WORD_GRAPH_INITIAL_STATE: + prev_scores[WORD_GRAPH_INITIAL_STATE] = self.initial_state_score + else: + prev_scores[state] = 0 + + accessible_states: Set[int] = {state} + for arc_index in range(len(self.arcs)): + arc = self.arcs[arc_index] + if arc.prev_state in accessible_states: + score = log_space_multiple(arc.score, prev_scores[arc.prev_state]) + if score > prev_scores[arc.next_state]: + prev_scores[arc.next_state] = score + state_best_prev_arcs[arc.next_state] = arc_index + accessible_states.add(arc.next_state) + else: + if arc.next_state not in accessible_states: + prev_scores[arc.next_state] = LOG_SPACE_ZERO + return prev_scores, state_best_prev_arcs + def _get_or_create_state_info(self, state: int) -> StateInfo: state_info = self._states.get(state) if state_info is None: diff --git a/machine/translation/word_graph_arc.py b/machine/translation/word_graph_arc.py index eacbcf2..5f31fa9 100644 --- a/machine/translation/word_graph_arc.py +++ b/machine/translation/word_graph_arc.py @@ -11,7 +11,7 @@ def __init__( prev_state: int, next_state: int, score: float, - words: Iterable[str], + target_tokens: Iterable[str], alignment: WordAlignmentMatrix, source_segment_range: Range[int], sources: Iterable[TranslationSources], @@ -20,7 +20,7 @@ def __init__( self._prev_state = prev_state self._next_state = next_state self._score = score - self._words = list(words) + self._target_tokens = list(target_tokens) self._alignment = alignment self._source_segment_range = source_segment_range self._sources = list(sources) @@ -39,8 +39,8 @@ def score(self) -> float: return self._score @property - def words(self) -> Sequence[str]: - return self._words + def target_tokens(self) -> Sequence[str]: + return self._target_tokens @property def alignment(self) -> WordAlignmentMatrix: diff --git a/poetry.lock b/poetry.lock index 4b74a7b..07a7fce 100644 --- a/poetry.lock +++ b/poetry.lock @@ -722,6 +722,17 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "decoy" +version = "2.1.0" +description = "Opinionated mocking library for Python" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "decoy-2.1.0-py3-none-any.whl", hash = "sha256:f1823fbb85e2cd602bc3eb386fd1b6be1a293ea7dc3324f6ccf7af56b3b127d6"}, + {file = "decoy-2.1.0.tar.gz", hash = "sha256:c6a6c09d158bc77de693332f40c4992b61f0a9dcd72631033f161ee570ee2d88"}, +] + [[package]] name = "defusedxml" version = "0.7.1" @@ -1490,17 +1501,6 @@ files = [ {file = "mistune-2.0.5.tar.gz", hash = "sha256:0246113cb2492db875c6be56974a7c893333bf26cd92891c85f63151cee09d34"}, ] -[[package]] -name = "mockito" -version = "1.4.0" -description = "Spying framework" -optional = false -python-versions = ">=2.7" -files = [ - {file = "mockito-1.4.0-py3-none-any.whl", hash = "sha256:1719c6bec3523f9b465c86d247bb76027f53ab10f76b2a126dde409d0492fe3e"}, - {file = "mockito-1.4.0.tar.gz", hash = "sha256:409ab604c9ebe1bb7dc18ec6b0ed98a8ad5127b08273f5804b22f4d1b51e5222"}, -] - [[package]] name = "mpmath" version = "1.3.0" @@ -2563,20 +2563,6 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] -[[package]] -name = "pytest-mockito" -version = "0.0.4" -description = "Base fixtures for mockito" -optional = false -python-versions = "*" -files = [ - {file = "pytest-mockito-0.0.4.tar.gz", hash = "sha256:40d40cdf118127dcb1e3c9e838b0d1c11d5197a23beaf10b6e3f42f9b6cb68a9"}, -] - -[package.dependencies] -mockito = ">=1.0.6" -pytest = ">=3" - [[package]] name = "python-dateutil" version = "2.8.2" @@ -3968,4 +3954,4 @@ thot = ["sil-thot"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "5737e433ca2a038cc44c20a878e94dba57baf5dd1431985c0229da1b9f875fd1" +content-hash = "7d53fccb68beb7d88f67c7e02522bd23ac4210c877db3e55912b95c9a129ecce" diff --git a/pyproject.toml b/pyproject.toml index aea760a..6c1fb24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,11 +76,11 @@ black = "^23.3.0" # match the vscode extension flake8 = "^3.9.2" isort = "^5.9.3" pytest-cov = "^4.1.0" -pytest-mockito = "^0.0.4" ipykernel = "^6.7.0" jupyter = "^1.0.0" pandas = "^1.3.0" pyright = "^1.1.331" +decoy = "^2.1.0" [tool.poetry.group.gpu.dependencies] # Torch is not included in the normal install to allow the user to choose the versions of these dependencies when diff --git a/tests/jobs/test_nmt_engine_build_job.py b/tests/jobs/test_nmt_engine_build_job.py index 7212e3e..b3a1bf6 100644 --- a/tests/jobs/test_nmt_engine_build_job.py +++ b/tests/jobs/test_nmt_engine_build_job.py @@ -1,10 +1,10 @@ import json from contextlib import contextmanager from io import StringIO -from typing import Iterator, Type, TypeVar, cast +from typing import Iterator import pytest -from mockito import ANY, mock, verify, when +from decoy import Decoy, matchers from machine.annotations import Range from machine.corpora import DictionaryTextCorpus @@ -14,18 +14,18 @@ from machine.utils import CanceledError, ContextManagedGenerator -def test_run() -> None: - env = _TestEnvironment() +def test_run(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) env.job.run() - verify(env.engine, times=1).translate_batch(...) + decoy.verify(env.engine.translate_batch(matchers.Anything()), times=1) pretranslations = json.loads(env.target_pretranslations) assert len(pretranslations) == 1 assert pretranslations[0]["translation"] == "Please, I have booked a room." -def test_cancel() -> None: - env = _TestEnvironment() +def test_cancel(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) checker = _CancellationChecker(3) with pytest.raises(CanceledError): env.job.run(check_canceled=checker.check_canceled) @@ -34,26 +34,20 @@ def test_cancel() -> None: class _TestEnvironment: - def __init__(self) -> None: + def __init__(self, decoy: Decoy) -> None: config = {"src_lang": "es", "trg_lang": "en", "batch_size": 100} - self.source_tokenizer_trainer = _mock(Trainer) - when(self.source_tokenizer_trainer).train(check_canceled=ANY).thenReturn() - when(self.source_tokenizer_trainer).save().thenReturn() + self.source_tokenizer_trainer = decoy.mock(cls=Trainer) + self.target_tokenizer_trainer = decoy.mock(cls=Trainer) - self.target_tokenizer_trainer = _mock(Trainer) - when(self.target_tokenizer_trainer).train(check_canceled=ANY).thenReturn() - when(self.target_tokenizer_trainer).save().thenReturn() - - self.model_trainer = _mock(Trainer) - when(self.model_trainer).train(progress=ANY, check_canceled=ANY).thenReturn() - when(self.model_trainer).save().thenReturn() + self.model_trainer = decoy.mock(cls=Trainer) stats = TrainStats() stats.train_corpus_size = 3 stats.metrics["bleu"] = 30.0 - setattr(self.model_trainer, "stats", stats) + decoy.when(self.model_trainer.stats).then_return(stats) - self.engine = _mock(TranslationEngine) - when(self.engine).translate_batch(ANY).thenReturn( + self.engine = decoy.mock(cls=TranslationEngine) + decoy.when(self.engine.__enter__()).then_return(self.engine) + decoy.when(self.engine.translate_batch(matchers.Anything())).then_return( [ TranslationResult( translation="Please, I have booked a room.", @@ -78,20 +72,23 @@ def __init__(self) -> None: ] ) - self.nmt_model_factory = _mock(NmtModelFactory) - setattr(self.nmt_model_factory, "train_tokenizer", True) - when(self.nmt_model_factory).init().thenReturn() - when(self.nmt_model_factory).create_source_tokenizer_trainer(ANY).thenReturn(self.source_tokenizer_trainer) - when(self.nmt_model_factory).create_target_tokenizer_trainer(ANY).thenReturn(self.target_tokenizer_trainer) - when(self.nmt_model_factory).create_model_trainer(ANY).thenReturn(self.model_trainer) - when(self.nmt_model_factory).create_engine().thenReturn(self.engine) - - self.shared_file_service = _mock(SharedFileService) - when(self.shared_file_service).create_source_corpus().thenReturn(DictionaryTextCorpus()) - when(self.shared_file_service).create_target_corpus().thenReturn(DictionaryTextCorpus()) - when(self.shared_file_service).exists_source_corpus().thenReturn(True) - when(self.shared_file_service).exists_target_corpus().thenReturn(True) - when(self.shared_file_service).get_source_pretranslations().thenAnswer( + self.nmt_model_factory = decoy.mock(cls=NmtModelFactory) + decoy.when(self.nmt_model_factory.train_tokenizer).then_return(True) + decoy.when(self.nmt_model_factory.create_source_tokenizer_trainer(matchers.Anything())).then_return( + self.source_tokenizer_trainer + ) + decoy.when(self.nmt_model_factory.create_target_tokenizer_trainer(matchers.Anything())).then_return( + self.target_tokenizer_trainer + ) + decoy.when(self.nmt_model_factory.create_model_trainer(matchers.Anything())).then_return(self.model_trainer) + decoy.when(self.nmt_model_factory.create_engine()).then_return(self.engine) + + self.shared_file_service = decoy.mock(cls=SharedFileService) + decoy.when(self.shared_file_service.create_source_corpus()).then_return(DictionaryTextCorpus()) + decoy.when(self.shared_file_service.create_target_corpus()).then_return(DictionaryTextCorpus()) + decoy.when(self.shared_file_service.exists_source_corpus()).then_return(True) + decoy.when(self.shared_file_service.exists_target_corpus()).then_return(True) + decoy.when(self.shared_file_service.get_source_pretranslations()).then_do( lambda: ContextManagedGenerator( ( pi @@ -117,24 +114,13 @@ def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[Pretran file.write("\n]\n") env.target_pretranslations = file.getvalue() - when(self.shared_file_service).open_target_pretranslation_writer().thenAnswer( + decoy.when(self.shared_file_service.open_target_pretranslation_writer()).then_do( lambda: open_target_pretranslation_writer(self) ) self.job = NmtEngineBuildJob(config, self.nmt_model_factory, self.shared_file_service) -T = TypeVar("T") - - -def _mock(class_to_mock: Type[T]) -> T: - o = cast(T, mock(class_to_mock)) - if hasattr(class_to_mock, "__enter__"): - when(o).__enter__().thenReturn(o) - when(o).__exit__(ANY, ANY, ANY).thenReturn() - return o - - class _CancellationChecker: def __init__(self, raise_count: int) -> None: self._call_count = 0 diff --git a/tests/testutils/__init__.py b/tests/testutils/__init__.py index 622131f..e77f9b4 100644 --- a/tests/testutils/__init__.py +++ b/tests/testutils/__init__.py @@ -1,18 +1,3 @@ -from abc import ABC from pathlib import Path -from typing import Type, TypeVar, cast - -T = TypeVar("T", bound=ABC) TEST_DATA_PATH = Path(__file__).parent / "data" - - -def make_concrete(abc_class: Type[T]) -> Type[T]: - if "__abstractmethods__" not in abc_class.__dict__: - return abc_class - new_dict = abc_class.__dict__.copy() - for abstractmethod in abc_class.__abstractmethods__: - # replace each abc method or property with an identity function: - new_dict[abstractmethod] = lambda x, *args, **kw: (x, args, kw) - # creates a new class, with the overriden ABCs: - return cast(Type[T], type("dummy_concrete_%s" % abc_class.__name__, (abc_class,), new_dict)) # type: ignore diff --git a/tests/translation/test_error_correction_model.py b/tests/translation/test_error_correction_model.py new file mode 100644 index 0000000..09e6476 --- /dev/null +++ b/tests/translation/test_error_correction_model.py @@ -0,0 +1,186 @@ +from machine.annotations import Range +from machine.translation import ErrorCorrectionModel, TranslationResultBuilder, TranslationSources, WordAlignmentMatrix + +ECM = ErrorCorrectionModel() + + +def test_correct_prefix_empty_uncorrected_prefix_appends_prefix() -> None: + builder = _create_result_builder("") + + prefix = "this is a test".split() + assert ( + ECM.correct_prefix( + builder, uncorrected_prefix_len=len(builder.target_tokens), prefix=prefix, is_last_word_complete=True + ) + == 4 + ) + assert len(builder.confidences) == len(prefix) + assert builder.target_tokens == prefix + assert len(builder.phrases) == 0 + + +def test_correct_prefix_new_end_word_inserts_word_at_end() -> None: + builder = _create_result_builder("this is a", 2, 3) + + prefix = "this is a test".split() + assert ( + ECM.correct_prefix( + builder, uncorrected_prefix_len=len(builder.target_tokens), prefix=prefix, is_last_word_complete=True + ) + == 1 + ) + assert len(builder.confidences) == len(prefix) + assert builder.target_tokens == prefix + assert len(builder.phrases) == 2 + assert builder.phrases[0].target_cut == 2 + assert builder.phrases[0].alignment.column_count == 2 + assert builder.phrases[1].target_cut == 3 + assert builder.phrases[1].alignment.column_count == 1 + + +def test_correct_prefix_substring_uncorrected_prefix_new_end_word_inserts_word_at_end() -> None: + builder = _create_result_builder("this is a and only a test", 2, 3, 5, 7) + + prefix = "this is a test".split() + assert ECM.correct_prefix(builder, uncorrected_prefix_len=3, prefix=prefix, is_last_word_complete=True) == 0 + assert len(builder.confidences) == 8 + assert builder.target_tokens == "this is a test and only a test".split() + assert len(builder.phrases) == 4 + assert builder.phrases[0].target_cut == 2 + assert builder.phrases[0].alignment.column_count == 2 + assert builder.phrases[1].target_cut == 3 + assert builder.phrases[1].alignment.column_count == 1 + assert builder.phrases[2].target_cut == 6 + assert builder.phrases[2].alignment.column_count == 3 + assert builder.phrases[3].target_cut == 8 + assert builder.phrases[3].alignment.column_count == 2 + + +def test_correct_prefix_new_middle_word_inserts_word() -> None: + builder = _create_result_builder("this is a test", 2, 4) + + prefix = "this is , a test".split() + assert ( + ECM.correct_prefix( + builder, uncorrected_prefix_len=len(builder.target_tokens), prefix=prefix, is_last_word_complete=True + ) + == 0 + ) + assert len(builder.confidences) == len(prefix) + assert builder.target_tokens == prefix + assert len(builder.phrases) == 2 + assert builder.phrases[0].target_cut == 2 + assert builder.phrases[0].alignment.column_count == 2 + assert builder.phrases[1].target_cut == 5 + assert builder.phrases[1].alignment.column_count == 3 + + +def test_correct_prefix_new_start_word_inserts_word_at_beginning() -> None: + builder = _create_result_builder("this is a test", 2, 4) + + prefix = "yes this is a test".split() + assert ( + ECM.correct_prefix( + builder, uncorrected_prefix_len=len(builder.target_tokens), prefix=prefix, is_last_word_complete=True + ) + == 0 + ) + assert len(builder.confidences) == len(prefix) + assert builder.target_tokens == prefix + assert len(builder.phrases) == 2 + assert builder.phrases[0].target_cut == 3 + assert builder.phrases[0].alignment.column_count == 3 + assert builder.phrases[1].target_cut == 5 + assert builder.phrases[1].alignment.column_count == 2 + + +def test_correct_prefix_missing_end_word_deletes_world_at_end() -> None: + builder = _create_result_builder("this is a test", 2, 4) + + prefix = "this is a".split() + assert ( + ECM.correct_prefix( + builder, uncorrected_prefix_len=len(builder.target_tokens), prefix=prefix, is_last_word_complete=True + ) + == 0 + ) + assert len(builder.confidences) == len(prefix) + assert builder.target_tokens == prefix + assert len(builder.phrases) == 2 + assert builder.phrases[0].target_cut == 2 + assert builder.phrases[0].alignment.column_count == 2 + assert builder.phrases[1].target_cut == 3 + assert builder.phrases[1].alignment.column_count == 1 + + +def test_correct_prefix_substring_uncorrected_prefix_missing_end_word_deletes_word_at_end() -> None: + builder = _create_result_builder("this is a test and only a test", 2, 4, 6, 8) + + prefix = "this is a".split() + assert ECM.correct_prefix(builder, uncorrected_prefix_len=4, prefix=prefix, is_last_word_complete=True) == 0 + assert len(builder.confidences) == 7 + assert builder.target_tokens == "this is a and only a test".split() + assert len(builder.phrases) == 4 + assert builder.phrases[0].target_cut == 2 + assert builder.phrases[0].alignment.column_count == 2 + assert builder.phrases[1].target_cut == 3 + assert builder.phrases[1].alignment.column_count == 1 + assert builder.phrases[2].target_cut == 5 + assert builder.phrases[2].alignment.column_count == 2 + assert builder.phrases[3].target_cut == 7 + assert builder.phrases[3].alignment.column_count == 2 + + +def test_correct_prefix_missing_middle_word_deletes_word() -> None: + builder = _create_result_builder("this is a test", 2, 4) + + prefix = "this a test".split() + assert ( + ECM.correct_prefix( + builder, uncorrected_prefix_len=len(builder.target_tokens), prefix=prefix, is_last_word_complete=True + ) + == 0 + ) + assert len(builder.confidences) == len(prefix) + assert builder.target_tokens == prefix + assert len(builder.phrases) == 2 + assert builder.phrases[0].target_cut == 1 + assert builder.phrases[0].alignment.column_count == 1 + assert builder.phrases[1].target_cut == 3 + assert builder.phrases[1].alignment.column_count == 2 + + +def test_correct_prefix_missing_start_word_deletes_word_at_beginning() -> None: + builder = _create_result_builder("yes this is a test", 3, 5) + + prefix = "this is a test".split() + assert ( + ECM.correct_prefix( + builder, uncorrected_prefix_len=len(builder.target_tokens), prefix=prefix, is_last_word_complete=True + ) + == 0 + ) + assert len(builder.confidences) == len(prefix) + assert builder.target_tokens == prefix + assert len(builder.phrases) == 2 + assert builder.phrases[0].target_cut == 2 + assert builder.phrases[0].alignment.column_count == 2 + assert builder.phrases[1].target_cut == 4 + assert builder.phrases[1].alignment.column_count == 2 + + +def _create_result_builder(target: str, *cuts: int) -> TranslationResultBuilder: + builder = TranslationResultBuilder("esto es una prueba".split()) + if target != "": + i = 0 + k = 0 + words = target.split() + for j in range(len(words)): + builder.append_token(words[j], TranslationSources.SMT, 1) + cut = j + 1 + if k < len(cuts) and cuts[k] == cut: + length = cut - i + builder.mark_phrase(Range.create(i, cut), WordAlignmentMatrix.from_word_pairs(length, length)) + k += 1 + i = cut + return builder diff --git a/tests/translation/test_interactive_translator.py b/tests/translation/test_interactive_translator.py new file mode 100644 index 0000000..f118e9a --- /dev/null +++ b/tests/translation/test_interactive_translator.py @@ -0,0 +1,526 @@ +from itertools import islice + +from decoy import Decoy + +from machine.annotations import Range +from machine.tokenization import WHITESPACE_TOKENIZER +from machine.translation import ( + MAX_SEGMENT_LENGTH, + InteractiveTranslationEngine, + InteractiveTranslator, + InteractiveTranslatorFactory, + TranslationSources, + WordAlignmentMatrix, + WordGraph, + WordGraphArc, +) + +_SOURCE_SEGMENT = "En el principio la Palabra ya existía ." + + +def test_get_current_results_empty_prefix(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + translator = env.create_translator() + + result = next(translator.get_current_results()) + assert result.translation == "In the beginning the Word already existía ." + + +def test_get_current_results_append_complete_word(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + translator = env.create_translator() + translator.append_to_prefix("In ") + + result = next(translator.get_current_results()) + assert result.translation == "In the beginning the Word already existía ." + + +def test_get_current_results_append_partial_word(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + translator = env.create_translator() + translator.append_to_prefix("In ") + translator.append_to_prefix("t") + + result = next(translator.get_current_results()) + assert result.translation == "In the beginning the Word already existía ." + + +def test_get_current_results_remove_word(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + translator = env.create_translator() + translator.append_to_prefix("In the beginning ") + translator.set_prefix("In the ") + + result = next(translator.get_current_results()) + assert result.translation == "In the beginning the Word already existía ." + + +def test_get_current_results_remove_all_words(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + translator = env.create_translator() + translator.append_to_prefix("In the beginning ") + translator.set_prefix("") + + result = next(translator.get_current_results()) + assert result.translation == "In the beginning the Word already existía ." + + +def test_is_source_segment_valid_valid(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + translator = env.create_translator() + + assert translator.is_segment_valid + + +def test_is_source_segment_valid_invalid(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + source_segment = "" + for _ in range(MAX_SEGMENT_LENGTH): + source_segment += "word " + source_segment += "." + decoy.when(env.engine.get_word_graph(source_segment)).then_return( + WordGraph(WHITESPACE_TOKENIZER.tokenize(source_segment)) + ) + translator = env.create_translator(source_segment) + + assert not translator.is_segment_valid + + +def test_approve_aligned_only(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + translator = env.create_translator() + translator.append_to_prefix("In the beginning ") + translator.approve(aligned_only=True) + + decoy.verify(env.engine.train_segment("En el principio", "In the beginning", sentence_start=True), times=1) + + translator.append_to_prefix("the Word already existed .") + translator.approve(aligned_only=True) + + decoy.verify( + env.engine.train_segment( + "En el principio la Palabra ya existía .", + "In the beginning the Word already existed .", + sentence_start=True, + ), + times=1, + ) + + +def test_approve_whole_source_segment(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + translator = env.create_translator() + translator.append_to_prefix("In the beginning ") + translator.approve(aligned_only=False) + + decoy.verify( + env.engine.train_segment("En el principio la Palabra ya existía .", "In the beginning", sentence_start=True), + times=1, + ) + + +def test_get_current_results_multiple_suggestions_empty_prefix(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + env.use_simple_word_graph() + translator = env.create_translator() + + results = list(islice(translator.get_current_results(), 2)) + assert results[0].translation == "In the beginning the Word already existía ." + assert results[1].translation == "In the start the Word already existía ." + + +def test_get_current_results_multiple_suggestions_nonempty_prefix(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + env.use_simple_word_graph() + translator = env.create_translator() + translator.append_to_prefix("In the ") + + results = list(islice(translator.get_current_results(), 2)) + assert results[0].translation == "In the beginning the Word already existía ." + assert results[1].translation == "In the start the Word already existía ." + + translator.append_to_prefix("beginning") + + results = list(islice(translator.get_current_results(), 2)) + assert results[0].translation == "In the beginning the Word already existía ." + assert results[1].translation == "In the beginning his Word already existía ." + + +class _TestEnvironment: + def __init__(self, decoy: Decoy) -> None: + self._decoy = decoy + self.engine = decoy.mock(cls=InteractiveTranslationEngine) + + word_graph = WordGraph( + source_tokens=WHITESPACE_TOKENIZER.tokenize(_SOURCE_SEGMENT), + arcs=[ + WordGraphArc( + 0, + 1, + -22.4162, + ["now", "it"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 1), (1, 0)]), + Range.create(0, 2), + [TranslationSources.SMT, TranslationSources.SMT], + [0.00006755903, 0.0116618536], + ), + WordGraphArc( + 0, + 2, + -23.5761, + ["In", "your"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(0, 2), + [TranslationSources.SMT, TranslationSources.SMT], + [0.355293363, 0.0000941652761], + ), + WordGraphArc( + 0, + 3, + -11.1167, + ["In", "the"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(0, 2), + [TranslationSources.SMT, TranslationSources.SMT], + [0.355293363, 0.5004668], + ), + WordGraphArc( + 0, + 4, + -13.7804, + ["In"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(0, 1), + [TranslationSources.SMT], + [0.355293363], + ), + WordGraphArc( + 3, + 5, + -12.9695, + ["beginning"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(2, 3), + [TranslationSources.SMT], + [0.348795831], + ), + WordGraphArc( + 4, + 5, + -7.68319, + ["the", "beginning"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(1, 3), + [TranslationSources.SMT, TranslationSources.SMT], + [0.5004668, 0.348795831], + ), + WordGraphArc( + 4, + 3, + -14.4373, + ["the"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(1, 2), + [TranslationSources.SMT], + [0.5004668], + ), + WordGraphArc( + 5, + 6, + -19.3042, + ["his", "Word"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(3, 5), + [TranslationSources.SMT, TranslationSources.SMT], + [0.00347203249, 0.477621228], + ), + WordGraphArc( + 5, + 7, + -8.49148, + ["the", "Word"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(3, 5), + [TranslationSources.SMT, TranslationSources.SMT], + [0.346071422, 0.477621228], + ), + WordGraphArc( + 1, + 8, + -15.2926, + ["beginning"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(2, 3), + [TranslationSources.SMT], + [0.348795831], + ), + WordGraphArc( + 2, + 9, + -15.2926, + ["beginning"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(2, 3), + [TranslationSources.SMT], + [0.348795831], + ), + WordGraphArc( + 7, + 10, + -14.3453, + ["already"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(5, 6), + [TranslationSources.SMT], + [0.2259867], + ), + WordGraphArc( + 8, + 6, + -19.3042, + ["his", "Word"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(3, 5), + [TranslationSources.SMT, TranslationSources.SMT], + [0.00347203249, 0.477621228], + ), + WordGraphArc( + 8, + 7, + -8.49148, + ["the", "Word"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(3, 5), + [TranslationSources.SMT, TranslationSources.SMT], + [0.346071422, 0.477621228], + ), + WordGraphArc( + 9, + 6, + -19.3042, + ["his", "Word"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(3, 5), + [TranslationSources.SMT, TranslationSources.SMT], + [0.00347203249, 0.477621228], + ), + WordGraphArc( + 9, + 7, + -8.49148, + ["the", "Word"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(3, 5), + [TranslationSources.SMT, TranslationSources.SMT], + [0.346071422, 0.477621228], + ), + WordGraphArc( + 6, + 10, + -14.0526, + ["already"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(5, 6), + [TranslationSources.SMT], + [0.2259867], + ), + WordGraphArc( + 10, + 11, + 51.1117, + ["existía"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(6, 7), + [TranslationSources.NONE], + [0.0], + ), + WordGraphArc( + 11, + 12, + -29.0049, + ["you", "."], + WordAlignmentMatrix.from_word_pairs(1, 2, [(0, 1)]), + Range.create(7, 8), + [TranslationSources.SMT, TranslationSources.SMT], + [0.005803475, 0.317073762], + ), + WordGraphArc( + 11, + 13, + -27.7143, + ["to"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT], + [0.038961038], + ), + WordGraphArc( + 11, + 14, + -30.0868, + [".", "‘"], + WordAlignmentMatrix.from_word_pairs(1, 2, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT, TranslationSources.SMT], + [0.317073762, 0.06190489], + ), + WordGraphArc( + 11, + 15, + -30.1586, + [".", "he"], + WordAlignmentMatrix.from_word_pairs(1, 2, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT, TranslationSources.SMT], + [0.317073762, 0.06702433], + ), + WordGraphArc( + 11, + 16, + -28.2444, + [".", "the"], + WordAlignmentMatrix.from_word_pairs(1, 2, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT, TranslationSources.SMT], + [0.317073762, 0.115540564], + ), + WordGraphArc( + 11, + 17, + -23.8056, + ["and"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT], + [0.08047272], + ), + WordGraphArc( + 11, + 18, + -23.5842, + ["the"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT], + [0.09361572], + ), + WordGraphArc( + 11, + 19, + -18.8988, + [","], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT], + [0.1428188], + ), + WordGraphArc( + 11, + 20, + -11.9218, + [".", "’"], + WordAlignmentMatrix.from_word_pairs(1, 2, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT, TranslationSources.SMT], + [0.317073762, 0.018057242], + ), + WordGraphArc( + 11, + 21, + -3.51852, + ["."], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT], + [0.317073762], + ), + ], + final_states=[12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + initial_state_score=-191.0998, + ) + + decoy.when(self.engine.get_word_graph(_SOURCE_SEGMENT)).then_return(word_graph) + + self._factory = InteractiveTranslatorFactory(self.engine) + + def use_simple_word_graph(self) -> None: + word_graph = WordGraph( + source_tokens=WHITESPACE_TOKENIZER.tokenize(_SOURCE_SEGMENT), + arcs=[ + WordGraphArc( + 0, + 1, + -10, + ["In", "the", "beginning"], + WordAlignmentMatrix.from_word_pairs(3, 3, [(0, 0), (1, 1), (2, 2)]), + Range.create(0, 3), + [TranslationSources.SMT, TranslationSources.SMT, TranslationSources.SMT], + [0.5, 0.5, 0.5], + ), + WordGraphArc( + 0, + 1, + -11, + ["In", "the", "start"], + WordAlignmentMatrix.from_word_pairs(3, 3, [(0, 0), (1, 1), (2, 2)]), + Range.create(0, 3), + [TranslationSources.SMT, TranslationSources.SMT, TranslationSources.SMT], + [0.5, 0.5, 0.4], + ), + WordGraphArc( + 1, + 2, + -10, + ["the", "Word"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(3, 5), + [TranslationSources.SMT, TranslationSources.SMT], + [0.5, 0.5], + ), + WordGraphArc( + 1, + 2, + -11, + ["his", "Word"], + WordAlignmentMatrix.from_word_pairs(2, 2, [(0, 0), (1, 1)]), + Range.create(3, 5), + [TranslationSources.SMT, TranslationSources.SMT], + [0.4, 0.5], + ), + WordGraphArc( + 2, + 3, + -10, + ["already"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(5, 6), + [TranslationSources.SMT], + [0.5], + ), + WordGraphArc( + 3, + 4, + 50, + ["existía"], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(6, 7), + [TranslationSources.NONE], + [0.0], + ), + WordGraphArc( + 4, + 5, + -10, + ["."], + WordAlignmentMatrix.from_word_pairs(1, 1, [(0, 0)]), + Range.create(7, 8), + [TranslationSources.SMT], + [0.5], + ), + ], + final_states=[5], + ) + self._decoy.when(self.engine.get_word_graph(_SOURCE_SEGMENT)).then_return(word_graph) + + def create_translator(self, segment: str = _SOURCE_SEGMENT) -> InteractiveTranslator: + return self._factory.create(segment) diff --git a/tests/translation/test_phrase_translation_suggester.py b/tests/translation/test_phrase_translation_suggester.py new file mode 100644 index 0000000..da09d6d --- /dev/null +++ b/tests/translation/test_phrase_translation_suggester.py @@ -0,0 +1,307 @@ +from typing import List + +from machine.annotations import Range +from machine.translation import ( + PhraseTranslationSuggester, + TranslationResult, + TranslationResultBuilder, + TranslationSources, + WordAlignmentMatrix, +) + + +def test_get_suggestions_punctuation() -> None: + builder = TranslationResultBuilder(["esto", "es", "una", "prueba", "."]) + builder.append_token("this", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + + suggester = PhraseTranslationSuggester(confidence_threshold=0.2) + suggestions = suggester.get_suggestions( + n=1, prefix_count=0, is_last_word_complete=True, results=[builder.to_result()] + ) + assert list(suggestions[0].target_words) == ["this", "is", "a", "test"] + + +def test_get_suggestions_untranslated_word() -> None: + builder = TranslationResultBuilder(["esto", "es", "una", "prueba", "."]) + builder.append_token("this", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 2), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + builder.append_token("a", TranslationSources.NONE, 0) + builder.mark_phrase( + Range.create(2, 3), + WordAlignmentMatrix.from_word_pairs(row_count=1, column_count=1, set_values=[(0, 0)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + + suggester = PhraseTranslationSuggester(confidence_threshold=0.2) + suggestions = suggester.get_suggestions( + n=1, prefix_count=0, is_last_word_complete=True, results=[builder.to_result()] + ) + assert list(suggestions[0].target_words) == ["this", "is"] + + +def test_get_suggestions_prefix_incomplete_word() -> None: + builder = TranslationResultBuilder(["esto", "es", "una", "prueba", "."]) + builder.append_token("this", TranslationSources.SMT | TranslationSources.PREFIX, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + + suggester = PhraseTranslationSuggester(confidence_threshold=0.2) + suggestions = suggester.get_suggestions( + n=1, prefix_count=1, is_last_word_complete=False, results=[builder.to_result()] + ) + assert list(suggestions[0].target_words) == ["this", "is", "a", "test"] + + +def test_get_suggestions_prefix_complete_word() -> None: + builder = TranslationResultBuilder(["esto", "es", "una", "prueba", "."]) + builder.append_token("this", TranslationSources.SMT | TranslationSources.PREFIX, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + + suggester = PhraseTranslationSuggester(confidence_threshold=0.2) + suggestions = suggester.get_suggestions( + n=1, prefix_count=1, is_last_word_complete=True, results=[builder.to_result()] + ) + assert list(suggestions[0].target_words) == ["is", "a", "test"] + + +def test_get_suggestions_prefix_partial_word() -> None: + builder = TranslationResultBuilder(["esto", "es", "una", "prueba", "."]) + builder.append_token("te", TranslationSources.PREFIX, -1) + builder.append_token("this", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=4, set_values=[(0, 1), (1, 2), (2, 3)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + + suggester = PhraseTranslationSuggester(confidence_threshold=0.2) + suggestions = suggester.get_suggestions( + n=1, prefix_count=1, is_last_word_complete=False, results=[builder.to_result()] + ) + assert suggestions == [] + + +def test_get_suggestions_multiple() -> None: + results: List[TranslationResult] = [] + builder = TranslationResultBuilder(["esto", "es", "una", "prueba", "."]) + builder.append_token("this", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + results.append(builder.to_result()) + + builder.reset() + builder.append_token("that", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + results.append(builder.to_result()) + + builder.reset() + builder.append_token("other", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + results.append(builder.to_result()) + + suggester = PhraseTranslationSuggester(confidence_threshold=0.2) + suggestions = suggester.get_suggestions(n=2, prefix_count=0, is_last_word_complete=True, results=results) + assert len(suggestions) == 2 + assert list(suggestions[0].target_words) == ["this", "is", "a", "test"] + assert list(suggestions[1].target_words) == ["that", "is", "a", "test"] + + +def test_get_suggestions_duplicate() -> None: + results: List[TranslationResult] = [] + builder = TranslationResultBuilder(["esto", "es", "una", "prueba", ".", "segunda", "frase"]) + builder.append_token("this", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + builder.append_token("second", TranslationSources.SMT, 0.1) + builder.append_token("sentence", TranslationSources.SMT, 0.1) + builder.mark_phrase( + Range.create(5, 7), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + results.append(builder.to_result()) + + builder.reset() + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=2, set_values=[(1, 0), (2, 1)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + builder.append_token("second", TranslationSources.SMT, 0.1) + builder.append_token("sentence", TranslationSources.SMT, 0.1) + builder.mark_phrase( + Range.create(5, 7), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + results.append(builder.to_result()) + + suggester = PhraseTranslationSuggester(confidence_threshold=0.2) + suggestions = suggester.get_suggestions(n=2, prefix_count=0, is_last_word_complete=True, results=results) + assert len(suggestions) == 1 + assert list(suggestions[0].target_words) == ["this", "is", "a", "test"] + + +def test_get_suggestions_starts_with_punctuation() -> None: + results: List[TranslationResult] = [] + builder = TranslationResultBuilder(["esto", "es", "una", "prueba", "."]) + builder.append_token(",", TranslationSources.SMT, 0.5) + builder.append_token("this", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=4, set_values=[(0, 1), (1, 2), (2, 3)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + results.append(builder.to_result()) + + builder.reset() + builder.append_token("this", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 5), + WordAlignmentMatrix.from_word_pairs(row_count=2, column_count=2, set_values=[(0, 0), (1, 1)]), + ) + results.append(builder.to_result()) + + suggester = PhraseTranslationSuggester(confidence_threshold=0.2) + suggestions = suggester.get_suggestions(n=2, prefix_count=0, is_last_word_complete=True, results=results) + assert len(suggestions) == 1 + assert list(suggestions[0].target_words) == ["this", "is", "a", "test"] + + +def test_get_suggestions_below_threshold() -> None: + builder = TranslationResultBuilder(["esto", "es", "una", "prueba", "."]) + builder.append_token("this", TranslationSources.SMT, 0.5) + builder.append_token("is", TranslationSources.SMT, 0.5) + builder.append_token("a", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(0, 3), + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ) + builder.append_token("bad", TranslationSources.SMT, 0.1) + builder.append_token("test", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(3, 4), + WordAlignmentMatrix.from_word_pairs(row_count=1, column_count=2, set_values=[(0, 1)]), + ) + builder.append_token(".", TranslationSources.SMT, 0.5) + builder.mark_phrase( + Range.create(4, 5), + WordAlignmentMatrix.from_word_pairs(row_count=1, column_count=1, set_values=[(0, 0)]), + ) + + suggester = PhraseTranslationSuggester(confidence_threshold=0.2) + suggestions = suggester.get_suggestions( + n=1, prefix_count=0, is_last_word_complete=True, results=[builder.to_result()] + ) + assert list(suggestions[0].target_words) == ["this", "is", "a"] diff --git a/tests/translation/test_word_aligner.py b/tests/translation/test_word_aligner.py index 4d01296..cf46838 100644 --- a/tests/translation/test_word_aligner.py +++ b/tests/translation/test_word_aligner.py @@ -1,5 +1,4 @@ -from mockito import ANY, when -from testutils import make_concrete +from typing import Sequence from machine.corpora.parallel_text_row import ParallelTextRow from machine.translation import WordAligner, WordAlignmentMatrix @@ -18,10 +17,19 @@ def test_align_parallel_text_row() -> None: estimated_alignment = WordAlignmentMatrix.from_word_pairs( 10, 7, {(1, 1), (2, 1), (4, 2), (5, 1), (6, 3), (7, 4), (8, 5), (9, 6)} ) - TestWordAligner = make_concrete(WordAligner) - when(TestWordAligner).align(ANY, ANY).thenReturn(estimated_alignment) - aligner = TestWordAligner() # type: ignore + aligner = _MockWordAligner(estimated_alignment) alignment = aligner.align_parallel_text_row(row) assert alignment == WordAlignmentMatrix.from_word_pairs( 10, 7, {(0, 0), (1, 1), (2, 1), (4, 2), (6, 3), (8, 4), (7, 5), (9, 6)} ) + + +class _MockWordAligner(WordAligner): + def __init__(self, alignment: WordAlignmentMatrix) -> None: + self._alignment = alignment + + def align(self, source_segment: Sequence[str], target_segment: Sequence[str]) -> WordAlignmentMatrix: + return self._alignment + + def align_batch(self, segments: Sequence[Sequence[Sequence[str]]]) -> Sequence[WordAlignmentMatrix]: + return [self._alignment for _ in segments]