Skip to content

Commit

Permalink
verify annotated mgf files
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Aug 9, 2024
1 parent 81a3267 commit 7b9557b
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 15 deletions.
55 changes: 45 additions & 10 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import uuid
import warnings
from pathlib import Path
from typing import Iterable, List, Optional, Union
from typing import Iterable, List, Optional, Union, TextIO

import depthcharge.masses
import lightning.pytorch as pl
Expand Down Expand Up @@ -399,6 +399,21 @@ def _get_index(

msg = msg.strip()
filenames = _get_peak_filenames(peak_path, ext)
if annotated:
# Filter unannotated MGF files to avoid Depth Charge exception
filtered_fnames = list()
for fname in filenames:
if Path(
fname
).suffix.lower() == ".mgf" and not _mgf_is_annotated(fname):
warnings.warn(
f"Ignoring unannotated MGF peak file: {fname}",
RuntimeWarning,
)
else:
filtered_fnames.append(fname)
filenames = filtered_fnames

if not filenames:
not_found_err = f"Cound not find {msg} peak files"
logger.error(not_found_err + " from %s", peak_path)
Expand All @@ -415,15 +430,9 @@ def _get_index(
else:
index_fname = Path(self.tmp_dir.name) / f"{uuid.uuid4().hex}.hdf5"

try:
Index = AnnotatedSpectrumIndex if annotated else SpectrumIndex
valid_charge = np.arange(1, self.config.max_charge + 1)
return Index(index_fname, filenames, valid_charge=valid_charge)
except TypeError:
raise ValueError(
"MGF peak files must be annotated when constructing "
"annotated spectrum indices."
)
Index = AnnotatedSpectrumIndex if annotated else SpectrumIndex
valid_charge = np.arange(1, self.config.max_charge + 1)
return Index(index_fname, filenames, valid_charge=valid_charge)

def _get_strategy(self) -> Union[str, DDPStrategy]:
"""Get the strategy for the Trainer.
Expand All @@ -448,6 +457,32 @@ def _get_strategy(self) -> Union[str, DDPStrategy]:
return "auto"


def _mgf_is_annotated(mgf_path: TextIO) -> bool:
"""Check whether MGF file is annotated
Parameters
----------
mgf_path : TextIO
MGF peak file to check
Returns
-------
bool
Whether MGF peak file is annotated
"""
num_spectra = 0
num_annotations = 0

with open(mgf_path) as f:
for curr_line in f:
if curr_line.startswith("BEGIN IONS"):
num_spectra += 1
elif curr_line.startswith("SEQ="):
num_annotations += 1

return num_spectra == num_annotations


def _get_peak_filenames(
paths: Iterable[str], supported_ext: Iterable[str]
) -> List[str]:
Expand Down
28 changes: 23 additions & 5 deletions tests/unit_tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from casanovo.config import Config
from casanovo.denovo.model_runner import ModelRunner
from casanovo.denovo.model_runner import ModelRunner, _mgf_is_annotated


def test_initialize_model(tmp_path, mgf_small):
Expand Down Expand Up @@ -203,31 +203,49 @@ def test_evaluate(
with ModelRunner(config, model_filename=str(model_file)) as runner:
runner.predict([mzml_small], result_file, evaluate=True)

with pytest.raises(ValueError):
with pytest.raises(FileNotFoundError):
with ModelRunner(config, model_filename=str(model_file)) as runner:
runner.predict([mgf_small_unannotated], result_file, evaluate=True)

with pytest.raises(ValueError):
with pytest.raises(FileNotFoundError):
with ModelRunner(config, model_filename=str(model_file)) as runner:
runner.predict(
[mgf_small_unannotated, mzml_small], result_file, evaluate=True
)

# MzTab with just metadata is written in the case of FileNotFound
# early exit
result_file.unlink()

# Test mix of annotated an unannotated peak files
with pytest.warns(RuntimeWarning):
with ModelRunner(config, model_filename=str(model_file)) as runner:
runner.predict([mgf_small, mzml_small], result_file, evaluate=True)

with pytest.raises(ValueError):
assert result_file.is_file()
result_file.unlink()

with pytest.warns(RuntimeWarning):
with ModelRunner(config, model_filename=str(model_file)) as runner:
runner.predict(
[mgf_small, mgf_small_unannotated], result_file, evaluate=True
)

with pytest.raises(ValueError):
assert result_file.is_file()
result_file.unlink()

with pytest.warns(RuntimeWarning):
with ModelRunner(config, model_filename=str(model_file)) as runner:
runner.predict(
[mgf_small, mgf_small_unannotated, mzml_small],
result_file,
evaluate=True,
)

assert result_file.is_file()
result_file.unlink()


def test_mgf_is_annotated(mgf_small, mgf_small_unannotated):
assert _mgf_is_annotated(mgf_small)
assert not _mgf_is_annotated(mgf_small_unannotated)

0 comments on commit 7b9557b

Please sign in to comment.