
Commit bfb168e: Override mismatching params with ckpt

melihyilmaz committed Nov 28, 2023
1 parent b6b374f commit bfb168e
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions casanovo/denovo/model_runner.py
@@ -7,6 +7,7 @@
 import uuid
 from pathlib import Path
 from typing import Iterable, List, Optional, Union
+import warnings
 
 import lightning.pytorch as pl
 import numpy as np
@@ -217,6 +218,7 @@ def initialize_model(self, train: bool) -> None:
             max_charge=self.config.max_charge,
             precursor_mass_tol=self.config.precursor_mass_tol,
             isotope_error_range=self.config.isotope_error_range,
+            min_peptide_len=self.config.min_peptide_len,
             n_beams=self.config.n_beams,
             top_match=self.config.top_match,
             n_log=self.config.n_log,
@@ -229,6 +231,24 @@
             calculate_precision=self.config.calculate_precision,
         )
 
+        # Reconfigurable, non-architecture-related parameters for a loaded model.
+        loaded_model_params = dict(
+            max_length=self.config.max_length,
+            precursor_mass_tol=self.config.precursor_mass_tol,
+            isotope_error_range=self.config.isotope_error_range,
+            n_beams=self.config.n_beams,
+            min_peptide_len=self.config.min_peptide_len,
+            top_match=self.config.top_match,
+            n_log=self.config.n_log,
+            tb_summarywriter=self.config.tb_summarywriter,
+            warmup_iters=self.config.warmup_iters,
+            max_iters=self.config.max_iters,
+            lr=self.config.learning_rate,
+            weight_decay=self.config.weight_decay,
+            out_writer=self.writer,
+            calculate_precision=self.config.calculate_precision,
+        )
+
         from_scratch = (
             self.config.train_from_scratch,
             self.model_filename is None,
@@ -254,11 +274,23 @@
             self.model = Spec2Pep.load_from_checkpoint(
                 self.model_filename,
                 map_location=device,
-                **model_params,
+                **loaded_model_params,
             )
+
+            architecture_params = set(model_params.keys()) - set(
+                loaded_model_params.keys()
+            )
+            for param in architecture_params:
+                if model_params[param] != self.model.hparams[param]:
+                    warnings.warn(
+                        f"Mismatching {param} parameter in "
+                        f"model checkpoint ({self.model.hparams[param]}) "
+                        f"vs. config file ({model_params[param]}), "
+                        f"using the checkpoint."
+                    )
         except RuntimeError:
             raise RuntimeError(
-                "Mismatching parameters between loaded model and config file"
+                "Weights file incompatible with the current version of Casanovo."
             )
 
     def initialize_data_module(
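The override mechanism in this commit leans on a documented behavior of lightning.pytorch: keyword arguments passed to `load_from_checkpoint` take precedence over the hyperparameters stored in the checkpoint, provided the module called `save_hyperparameters()` in its `__init__`. Below is a minimal, self-contained sketch of the same pattern; `TinyModel` and the checkpoint path `tiny.ckpt` are hypothetical stand-ins for illustration, not part of Casanovo.

```python
import warnings

import lightning.pytorch as pl
import torch


class TinyModel(pl.LightningModule):
    """Stand-in for Spec2Pep: one architecture param, one runtime param."""

    def __init__(self, dim_model: int = 64, n_beams: int = 1):
        super().__init__()
        # Records dim_model and n_beams in self.hparams and in any checkpoint.
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(dim_model, dim_model)


# Hypothetical config values; in Casanovo these come from the config file.
model_params = dict(dim_model=128, n_beams=5)  # everything in the config
loaded_model_params = dict(n_beams=5)          # safe to override at load time

model = TinyModel.load_from_checkpoint(
    "tiny.ckpt",            # hypothetical checkpoint path
    **loaded_model_params,  # kwargs override the checkpoint's saved hparams
)

# Architecture params must match the checkpoint; warn when the config disagrees.
for param in set(model_params) - set(loaded_model_params):
    if model_params[param] != model.hparams[param]:
        warnings.warn(
            f"Mismatching {param} parameter in model checkpoint "
            f"({model.hparams[param]}) vs. config file ({model_params[param]}), "
            "using the checkpoint."
        )
```

Architecture parameters (here `dim_model`) stay out of the override set on purpose: changing them would make the stored weights incompatible with the instantiated network, so the checkpoint's value is kept and a warning surfaces the discrepancy, mirroring what the commit does for `Spec2Pep`.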
