From 3d91f813247c3d14379ced1532a4a655d080a128 Mon Sep 17 00:00:00 2001 From: Lilferrit Date: Fri, 30 Aug 2024 16:49:13 -0700 Subject: [PATCH] requested changes --- casanovo/casanovo.py | 52 ++++++++++++++++++++++++++------- casanovo/denovo/model_runner.py | 14 ++++----- casanovo/utils.py | 24 ++++++--------- tests/test_integration.py | 15 ---------- tests/unit_tests/test_runner.py | 13 +++++++-- tests/unit_tests/test_unit.py | 4 +-- 6 files changed, 69 insertions(+), 53 deletions(-) diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index bdeb9766..f9c37f08 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -96,7 +96,7 @@ def __init__(self, *args, **kwargs) -> None: default="info", ), click.Option( - ("-f", "--overwrite"), + ("-f", "--force_overwrite"), help="Whether to overwrite output files.", is_flag=True, show_default=True, @@ -159,7 +159,7 @@ def sequence( output_dir: Optional[str], output_root: Optional[str], verbosity: str, - overwrite: bool, + force_overwrite: bool, evaluate: bool, ) -> None: """De novo sequence peptides from tandem mass spectra. @@ -168,11 +168,13 @@ def sequence( to sequence peptides. If evaluate is set to True PEAK_PATH must be one or more annotated MGF file. """ - output, output_dir = _resolve_output(output_dir, output_root, verbosity) - if output_root is not None and not overwrite: - file_pattern = re.escape(output_root) + r"\.(?:log|mztab)" - utils.check_dir(output_dir, [file_pattern]) + file_patterns = list() + if output_root is not None and not force_overwrite: + file_patterns = [f"{output_root}.log", f"{output_root}.mztab"] + output, output_dir = _resolve_output( + output_dir, output_root, file_patterns, verbosity + ) config, model = setup_model(model, config, output, False) start_time = time.time() with ModelRunner(config, model, output_root, output_dir, False) as runner: @@ -216,21 +218,24 @@ def train( output_dir: Optional[str], output_root: Optional[str], verbosity: str, - overwrite: bool, + force_overwrite: bool, ) -> None: """Train a Casanovo model on your own data. TRAIN_PEAK_PATH must be one or more annoated MGF files, such as those provided by MassIVE-KB, from which to train a new Casnovo model. """ - output, output_dir = _resolve_output(output_dir, output_root, verbosity) - if output_root is not None and not overwrite: - utils.check_dir(output_dir, [re.escape(output_root) + r"\.log"]) + file_patterns = list() + if output_root is not None and not force_overwrite: + file_patterns = [f"{output_root}.log"] + output, output_dir = _resolve_output( + output_dir, output_root, file_patterns, verbosity + ) config, model = setup_model(model, config, output, True) start_time = time.time() with ModelRunner( - config, model, output_root, output_dir, not overwrite + config, model, output_root, output_dir, not force_overwrite ) as runner: logger.info("Training a model from:") for peak_file in train_peak_path: @@ -520,12 +525,37 @@ def _get_model_weights(cache_dir: Path) -> str: def _resolve_output( output_dir: str | None, output_root: str | None, + file_patterns: list[str], verbosity: str, ) -> Tuple[Path, str]: + """ + Resolves the output directory and sets up logging. + + Parameters: + ----------- + output_dir : str | None + The path to the output directory. If `None`, the current working + directory will be used. + output_root : str | None + The base name for the output files. If `None`, no specific base name is + set, and logging will be configured accordingly to the behavior of + `setup_logging`. + file_patterns : list[str] + A list of file patterns that should be checked within the `output_dir`. + verbosity : str + The verbosity level for logging. + + Returns: + -------- + Tuple[Path, str] + The output directory and the base name for log and results files (if + applicable). + """ output_dir = Path(output_dir) if output_dir is not None else Path.cwd() output_base_name = ( None if output_root is None else (output_dir / output_root) ) + utils.check_dir(output_dir, file_patterns) output = setup_logging(output_base_name, verbosity) return output, output_dir diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index ff441179..5c8833de 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -71,13 +71,13 @@ def __init__( ) if overwrite_ckpt_check: - patterns = [r"epoch=\d+\-step=\d+\.ckpt", r"best\.ckpt"] - if output_rootname is not None: - patterns = [ - re.escape(output_rootname + ".") + pattern - for pattern in patterns - ] - utils.check_dir(output_dir, patterns) + utils.check_dir( + output_dir, + [ + f"{curr_filename.format(epoch='*', step='*')}.ckpt", + f"{best_filename}.ckpt", + ], + ) # Configure checkpoints. self.callbacks = [ diff --git a/casanovo/utils.py b/casanovo/utils.py index 3d778791..fb0c3327 100644 --- a/casanovo/utils.py +++ b/casanovo/utils.py @@ -256,9 +256,7 @@ def log_sequencing_report( ) -def check_dir( - dir: pathlib.Path, file_patterns: Iterable[re.Pattern[str]] -) -> None: +def check_dir(dir: pathlib.Path, file_patterns: Iterable[str]) -> None: """ Check that no file names in dir match any of file_patterns @@ -266,8 +264,8 @@ def check_dir( ---------- dir : pathlib.Path The directory to check for matching file names - file_patterns : Iterable[re.Pattern[str]] - File name re patterns to test file names against + file_patterns : Iterable[str] + UNIX style wildcard pattern to test file names against Raises ------ @@ -275,13 +273,9 @@ def check_dir( If matching file name is found in dir """ for pattern in file_patterns: - comp_pattern = re.compile(pattern) - for file in dir.iterdir(): - if not file.is_file(): - continue - - if comp_pattern.fullmatch(file.name) is not None: - raise FileExistsError( - f"File {file.name} already exists in {dir} " - "and can not be overwritten." - ) + matches = list(dir.glob(pattern)) + if len(matches) > 0: + raise FileExistsError( + f"File {matches[0].name} already exists in {dir} " + "and can not be overwritten." + ) diff --git a/tests/test_integration.py b/tests/test_integration.py index ded8294d..bb6ef66e 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -3,7 +3,6 @@ from pathlib import Path import pyteomics.mztab -import pytest from click.testing import CliRunner from casanovo import casanovo @@ -44,10 +43,6 @@ def test_train_and_run( assert model_file.exists() assert best_model.exists() - # Check that re-running train fails due to no overwrite - with pytest.raises(FileExistsError): - run(train_args) - assert model_file.exists() assert best_model.exists() @@ -90,12 +85,6 @@ def test_train_and_run( assert psms.loc[4, "sequence"] == "PEPTLDEK" assert psms.loc[4, "spectra_ref"] == "ms_run[2]:scan=111" - # Verify that running predict again fails due to no overwrite - with pytest.raises(FileExistsError): - run(predict_args) - - assert output_filename.is_file() - # Finally, try evaluating: output_rootname = "test-eval" output_filename = (tmp_path / output_rootname).with_suffix(".mztab") @@ -156,10 +145,6 @@ def test_train_and_run( ] ) - # Verify that running again fails due to no overwrite - with pytest.raises(FileExistsError): - run(eval_args) - assert output_filename.is_file() diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py index f31725ff..8cea9021 100644 --- a/tests/unit_tests/test_runner.py +++ b/tests/unit_tests/test_runner.py @@ -121,12 +121,17 @@ def test_save_and_load_weights_deprecated(tmp_path, mgf_small, tiny_config): torch.save(ckpt_data, str(ckpt)) # Inference. - with ModelRunner(config=config, model_filename=str(ckpt)) as runner: + with ModelRunner( + config=config, model_filename=str(ckpt), overwrite_ckpt_check=False + ) as runner: runner.initialize_model(train=False) assert runner.model.cosine_schedule_period_iters == 5 # Fine-tuning. with ModelRunner( - config=config, model_filename=str(ckpt), output_dir=tmp_path + config=config, + model_filename=str(ckpt), + output_dir=tmp_path, + overwrite_ckpt_check=False, ) as runner: with pytest.warns(DeprecationWarning): runner.train([mgf_small], [mgf_small]) @@ -149,7 +154,9 @@ def test_calculate_precision(tmp_path, mgf_small, tiny_config): assert "valid_pep_precision" not in runner.model.history.columns config.calculate_precision = True - runner = ModelRunner(config=config, output_dir=tmp_path) + runner = ModelRunner( + config=config, output_dir=tmp_path, overwrite_ckpt_check=False + ) with runner: runner.train([mgf_small], [mgf_small]) diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index 33f20bf5..8cdd291d 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -924,9 +924,9 @@ def test_run_map(mgf_small): def test_check_dir(tmp_path): exists_path = tmp_path / "exists-1234.ckpt" - exists_pattern = r"exists\-\d+\.ckpt" + exists_pattern = "exists-*.ckpt" exists_path.touch() - dne_pattern = r"dne\-\d+\.ckpt" + dne_pattern = "dne-*.ckpt" with pytest.raises(FileExistsError): utils.check_dir(tmp_path, [exists_pattern, dne_pattern])