
313 save final model #340

Merged: 11 commits into dev on Jun 28, 2024

Conversation

@Lilferrit (Contributor)

To ensure the final model is always saved, the following lines were added to the end of ModelRunner.train:

# Always save final model weights at the end of training
if self.config.model_save_folder_path is not None:
    self.trainer.save_checkpoint(
        os.path.join(
            self.config.model_save_folder_path,
            "train-run-final.ckpt"
        )
    )

This implementation was tested with a small training run in two cases: where val_check_interval is not a factor of the total number of training steps, and where val_check_interval is greater than the total number of training steps. In both cases the final model checkpoint was saved.

@Lilferrit (Contributor Author)

Added the final epoch number to the file name; the implementation is now:

# Always save final model weights at the end of training
if self.config.model_save_folder_path is not None:
    self.trainer.save_checkpoint(
        os.path.join(
            self.config.model_save_folder_path,
            f"train-run-final-{self.trainer.current_epoch}.ckpt"
        )
    )

@wsnoble requested a review from @bittremieux on June 25, 2024.
@bittremieux (Collaborator) left a comment:

Some minor comments to address.

However, can't we get the same result much more easily by always setting enable_checkpointing to True when creating the Trainer and adding a default ModelCheckpoint callback? See the Trainer documentation. That way we can benefit from letting Lightning handle all of this.

Three review comments on casanovo/denovo/model_runner.py (outdated, resolved).
@Lilferrit (Contributor Author)

> Some minor comments to address.
>
> However, can't we get the same result much more easily by always setting enable_checkpointing to True when creating the Trainer and adding a default ModelCheckpoint callback? See the Trainer documentation. That way we can benefit from letting Lightning handle all of this.

Agreed - I've reimplemented this using a ModelCheckpoint instead, and I will push it upstream once I've done some more testing on my end. However, regarding the enable_checkpointing option: it looks like it only adds a default callback if no user-defined callbacks are passed to callbacks, so it would effectively do nothing once the validation ModelCheckpoint is added. To always save the final model, I instead added a new ModelCheckpoint that fires at the end of every training epoch.
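
For illustration, a minimal sketch of that enable_checkpointing behavior, assuming Lightning's documented Trainer semantics (the import path and monitor name follow the code in this PR; everything else is illustrative):

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

# Because a ModelCheckpoint is already present in `callbacks`,
# enable_checkpointing=True does not inject another default callback.
trainer = pl.Trainer(
    enable_checkpointing=True,
    callbacks=[ModelCheckpoint(monitor="valid_CELoss", mode="min")],
)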

@Lilferrit (Contributor Author)

Reimplemented using ModelCheckpoint; the last lines of the ModelRunner constructor are now:

# Configure checkpoints.
self.callbacks = [
    ModelCheckpoint(
        dirpath=config.model_save_folder_path,
        save_on_train_epoch_end=True,
    )
]

if config.save_top_k is not None:
    self.callbacks.append(
        ModelCheckpoint(
            dirpath=config.model_save_folder_path,
            monitor="valid_CELoss",
            mode="min",
            save_top_k=config.save_top_k,
        )
    )

@bittremieux (Collaborator) left a comment:

Great, I think that this is a better solution.

The remaining thing to do is to add unit tests that verify that the final model is saved in different situations (different values of steps and epochs). You could also add tests to check that the periodic checkpoints are created properly.
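
For illustration, such a test might look roughly like the sketch below; the fixture names (tiny_config, mgf_small) and the ModelRunner usage are assumptions based on this discussion, not the merged test code:

from casanovo.denovo.model_runner import ModelRunner

def test_final_checkpoint_saved(tmp_path, tiny_config, mgf_small):
    # Hypothetical fixtures: tiny_config is a minimal Config and mgf_small
    # a tiny peak file. Train briefly, then check a checkpoint was written.
    config = tiny_config
    config.model_save_folder_path = str(tmp_path)
    config.max_epochs = 1
    with ModelRunner(config=config) as runner:
        runner.train([mgf_small], [mgf_small])
    assert any(path.suffix == ".ckpt" for path in tmp_path.iterdir())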

@bittremieux linked an issue on Jun 26, 2024 that may be closed by this pull request.
codecov bot commented on Jun 26, 2024:

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 89.88%. Comparing base (70ea9fc) to head (a743dc5).

Additional details and impacted files
@@            Coverage Diff             @@
##              dev     #340      +/-   ##
==========================================
+ Coverage   89.77%   89.88%   +0.10%     
==========================================
  Files          12       12              
  Lines         929      929              
==========================================
+ Hits          834      835       +1     
+ Misses         95       94       -1     


@Lilferrit (Contributor Author)

> Great, I think that this is a better solution.
>
> The remaining thing to do is to add unit tests that verify that the final model is saved in different situations (different values of steps and epochs). You could also add tests to check that the periodic checkpoints are created properly.

Sounds good, I added some unit tests that ensure the last model weights are saved in the scenarios where val_check_interval is greater than, or not a factor of, the number of training steps. Unfortunately, the ModelCheckpoint that saves a checkpoint at the end of every training epoch deletes the previous epoch's checkpoint when a new one is saved (it doesn't touch the validation checkpoints), and CliRunner.invoke is blocking, so I couldn't think of a practical way to test whether the periodic checkpoints are saved properly.
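
For context, that deletion behavior follows from Lightning's defaults; a sketch of the relevant callback (the path is illustrative, not the project's configuration):

from lightning.pytorch.callbacks import ModelCheckpoint

# With monitor=None and the default save_top_k=1, Lightning keeps only the
# most recent epoch-end checkpoint and deletes the previous epoch's file.
# Setting save_top_k=-1 would instead retain every epoch's checkpoint.
epoch_checkpoint = ModelCheckpoint(
    dirpath="checkpoints/",        # illustrative path
    save_on_train_epoch_end=True,  # fire at the end of every training epoch
)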

@bittremieux (Collaborator) left a comment:

A few suggestions for the tests.

Five review comments on tests/conftest.py and one on tests/test_integration.py (outdated, resolved).
@Lilferrit (Contributor Author)

Sounds good, I factored the save-final-model check into a separate unit test, test_save_final_model, in test_runner.py.

@bittremieux (Collaborator)

Great! The final thing to do is update the changelog.

@Lilferrit (Contributor Author)

Sounds great, I added an entry to the changelog.

@Lilferrit merged commit 7372eb0 into dev on Jun 28, 2024.
6 checks passed
@Lilferrit deleted the 313-save-final-model branch on June 28, 2024.