From 7b9557b6d6bdcf773bd3ea4d857596a68ad272e0 Mon Sep 17 00:00:00 2001 From: Lilferrit Date: Fri, 9 Aug 2024 13:30:48 -0700 Subject: [PATCH] verify annotated mgf files --- casanovo/denovo/model_runner.py | 55 +++++++++++++++++++++++++++------ tests/unit_tests/test_runner.py | 28 ++++++++++++++--- 2 files changed, 68 insertions(+), 15 deletions(-) diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 1404a1fe..b47b6688 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -8,7 +8,7 @@ import uuid import warnings from pathlib import Path -from typing import Iterable, List, Optional, Union +from typing import Iterable, List, Optional, Union, TextIO import depthcharge.masses import lightning.pytorch as pl @@ -399,6 +399,21 @@ def _get_index( msg = msg.strip() filenames = _get_peak_filenames(peak_path, ext) + if annotated: + # Filter unannotated MGF files to avoid Depth Charge exception + filtered_fnames = list() + for fname in filenames: + if Path( + fname + ).suffix.lower() == ".mgf" and not _mgf_is_annotated(fname): + warnings.warn( + f"Ignoring unannotated MGF peak file: {fname}", + RuntimeWarning, + ) + else: + filtered_fnames.append(fname) + filenames = filtered_fnames + if not filenames: not_found_err = f"Cound not find {msg} peak files" logger.error(not_found_err + " from %s", peak_path) @@ -415,15 +430,9 @@ def _get_index( else: index_fname = Path(self.tmp_dir.name) / f"{uuid.uuid4().hex}.hdf5" - try: - Index = AnnotatedSpectrumIndex if annotated else SpectrumIndex - valid_charge = np.arange(1, self.config.max_charge + 1) - return Index(index_fname, filenames, valid_charge=valid_charge) - except TypeError: - raise ValueError( - "MGF peak files must be annotated when constructing " - "annotated spectrum indices." - ) + Index = AnnotatedSpectrumIndex if annotated else SpectrumIndex + valid_charge = np.arange(1, self.config.max_charge + 1) + return Index(index_fname, filenames, valid_charge=valid_charge) def _get_strategy(self) -> Union[str, DDPStrategy]: """Get the strategy for the Trainer. @@ -448,6 +457,32 @@ def _get_strategy(self) -> Union[str, DDPStrategy]: return "auto" +def _mgf_is_annotated(mgf_path: TextIO) -> bool: + """Check whether MGF file is annotated + + Parameters + ---------- + mgf_path : TextIO + MGF peak file to check + + Returns + ------- + bool + Whether MGF peak file is annotated + """ + num_spectra = 0 + num_annotations = 0 + + with open(mgf_path) as f: + for curr_line in f: + if curr_line.startswith("BEGIN IONS"): + num_spectra += 1 + elif curr_line.startswith("SEQ="): + num_annotations += 1 + + return num_spectra == num_annotations + + def _get_peak_filenames( paths: Iterable[str], supported_ext: Iterable[str] ) -> List[str]: diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py index fd52aeef..853e193f 100644 --- a/tests/unit_tests/test_runner.py +++ b/tests/unit_tests/test_runner.py @@ -6,7 +6,7 @@ import torch from casanovo.config import Config -from casanovo.denovo.model_runner import ModelRunner +from casanovo.denovo.model_runner import ModelRunner, _mgf_is_annotated def test_initialize_model(tmp_path, mgf_small): @@ -203,31 +203,49 @@ def test_evaluate( with ModelRunner(config, model_filename=str(model_file)) as runner: runner.predict([mzml_small], result_file, evaluate=True) - with pytest.raises(ValueError): + with pytest.raises(FileNotFoundError): with ModelRunner(config, model_filename=str(model_file)) as runner: runner.predict([mgf_small_unannotated], result_file, evaluate=True) - with pytest.raises(ValueError): + with pytest.raises(FileNotFoundError): with ModelRunner(config, model_filename=str(model_file)) as runner: runner.predict( [mgf_small_unannotated, mzml_small], result_file, evaluate=True ) + # MzTab with just metadata is written in the case of FileNotFound + # early exit + result_file.unlink() + # Test mix of annotated an unannotated peak files with pytest.warns(RuntimeWarning): with ModelRunner(config, model_filename=str(model_file)) as runner: runner.predict([mgf_small, mzml_small], result_file, evaluate=True) - with pytest.raises(ValueError): + assert result_file.is_file() + result_file.unlink() + + with pytest.warns(RuntimeWarning): with ModelRunner(config, model_filename=str(model_file)) as runner: runner.predict( [mgf_small, mgf_small_unannotated], result_file, evaluate=True ) - with pytest.raises(ValueError): + assert result_file.is_file() + result_file.unlink() + + with pytest.warns(RuntimeWarning): with ModelRunner(config, model_filename=str(model_file)) as runner: runner.predict( [mgf_small, mgf_small_unannotated, mzml_small], result_file, evaluate=True, ) + + assert result_file.is_file() + result_file.unlink() + + +def test_mgf_is_annotated(mgf_small, mgf_small_unannotated): + assert _mgf_is_annotated(mgf_small) + assert not _mgf_is_annotated(mgf_small_unannotated)