Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use config options and auto-downloaded weights #246

Merged
merged 14 commits into from
Dec 12, 2023
8 changes: 4 additions & 4 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def sequence(
to sequence peptides.
"""
output = setup_logging(output, verbosity)
config = setup_model(model, config, output, False)
config, model = setup_model(model, config, output, False)
with ModelRunner(config, model) as runner:
logger.info("Sequencing peptides from:")
for peak_file in peak_path:
Expand Down Expand Up @@ -164,7 +164,7 @@ def evaluate(
such as those provided by MassIVE-KB.
"""
output = setup_logging(output, verbosity)
config = setup_model(model, config, output, False)
config, model = setup_model(model, config, output, False)
with ModelRunner(config, model) as runner:
logger.info("Sequencing and evaluating peptides from:")
for peak_file in annotated_peak_path:
Expand Down Expand Up @@ -207,7 +207,7 @@ def train(
provided by MassIVE-KB, from which to train a new Casnovo model.
"""
output = setup_logging(output, verbosity)
config = setup_model(model, config, output, True)
config, model = setup_model(model, config, output, True)
with ModelRunner(config, model) as runner:
logger.info("Training a model from:")
for peak_file in train_peak_path:
Expand Down Expand Up @@ -378,7 +378,7 @@ def setup_model(
for key, value in config.items():
logger.debug("%s = %s", str(key), str(value))

return config
return config, model
bittremieux marked this conversation as resolved.
Show resolved Hide resolved


def _get_model_weights() -> str:
Expand Down
7 changes: 3 additions & 4 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,11 @@ def initialize_model(self, train: bool) -> None:
self.model = Spec2Pep.load_from_checkpoint(
self.model_filename,
map_location=device,
**model_params,
)
except RuntimeError:
self.model = Spec2Pep.load_from_checkpoint(
self.model_filename,
map_location=device,
**model_params,
raise RuntimeError(
"Mismatching parameters between loaded model and config file"
)

def initialize_data_module(
Expand Down
Loading