
Designate a filename for the "best" model #351

Closed
wsnoble opened this issue Jul 2, 2024 · 11 comments · Fixed by #365
Comments

@wsnoble (Contributor) commented Jul 2, 2024

It would be nice if there were a programmatic way to identify the best-performing model, as measured by validation error. I am thinking that in addition to outputting checkpoints with names like epoch=4-step=450000.ckpt, we could output a copy of (or symlink to) the checkpoint with the best validation error, using a name like "final.ckpt".

@wsnoble added the enhancement (New feature or request) label Jul 2, 2024
@Lilferrit (Contributor) commented

Unfortunately, looking at the PyTorch Lightning documentation, this isn't natively supported by the ModelCheckpoint class. However, it would be possible to define a custom callback that holds a reference to the validation ModelCheckpoint; whenever that ModelCheckpoint fires, the custom callback would save a symlink named best.ckpt pointing to the checkpoint with the best validation loss. Thoughts on this approach would be appreciated.
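
A minimal sketch of that idea, assuming lightning.pytorch imports and that the callback is registered after the validation ModelCheckpoint (so best_model_path is already updated when it runs); the class name is just a placeholder:

from pathlib import Path

from lightning.pytorch.callbacks import Callback, ModelCheckpoint

class BestCheckpointSymlink(Callback):
    """Maintain a best.ckpt symlink pointing at the best checkpoint
    tracked by an existing validation ModelCheckpoint."""

    def __init__(self, checkpoint_callback: ModelCheckpoint):
        self.checkpoint_callback = checkpoint_callback

    def on_validation_end(self, trainer, pl_module):
        best_path = self.checkpoint_callback.best_model_path
        if not best_path:
            return  # No checkpoint has been saved yet.
        symlink_path = Path(self.checkpoint_callback.dirpath) / "best.ckpt"
        symlink_path.unlink(missing_ok=True)  # Drop a stale link, if any.
        symlink_path.symlink_to(Path(best_path))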

@wsnoble (Contributor, Author) commented Jul 2, 2024

I like the name best.ckpt better than my suggestion (final.ckpt) because it's more accurate. Using a symlink rather than a copy is also better. I will let others comment on the best way to implement this.

@bittremieux (Collaborator) commented

Simply saving the best model is already supported by setting save_top_k to 1 in the config, no code changes required.
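
For context, this is roughly what that amounts to in the existing callback setup (mirroring the model_runner.py snippet quoted later in this thread; the dirpath is hard-coded here purely for illustration):

from lightning.pytorch.callbacks import ModelCheckpoint

# With save_top_k=1, Lightning keeps only the checkpoint with the lowest
# monitored value and deletes the previously best file once it is beaten.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",  # e.g. config.model_save_folder_path in Casanovo
    monitor="valid_CELoss",
    mode="min",
    save_top_k=1,
)

The retained file still gets an auto-generated name such as epoch=4-step=450000.ckpt, which is exactly the naming concern raised in the next comment.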

@wsnoble (Contributor, Author) commented Jul 3, 2024

Yes, but my point is that the user has no way of knowing in advance exactly what that model file will be called. And if save_top_k is >1, then it's hard to tell which one is the best.

@bittremieux (Collaborator) commented Jul 4, 2024

> Yes, but my point is that the user has no way of knowing in advance exactly what that model file will be called.

I see two low-code options to address that:

  1. Provide a config option to set the name of the output file, optionally removing the epoch and step information (a sketch of this follows after this comment).
  2. Add an additional ModelCheckpoint that only saves the best model under a fixed name, in addition to the currently saved best 5 models and the last model.

> And if save_top_k is >1, then it's hard to tell which one is the best.

Well yes, but that's by design. That's exactly what saving the top 5 models means. If you don't want that, change the number of models to be saved. 🤷‍♂️ Likely you don't need any of those lower-performing models anyway.
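
Regarding option 1, a sketch of how a configurable output name could map onto the filename argument that Lightning's ModelCheckpoint already accepts (the hard-coded values stand in for what would come from the config):

from lightning.pytorch.callbacks import ModelCheckpoint

# A static filename replaces the default "{epoch}-{step}" template, so the
# single retained checkpoint is always written as best.ckpt.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",  # e.g. config.model_save_folder_path
    filename="best",         # could be exposed as a new config option
    monitor="valid_CELoss",
    mode="min",
    save_top_k=1,
)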

@Lilferrit (Contributor) commented

Regarding the second low-code approach: I'm not terribly familiar with the PyTorch Lightning API, but wouldn't this create a copy of the best-performing weights rather than a symlink? I'm not sure how concerned we are about disk usage.

@bittremieux (Collaborator) commented

Yes, it would be a copy of the weights. I don't think that's a real concern though (weights will be smaller than training data or even a run for sequencing anyway).

@Lilferrit (Contributor) commented

I tried introducing another ModelCheckpoint that monitors valid_CELoss but has save_top_k set to one. However, it turns out you can't have two ModelCheckpoints monitoring the same quantity in the same mode (in this case valid_CELoss and min). I'm looking into a workaround at the moment.

@Lilferrit (Contributor) commented Aug 2, 2024

After investigating some potential workarounds, I found a workable low(ish)-code solution. By adding this to model_runner.py (tentatively, let me know if it should be moved):

from pathlib import Path

from lightning.pytorch.callbacks import ModelCheckpoint

class SaveBestModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that additionally maintains a best.ckpt symlink
    pointing at the checkpoint with the best validation loss."""

    def on_validation_end(self, trainer, pl_module):
        super().on_validation_end(trainer, pl_module)
        if not self.best_model_path:
            return  # No checkpoint has been saved yet.
        symlink_path = Path(self.dirpath) / "best.ckpt"
        symlink_path.unlink(missing_ok=True)  # Drop a stale link, if any.
        symlink_path.symlink_to(Path(self.best_model_path))

and replacing

if config.save_top_k is not None:
  self.callbacks.append(
      ModelCheckpoint(
          dirpath=config.model_save_folder_path,
          monitor="valid_CELoss",
          mode="min",
          save_top_k=config.save_top_k,
      )
  )

with

if config.save_top_k is not None:
  self.callbacks.append(
      SaveBestModelCheckpoint(
          dirpath=config.model_save_folder_path,
          monitor="valid_CELoss",
          mode="min",
          save_top_k=config.save_top_k,
      )
  )

we can essentially create a thin wrapper around ModelCheckpoint that creates a symlink to the best model. I've done some preliminary testing and this solution appears to work. The only problem thus far is that (on Windows at least) creating a symlink requires admin-level privileges, so maybe SaveBestModelCheckpoint could simply copy the best file instead of creating a symlink.
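
A sketch of that copy-based variant, in case it's useful (same setup as above; shutil.copy2 is used here as one assumption about how the copy would be done):

import shutil
from pathlib import Path

from lightning.pytorch.callbacks import ModelCheckpoint

class SaveBestModelCheckpoint(ModelCheckpoint):
    """Variant that copies the best checkpoint instead of symlinking,
    avoiding the elevated-privilege requirement for symlinks on Windows."""

    def on_validation_end(self, trainer, pl_module):
        super().on_validation_end(trainer, pl_module)
        if not self.best_model_path:
            return  # Nothing has been checkpointed yet.
        best_copy = Path(self.dirpath) / "best.ckpt"
        # Overwrites any previous best.ckpt; copy2 also preserves metadata.
        shutil.copy2(self.best_model_path, best_copy)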

@bittremieux (Collaborator) commented

> However, it turns out you can't have two ModelCheckpoints monitoring the same quantity in the same mode

So the conflict is between saving the top k models and only the best one? We've discussed this a bit before already, but I'm not sure what the benefit is of saving the best 5 models. Why would you want to use models ranked 2–5? 🤷‍♂️

Alternatively, we could change it to save all checkpoints, which is default Lightning behavior and doesn't require a ModelCheckpoint callback, and then use the callback for just the best model.

@wsnoble (Contributor, Author) commented Aug 3, 2024

I agree that the utility of saving the top 5 is not clear to me. I'm fine with saving all models, as long as there is a way to disable that behavior to save disk space.

@wsnoble added this to the Casanovo v5.0.0 milestone Aug 6, 2024
@Lilferrit linked a pull request Aug 12, 2024 that will close this issue