Skip to content

Commit

Permalink
requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Aug 30, 2024
1 parent e68858b commit 3d91f81
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 53 deletions.
52 changes: 41 additions & 11 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, *args, **kwargs) -> None:
default="info",
),
click.Option(
("-f", "--overwrite"),
("-f", "--force_overwrite"),
help="Whether to overwrite output files.",
is_flag=True,
show_default=True,
Expand Down Expand Up @@ -159,7 +159,7 @@ def sequence(
output_dir: Optional[str],
output_root: Optional[str],
verbosity: str,
overwrite: bool,
force_overwrite: bool,
evaluate: bool,
) -> None:
"""De novo sequence peptides from tandem mass spectra.
Expand All @@ -168,11 +168,13 @@ def sequence(
to sequence peptides. If evaluate is set to True PEAK_PATH must be
one or more annotated MGF file.
"""
output, output_dir = _resolve_output(output_dir, output_root, verbosity)
if output_root is not None and not overwrite:
file_pattern = re.escape(output_root) + r"\.(?:log|mztab)"
utils.check_dir(output_dir, [file_pattern])
file_patterns = list()
if output_root is not None and not force_overwrite:
file_patterns = [f"{output_root}.log", f"{output_root}.mztab"]

output, output_dir = _resolve_output(
output_dir, output_root, file_patterns, verbosity
)
config, model = setup_model(model, config, output, False)
start_time = time.time()
with ModelRunner(config, model, output_root, output_dir, False) as runner:
Expand Down Expand Up @@ -216,21 +218,24 @@ def train(
output_dir: Optional[str],
output_root: Optional[str],
verbosity: str,
overwrite: bool,
force_overwrite: bool,
) -> None:
"""Train a Casanovo model on your own data.
TRAIN_PEAK_PATH must be one or more annoated MGF files, such as those
provided by MassIVE-KB, from which to train a new Casnovo model.
"""
output, output_dir = _resolve_output(output_dir, output_root, verbosity)
if output_root is not None and not overwrite:
utils.check_dir(output_dir, [re.escape(output_root) + r"\.log"])
file_patterns = list()
if output_root is not None and not force_overwrite:
file_patterns = [f"{output_root}.log"]

output, output_dir = _resolve_output(
output_dir, output_root, file_patterns, verbosity
)
config, model = setup_model(model, config, output, True)
start_time = time.time()
with ModelRunner(
config, model, output_root, output_dir, not overwrite
config, model, output_root, output_dir, not force_overwrite
) as runner:
logger.info("Training a model from:")
for peak_file in train_peak_path:
Expand Down Expand Up @@ -520,12 +525,37 @@ def _get_model_weights(cache_dir: Path) -> str:
def _resolve_output(
output_dir: str | None,
output_root: str | None,
file_patterns: list[str],
verbosity: str,
) -> Tuple[Path, str]:
"""
Resolves the output directory and sets up logging.
Parameters:
-----------
output_dir : str | None
The path to the output directory. If `None`, the current working
directory will be used.
output_root : str | None
The base name for the output files. If `None`, no specific base name is
set, and logging will be configured accordingly to the behavior of
`setup_logging`.
file_patterns : list[str]
A list of file patterns that should be checked within the `output_dir`.
verbosity : str
The verbosity level for logging.
Returns:
--------
Tuple[Path, str]
The output directory and the base name for log and results files (if
applicable).
"""
output_dir = Path(output_dir) if output_dir is not None else Path.cwd()
output_base_name = (
None if output_root is None else (output_dir / output_root)
)
utils.check_dir(output_dir, file_patterns)
output = setup_logging(output_base_name, verbosity)
return output, output_dir

Expand Down
14 changes: 7 additions & 7 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ def __init__(
)

if overwrite_ckpt_check:
patterns = [r"epoch=\d+\-step=\d+\.ckpt", r"best\.ckpt"]
if output_rootname is not None:
patterns = [
re.escape(output_rootname + ".") + pattern
for pattern in patterns
]
utils.check_dir(output_dir, patterns)
utils.check_dir(
output_dir,
[
f"{curr_filename.format(epoch='*', step='*')}.ckpt",
f"{best_filename}.ckpt",
],
)

# Configure checkpoints.
self.callbacks = [
Expand Down
24 changes: 9 additions & 15 deletions casanovo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,32 +256,26 @@ def log_sequencing_report(
)


def check_dir(
dir: pathlib.Path, file_patterns: Iterable[re.Pattern[str]]
) -> None:
def check_dir(dir: pathlib.Path, file_patterns: Iterable[str]) -> None:
"""
Check that no file names in dir match any of file_patterns
Parameters
----------
dir : pathlib.Path
The directory to check for matching file names
file_patterns : Iterable[re.Pattern[str]]
File name re patterns to test file names against
file_patterns : Iterable[str]
UNIX style wildcard pattern to test file names against
Raises
------
FileExistsError
If matching file name is found in dir
"""
for pattern in file_patterns:
comp_pattern = re.compile(pattern)
for file in dir.iterdir():
if not file.is_file():
continue

if comp_pattern.fullmatch(file.name) is not None:
raise FileExistsError(
f"File {file.name} already exists in {dir} "
"and can not be overwritten."
)
matches = list(dir.glob(pattern))
if len(matches) > 0:
raise FileExistsError(
f"File {matches[0].name} already exists in {dir} "
"and can not be overwritten."
)
15 changes: 0 additions & 15 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path

import pyteomics.mztab
import pytest
from click.testing import CliRunner

from casanovo import casanovo
Expand Down Expand Up @@ -44,10 +43,6 @@ def test_train_and_run(
assert model_file.exists()
assert best_model.exists()

# Check that re-running train fails due to no overwrite
with pytest.raises(FileExistsError):
run(train_args)

assert model_file.exists()
assert best_model.exists()

Expand Down Expand Up @@ -90,12 +85,6 @@ def test_train_and_run(
assert psms.loc[4, "sequence"] == "PEPTLDEK"
assert psms.loc[4, "spectra_ref"] == "ms_run[2]:scan=111"

# Verify that running predict again fails due to no overwrite
with pytest.raises(FileExistsError):
run(predict_args)

assert output_filename.is_file()

# Finally, try evaluating:
output_rootname = "test-eval"
output_filename = (tmp_path / output_rootname).with_suffix(".mztab")
Expand Down Expand Up @@ -156,10 +145,6 @@ def test_train_and_run(
]
)

# Verify that running again fails due to no overwrite
with pytest.raises(FileExistsError):
run(eval_args)

assert output_filename.is_file()


Expand Down
13 changes: 10 additions & 3 deletions tests/unit_tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,17 @@ def test_save_and_load_weights_deprecated(tmp_path, mgf_small, tiny_config):
torch.save(ckpt_data, str(ckpt))

# Inference.
with ModelRunner(config=config, model_filename=str(ckpt)) as runner:
with ModelRunner(
config=config, model_filename=str(ckpt), overwrite_ckpt_check=False
) as runner:
runner.initialize_model(train=False)
assert runner.model.cosine_schedule_period_iters == 5
# Fine-tuning.
with ModelRunner(
config=config, model_filename=str(ckpt), output_dir=tmp_path
config=config,
model_filename=str(ckpt),
output_dir=tmp_path,
overwrite_ckpt_check=False,
) as runner:
with pytest.warns(DeprecationWarning):
runner.train([mgf_small], [mgf_small])
Expand All @@ -149,7 +154,9 @@ def test_calculate_precision(tmp_path, mgf_small, tiny_config):
assert "valid_pep_precision" not in runner.model.history.columns

config.calculate_precision = True
runner = ModelRunner(config=config, output_dir=tmp_path)
runner = ModelRunner(
config=config, output_dir=tmp_path, overwrite_ckpt_check=False
)
with runner:
runner.train([mgf_small], [mgf_small])

Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,9 +924,9 @@ def test_run_map(mgf_small):

def test_check_dir(tmp_path):
exists_path = tmp_path / "exists-1234.ckpt"
exists_pattern = r"exists\-\d+\.ckpt"
exists_pattern = "exists-*.ckpt"
exists_path.touch()
dne_pattern = r"dne\-\d+\.ckpt"
dne_pattern = "dne-*.ckpt"

with pytest.raises(FileExistsError):
utils.check_dir(tmp_path, [exists_pattern, dne_pattern])
Expand Down

0 comments on commit 3d91f81

Please sign in to comment.