Commit

I think it's working now.
johnml1135 committed Nov 18, 2024
1 parent 2276dbc commit 98c1dcd
Showing 8 changed files with 160 additions and 76 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
@@ -59,7 +59,8 @@
"tamasfe.even-better-toml",
"github.vscode-github-actions",
"mhutchie.git-graph",
"GitHub.copilot"
"GitHub.copilot",
"ms-toolsai.jupyter"
]
}
}
18 changes: 14 additions & 4 deletions .vscode/settings.json
@@ -5,11 +5,21 @@
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.extraPaths": ["tests"],
"python.analysis.importFormat": "relative",
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
},
"black-formatter.path": ["poetry", "run", "black"]
}
"black-formatter.path": [
"poetry",
"run",
"black"
],
"isort.args": [
"--profile",
"black"
],
}
27 changes: 27 additions & 0 deletions machine/corpora/parallel_text_corpus.py
@@ -86,6 +86,13 @@ def from_hf_dataset(
ref_factory,
)

@classmethod
def from_parallel_rows(
cls,
rows: Iterable[ParallelTextRow],
) -> ParallelTextCorpus:
return _FromParallelRowsTextCorpus(rows)

@property
@abstractmethod
def is_source_tokenized(self) -> bool: ...
@@ -754,3 +761,23 @@ def _get_translation(self, lang: str, example: dict) -> str:
except ValueError:
return ""
return ""


class _FromParallelRowsTextCorpus(ParallelTextCorpus):
def __init__(self, rows: Iterable[ParallelTextRow]) -> None:
self._rows = rows

def _get_rows(self, text_ids: Optional[Iterable[str]] = None) -> Generator[ParallelTextRow, None, None]:
if text_ids is None:
yield from self._rows
else:
text_ids = set(text_ids)
yield from [row for row in self._rows if row.text_id in text_ids]

@property
def is_source_tokenized(self) -> bool:
return True

@property
def is_target_tokenized(self) -> bool:
return True
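
As a usage sketch (not part of this commit), the new from_parallel_rows hook makes it easy to build a small in-memory parallel corpus. The ParallelTextRow constructor arguments below follow the call pattern used in word_alignment_build_job.py later in this commit (text_id, source refs, target refs, source tokens, target tokens); treat the exact signature as an assumption.

# Hypothetical usage sketch; the row constructor arguments are assumed from
# the call site in word_alignment_build_job.py below.
from machine.corpora.parallel_text_corpus import ParallelTextCorpus
from machine.corpora.parallel_text_row import ParallelTextRow

rows = [
    ParallelTextRow("MAT", ["MAT 1:1"], ["MAT 1:1"], ["in", "the", "beginning"], ["au", "commencement"]),
]
corpus = ParallelTextCorpus.from_parallel_rows(rows)
for row in corpus.get_rows():
    print(row.text_id, row.source_segment, row.target_segment)
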
17 changes: 16 additions & 1 deletion machine/jobs/shared_file_service_base.py
@@ -4,6 +4,21 @@
from typing import Any, Iterator, TextIO


class PrettyFloat(float):
def __repr__(self):
return "%.8g" % self


def pretty_floats(obj):
if isinstance(obj, float):
return PrettyFloat(obj)
elif isinstance(obj, dict):
return dict((k, pretty_floats(v)) for k, v in obj.items())
elif isinstance(obj, (list, tuple)):
return list(map(pretty_floats, obj))
return obj
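
One caveat worth flagging: CPython 3's json encoder formats floats with float.__repr__ directly, so a __repr__ override on a float subclass like PrettyFloat can be ignored by json.dumps even though repr() on the object itself shows the trimmed form. A quick verification sketch (not part of the commit):

# Verification sketch: compare repr() of the wrapper with what json.dumps
# actually emits; on CPython 3 the two can differ.
import json

value = 0.1 + 0.2                                    # 0.30000000000000004
print(repr(PrettyFloat(value)))                      # 0.3 (the override works here)
print(json.dumps(pretty_floats({"score": value})))   # may still show full precision
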


class DictToJsonWriter:
def __init__(self, file: TextIO) -> None:
self._file = file
@@ -12,7 +27,7 @@ def __init__(self, file: TextIO) -> None:
def write(self, pi: object) -> None:
if not self._first:
self._file.write(",\n")
self._file.write(" " + json.dumps(pi))
self._file.write(" " + json.dumps(pretty_floats(pi)))
self._first = False
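
For reference, a minimal sketch (not part of the commit) of the framing this writer produces. It assumes _first is initialized to True in the elided part of __init__ and that the enclosing JSON array brackets are written elsewhere by the shared file service.

# Minimal sketch; separator and indentation follow the write method above.
import io

buf = io.StringIO()
writer = DictToJsonWriter(buf)
writer.write({"score": 0.5})
writer.write({"score": 0.25})
print(buf.getvalue())  # -> ' {"score": 0.5},\n {"score": 0.25}'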


55 changes: 36 additions & 19 deletions machine/jobs/word_alignment_build_job.py
@@ -4,6 +4,7 @@

from ..corpora.aligned_word_pair import AlignedWordPair
from ..corpora.parallel_text_corpus import ParallelTextCorpus
from ..corpora.parallel_text_row import ParallelTextRow
from ..tokenization.tokenizer_factory import create_tokenizer
from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
from ..utils.progress_status import ProgressStatus
@@ -51,11 +52,7 @@ def run(

logger.info("Generating alignments")

source_word_alignment_corpus = self._word_alignment_file_service.create_source_word_alignment_corpus()
target_word_alignment_corpus = self._word_alignment_file_service.create_target_word_alignment_corpus()
word_alignment_parallel_corpus: ParallelTextCorpus = source_word_alignment_corpus.align_rows(target_word_alignment_corpus)

self._batch_inference(word_alignment_parallel_corpus, progress_reporter, check_canceled)
self._batch_inference(progress_reporter, check_canceled)

self._save_model()
return train_corpus_size
@@ -88,39 +85,59 @@ def _train_model(

def _batch_inference(
self,
parallel_corpus: ParallelTextCorpus,
progress_reporter: PhasedProgressReporter,
check_canceled: Optional[Callable[[], None]],
) -> None:
inference_step_count = parallel_corpus.count(include_empty=False)

inference_inputs = self._word_alignment_file_service.get_word_alignment_inputs()

inference_step_count = len(inference_inputs)

with ExitStack() as stack:
phase_progress = stack.enter_context(progress_reporter.start_next_phase())
alignment_model = stack.enter_context(self._word_alignment_model_factory.create_alignment_model())
writer = stack.enter_context(self._word_alignment_file_service.open_target_alignment_writer())
writer = stack.enter_context(self._word_alignment_file_service.open_alignment_output_writer())
current_inference_step = 0
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
batch_size = self._config["inference_batch_size"]
segment_batch = list(parallel_corpus.lowercase().tokenize(self._tokenizer).take(batch_size))

parallel_corpus = ParallelTextCorpus.from_parallel_rows(
[
ParallelTextRow(
ii["textId"],
ii["refs"],
ii["refs"],
list(self._tokenizer.tokenize(ii["source"])),
list(self._tokenizer.tokenize(ii["target"])),
)
for ii in inference_inputs
]
).lowercase()

segment_batch = list(parallel_corpus.take(batch_size))
if check_canceled is not None:
check_canceled()
alignments = alignment_model.align_batch(segment_batch)
if check_canceled is not None:
check_canceled()

def format_score(score: float) -> str:
return f"{score:.8f}".rstrip("0").rstrip(".")

for row, alignment in zip(parallel_corpus.get_rows(), alignments):
source_segment = list(self._tokenizer.tokenize(row.source_text))
target_segment = list(self._tokenizer.tokenize(row.target_text))
for parallel_text_row, inference_input, alignment in zip(
parallel_corpus.get_rows(), inference_inputs, alignments
):
word_pairs = alignment.to_aligned_word_pairs(include_null=True)
alignment_model.compute_aligned_word_pair_scores(source_segment, target_segment, word_pairs)
alignment_model.compute_aligned_word_pair_scores(
parallel_text_row.source_segment, parallel_text_row.target_segment, word_pairs
)

word_alignment_info = {
"refs": [str(ref) for ref in row.source_refs],
"column_count": alignment.column_count,
"row_count": alignment.row_count,
"corpus_id": inference_input["corpusId"],
"text_id": inference_input["textId"],
"refs": [str(ref) for ref in inference_input["refs"]],
"source_tokens": parallel_text_row.source_segment,
"target_tokens": parallel_text_row.target_segment,
"confidences": [
word_pair.alignment_score * word_pair.translation_score for word_pair in word_pairs
],
"alignment": AlignedWordPair.to_string(word_pairs),
}
writer.write(word_alignment_info)
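
For illustration, one output record written by this loop might look like the following; the keys mirror word_alignment_info above, the values are invented, and the alignment string format (Pharaoh-style index pairs) is an assumption about AlignedWordPair.to_string.

# Hypothetical output record with made-up values.
example_record = {
    "corpus_id": "corpus1",
    "text_id": "MAT",
    "refs": ["MAT 1:1"],
    "source_tokens": ["in", "the", "beginning"],
    "target_tokens": ["au", "commencement"],
    "confidences": [0.9025, 0.8649],
    "alignment": "0-0 1-1 2-1",  # assumed Pharaoh-style pairs
}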
58 changes: 36 additions & 22 deletions machine/jobs/word_alignment_file_service.py
@@ -1,30 +1,38 @@
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Iterator
from typing import Any, Iterator, List, TypedDict

import json_stream

from ..corpora.text_corpus import TextCorpus
from ..corpora.text_file_text_corpus import TextFileTextCorpus
from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase
from .shared_file_service_factory import SharedFileServiceType, get_shared_file_service


class WordAlignmentInput(TypedDict):
corpusId: str # noqa: N815
textId: str # noqa: N815
refs: List[str]
source: str
target: str
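
An illustrative entry (invented values) matching this shape, as one element of word_alignment_inputs.json:

# Hypothetical example of a single WordAlignmentInput entry.
sample_input: WordAlignmentInput = {
    "corpusId": "corpus1",
    "textId": "MAT",
    "refs": ["MAT 1:1"],
    "source": "in the beginning",
    "target": "au commencement",
}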


class WordAlignmentFileService:
def __init__(
self,
type: SharedFileServiceType,
config: Any,
source_filename: str = "train.src.txt",
target_filename: str = "train.trg.txt",
source_word_alignment_filename: str = "align_words.src.json",
target_word_alignment_filename: str = "align_words.trg.json",
word_alignment_filename: str = "word_alignments.json",
word_alignment_input_filename: str = "word_alignment_inputs.json",
word_alignment_output_filename: str = "word_alignment_outputs.json",
) -> None:

self._source_filename = source_filename
self._target_filename = target_filename
self._source_word_alignment_filename = source_word_alignment_filename
self._target_word_alignment_filename = target_word_alignment_filename
self._word_alignment_filename = word_alignment_filename
self._word_alignment_input_filename = word_alignment_input_filename
self._word_alignment_output_filename = word_alignment_output_filename

self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)

@@ -38,31 +46,37 @@ def create_target_corpus(self) -> TextCorpus:
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
)

def create_source_word_alignment_corpus(self) -> TextCorpus:
return TextFileTextCorpus(
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._source_word_alignment_filename}")
def get_word_alignment_inputs(self) -> List[WordAlignmentInput]:
src_pretranslate_path = self.shared_file_service.download_file(
f"{self.shared_file_service.build_path}/{self._word_alignment_input_filename}"
)
with src_pretranslate_path.open("r", encoding="utf-8-sig") as file:
wa_inputs = [
WordAlignmentInput(
corpusId=pi["corpusId"],
textId=pi["textId"],
refs=list(pi["refs"]),
source=pi["source"],
target=pi["target"],
)
for pi in json_stream.load(file)
]
return wa_inputs

def create_target_word_alignment_corpus(self) -> TextCorpus:
return TextFileTextCorpus(
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._target_word_alignment_filename}")
)

def exists_source_corpus(self) -> bool:
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._source_filename}")

def exists_target_corpus(self) -> bool:
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._target_filename}")

def exists_source_word_alignment_corpus(self) -> bool:
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._source_word_alignment_filename}")

def exists_target_word_alignment_corpus(self) -> bool:
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._target_word_alignment_filename}")
def exists_word_alignment_inputs(self) -> bool:
return self.shared_file_service._exists_file(
f"{self.shared_file_service.build_path}/{self._word_alignment_input_filename}"
)

def save_model(self, model_path: Path, destination: str) -> None:
self.shared_file_service.upload_path(model_path, destination)

@contextmanager
def open_target_alignment_writer(self) -> Iterator[DictToJsonWriter]:
return self.shared_file_service.open_target_writer(self._word_alignment_filename)
def open_alignment_output_writer(self) -> Iterator[DictToJsonWriter]:
return self.shared_file_service.open_target_writer(self._word_alignment_output_filename)
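
Putting the pieces together, a rough sketch of the intended call pattern (not from the commit; it assumes a valid config object and a SharedFileServiceType member such as LOCAL, neither of which is shown in this diff):

# Rough flow sketch; config and compute_alignments are hypothetical
# placeholders standing in for code that lives elsewhere.
service = WordAlignmentFileService(SharedFileServiceType.LOCAL, config)
if service.exists_word_alignment_inputs():
    inputs = service.get_word_alignment_inputs()
    with service.open_alignment_output_writer() as writer:
        for record in compute_alignments(inputs):  # e.g. the batch inference in word_alignment_build_job.py
            writer.write(record)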