Skip to content

Commit

Permalink
More consistent fix
Browse files Browse the repository at this point in the history
Better type hinting
  • Loading branch information
johnml1135 committed Sep 5, 2024
1 parent 1ff02b4 commit 6df7ece
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _train_model(
) as model_trainer:
model_trainer.train(progress=phase_progress, check_canceled=check_canceled)
model_trainer.save()
train_corpus_size = parallel_corpus.count()
train_corpus_size = model_trainer.stats.train_corpus_size
return train_corpus_size, float("nan")

def _batch_inference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ def __init__(
self._add_unk_trg_tokens = add_unk_trg_tokens
self._mpn = MosesPunctNormalizer()
self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions]
self._stats = TrainStats()

@property
def stats(self) -> TrainStats:
    """Return the training statistics gathered by this trainer.

    The diff rendering left both the removed `return super().stats` and the
    added `return self._stats` in place, making the second return unreachable.
    Post-commit intent (the commit also initializes `self._stats = TrainStats()`
    in `__init__`) is to return the trainer's own stats object, which `train()`
    populates with `train_corpus_size`.
    """
    return self._stats

def train(
self,
Expand Down Expand Up @@ -141,7 +142,7 @@ def train(
set_seed(self._training_args.seed)

if isinstance(self._model, PreTrainedModel):
model = self._model
model: PreTrainedModel = self._model
self._original_use_cache = model.config.use_cache
model.config.use_cache = not self._training_args.gradient_checkpointing
else:
Expand Down Expand Up @@ -366,6 +367,7 @@ def preprocess_function(examples):

self._metrics = train_result.metrics
self._metrics["train_samples"] = len(train_dataset)
self._stats.train_corpus_size = len(train_dataset)

self._trainer.log_metrics("train", self._metrics)
logger.info("Model training finished")
Expand All @@ -377,9 +379,10 @@ def save(self) -> None:
self._trainer.save_metrics("train", self._metrics)
self._trainer.save_state()
if isinstance(self._model, PreTrainedModel):
self._model.name_or_path = self._training_args.output_dir
self._model.config.name_or_path = self._training_args.output_dir
self._model.config.use_cache = self._original_use_cache
model: PreTrainedModel = self._model
model.name_or_path = self._training_args.output_dir
model.config.name_or_path = self._training_args.output_dir
model.config.use_cache = self._original_use_cache

def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
if self._trainer is not None:
Expand Down

0 comments on commit 6df7ece

Please sign in to comment.