diff --git a/.vscode/settings.json b/.vscode/settings.json
index f5cbecac..6e4c2a3e 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -8,5 +8,6 @@
     "[python]": {
         "editor.defaultFormatter": "ms-python.black-formatter",
         "editor.formatOnSave": true
-    }
+    },
+    "black-formatter.path": ["poetry", "run", "black"]
 }
diff --git a/machine/jobs/build_nmt_engine.py b/machine/jobs/build_nmt_engine.py
index 5e9b7e33..9dc78550 100644
--- a/machine/jobs/build_nmt_engine.py
+++ b/machine/jobs/build_nmt_engine.py
@@ -88,6 +88,7 @@ def main() -> None:
     parser.add_argument("--trg-lang", required=True, type=str, help="Target language tag")
     parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
     parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
+    parser.add_argument("--save-model", default=None, type=str, help="Save the model using the specified base name")
     args = parser.parse_args()
 
     run({k: v for k, v in vars(args).items() if v is not None})
diff --git a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py
index df8cd91b..871c3afd 100644
--- a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py
+++ b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py
@@ -1,4 +1,5 @@
 import logging
+import tarfile
 from pathlib import Path
 from typing import Any, cast
 
@@ -84,7 +85,17 @@ def create_engine(self) -> TranslationEngine:
         )
 
     def save_model(self) -> None:
-        self._shared_file_service.save_model(self._model_dir)
+        """Pack the files directly inside the model directory into a .tar.gz and pass it to the shared file service."""
+        # No base name configured (key missing or None): nothing to save.
+        if "save_model" not in self._config or self._config.save_model is None:
+            return
+
+        tar_file_path = Path(self._config.data_dir, "builds", self._config.build_id, "model.tar.gz")
+        with tarfile.open(tar_file_path, "w:gz") as tar:
+            for path in self._model_dir.iterdir():
+                if path.is_file():
+                    tar.add(path, arcname=path.name)
+        self._shared_file_service.save_model(tar_file_path, self._config.save_model + ".tar.gz")
 
     @property
     def _model_dir(self) -> Path:
diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py
index 7fb19efd..24d6789d 100644
--- a/machine/jobs/nmt_engine_build_job.py
+++ b/machine/jobs/nmt_engine_build_job.py
@@ -94,6 +94,10 @@ def run(
                 current_inference_step += len(pi_batch)
                 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
 
+        if "save_model" in self._config and self._config.save_model is not None:
+            logger.info("Saving model")
+            self._nmt_model_factory.save_model()
+
 
 def _translate_batch(
     engine: TranslationEngine,
diff --git a/machine/jobs/shared_file_service.py b/machine/jobs/shared_file_service.py
index 112b9ff2..0a478b2a 100644
--- a/machine/jobs/shared_file_service.py
+++ b/machine/jobs/shared_file_service.py
@@ -76,8 +76,12 @@ def open_target_pretranslation_writer(self) -> Iterator[PretranslationWriter]:
     def get_parent_model(self, language_tag: str) -> Path:
         return self._download_folder(f"parent_models/{language_tag}", cache=True)
 
-    def save_model(self, model_dir: Path) -> None:
-        self._upload_folder(f"models/{self._engine_id}", model_dir)
+    def save_model(self, model_path: Path, name: str) -> None:
+        """Upload model_path under models/{name}: as a single file if it is a file, otherwise as a folder."""
+        if model_path.is_file():
+            self._upload_file(f"models/{name}", model_path)
+        else:
+            self._upload_folder(f"models/{name}", model_path)
 
     @property
     def _data_dir(self) -> Path: