Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use config options and auto-downloaded weights #246

Merged
merged 14 commits into from
Dec 12, 2023
Merged
10 changes: 7 additions & 3 deletions .github/workflows/screenshots.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out the repo
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }}

- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: "3.10"

- name: Install your custom tools
run: pip install .
run: |
python -m pip install --upgrade pip
pip install .

- name: Generate terminal images with rich-codex
uses: ewels/rich-codex@v1
Expand Down
8 changes: 4 additions & 4 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def sequence(
to sequence peptides.
"""
output = setup_logging(output, verbosity)
config = setup_model(model, config, output, False)
config, model = setup_model(model, config, output, False)
with ModelRunner(config, model) as runner:
logger.info("Sequencing peptides from:")
for peak_file in peak_path:
Expand Down Expand Up @@ -164,7 +164,7 @@ def evaluate(
such as those provided by MassIVE-KB.
"""
output = setup_logging(output, verbosity)
config = setup_model(model, config, output, False)
config, model = setup_model(model, config, output, False)
with ModelRunner(config, model) as runner:
logger.info("Sequencing and evaluating peptides from:")
for peak_file in annotated_peak_path:
Expand Down Expand Up @@ -207,7 +207,7 @@ def train(
provided by MassIVE-KB, from which to train a new Casanovo model.
"""
output = setup_logging(output, verbosity)
config = setup_model(model, config, output, True)
config, model = setup_model(model, config, output, True)
with ModelRunner(config, model) as runner:
logger.info("Training a model from:")
for peak_file in train_peak_path:
Expand Down Expand Up @@ -378,7 +378,7 @@ def setup_model(
for key, value in config.items():
logger.debug("%s = %s", str(key), str(value))

return config
return config, model
bittremieux marked this conversation as resolved.
Show resolved Hide resolved


def _get_model_weights() -> str:
Expand Down
50 changes: 45 additions & 5 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uuid
from pathlib import Path
from typing import Iterable, List, Optional, Union
import warnings

import lightning.pytorch as pl
import numpy as np
Expand Down Expand Up @@ -217,6 +218,7 @@ def initialize_model(self, train: bool) -> None:
max_charge=self.config.max_charge,
precursor_mass_tol=self.config.precursor_mass_tol,
isotope_error_range=self.config.isotope_error_range,
min_peptide_len=self.config.min_peptide_len,
n_beams=self.config.n_beams,
top_match=self.config.top_match,
n_log=self.config.n_log,
Expand All @@ -229,6 +231,24 @@ def initialize_model(self, train: bool) -> None:
calculate_precision=self.config.calculate_precision,
)

# Reconfigurable non-architecture related parameters for a loaded model
loaded_model_params = dict(
max_length=self.config.max_length,
precursor_mass_tol=self.config.precursor_mass_tol,
isotope_error_range=self.config.isotope_error_range,
n_beams=self.config.n_beams,
min_peptide_len=self.config.min_peptide_len,
top_match=self.config.top_match,
n_log=self.config.n_log,
tb_summarywriter=self.config.tb_summarywriter,
warmup_iters=self.config.warmup_iters,
max_iters=self.config.max_iters,
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay,
out_writer=self.writer,
calculate_precision=self.config.calculate_precision,
)

from_scratch = (
self.config.train_from_scratch,
self.model_filename is None,
Expand All @@ -254,13 +274,33 @@ def initialize_model(self, train: bool) -> None:
self.model = Spec2Pep.load_from_checkpoint(
self.model_filename,
map_location=device,
**loaded_model_params,
)
except RuntimeError:
self.model = Spec2Pep.load_from_checkpoint(
self.model_filename,
map_location=device,
**model_params,

architecture_params = set(model_params.keys()) - set(
loaded_model_params.keys()
)
for param in architecture_params:
if model_params[param] != self.model.hparams[param]:
warnings.warn(
f"Mismatching {param} parameter in "
f"model checkpoint ({self.model.hparams[param]}) "
f"vs. config file ({model_params[param]}), "
f"using the checkpoint."
)
except RuntimeError:
# Loading fails here only if the weights are from an older version
try:
self.model = Spec2Pep.load_from_checkpoint(
self.model_filename,
map_location=device,
**model_params,
)
except RuntimeError:
raise RuntimeError(
"Weights file incompatible "
"with the current version of Casanovo. "
)

def initialize_data_module(
self,
Expand Down
Loading