Skip to content

Commit

Permalink
Script to strip the optimizer state from the model checkpoints (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
avaucher authored Oct 5, 2023
1 parent fa0fba9 commit 5beae7a
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 11 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ console_scripts =
rxn-onmt-continue-training = rxn.onmt_models.scripts.rxn_onmt_continue_training:main
rxn-onmt-finetune = rxn.onmt_models.scripts.rxn_onmt_finetune:main
rxn-onmt-preprocess = rxn.onmt_models.scripts.rxn_onmt_preprocess:main
rxn-onmt-strip-checkpoints = rxn.onmt_models.scripts.rxn_onmt_strip_checkpoints:main
rxn-onmt-train = rxn.onmt_models.scripts.rxn_onmt_train:main
rxn-plan-training = rxn.onmt_models.scripts.rxn_plan_training:main
rxn-prepare-data = rxn.onmt_models.scripts.rxn_prepare_data:main
Expand Down
80 changes: 80 additions & 0 deletions src/rxn/onmt_models/scripts/rxn_onmt_strip_checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import copy
import logging
from pathlib import Path

import click
from rxn.onmt_utils.strip_model import strip_model
from rxn.utilities.logging import setup_console_logger

from rxn.onmt_models.training_files import ModelFiles

# Module-level logger named after this module. A NullHandler is attached so
# that the module itself does not install any output handler.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


@click.command(context_settings=dict(show_default=True))
@click.option(
    "--model_dir",
    "-m",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
    required=True,
    help="Directory with the model checkpoints to strip.",
)
@click.option(
    "--strip_last_checkpoint",
    is_flag=True,
    help="If specified, the optimizer state will be removed from the last checkpoint as well.",
)
def main(model_dir: Path, strip_last_checkpoint: bool) -> None:
    """Strip the checkpoints (i.e. remove the optimizer state) contained in a model directory.
    By default, it will not remove the optimizer state from the last checkpoint, as
    that one may be needed for finetuning or continued training.
    Also, all the model files that do not incorporate a step number are ignored.
    If you want to strip a single model, use the ``rxn-strip-opennmt-model`` command.
    """
    # NOTE: the docstring above is rendered by click as the command's help
    # text, so implementation notes are kept in comments instead.
    setup_console_logger()

    model_files = ModelFiles(model_dir)

    # Checkpoints come back sorted by step number, so the last element of any
    # order-preserving sub-list is the one with the highest step.
    all_checkpoints = model_files.get_checkpoints()

    # Symlinks are never rewritten: stripping happens in place, and writing
    # through a link would modify (or replace the link to) another file.
    symlink_checkpoints = [p for p in all_checkpoints if p.is_symlink()]
    checkpoints_to_strip = [p for p in all_checkpoints if not p.is_symlink()]

    if symlink_checkpoints:
        print("The following checkpoint(s) are symlinks and will not be stripped:")
        for checkpoint in symlink_checkpoints:
            print(f" - {checkpoint}")

    # Work on a copy so that appending below leaves ``symlink_checkpoints`` intact.
    checkpoints_not_to_strip = copy.deepcopy(symlink_checkpoints)

    # Keep the most recent regular checkpoint untouched unless asked otherwise.
    # The emptiness guard avoids an IndexError when the directory contains no
    # regular (non-symlink) checkpoint at all.
    if not strip_last_checkpoint and checkpoints_to_strip:
        checkpoints_not_to_strip.append(checkpoints_to_strip[-1])
        checkpoints_to_strip = checkpoints_to_strip[:-1]

    if checkpoints_to_strip:
        print("The optimizer state will be removed from the following checkpoints:")
        for checkpoint in checkpoints_to_strip:
            print(f" - {checkpoint}")
    else:
        print("No checkpoint to modify.")

    if checkpoints_not_to_strip:
        print("The following checkpoints will not be modified:")
        for checkpoint in checkpoints_not_to_strip:
            print(f" - {checkpoint}")

    # Nothing to do: do not bother the user with a confirmation prompt.
    if not checkpoints_to_strip:
        return

    confirm = click.confirm("Do you want to proceed?", default=True)

    if not confirm:
        print("Stopping here.")
        return

    # Overwrite each checkpoint with its stripped version.
    for checkpoint in checkpoints_to_strip:
        strip_model(model_in=checkpoint, model_out=checkpoint)


# Allow running the module directly, in addition to the console script entry point.
if __name__ == "__main__":
    main()
29 changes: 18 additions & 11 deletions src/rxn/onmt_models/training_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from itertools import count
from pathlib import Path
from typing import Optional
from typing import List, Optional

from rxn.utilities.files import PathLike

Expand Down Expand Up @@ -39,23 +39,30 @@ def next_config_file(self) -> Path:
return config_file
return Path() # Note: in order to satisfy mypy. This is never reached.

def get_checkpoints(self) -> List[Path]:
    """Return the step-numbered checkpoints of the model directory.

    Every entry of ``self.model_dir`` is passed to ``_get_checkpoint_step``;
    entries for which it returns ``None`` are left out.

    Returns:
        The remaining paths, ordered from the lowest to the highest step.
    """
    numbered = []
    for candidate in self.model_dir.iterdir():
        step = self._get_checkpoint_step(candidate)
        # A step of 0 is valid; only ``None`` means "not a checkpoint".
        if step is not None:
            numbered.append((step, candidate))

    # Tuples compare by step first, i.e. from low checkpoint to high checkpoint.
    return [candidate for _, candidate in sorted(numbered)]

def get_last_checkpoint(self) -> Path:
    """Get the last checkpoint matching the naming including the step number.

    Returns:
        The path of the checkpoint with the highest step number.

    Raises:
        RuntimeError: no model is found in the expected directory.
    """
    # ``get_checkpoints`` already filters and sorts by increasing step, so
    # there is no need to scan and sort the directory a second time here.
    models = self.get_checkpoints()
    if not models:
        raise RuntimeError(f'No model found in "{self.model_dir}"')

    return models[-1]

@staticmethod
def _get_checkpoint_step(path: Path) -> Optional[int]:
Expand Down
53 changes: 53 additions & 0 deletions tests/test_training_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from pathlib import Path
from typing import Iterable

from rxn.utilities.files import named_temporary_directory, paths_are_identical

from rxn.onmt_models.training_files import ModelFiles


def create_files(directory: Path, files_to_create: Iterable[str]) -> None:
    """Touch one file inside ``directory`` for every name in ``files_to_create``."""
    for target in map(directory.joinpath, files_to_create):
        target.touch()


def test_get_checkpoints() -> None:
    """Check that ``get_checkpoints`` returns only the step-numbered files, by increasing step."""
    filenames = [
        "model_ref.pt",
        "model_step_99.pt",
        "model_step_0.pt",
        "model_step_100.pt",
        "model_100000.pt",
    ]
    expected_names = [
        "model_step_0.pt",
        "model_step_99.pt",
        "model_step_100.pt",
    ]

    with named_temporary_directory() as directory:
        create_files(directory, filenames)

        found = ModelFiles(directory).get_checkpoints()

        # Only the file names are compared, not the full paths.
        assert [path.name for path in found] == expected_names


def test_get_last_checkpoint() -> None:
    """Check that ``get_last_checkpoint`` picks the step-numbered file with the highest step."""
    filenames = [
        "model_ref.pt",
        "model_step_99.pt",
        "model_step_0.pt",
        "model_step_100.pt",
        "model_100000.pt",
    ]

    with named_temporary_directory() as directory:
        create_files(directory, filenames)

        expected = directory / "model_step_100.pt"
        found = ModelFiles(directory).get_last_checkpoint()

        assert paths_are_identical(found, expected)

0 comments on commit 5beae7a

Please sign in to comment.