From 6df7ece3f552ac0206b363bebf44b338a24bea1b Mon Sep 17 00:00:00 2001 From: John Lambert Date: Thu, 5 Sep 2024 16:59:13 -0400 Subject: [PATCH] More consistent fix Better type hinting --- machine/jobs/nmt_engine_build_job.py | 2 +- .../huggingface/hugging_face_nmt_model_trainer.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py index c3caf72..63b4b81 100644 --- a/machine/jobs/nmt_engine_build_job.py +++ b/machine/jobs/nmt_engine_build_job.py @@ -74,7 +74,7 @@ def _train_model( ) as model_trainer: model_trainer.train(progress=phase_progress, check_canceled=check_canceled) model_trainer.save() - train_corpus_size = parallel_corpus.count() + train_corpus_size = model_trainer.stats.train_corpus_size return train_corpus_size, float("nan") def _batch_inference( diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py index 6103d17..a29824b 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py +++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py @@ -102,10 +102,11 @@ def __init__( self._add_unk_trg_tokens = add_unk_trg_tokens self._mpn = MosesPunctNormalizer() self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions] + self._stats = TrainStats() @property def stats(self) -> TrainStats: - return super().stats + return self._stats def train( self, @@ -141,7 +142,7 @@ def train( set_seed(self._training_args.seed) if isinstance(self._model, PreTrainedModel): - model = self._model + model: PreTrainedModel = self._model self._original_use_cache = model.config.use_cache model.config.use_cache = not self._training_args.gradient_checkpointing else: @@ -366,6 +367,7 @@ def preprocess_function(examples): self._metrics = train_result.metrics self._metrics["train_samples"] = len(train_dataset) + self._stats.train_corpus_size = len(train_dataset) self._trainer.log_metrics("train", self._metrics) logger.info("Model training finished") @@ -377,9 +379,10 @@ def save(self) -> None: self._trainer.save_metrics("train", self._metrics) self._trainer.save_state() if isinstance(self._model, PreTrainedModel): - self._model.name_or_path = self._training_args.output_dir - self._model.config.name_or_path = self._training_args.output_dir - self._model.config.use_cache = self._original_use_cache + model: PreTrainedModel = self._model + model.name_or_path = self._training_args.output_dir + model.config.name_or_path = self._training_args.output_dir + model.config.use_cache = self._original_use_cache def __exit__(self, type: Any, value: Any, traceback: Any) -> None: if self._trainer is not None: