diff --git a/casanovo/__init__.py b/casanovo/__init__.py index 1afa731a..f0756992 100644 --- a/casanovo/__init__.py +++ b/casanovo/__init__.py @@ -1,4 +1,3 @@ from .version import _get_version - __version__ = _get_version() diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index fef73a9b..3bda9cd5 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -41,10 +41,9 @@ import tqdm from lightning.pytorch import seed_everything -from . import __version__ -from . import utils -from .denovo import ModelRunner +from . import __version__, utils from .config import Config +from .denovo import ModelRunner logger = logging.getLogger("casanovo") click.rich_click.USE_MARKDOWN = True @@ -139,7 +138,7 @@ def main() -> None: "peak_path", required=True, nargs=-1, - type=click.Path(exists=True, dir_okay=False), + type=click.Path(exists=True, dir_okay=True), ) @click.option( "--evaluate", @@ -206,7 +205,7 @@ def sequence( "peak_path", required=True, nargs=-1, - type=click.Path(exists=True, dir_okay=False), + type=click.Path(exists=True, dir_okay=True), ) @click.argument( "fasta_path", @@ -266,7 +265,7 @@ def db_search( "train_peak_path", required=True, nargs=-1, - type=click.Path(exists=True, dir_okay=False), + type=click.Path(exists=True, dir_okay=True), ) @click.option( "-p", @@ -277,7 +276,7 @@ def db_search( """, required=False, multiple=True, - type=click.Path(exists=True, dir_okay=False), + type=click.Path(exists=True, dir_okay=True), ) def train( train_peak_path: Tuple[str], diff --git a/casanovo/config.py b/casanovo/config.py index e276e12d..76c0ec5d 100644 --- a/casanovo/config.py +++ b/casanovo/config.py @@ -4,7 +4,7 @@ import shutil import warnings from pathlib import Path -from typing import Optional, Dict, Callable, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union import yaml @@ -55,6 +55,12 @@ class Config: max_charge=int, precursor_mass_tol=float, isotope_error_range=lambda min_max: (int(min_max[0]), int(min_max[1])), + enzyme=str, + digestion=str, + missed_cleavages=int, + max_mods=int, + allowed_fixed_mods=str, + allowed_var_mods=str, min_peptide_len=int, dim_model=int, n_head=int, @@ -83,6 +89,16 @@ class Config: calculate_precision=bool, accelerator=str, devices=int, + lance_dir=str, + shuffle=bool, + buffer_size=int, + reverse_peptides=bool, + replace_isoleucine_with_leucine=bool, + accumulate_grad_batches=int, + gradient_clip_val=float, + gradient_clip_algorithm=str, + precision=str, + mskb_tokenizer=bool, ) def __init__(self, config_file: Optional[str] = None): diff --git a/casanovo/config.yaml b/casanovo/config.yaml index b7179347..74d6b782 100644 --- a/casanovo/config.yaml +++ b/casanovo/config.yaml @@ -63,8 +63,8 @@ max_mods: 1 # where aa is a standard amino acid (or "nterm" for an N-terminal mod) # and mod_residue is a key from the "residues" dictionary. # Example: "M:M+15.995,nterm:+43.006" -allowed_fixed_mods: "C:C+57.021" -allowed_var_mods: "M:M+15.995,N:N+0.984,Q:Q+0.984,nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" +allowed_fixed_mods: "C:C[Carbamidomethyl]" +allowed_var_mods: "M:M[Oxidation],N:N[Deamidated],Q:Q[Deamidated],nterm:[Acetyl]-,nterm:[Carbamyl]-,nterm:[Ammonia-loss]-,nterm:[+25.980265]-" ### @@ -84,6 +84,8 @@ tb_summarywriter: false log_metrics: false # How often to log optimizer parameters in steps log_every_n_steps: 50 +# Path to save lance instances +lance_dir: # Model validation and checkpointing frequency in training steps. 
val_check_interval: 50_000 @@ -125,6 +127,10 @@ learning_rate: 5e-4 weight_decay: 1e-5 # Amount of label smoothing when computing the training loss. train_label_smoothing: 0.01 +# Shuffle dataset during training. +# A buffer of size buffer_size is filled and examples from this buffer are randomly sampled. +shuffle: True +buffer_size: 100_000 # TRAINING/INFERENCE OPTIONS # Number of spectra in one training batch. @@ -137,6 +143,19 @@ num_sanity_val_steps: 0 # This is expensive, so we recommend against it. calculate_precision: False +# Additional Pytorch lightning trainer flags +accumulate_grad_batches: 1 +gradient_clip_val: +gradient_clip_algorithm: +precision: "32-true" # '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true', '64', '32', '16', 'bf16' + +# Replace I by L in peptide sequences +replace_isoleucine_with_leucine: True +# Reverse peptide sequences +reverse_peptides: True +# mskb tokenizer, otherwise proforma syntax +mskb_tokenizer: True + # AMINO ACID AND MODIFICATION VOCABULARY residues: "G": 57.021464 @@ -145,7 +164,7 @@ residues: "P": 97.052764 "V": 99.068414 "T": 101.047670 - "C+57.021": 160.030649 # 103.009185 + 57.021464 + "C[Carbamidomethyl]": 160.030649 # 103.009185 + 57.021464 "L": 113.084064 "I": 113.084064 "N": 114.042927 @@ -160,11 +179,11 @@ residues: "Y": 163.063329 "W": 186.079313 # Amino acid modifications. - "M+15.995": 147.035400 # Met oxidation: 131.040485 + 15.994915 - "N+0.984": 115.026943 # Asn deamidation: 114.042927 + 0.984016 - "Q+0.984": 129.042594 # Gln deamidation: 128.058578 + 0.984016 + "M[Oxidation]": 147.035400 # Met oxidation: 131.040485 + 15.994915 + "N[Deamidated]": 115.026943 # Asn deamidation: 114.042927 + 0.984016 + "Q[Deamidated]": 129.042594 # Gln deamidation: 128.058578 + 0.984016 # N-terminal modifications. - "+42.011": 42.010565 # Acetylation - "+43.006": 43.005814 # Carbamylation - "-17.027": -17.026549 # NH3 loss - "+43.006-17.027": 25.980265 # Carbamylation and NH3 loss + "[Acetyl]-": 42.010565 # Acetylation + "[Carbamyl]-": 43.005814 # Carbamylation "+43.006" + "[Ammonia-loss]-": -17.026549 # NH3 loss + "[+25.980265]-": 25.980265 # Carbamylation and NH3 loss diff --git a/casanovo/data/datasets.py b/casanovo/data/datasets.py deleted file mode 100644 index 3917a2c8..00000000 --- a/casanovo/data/datasets.py +++ /dev/null @@ -1,269 +0,0 @@ -"""A PyTorch Dataset class for annotated spectra.""" - -from typing import Optional, Tuple - -import depthcharge -import numpy as np -import spectrum_utils.spectrum as sus -import torch -from torch.utils.data import Dataset - - -class SpectrumDataset(Dataset): - """ - Parse and retrieve collections of MS/MS spectra. - - Parameters - ---------- - spectrum_index : depthcharge.data.SpectrumIndex - The MS/MS spectra to use as a dataset. - n_peaks : Optional[int] - The number of top-n most intense peaks to keep in each spectrum. `None` - retains all peaks. - min_mz : float - The minimum m/z to include. The default is 140 m/z, in order to exclude - TMT and iTRAQ reporter ions. - max_mz : float - The maximum m/z to include. - min_intensity : float - Remove peaks whose intensity is below `min_intensity` percentage of the - base peak intensity. - remove_precursor_tol : float - Remove peaks within the given mass tolerance in Dalton around the - precursor mass. - random_state : Optional[int] - The NumPy random state. ``None`` leaves mass spectra in the order they - were parsed. 
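# Note: the trainer-related settings added to config.yaml above
# (accumulate_grad_batches, gradient_clip_val, gradient_clip_algorithm,
# precision) map directly onto lightning.pytorch.Trainer flags; the runner
# code that forwards them is outside this diff. A minimal sketch, assuming
# the values are read from the Casanovo config:

from lightning.pytorch import Trainer

trainer = Trainer(
    precision="32-true",           # or "16-mixed", "bf16-true", ...
    accumulate_grad_batches=1,     # number of batches to accumulate gradients over
    gradient_clip_val=None,        # e.g. 1.0 to enable gradient clipping
    gradient_clip_algorithm=None,  # "norm" or "value"
)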
- """ - - def __init__( - self, - spectrum_index: depthcharge.data.SpectrumIndex, - n_peaks: int = 150, - min_mz: float = 140.0, - max_mz: float = 2500.0, - min_intensity: float = 0.01, - remove_precursor_tol: float = 2.0, - random_state: Optional[int] = None, - ): - """Initialize a SpectrumDataset""" - super().__init__() - self.n_peaks = n_peaks - self.min_mz = min_mz - self.max_mz = max_mz - self.min_intensity = min_intensity - self.remove_precursor_tol = remove_precursor_tol - self.rng = np.random.default_rng(random_state) - self._index = spectrum_index - - def __len__(self) -> int: - """The number of spectra.""" - return self.n_spectra - - def __getitem__( - self, idx - ) -> Tuple[torch.Tensor, float, int, Tuple[str, str]]: - """ - Return the MS/MS spectrum with the given index. - - Parameters - ---------- - idx : int - The index of the spectrum to return. - - Returns - ------- - spectrum : torch.Tensor of shape (n_peaks, 2) - A tensor of the spectrum with the m/z and intensity peak values. - precursor_mz : float - The precursor m/z. - precursor_charge : int - The precursor charge. - spectrum_id: Tuple[str, str] - The unique spectrum identifier, formed by its original peak file and - identifier (index or scan number) therein. - """ - mz_array, int_array, precursor_mz, precursor_charge = self.index[idx][ - :4 - ] - spectrum = self._process_peaks( - mz_array, int_array, precursor_mz, precursor_charge - ) - return ( - spectrum, - precursor_mz, - precursor_charge, - self.get_spectrum_id(idx), - ) - - def get_spectrum_id(self, idx: int) -> Tuple[str, str]: - """ - Return the identifier of the MS/MS spectrum with the given index. - - Parameters - ---------- - idx : int - The index of the MS/MS spectrum within the SpectrumIndex. - - Returns - ------- - ms_data_file : str - The peak file from which the MS/MS spectrum was originally parsed. - identifier : str - The MS/MS spectrum identifier, per PSI recommendations. - """ - with self.index: - return self.index.get_spectrum_id(idx) - - def _process_peaks( - self, - mz_array: np.ndarray, - int_array: np.ndarray, - precursor_mz: float, - precursor_charge: int, - ) -> torch.Tensor: - """ - Preprocess the spectrum by removing noise peaks and scaling the peak - intensities. - - Parameters - ---------- - mz_array : numpy.ndarray of shape (n_peaks,) - The spectrum peak m/z values. - int_array : numpy.ndarray of shape (n_peaks,) - The spectrum peak intensity values. - precursor_mz : float - The precursor m/z. - precursor_charge : int - The precursor charge. - - Returns - ------- - torch.Tensor of shape (n_peaks, 2) - A tensor of the spectrum with the m/z and intensity peak values. - """ - spectrum = sus.MsmsSpectrum( - "", - precursor_mz, - precursor_charge, - mz_array.astype(np.float64), - int_array.astype(np.float32), - ) - try: - spectrum.set_mz_range(self.min_mz, self.max_mz) - if len(spectrum.mz) == 0: - raise ValueError - spectrum.remove_precursor_peak(self.remove_precursor_tol, "Da") - if len(spectrum.mz) == 0: - raise ValueError - spectrum.filter_intensity(self.min_intensity, self.n_peaks) - if len(spectrum.mz) == 0: - raise ValueError - spectrum.scale_intensity("root", 1) - intensities = spectrum.intensity / np.linalg.norm( - spectrum.intensity - ) - return torch.tensor(np.array([spectrum.mz, intensities])).T.float() - except ValueError: - # Replace invalid spectra by a dummy spectrum. 
- return torch.tensor([[0, 1]]).float() - - @property - def n_spectra(self) -> int: - """The total number of spectra.""" - return self.index.n_spectra - - @property - def index(self) -> depthcharge.data.SpectrumIndex: - """The underlying SpectrumIndex.""" - return self._index - - @property - def rng(self): - """The NumPy random number generator.""" - return self._rng - - @rng.setter - def rng(self, seed): - """Set the NumPy random number generator.""" - self._rng = np.random.default_rng(seed) - - -class AnnotatedSpectrumDataset(SpectrumDataset): - """ - Parse and retrieve collections of annotated MS/MS spectra. - - Parameters - ---------- - annotated_spectrum_index : depthcharge.data.SpectrumIndex - The MS/MS spectra to use as a dataset. - n_peaks : Optional[int] - The number of top-n most intense peaks to keep in each spectrum. `None` - retains all peaks. - min_mz : float - The minimum m/z to include. The default is 140 m/z, in order to exclude - TMT and iTRAQ reporter ions. - max_mz : float - The maximum m/z to include. - min_intensity : float - Remove peaks whose intensity is below `min_intensity` percentage of the - base peak intensity. - remove_precursor_tol : float - Remove peaks within the given mass tolerance in Dalton around the - precursor mass. - random_state : Optional[int] - The NumPy random state. ``None`` leaves mass spectra in the order they - were parsed. - """ - - def __init__( - self, - annotated_spectrum_index: depthcharge.data.SpectrumIndex, - n_peaks: int = 150, - min_mz: float = 140.0, - max_mz: float = 2500.0, - min_intensity: float = 0.01, - remove_precursor_tol: float = 2.0, - random_state: Optional[int] = None, - ): - super().__init__( - annotated_spectrum_index, - n_peaks=n_peaks, - min_mz=min_mz, - max_mz=max_mz, - min_intensity=min_intensity, - remove_precursor_tol=remove_precursor_tol, - random_state=random_state, - ) - - def __getitem__(self, idx: int) -> Tuple[torch.Tensor, float, int, str]: - """ - Return the annotated MS/MS spectrum with the given index. - - Parameters - ---------- - idx : int - The index of the spectrum to return. - - Returns - ------- - spectrum : torch.Tensor of shape (n_peaks, 2) - A tensor of the spectrum with the m/z and intensity peak values. - precursor_mz : float - The precursor m/z. - precursor_charge : int - The precursor charge. - annotation : str - The peptide annotation of the spectrum. - """ - ( - mz_array, - int_array, - precursor_mz, - precursor_charge, - peptide, - ) = self.index[idx] - spectrum = self._process_peaks( - mz_array, int_array, precursor_mz, precursor_charge - ) - return spectrum, precursor_mz, precursor_charge, peptide diff --git a/casanovo/data/db_utils.py b/casanovo/data/db_utils.py index d3670930..6c5bc69a 100644 --- a/casanovo/data/db_utils.py +++ b/casanovo/data/db_utils.py @@ -7,13 +7,13 @@ import string from typing import Dict, Iterator, Pattern, Set, Tuple -import depthcharge.masses +import depthcharge.constants +import depthcharge.tokenizers import numpy as np import pandas as pd import pyteomics.fasta import pyteomics.parser - logger = logging.getLogger("casanovo") # CONSTANTS @@ -53,8 +53,8 @@ class ProteinDatabase: A comma-separated string of fixed modifications to consider. allowed_var_mods : str A comma-separated string of variable modifications to consider. - residues : Dict[str, float] - A dictionary of amino acid masses. + tokenizer: depthcharge.tokenizers.PeptideTokenizer + Used to access residues. 
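# Note: ProteinDatabase now receives a depthcharge PeptideTokenizer instead
# of a plain residue-mass dict. The caller is not shown in this hunk; a
# plausible construction, where the PeptideTokenizer keyword arguments and
# the `config` object are assumptions based on depthcharge's tokenizer API
# and the Casanovo Config:

from depthcharge.tokenizers import PeptideTokenizer

tokenizer = PeptideTokenizer(
    residues=config.residues,              # token -> mass mapping from config.yaml
    replace_isoleucine_with_leucine=True,  # mirrors the new config option
    reverse=True,                          # decode peptides from the C-terminus
)
# Passed as ProteinDatabase(..., tokenizer=tokenizer); the database uses
# tokenizer.residues for the digestion alphabet and tokenizer.masses for
# peptide mass calculation (see _calc_pep_mass below).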
""" def __init__( @@ -70,7 +70,7 @@ def __init__( isotope_error: Tuple[int, int], allowed_fixed_mods: str, allowed_var_mods: str, - residues: Dict[str, float], + tokenizer: depthcharge.tokenizers.PeptideTokenizer, ): self.fixed_mods, self.var_mods, self.swap_map = _construct_mods_dict( allowed_fixed_mods, allowed_var_mods @@ -86,7 +86,9 @@ def __init__( missed_cleavages, min_peptide_len, max_peptide_len, - set([aa[0] for aa in residues.keys() if aa[0].isalpha()]), + set( + [aa[0] for aa in tokenizer.residues.keys() if aa[0].isalpha()] + ), ) logger.info( "Digesting FASTA file (enzyme = %s, digestion = %s, missed " @@ -95,6 +97,7 @@ def __init__( digestion, missed_cleavages, ) + self.tokenizer = tokenizer self.db_peptides = self._digest_fasta(peptide_generator) self.precursor_tolerance = precursor_tolerance self.isotope_error = isotope_error @@ -148,9 +151,8 @@ def _digest_fasta( .reset_index() ) # Calculate the mass of each peptide. - mass_calculator = depthcharge.masses.PeptideMass(residues="massivekb") peptides["calc_mass"] = ( - peptides["peptide"].apply(mass_calculator.mass).round(5) + peptides["peptide"].apply(self._calc_pep_mass).round(5) ) # Sort by peptide mass and index by peptide sequence. peptides.sort_values( @@ -163,6 +165,27 @@ def _digest_fasta( ) return peptides + def _calc_pep_mass(self, pep: str) -> float: + """ + Calculates the neutral mass of a peptide sequence. + + Parameters + ---------- + pep : str + The peptide sequence for which the mass is to be calculated. + + Returns + ------- + float + The neutral mass of the peptide + """ + return ( + self.tokenizer.masses[self.tokenizer.tokenize(pep)] + .sum(dim=1) + .item() + + depthcharge.constants.H2O + ) + def get_candidates( self, precursor_mz: float, diff --git a/casanovo/data/ms_io.py b/casanovo/data/ms_io.py index bb9a8a3e..da9f7dbb 100644 --- a/casanovo/data/ms_io.py +++ b/casanovo/data/ms_io.py @@ -142,7 +142,7 @@ def set_ms_run(self, peak_filenames: List[str]) -> None: self.metadata.append( (f"ms_run[{i}]-location", Path(filename).as_uri()), ) - self._run_map[filename] = i + self._run_map[Path(filename).name] = i def save(self) -> None: """ @@ -184,8 +184,11 @@ def save(self) -> None: ), 1, ): - filename = os.path.abspath(psm.spectrum_id[0]) + filename = psm.spectrum_id[0] idx = psm.spectrum_id[1] + if Path(filename).suffix == ".mgf" and idx.isnumeric(): + idx = f"index={idx}" + writer.writerow( [ "PSM", diff --git a/casanovo/data/psm.py b/casanovo/data/psm.py index eece07a4..cef4a29a 100644 --- a/casanovo/data/psm.py +++ b/casanovo/data/psm.py @@ -1,7 +1,7 @@ """Peptide spectrum match dataclass.""" import dataclasses -from typing import Tuple, Iterable +from typing import Iterable, Tuple @dataclasses.dataclass diff --git a/casanovo/denovo/dataloaders.py b/casanovo/denovo/dataloaders.py index cdbf71bf..c22e7887 100644 --- a/casanovo/denovo/dataloaders.py +++ b/casanovo/denovo/dataloaders.py @@ -1,18 +1,24 @@ """Data loaders for the de novo sequencing task.""" -import functools import logging import os -from typing import List, Optional, Tuple +import tempfile +from pathlib import Path +from typing import Iterable, Optional import lightning.pytorch as pl import numpy as np +import pyarrow as pa import torch -from depthcharge.data import AnnotatedSpectrumIndex - -from ..data import db_utils -from ..data.datasets import AnnotatedSpectrumDataset, SpectrumDataset - +from depthcharge.data import ( + AnnotatedSpectrumDataset, + CustomField, + SpectrumDataset, + preprocessing, +) +from depthcharge.tokenizers import 
PeptideTokenizer +from torch.utils.data import DataLoader +from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe logger = logging.getLogger("casanovo") @@ -23,12 +29,12 @@ class DeNovoDataModule(pl.LightningDataModule): Parameters ---------- - train_index : Optional[AnnotatedSpectrumIndex] - The spectrum index file corresponding to the training data. - valid_index : Optional[AnnotatedSpectrumIndex] - The spectrum index file corresponding to the validation data. - test_index : Optional[AnnotatedSpectrumIndex] - The spectrum index file corresponding to the testing data. + train_paths : str, optional + A spectrum lance path for model training. + valid_paths : str, optional + A spectrum lance path for validation. + test_paths : str, optional + A spectrum lance path for evaluation or inference. train_batch_size : int The batch size to use for training. eval_batch_size : int @@ -48,18 +54,27 @@ class DeNovoDataModule(pl.LightningDataModule): Remove peaks within the given mass tolerance in Dalton around the precursor mass. n_workers : int, optional - The number of workers to use for data loading. By default, the - number of available CPU cores on the current machine is used. + The number of workers to use for data loading. By default, the number of + available CPU cores on the current machine is used. + max_charge: int + Remove PSMs which precursor charge higher than specified max_charge + tokenizer: Optional[PeptideTokenizer] + Peptide tokenizer for tokenizing sequences random_state : Optional[int] - The NumPy random state. ``None`` leaves mass spectra in the - order they were parsed. + The NumPy random state. ``None`` leaves mass spectra in the order they + were parsed. + shuffle: Optional[bool] + Should the training dataset be shuffled? 
Suffling based on specified buffer_size + buffer_size: Optional[int] + See more here: + https://huggingface.co/docs/datasets/v1.11.0/dataset_streaming.html#shuffling-the-dataset-shuffle """ def __init__( self, - train_index: Optional[AnnotatedSpectrumIndex] = None, - valid_index: Optional[AnnotatedSpectrumIndex] = None, - test_index: Optional[AnnotatedSpectrumIndex] = None, + train_paths: Optional[Iterable[str]] = None, + valid_paths: Optional[Iterable[str]] = None, + test_paths: Optional[str] = None, train_batch_size: int = 128, eval_batch_size: int = 1028, n_peaks: Optional[int] = 150, @@ -69,25 +84,135 @@ def __init__( remove_precursor_tol: float = 2.0, n_workers: Optional[int] = None, random_state: Optional[int] = None, + max_charge: Optional[int] = 10, + tokenizer: Optional[PeptideTokenizer] = None, + lance_dir: Optional[str] = None, + shuffle: Optional[bool] = True, + buffer_size: Optional[int] = 100_000, ): super().__init__() - self.train_index: Optional[AnnotatedSpectrumIndex] = train_index - self.valid_index: Optional[AnnotatedSpectrumIndex] = valid_index - self.test_index: Optional[AnnotatedSpectrumIndex] = test_index + self.train_paths = train_paths + self.valid_paths = valid_paths + self.test_paths = test_paths self.train_batch_size = train_batch_size self.eval_batch_size = eval_batch_size - self.n_peaks: Optional[int] = n_peaks - self.min_mz = min_mz - self.max_mz = max_mz - self.min_intensity = min_intensity - self.remove_precursor_tol = remove_precursor_tol - self.n_workers = n_workers if n_workers is not None else os.cpu_count() - self.rng = np.random.default_rng(random_state) + + self.tokenizer = ( + tokenizer if tokenizer is not None else PeptideTokenizer() + ) + self.lance_dir = ( + lance_dir + if lance_dir is not None + else tempfile.TemporaryDirectory(suffix=".lance").name + ) + self.train_dataset = None self.valid_dataset = None self.test_dataset = None self.protein_database = None + self.n_workers = n_workers if n_workers is not None else os.cpu_count() + self.shuffle = ( + shuffle if shuffle else None + ) # set to None if not wanted. Otherwise torch throws and error + self.buffer_size = buffer_size + + self.valid_charge = np.arange(1, max_charge + 1) + self.preprocessing_fn = [ + preprocessing.set_mz_range(min_mz=min_mz, max_mz=max_mz), + preprocessing.remove_precursor_peak(remove_precursor_tol, "Da"), + preprocessing.filter_intensity(min_intensity, n_peaks), + preprocessing.scale_intensity("root", 1), + scale_to_unit_norm, + ] + self.custom_field_test_mgf = [ + CustomField("title", lambda x: x["params"]["title"], pa.string()), + ] + self.custom_field_test_mzml = [ + CustomField("title", lambda x: x["id"], pa.string()), + ] + + self.custom_field_anno = [ + CustomField("seq", lambda x: x["params"]["seq"], pa.string()) + ] + + def make_dataset(self, paths, annotated, mode, shuffle): + """Make spectrum datasets. + + Parameters + ---------- + paths : Iterable[str] + Paths to input datasets + annotated: bool + True if peptide sequence annotations are available for the test + data. 
+ mode: str {"train", "valid", "test"} + The mode indicating name of lance instance + shuffle: bool + Indicates whether to shuffle training data based on buffer_size + """ + custom_fields = self.custom_field_anno if annotated else [] + + if mode == "test": + if all([Path(f).suffix in (".mgf") for f in paths]): + custom_fields = custom_fields + self.custom_field_test_mgf + if all( + [Path(f).suffix in (".mzml", ".mzxml", ".mzML") for f in paths] + ): + custom_fields = custom_fields + self.custom_field_test_mzml + + lance_path = f"{self.lance_dir}/{mode}.lance" + + parse_kwargs = dict( + preprocessing_fn=self.preprocessing_fn, + custom_fields=custom_fields, + valid_charge=self.valid_charge, + ) + + dataset_params = dict( + batch_size=( + self.train_batch_size + if mode == "train" + else self.eval_batch_size + ) + ) + anno_dataset_params = dataset_params | dict( + tokenizer=self.tokenizer, + annotations="seq", + ) + + if any([Path(f).suffix in (".lance") for f in paths]): + if annotated: + dataset = AnnotatedSpectrumDataset.from_lance( + paths[0], **anno_dataset_params + ) + else: + dataset = SpectrumDataset.from_lance( + paths[0], **dataset_params + ) + else: + if annotated: + dataset = AnnotatedSpectrumDataset( + spectra=paths, + path=lance_path, + parse_kwargs=parse_kwargs, + **anno_dataset_params, + ) + else: + dataset = SpectrumDataset( + spectra=paths, + path=lance_path, + parse_kwargs=parse_kwargs, + **dataset_params, + ) + + if shuffle: + dataset = ShufflerIterDataPipe( + dataset, buffer_size=self.buffer_size + ) + + return dataset + def setup(self, stage: str = None, annotated: bool = True) -> None: """ Set up the PyTorch Datasets. @@ -102,39 +227,34 @@ def setup(self, stage: str = None, annotated: bool = True) -> None: test data. """ if stage in (None, "fit", "validate"): - make_dataset = functools.partial( - AnnotatedSpectrumDataset, - n_peaks=self.n_peaks, - min_mz=self.min_mz, - max_mz=self.max_mz, - min_intensity=self.min_intensity, - remove_precursor_tol=self.remove_precursor_tol, - ) - if self.train_index is not None: - self.train_dataset = make_dataset( - self.train_index, - random_state=self.rng, + if self.train_paths is not None: + self.train_dataset = self.make_dataset( + self.train_paths, + annotated=True, + mode="train", + shuffle=self.shuffle, + ) + if self.valid_paths is not None: + self.valid_dataset = self.make_dataset( + self.valid_paths, + annotated=True, + mode="valid", + shuffle=False, ) - if self.valid_index is not None: - self.valid_dataset = make_dataset(self.valid_index) if stage in (None, "test"): - make_dataset = functools.partial( - AnnotatedSpectrumDataset if annotated else SpectrumDataset, - n_peaks=self.n_peaks, - min_mz=self.min_mz, - max_mz=self.max_mz, - min_intensity=self.min_intensity, - remove_precursor_tol=self.remove_precursor_tol, - ) - if self.test_index is not None: - self.test_dataset = make_dataset(self.test_index) + if self.test_paths is not None: + self.test_dataset = self.make_dataset( + self.test_paths, + annotated=annotated, + mode="test", + shuffle=False, + ) def _make_loader( self, dataset: torch.utils.data.Dataset, batch_size: int, shuffle: bool = False, - collate_fn: Optional[callable] = None, ) -> torch.utils.data.DataLoader: """ Create a PyTorch DataLoader. @@ -147,18 +267,15 @@ def _make_loader( The batch size to use. shuffle : bool Option to shuffle the batches. - collate_fn : Optional[callable] - A function to collate the data into a batch. Returns ------- torch.utils.data.DataLoader A PyTorch DataLoader. 
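# Note: DeNovoDataModule now builds lance-backed depthcharge datasets instead
# of HDF5 spectrum indexes. The ModelRunner changes that drive it are not part
# of this section; a minimal usage sketch with hypothetical file names and a
# `tokenizer` as constructed earlier:

dm = DeNovoDataModule(
    train_paths=["train.mgf"],   # peak files or an existing .lance directory
    valid_paths=["valid.mgf"],
    tokenizer=tokenizer,
    lance_dir="./lances",        # where train.lance / valid.lance / test.lance are written
    shuffle=True,
    buffer_size=100_000,
)
dm.setup("fit")
train_loader = dm.train_dataloader()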
""" - return torch.utils.data.DataLoader( + return DataLoader( dataset, batch_size=batch_size, - collate_fn=prepare_batch if collate_fn is None else collate_fn, pin_memory=True, num_workers=self.n_workers, shuffle=shuffle, @@ -167,7 +284,7 @@ def _make_loader( def train_dataloader(self) -> torch.utils.data.DataLoader: """Get the training DataLoader.""" return self._make_loader( - self.train_dataset, self.train_batch_size, shuffle=True + self.train_dataset, self.train_batch_size, shuffle=self.shuffle ) def val_dataloader(self) -> torch.utils.data.DataLoader: @@ -184,123 +301,15 @@ def predict_dataloader(self) -> torch.utils.data.DataLoader: def db_dataloader(self) -> torch.utils.data.DataLoader: """Get a special dataloader for DB search.""" - return self._make_loader( - self.test_dataset, - self.eval_batch_size, - collate_fn=functools.partial( - prepare_psm_batch, protein_database=self.protein_database - ), - ) - - -def prepare_batch( - batch: List[Tuple[torch.Tensor, float, int, str]] -) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]: - """ - Collate MS/MS spectra into a batch. + return self._make_loader(self.test_dataset, self.eval_batch_size) - The MS/MS spectra will be padded so that they fit nicely as a - tensor. However, the padded elements are ignored during the - subsequent steps. - Parameters - ---------- - batch : List[Tuple[torch.Tensor, float, int, str]] - A batch of data from an AnnotatedSpectrumDataset, consisting of - for each spectrum (i) a tensor with the m/z and intensity peak - values, (ii), the precursor m/z, (iii) the precursor charge, - (iv) the spectrum identifier. - - Returns - ------- - spectra : torch.Tensor of shape (batch_size, n_peaks, 2) - The padded mass spectra tensor with the m/z and intensity peak - values for each spectrum. - precursors : torch.Tensor of shape (batch_size, 3) - A tensor with the precursor neutral mass, precursor charge, and - precursor m/z. - spectrum_ids : np.ndarray - The spectrum identifiers (during de novo sequencing) or peptide - sequences (during training). +def scale_to_unit_norm(spectrum): """ - spectra, precursor_mzs, precursor_charges, spectrum_ids = list(zip(*batch)) - spectra = torch.nn.utils.rnn.pad_sequence(spectra, batch_first=True) - precursor_mzs = torch.tensor(precursor_mzs) - precursor_charges = torch.tensor(precursor_charges) - precursor_masses = (precursor_mzs - 1.007276) * precursor_charges - precursors = torch.vstack( - [precursor_masses, precursor_charges, precursor_mzs] - ).T.float() - return spectra, precursors, np.asarray(spectrum_ids) - - -def prepare_psm_batch( - batch: List[Tuple[torch.Tensor, float, int, str]], - protein_database: db_utils.ProteinDatabase, -) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]: - """ - Collate MS/MS spectra into a batch for DB search. - - The MS/MS spectra will be padded so that they fit nicely as a - tensor. However, the padded elements are ignored during the - subsequent steps. - - Parameters - ---------- - batch : List[Tuple[torch.Tensor, float, int, str]] - A batch of data from an AnnotatedSpectrumDataset, consisting of - for each spectrum (i) a tensor with the m/z and intensity peak - values, (ii), the precursor m/z, (iii) the precursor charge, - (iv) the spectrum identifier. - protein_database : db_utils.ProteinDatabase - The protein database to use for candidate peptide retrieval. - - Returns - ------- - batch_spectra : torch.Tensor of shape (batch_size, n_peaks, 2) - The padded mass spectra tensor with the m/z and intensity peak - values for each spectrum. 
- batch_precursors : torch.Tensor of shape (batch_size, 3) - A tensor with the precursor neutral mass, precursor charge, and - precursor m/z. - batch_spectrum_ids : np.ndarray - The spectrum identifiers. - batch_peptides : np.ndarray - The candidate peptides for each spectrum. + Scaling function used in Casanovo + slightly differing from the depthcharge implementation """ - spectra, precursors, spectrum_ids = prepare_batch(batch) - - batch_spectra = [] - batch_precursors = [] - batch_spectrum_ids = [] - batch_peptides = [] - # FIXME: This can be optimized by using a sliding window instead of - # retrieving candidates for each spectrum independently. - for i in range(len(batch)): - candidate_pep = protein_database.get_candidates( - precursors[i][2], precursors[i][1] - ) - if len(candidate_pep) == 0: - logger.debug( - "No candidate peptides found for spectrum %s with precursor " - "charge %d and precursor m/z %f", - spectrum_ids[i], - precursors[i][1], - precursors[i][2], - ) - else: - batch_spectra.append( - spectra[i].unsqueeze(0).repeat(len(candidate_pep), 1, 1) - ) - batch_precursors.append( - precursors[i].unsqueeze(0).repeat(len(candidate_pep), 1) - ) - batch_spectrum_ids.extend([spectrum_ids[i]] * len(candidate_pep)) - batch_peptides.extend(candidate_pep) - - return ( - torch.cat(batch_spectra, dim=0), - torch.cat(batch_precursors, dim=0), - np.asarray(batch_spectrum_ids), - np.asarray(batch_peptides), + spectrum._inner._intensity = spectrum.intensity / np.linalg.norm( + spectrum.intensity ) + return spectrum diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index f350f3b3..53c6a9a0 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -5,25 +5,23 @@ import itertools import logging import warnings -from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union -import depthcharge.masses import einops -import torch -import numpy as np import lightning.pytorch as pl -from torch.utils.tensorboard import SummaryWriter -from depthcharge.components import ModelMixin, PeptideDecoder, SpectrumEncoder +import numpy as np +import torch +from depthcharge.tokenizers import PeptideTokenizer -from . import evaluate from .. import config -from ..data import ms_io +from ..data import ms_io, psm +from ..denovo.transformers import PeptideDecoder, SpectrumEncoder +from . import evaluate logger = logging.getLogger("casanovo") -class Spec2Pep(pl.LightningModule, ModelMixin): +class Spec2Pep(pl.LightningModule): """ A Transformer model for de novo peptide sequencing. @@ -77,9 +75,6 @@ class Spec2Pep(pl.LightningModule, ModelMixin): Number of PSMs to return for each spectrum. n_log : int The number of epochs to wait between logging messages. - tb_summarywriter : Optional[Path] - Folder path to record performance metrics during training. If - ``None``, don't use a ``SummaryWriter``. train_label_smoothing : float Smoothing factor when calculating the training loss. warmup_iters : int @@ -93,6 +88,8 @@ class Spec2Pep(pl.LightningModule, ModelMixin): calculate_precision : bool Calculate the validation set precision during training. This is expensive. + tokenizer: Optional[PeptideTokenizer] + Tokenizer object to tokenize and detokenize peptide sequences. **kwargs : Dict Additional keyword arguments passed to the Adam optimizer. 
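# Note: Spec2Pep is no longer built on depthcharge's ModelMixin/PeptideDecoder;
# it wraps the local transformer modules and derives its output vocabulary from
# a PeptideTokenizer. Construction (not shown in this diff) might look like the
# sketch below, with other hyperparameters left at their defaults:

model = Spec2Pep(
    dim_model=512,
    n_head=8,
    tokenizer=tokenizer,  # defines the vocabulary, stop token, and residue masses
)
assert model.vocab_size == len(tokenizer) + 1  # +1 for the padding index 0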
""" @@ -104,7 +101,6 @@ def __init__( dim_feedforward: int = 1024, n_layers: int = 9, dropout: float = 0.0, - dim_intensity: Optional[int] = None, max_peptide_len: int = 100, residues: Union[Dict[str, float], str] = "canonical", max_charge: int = 5, @@ -114,40 +110,44 @@ def __init__( n_beams: int = 1, top_match: int = 1, n_log: int = 10, - tb_summarywriter: Optional[Path] = None, train_label_smoothing: float = 0.01, warmup_iters: int = 100_000, cosine_schedule_period_iters: int = 600_000, out_writer: Optional[ms_io.MztabWriter] = None, calculate_precision: bool = False, + tokenizer: Optional[PeptideTokenizer] = None, **kwargs: Dict, ): super().__init__() self.save_hyperparameters() + self.tokenizer = ( + tokenizer if tokenizer is not None else PeptideTokenizer() + ) + self.vocab_size = len(self.tokenizer) + 1 # Build the model. self.encoder = SpectrumEncoder( - dim_model=dim_model, + d_model=dim_model, n_head=n_head, dim_feedforward=dim_feedforward, n_layers=n_layers, dropout=dropout, - dim_intensity=dim_intensity, ) self.decoder = PeptideDecoder( - dim_model=dim_model, + d_model=dim_model, + n_tokens=self.tokenizer, n_head=n_head, dim_feedforward=dim_feedforward, n_layers=n_layers, dropout=dropout, - residues=residues, max_charge=max_charge, ) self.softmax = torch.nn.Softmax(2) + ignore_index = 0 self.celoss = torch.nn.CrossEntropyLoss( - ignore_index=0, label_smoothing=train_label_smoothing + ignore_index=ignore_index, label_smoothing=train_label_smoothing ) - self.val_celoss = torch.nn.CrossEntropyLoss(ignore_index=0) + self.val_celoss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) # Optimizer settings. self.warmup_iters = warmup_iters self.cosine_schedule_period_iters = cosine_schedule_period_iters @@ -170,41 +170,40 @@ def __init__( self.min_peptide_len = min_peptide_len self.n_beams = n_beams self.top_match = top_match - self.peptide_mass_calculator = depthcharge.masses.PeptideMass( - self.residues - ) - self.stop_token = self.decoder._aa2idx["$"] + + self.stop_token = self.tokenizer.stop_int # Logging. self.calculate_precision = calculate_precision self.n_log = n_log self._history = [] - if tb_summarywriter is not None: - self.tb_summarywriter = SummaryWriter(str(tb_summarywriter)) - else: - self.tb_summarywriter = None # Output writer during predicting. self.out_writer: ms_io.MztabWriter = out_writer + @property + def device(self) -> torch.device: + """The current device for first parameter of the model.""" + return next(self.parameters()).device + + @property + def n_parameters(self): + """The number of learnable parameters.""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + def forward( - self, spectra: torch.Tensor, precursors: torch.Tensor + self, batch: dict ) -> List[List[Tuple[float, np.ndarray, str]]]: """ Predict peptide sequences for a batch of MS/MS spectra. Parameters ---------- - spectra : torch.Tensor of shape (n_spectra, n_peaks, 2) - The spectra for which to predict peptide sequences. - Axis 0 represents an MS/MS spectrum, axis 1 contains the - peaks in the MS/MS spectrum, and axis 2 is essentially a - 2-tuple specifying the m/z-intensity pair for each peak. - These should be zero-padded, such that all the spectra in - the batch are the same length. - precursors : torch.Tensor of size (n_spectra, 3) - The measured precursor mass (axis 0), precursor charge - (axis 1), and precursor m/z (axis 2) of each MS/MS spectrum. 
+ batch : Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]] + A batch of (i) m/z values of MS/MS spectra, + (ii) intensity values of MS/MS spectra, + (iii) precursor information, + (iv) peptide sequences as torch Tensors. Returns ------- @@ -214,59 +213,67 @@ def forward( score, the amino acid scores, and the predicted peptide sequence. """ - return self.beam_search_decode( - spectra.to(self.encoder.device), - precursors.to(self.decoder.device), - ) + mzs, ints, precursors, _ = self._process_batch(batch) + return self.beam_search_decode(mzs, ints, precursors) def beam_search_decode( - self, spectra: torch.Tensor, precursors: torch.Tensor + self, mzs: torch.Tensor, ints: torch.Tensor, precursors: torch.Tensor ) -> List[List[Tuple[float, np.ndarray, str]]]: """ Beam search decoding of the spectrum predictions. Parameters ---------- - spectra : torch.Tensor of shape (n_spectra, n_peaks, 2) - The spectra for which to predict peptide sequences. - Axis 0 represents an MS/MS spectrum, axis 1 contains the - peaks in the MS/MS spectrum, and axis 2 is essentially a - 2-tuple specifying the m/z-intensity pair for each peak. - These should be zero-padded, such that all the spectra in - the batch are the same length. + mzs : torch.Tensor of shape (n_spectra, n_peaks) + The m/z axis of spectra for which to predict peptide sequences. + Axis 0 represents an MS/MS spectrum, axis 1 contains the peaks in + the MS/MS spectrum. These should be zero-padded, + such that all the spectra in the batch are the same length. + ints: torch.Tensor of shape (n_spectra, n_peaks) + The m/z axis of spectra for which to predict peptide sequences. + Axis 0 represents an MS/MS spectrum, axis 1 specifies + the m/z-intensity pair for each peak. These should be zero-padded, + such that all the spectra in the batch are the same length. precursors : torch.Tensor of size (n_spectra, 3) - The measured precursor mass (axis 0), precursor charge - (axis 1), and precursor m/z (axis 2) of each MS/MS spectrum. + The measured precursor mass (axis 0), precursor charge (axis 1), and + precursor m/z (axis 2) of each MS/MS spectrum. Returns ------- pred_peptides : List[List[Tuple[float, np.ndarray, str]]] - For each spectrum, a list with the top peptide - prediction(s). A peptide predictions consists of a tuple - with the peptide score, the amino acid scores, and the - predicted peptide sequence. + For each spectrum, a list with the top peptide prediction(s). A + peptide predictions consists of a tuple with the peptide score, + the amino acid scores, and the predicted peptide sequence. """ - memories, mem_masks = self.encoder(spectra) + memories, mem_masks = self.encoder(mzs, ints) # Sizes. - batch = spectra.shape[0] # B + batch = mzs.shape[0] # B length = self.max_peptide_len + 1 # L - vocab = self.decoder.vocab_size + 1 # V + vocab = self.vocab_size # V beam = self.n_beams # S # Initialize scores and tokens. scores = torch.full( size=(batch, length, vocab, beam), fill_value=torch.nan + ).type_as(mzs) + + tokens = torch.zeros( + batch, length, beam, dtype=torch.int64, device=self.encoder.device ) - scores = scores.type_as(spectra) - tokens = torch.zeros(batch, length, beam, dtype=torch.int64) - tokens = tokens.to(self.encoder.device) # Create cache for decoded beams. pred_cache = collections.OrderedDict((i, []) for i in range(batch)) # Get the first prediction. 
- pred, _ = self.decoder(None, precursors, memories, mem_masks) + pred = self.decoder( + tokens=torch.zeros( + batch, 0, dtype=torch.int64, device=self.encoder.device + ), + memory=memories, + memory_key_padding_mask=mem_masks, + precursors=precursors, + ) tokens[:, 0, :] = torch.topk(pred[:, 0, :], beam, dim=1)[1] scores[:, :1, :, :] = einops.repeat(pred, "B L V -> B L V S", S=beam) @@ -279,16 +286,15 @@ def beam_search_decode( # The main decoding loop. for step in range(0, self.max_peptide_len): - # Terminate beams exceeding the precursor m/z tolerance and - # track all finished beams (either terminated or stop token - # predicted). + # Terminate beams exceeding the precursor m/z tolerance and track + # all finished beams (either terminated or stop token predicted). ( finished_beams, beam_fits_precursor, discarded_beams, ) = self._finish_beams(tokens, precursors, step) - # Cache peptide predictions from the finished beams (but not - # the discarded beams). + # Cache peptide predictions from the finished beams (but not the + # discarded beams). self._cache_finished_beams( tokens, scores, @@ -299,26 +305,25 @@ def beam_search_decode( ) # Stop decoding when all current beams have been finished. - # Continue with beams that have not been finished and not - # discarded. + # Continue with beams that have not been finished and not discarded. finished_beams |= discarded_beams if finished_beams.all(): break # Update the scores. - scores[~finished_beams, : step + 2, :], _ = self.decoder( - tokens[~finished_beams, : step + 1], - precursors[~finished_beams, :], - memories[~finished_beams, :, :], - mem_masks[~finished_beams, :], + scores[~finished_beams, : step + 2, :] = self.decoder( + tokens=tokens[~finished_beams, : step + 1], + precursors=precursors[~finished_beams, :], + memory=memories[~finished_beams, :, :], + memory_key_padding_mask=mem_masks[~finished_beams, :], ) - # Find the top-k beams with the highest scores and continue - # decoding those. + # Find the top-k beams with the highest scores and continue decoding + # those. tokens, scores = self._get_topk_beams( tokens, scores, finished_beams, batch, step + 1 ) - # Return the peptide with the highest confidence score, within - # the precursor m/z tolerance if possible. + # Return the peptide with the highest confidence score, within the + # precursor m/z tolerance if possible. return list(self._get_top_peptide(pred_cache)) def _finish_beams( @@ -328,53 +333,54 @@ def _finish_beams( step: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Track all beams that have been finished, either by predicting - the stop token or because they were terminated due to exceeding - the precursor m/z tolerance. + Track all beams that have been finished, either by predicting the stop + token or because they were terminated due to exceeding the precursor + m/z tolerance. Parameters ---------- - tokens : torch.Tensor of shape (n_spectra * n_beams, max_peptide_len) + tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) Predicted amino acid tokens for all beams and all spectra. scores : torch.Tensor of shape - (n_spectra * n_beams, max_peptide_len, n_amino_acids) - Scores for the predicted amino acid tokens for all beams and - all spectra. + (n_spectra * n_beams, max_length, n_amino_acids) + Scores for the predicted amino acid tokens for all beams and all + spectra. step : int Index of the current decoding step. 
Returns ------- finished_beams : torch.Tensor of shape (n_spectra * n_beams) - Boolean tensor indicating whether the current beams have - been finished. + Boolean tensor indicating whether the current beams have been + finished. beam_fits_precursor: torch.Tensor of shape (n_spectra * n_beams) - Boolean tensor indicating if current beams are within - precursor m/z tolerance. + Boolean tensor indicating if current beams are within precursor m/z + tolerance. discarded_beams : torch.Tensor of shape (n_spectra * n_beams) - Boolean tensor indicating whether the current beams should - be discarded (e.g. because they were predicted to end but - violate the minimum peptide length). + Boolean tensor indicating whether the current beams should be + discarded (e.g. because they were predicted to end but violate the + minimum peptide length). """ # Check for tokens with a negative mass (i.e. neutral loss). - aa_neg_mass = [None] - for aa, mass in self.peptide_mass_calculator.masses.items(): + aa_neg_mass_idx = [None] + for aa, mass in self.tokenizer.residues.items(): if mass < 0: - aa_neg_mass.append(aa) + # aa_neg_mass.append(aa) + aa_neg_mass_idx.append(self.tokenizer.index[aa]) + # Find N-terminal residues. n_term = torch.Tensor( [ - self.decoder._aa2idx[aa] - for aa in self.peptide_mass_calculator.masses - if aa.startswith(("+", "-")) + self.tokenizer.index[aa] + for aa in self.tokenizer.index + if aa.startswith("[") and aa.endswith("]-") ] ).to(self.decoder.device) beam_fits_precursor = torch.zeros( tokens.shape[0], dtype=torch.bool ).to(self.encoder.device) - # Beams with a stop token predicted in the current step can be - # finished. + # Beams with a stop token predicted in the current step can be finished. finished_beams = torch.zeros(tokens.shape[0], dtype=torch.bool).to( self.encoder.device ) @@ -385,10 +391,10 @@ def _finish_beams( discarded_beams = torch.zeros(tokens.shape[0], dtype=torch.bool).to( self.encoder.device ) + discarded_beams[tokens[:, step] == 0] = True - # Discard beams with invalid modification combinations (i.e. - # N-terminal modifications occur multiple times or in internal - # positions). + # Discard beams with invalid modification combinations (i.e. N-terminal + # modifications occur multiple times or in internal positions). if step > 1: # Only relevant for longer predictions. dim0 = torch.arange(tokens.shape[0]) final_pos = torch.full((ends_stop_token.shape[0],), step) @@ -405,44 +411,58 @@ def _finish_beams( ).any(dim=1) discarded_beams[multiple_mods | internal_mods] = True - # Check which beams should be terminated or discarded based on - # the predicted peptide. + # Check which beams should be terminated or discarded based on the + # predicted peptide. for i in range(len(finished_beams)): # Skip already discarded beams. if discarded_beams[i]: continue pred_tokens = tokens[i][: step + 1] peptide_len = len(pred_tokens) - peptide = self.decoder.detokenize(pred_tokens) + # Omit stop token. - if self.decoder.reverse and peptide[0] == "$": - peptide = peptide[1:] + if self.tokenizer.reverse and pred_tokens[0] == self.stop_token: + pred_tokens = pred_tokens[1:] peptide_len -= 1 - elif not self.decoder.reverse and peptide[-1] == "$": - peptide = peptide[:-1] + elif ( + not self.tokenizer.reverse + and pred_tokens[-1] == self.stop_token + ): + pred_tokens = pred_tokens[:-1] peptide_len -= 1 - # Discard beams that were predicted to end but don't fit the - # minimum peptide length. + # Discard beams that were predicted to end but don't fit the minimum + # peptide length. 
if finished_beams[i] and peptide_len < self.min_peptide_len: discarded_beams[i] = True continue - # Terminate the beam if it has not been finished by the - # model but the peptide mass exceeds the precursor m/z to an - # extent that it cannot be corrected anymore by a - # subsequently predicted AA with negative mass. + # Terminate the beam if it has not been finished by the model but + # the peptide mass exceeds the precursor m/z to an extent that it + # cannot be corrected anymore by a subsequently predicted AA with + # negative mass. precursor_charge = precursors[i, 1] precursor_mz = precursors[i, 2] matches_precursor_mz = exceeds_precursor_mz = False - for aa in [None] if finished_beams[i] else aa_neg_mass: + + # Send tokenizer masses to correct device for calculate_precursor_ions() + self.tokenizer.masses = self.tokenizer.masses.type_as(precursor_mz) + + for aa in [None] if finished_beams[i] else aa_neg_mass_idx: if aa is None: - calc_peptide = peptide + calc_peptide = pred_tokens else: - calc_peptide = peptide.copy() - calc_peptide.append(aa) - try: - calc_mz = self.peptide_mass_calculator.mass( - seq=calc_peptide, charge=precursor_charge + calc_peptide = pred_tokens.detach().clone() + calc_peptide = torch.cat( + ( + calc_peptide, + torch.tensor([aa]).type_as(calc_peptide), + ) ) + try: + calc_mz = self.tokenizer.calculate_precursor_ions( + calc_peptide.unsqueeze(0), + precursor_charge.unsqueeze(0), + )[0] + delta_mass_ppm = [ _calc_mass_error( calc_mz, @@ -455,18 +475,16 @@ def _finish_beams( self.isotope_error_range[1] + 1, ) ] - # Terminate the beam if the calculated m/z for the - # predicted peptide (without potential additional - # AAs with negative mass) is within the precursor - # m/z tolerance. + # Terminate the beam if the calculated m/z for the predicted + # peptide (without potential additional AAs with negative + # mass) is within the precursor m/z tolerance. matches_precursor_mz = aa is None and any( abs(d) < self.precursor_mass_tol for d in delta_mass_ppm ) - # Terminate the beam if the calculated m/z exceeds - # the precursor m/z + tolerance and hasn't been - # corrected by a subsequently predicted AA with - # negative mass. + # Terminate the beam if the calculated m/z exceeds the + # precursor m/z + tolerance and hasn't been corrected by a + # subsequently predicted AA with negative mass. if matches_precursor_mz: exceeds_precursor_mz = False else: @@ -481,8 +499,8 @@ def _finish_beams( except KeyError: matches_precursor_mz = exceeds_precursor_mz = False # Finish beams that fit or exceed the precursor m/z. - # Don't finish beams that don't include a stop token if they - # don't exceed the precursor m/z tolerance yet. + # Don't finish beams that don't include a stop token if they don't + # exceed the precursor m/z tolerance yet. if finished_beams[i]: beam_fits_precursor[i] = matches_precursor_mz elif exceeds_precursor_mz: @@ -506,17 +524,17 @@ def _cache_finished_beams( Parameters ---------- - tokens : torch.Tensor of shape (n_spectra * n_beams, max_peptide_len) + tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) Predicted amino acid tokens for all beams and all spectra. scores : torch.Tensor of shape - (n_spectra * n_beams, max_peptide_len, n_amino_acids) - Scores for the predicted amino acid tokens for all beams and - all spectra. + (n_spectra * n_beams, max_length, n_amino_acids) + Scores for the predicted amino acid tokens for all beams and all + spectra. step : int Index of the current decoding step. 
beams_to_cache : torch.Tensor of shape (n_spectra * n_beams) - Boolean tensor indicating whether the current beams are - ready for caching. + Boolean tensor indicating whether the current beams are ready for + caching. beam_fits_precursor: torch.Tensor of shape (n_spectra * n_beams) Boolean tensor indicating whether the beams are within the precursor m/z tolerance. @@ -524,9 +542,9 @@ def _cache_finished_beams( int, List[Tuple[float, float, np.ndarray, torch.Tensor]] ] Priority queue with finished beams for each spectrum, ordered by - peptide score. For each finished beam, a tuple with the - (negated) peptide score, a random tie-breaking float, the - amino acid-level scores, and the predicted tokens is stored. + peptide score. For each finished beam, a tuple with the (negated) + peptide score, a random tie-breaking float, the amino acid-level + scores, and the predicted tokens is stored. """ for i in range(len(beams_to_cache)): if not beams_to_cache[i]: @@ -548,8 +566,8 @@ def _cache_finished_beams( continue smx = self.softmax(scores[i : i + 1, : step + 1, :]) aa_scores = smx[0, range(len(pred_tokens)), pred_tokens].tolist() - # Add an explicit score 0 for the missing stop token in case - # this was not predicted (i.e. early stopping). + # Add an explicit score 0 for the missing stop token in case this + # was not predicted (i.e. early stopping). if not has_stop_token: aa_scores.append(0) aa_scores = np.asarray(aa_scores) @@ -559,8 +577,8 @@ def _cache_finished_beams( ) # Omit the stop token from the amino acid-level scores. aa_scores = aa_scores[:-1] - # Add the prediction to the cache (minimum priority queue, - # maximum the number of beams elements). + # Add the prediction to the cache (minimum priority queue, maximum + # the number of beams elements). if len(pred_cache[spec_idx]) < self.n_beams: heapadd = heapq.heappush else: @@ -584,22 +602,22 @@ def _get_topk_beams( step: int, ) -> Tuple[torch.tensor, torch.tensor]: """ - Find the top-k beams with the highest scores and continue - decoding those. + Find the top-k beams with the highest scores and continue decoding + those. Stop decoding for beams that have been finished. Parameters ---------- - tokens : torch.Tensor of shape (n_spectra * n_beams, max_peptide_len) + tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) Predicted amino acid tokens for all beams and all spectra. scores : torch.Tensor of shape - (n_spectra * n_beams, max_peptide_len, n_amino_acids) - Scores for the predicted amino acid tokens for all beams and - all spectra. + (n_spectra * n_beams, max_length, n_amino_acids) + Scores for the predicted amino acid tokens for all beams and all + spectra. finished_beams : torch.Tensor of shape (n_spectra * n_beams) - Boolean tensor indicating whether the current beams are - ready for caching. + Boolean tensor indicating whether the current beams are ready for + caching. batch: int Number of spectra in the batch. step : int @@ -607,15 +625,15 @@ def _get_topk_beams( Returns ------- - tokens : torch.Tensor of shape (n_spectra * n_beams, max_peptide_len) + tokens : torch.Tensor of shape (n_spectra * n_beams, max_length) Predicted amino acid tokens for all beams and all spectra. scores : torch.Tensor of shape - (n_spectra * n_beams, max_peptide_len, n_amino_acids) - Scores for the predicted amino acid tokens for all beams and - all spectra. + (n_spectra * n_beams, max_length, n_amino_acids) + Scores for the predicted amino acid tokens for all beams and all + spectra. 
""" beam = self.n_beams # S - vocab = self.decoder.vocab_size + 1 # V + vocab = self.vocab_size # V # Reshape to group by spectrum (B for "batch"). tokens = einops.rearrange(tokens, "(B S) L -> B L S", S=beam) @@ -647,7 +665,7 @@ def _get_topk_beams( ).float() # Mask out the index '0', i.e. padding token, by default. # FIXME: Set this to a very small, yet non-zero value, to only - # get padding after stop token. + # get padding after stop token. active_mask[:, :beam] = 1e-8 # Figure out the top K decodings. @@ -702,7 +720,7 @@ def _get_top_peptide( ( pep_score, aa_scores, - "".join(self.decoder.detokenize(pred_tokens)), + pred_tokens, ) for pep_score, _, aa_scores, pred_tokens in heapq.nlargest( self.top_match, peptides @@ -711,29 +729,61 @@ def _get_top_peptide( else: yield [] + def _process_batch(self, batch): + """Prepare batch returned from AnnotatedSpectrumDataset of the + latest depthcharge version + + Each batch is a dict and contains these keys: + ['peak_file', 'scan_id', 'ms_level', 'precursor_mz', + 'precursor_charge', 'mz_array', 'intensity_array', + 'seq'] + Returns + ------- + spectra : torch.Tensor of shape (batch_size, n_peaks, 2) + The padded mass spectra tensor with the m/z and intensity peak values + for each spectrum. + precursors : torch.Tensor of shape (batch_size, 3) + A tensor with the precursor neutral mass, precursor charge, and + precursor m/z. + seqs : np.ndarray + The spectrum identifiers (during de novo sequencing) or peptide + sequences (during training). + + """ + for k in batch.keys(): + try: + batch[k] = batch[k].squeeze(0) + except: + continue + + precursor_mzs = batch["precursor_mz"] + precursor_charges = batch["precursor_charge"] + precursor_masses = (precursor_mzs - 1.007276) * precursor_charges + precursors = torch.vstack( + [precursor_masses, precursor_charges, precursor_mzs] + ).T # .float() + + mzs, ints = batch["mz_array"], batch["intensity_array"] + # spectra = torch.stack([mzs, ints], dim=2) + + seqs = batch["seq"] if "seq" in batch else None + + return mzs, ints, precursors, seqs + def _forward_step( self, - spectra: torch.Tensor, - precursors: torch.Tensor, - sequences: List[str], + batch, ) -> Tuple[torch.Tensor, torch.Tensor]: """ The forward learning step. Parameters ---------- - spectra : torch.Tensor of shape (n_spectra, n_peaks, 2) - The spectra for which to predict peptide sequences. - Axis 0 represents an MS/MS spectrum, axis 1 contains the - peaks in the MS/MS spectrum, and axis 2 is essentially a - 2-tuple specifying the m/z-intensity pair for each peak. - These should be zero-padded, such that all the spectra in - the batch are the same length. - precursors : torch.Tensor of size (n_spectra, 3) - The measured precursor mass (axis 0), precursor charge - (axis 1), and precursor m/z (axis 2) of each MS/MS spectrum. - sequences : List[str] of length n_spectra - The partial peptide sequences to predict. + batch : Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]] + A batch of (i) m/z values of MS/MS spectra, + (ii) intensity values of MS/MS spectra, + (iii) precursor information, + (iv) peptide sequences as torch Tensors. Returns ------- @@ -742,11 +792,19 @@ def _forward_step( tokens : torch.Tensor of shape (n_spectra, length) The predicted tokens for each spectrum. 
""" - return self.decoder(sequences, precursors, *self.encoder(spectra)) + mzs, ints, precursors, tokens = self._process_batch(batch) + memories, mem_masks = self.encoder(mzs, ints) + decoded = self.decoder( + tokens=tokens, + memory=memories, + memory_key_padding_mask=mem_masks, + precursors=precursors, + ) + return decoded, tokens def training_step( self, - batch: Tuple[torch.Tensor, torch.Tensor, List[str]], + batch: dict, *args, mode: str = "train", ) -> torch.Tensor: @@ -755,9 +813,11 @@ def training_step( Parameters ---------- - batch : Tuple[torch.Tensor, torch.Tensor, List[str]] - A batch of (i) MS/MS spectra, (ii) precursor information, - (iii) peptide sequences as torch Tensors. + batch : Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]] + A batch of (i) m/z values of MS/MS spectra, + (ii) intensity values of MS/MS spectra, + (iii) precursor information, + (iv) peptide sequences as torch Tensors. mode : str Logging key to describe the current stage. @@ -766,8 +826,9 @@ def training_step( torch.Tensor The loss of the training step. """ - pred, truth = self._forward_step(*batch) - pred = pred[:, :-1, :].reshape(-1, self.decoder.vocab_size + 1) + pred, truth = self._forward_step(batch) + pred = pred[:, :-1, :].reshape(-1, self.vocab_size) + if mode == "train": loss = self.celoss(pred, truth.flatten()) else: @@ -778,6 +839,7 @@ def training_step( on_step=False, on_epoch=True, sync_dist=True, + batch_size=pred.shape[0], ) return loss @@ -789,9 +851,11 @@ def validation_step( Parameters ---------- - batch : Tuple[torch.Tensor, torch.Tensor, List[str]] - A batch of (i) MS/MS spectra, (ii) precursor information, - (iii) peptide sequences. + batch : Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]] + A batch of (i) m/z values of MS/MS spectra, + (ii) intensity values of MS/MS spectra, + (iii) precursor information, + (iv) peptide sequences as torch Tensors. Returns ------- @@ -803,21 +867,36 @@ def validation_step( if not self.calculate_precision: return loss - # Calculate and log amino acid and peptide match evaluation - # metrics from the predicted peptides. - peptides_pred, peptides_true = [], batch[2] - for spectrum_preds in self.forward(batch[0], batch[1]): + # Calculate and log amino acid and peptide match evaluation metrics from + # the predicted peptides. + peptides_true = [ + "".join(p) + for p in self.tokenizer.detokenize(batch["seq"], join=False) + ] + peptides_pred = [] + for spectrum_preds in self.forward(batch): for _, _, pred in spectrum_preds: peptides_pred.append(pred) - + peptides_pred = [ + "".join(p) + for p in self.tokenizer.detokenize(peptides_pred, join=False) + ] + batch_size = len(peptides_true) aa_precision, _, pep_precision = evaluate.aa_match_metrics( *evaluate.aa_match_batch( - peptides_true, peptides_pred, self.decoder._peptide_mass.masses + peptides_true, + peptides_pred, + self.tokenizer.residues, ) ) + log_args = dict(on_step=False, on_epoch=True, sync_dist=True) - self.log("Peptide precision at coverage=1", pep_precision, **log_args) - self.log("AA precision at coverage=1", aa_precision, **log_args) + self.log( + "pep_precision", pep_precision, **log_args, batch_size=batch_size + ) + self.log( + "aa_precision", aa_precision, **log_args, batch_size=batch_size + ) return loss def predict_step( @@ -828,39 +907,44 @@ def predict_step( Parameters ---------- - batch : Tuple[torch.Tensor, torch.Tensor, torch.Tensor] - A batch of (i) MS/MS spectra, (ii) precursor information, - (iii) spectrum identifiers as torch Tensors. 
+ batch : Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]] + A batch of (i) m/z values of MS/MS spectra, + (ii) intensity values of MS/MS spectra, + (iii) precursor information, + (iv) peptide sequences as torch Tensors. Returns ------- predictions: List[ms_io.PepSpecMatch] Predicted PSMs for the given batch of spectra. """ + _, _, precursors, _ = self._process_batch(batch) + prec_charges = precursors[:, 1].cpu().detach().numpy() + prec_mzs = precursors[:, 2].cpu().detach().numpy() predictions = [] for ( precursor_charge, precursor_mz, - spectrum_i, + scan, + file_name, spectrum_preds, ) in zip( - batch[1][:, 1].cpu().detach().numpy(), - batch[1][:, 2].cpu().detach().numpy(), - batch[2], - self.forward(batch[0], batch[1]), + prec_charges, + prec_mzs, + batch["scan_id"], + batch["peak_file"], + self.forward(batch), ): for peptide_score, aa_scores, peptide in spectrum_preds: predictions.append( - ms_io.PepSpecMatch( - sequence=peptide, - spectrum_id=tuple(spectrum_i), - peptide_score=peptide_score, - charge=int(precursor_charge), - calc_mz=self.peptide_mass_calculator.mass( - peptide, precursor_charge - ), - exp_mz=precursor_mz, - aa_scores=aa_scores, + ( + scan[0], + precursor_charge, + precursor_mz, + peptide, + peptide_score, + aa_scores, + file_name[0], ) ) @@ -870,10 +954,15 @@ def on_train_epoch_end(self) -> None: """ Log the training loss at the end of each epoch. """ - train_loss = self.trainer.callback_metrics["train_CELoss"].detach() + if "train_CELoss" in self.trainer.callback_metrics: + train_loss = ( + self.trainer.callback_metrics["train_CELoss"].detach().item() + ) + else: + train_loss = np.nan metrics = { "step": self.trainer.global_step, - "train": train_loss.item(), + "train": train_loss, } self._history.append(metrics) self._log_history() @@ -890,12 +979,10 @@ def on_validation_epoch_end(self) -> None: if self.calculate_precision: metrics["valid_aa_precision"] = ( - callback_metrics["AA precision at coverage=1"].detach().item() + callback_metrics["aa_precision"].detach().item() ) metrics["valid_pep_precision"] = ( - callback_metrics["Peptide precision at coverage=1"] - .detach() - .item() + callback_metrics["pep_precision"].detach().item() ) self._history.append(metrics) self._log_history() @@ -909,9 +996,46 @@ def on_predict_batch_end( """ if self.out_writer is None: return - for pred in outputs: - if len(pred.sequence) > 0: - self.out_writer.psms.append(pred) + # Triply nested lists: results -> batch -> step -> spectrum. 
+ for ( + scan, + charge, + precursor_mz, + peptide, + peptide_score, + aa_scores, + file_name, + ) in outputs: + if len(peptide) == 0: + continue + + # Compute mass and detokenize + calc_mass = self.tokenizer.calculate_precursor_ions( + peptide.unsqueeze(0), torch.tensor([charge]).type_as(peptide) + )[0] + peptide = "".join( + self.tokenizer.detokenize(peptide.unsqueeze(0), join=False)[0] + ) + + self.out_writer.psms.append( + psm.PepSpecMatch( + sequence=peptide, + spectrum_id=(file_name, scan), + peptide_score=peptide_score, + charge=int(charge), + calc_mz=calc_mass.item(), + exp_mz=precursor_mz, + aa_scores=aa_scores, + ) + ) + + def on_train_start(self): + """Log optimizer settings.""" + self.log("hp/optimizer_warmup_iters", self.warmup_iters) + self.log( + "hp/optimizer_cosine_schedule_period_iters", + self.cosine_schedule_period_iters, + ) def _log_history(self) -> None: """ @@ -943,18 +1067,6 @@ def _log_history(self) -> None: ] logger.info(msg, *vals) - if self.tb_summarywriter is not None: - for descr, key in [ - ("loss/train_crossentropy_loss", "train"), - ("loss/val_crossentropy_loss", "valid"), - ("eval/val_pep_precision", "valid_pep_precision"), - ("eval/val_aa_precision", "valid_aa_precision"), - ]: - metric_value = metrics.get(key, np.nan) - if not np.isnan(metric_value): - self.tb_summarywriter.add_scalar( - descr, metric_value, metrics["step"] - ) def configure_optimizers( self, @@ -998,7 +1110,7 @@ def __init__(self, *args, **kwargs): def predict_step( self, - batch: Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray], + batch: Dict[str, torch.Tensor | List], *args, ) -> List[ms_io.PepSpecMatch]: """ @@ -1006,59 +1118,68 @@ def predict_step( Parameters ---------- - batch : Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray] - A batch of (i) MS/MS spectra, (ii) precursor information, - (iii) spectrum identifiers, (iv) candidate peptides. + batch : Dict[str, torch.Tensor | List] + A batch of MS/MS spectra, as generated by a depthcharge + dataloader. Returns ------- predictions: List[ms_io.PepSpecMatch] Predicted PSMs for the given batch of spectra. 
""" + for batch_key in [ + "ms_level", + "precursor_mz", + "precursor_charge", + "mz_array", + "intensity_array", + ]: + batch[batch_key] = batch[batch_key].squeeze(0) + predictions_all = collections.defaultdict(list) - for start_i in range(0, len(batch[0]), self.psm_batch_size): - psm_batch = [ - b[start_i : start_i + self.psm_batch_size] for b in batch - ] - pred, truth = self._forward_step( - psm_batch[0], psm_batch[1], psm_batch[3] - ) + for psm_batch in self._psm_batches(batch): + pred, truth = self._forward_step(psm_batch) pred = self.softmax(pred) batch_peptide_scores, batch_aa_scores = _calc_match_score( - pred, truth, self.decoder.reverse + pred, + truth, ) + for ( + scan, charge, precursor_mz, - spectrum_i, + peptide, peptide_score, aa_scores, - peptide, + file_name, ) in zip( - psm_batch[1][:, 1].cpu().detach().numpy(), - psm_batch[1][:, 2].cpu().detach().numpy(), - psm_batch[2], + psm_batch["scan_id"], + psm_batch["precursor_charge"], + psm_batch["precursor_mz"], + self.tokenizer.detokenize(psm_batch["seq"]), batch_peptide_scores, batch_aa_scores, - psm_batch[3], + psm_batch["peak_file"], ): - spectrum_i = tuple(spectrum_i) - predictions_all[spectrum_i].append( - ms_io.PepSpecMatch( + spectrum_id = (file_name[0], scan[0]) + predictions_all[spectrum_id].append( + psm.PepSpecMatch( sequence=peptide, - spectrum_id=spectrum_i, + spectrum_id=spectrum_id, peptide_score=peptide_score, charge=int(charge), - calc_mz=self.peptide_mass_calculator.mass( + calc_mz=self.tokenizer.calculate_precursor_ions( peptide, charge - ), - exp_mz=precursor_mz, + ).item(), + exp_mz=precursor_mz.item(), aa_scores=aa_scores, protein=self.protein_database.get_associated_protein( peptide ), ) ) + # Filter the top-scoring prediction(s) for each spectrum. predictions = list( itertools.chain.from_iterable( @@ -1076,6 +1197,142 @@ def predict_step( ) return predictions + def on_predict_batch_end( + self, outputs: List[psm.PepSpecMatch], *args + ) -> None: + """ + Write top scoring batches to the outwriter + + Parameters + ---------- + outputs : List[psm.PepSpecMatch] + List of peptide-spectrum matches predicted in the batch. + *args : tuple + Additional arguments. + """ + self.out_writer.psms.extend(outputs) + + def _psm_batches( + self, batch: Dict[str, torch.Tensor | List] + ) -> Generator[Dict[str, Union[torch.Tensor, list]], None, None]: + """ + Generates batches of candidate database PSMs. + + Parameters + ---------- + batch : Dict[str, torch.Tensor | List] + One predict batch, from a depthcharge dataloader + + Yields + ------ + psm_batch : Dict[str, torch.Tensor | List] + A batch of candidate database PSMs ready for scoring. 
+ """ + num_candidate_psms = 0 + psm_batch = self._initialize_psm_batch(batch) + + for i, (precursor_mz, precursor_charge) in enumerate( + zip(batch["precursor_mz"], batch["precursor_charge"]) + ): + candidate_peps = self.protein_database.get_candidates( + precursor_mz.item(), precursor_charge.item() + ).to_list() + + if len(candidate_peps) == 0: + logger.debug( + "No candidate peptides found for spectrum %s with precursor " + "charge %d and precursor m/z %f", + f"{batch['peak_file'][i]}:{batch['scan_id']}", + precursor_charge, + precursor_mz, + ) + continue + + while len(candidate_peps) > 0: + peps_to_add = min( + self.psm_batch_size + - (num_candidate_psms % self.psm_batch_size), + len(candidate_peps), + ) + + for key in batch.keys(): + psm_batch[key] += [batch[key][i]] * peps_to_add + + psm_batch["seq"] += candidate_peps[:peps_to_add] + num_candidate_psms += peps_to_add + + if self._pep_batch_ready(num_candidate_psms): + yield self._finalize_psm_batch(psm_batch) + psm_batch = self._initialize_psm_batch(batch) + + candidate_peps = candidate_peps[peps_to_add:] + + if ( + not self._pep_batch_ready(num_candidate_psms) + and num_candidate_psms > 0 + ): + yield self._finalize_psm_batch(psm_batch) + + def _pep_batch_ready(self, num_candidate_psms: int) -> bool: + """ + Checks if a batch of candidate PSMs is ready for processing. + + Parameters + ---------- + num_candidate_psms : int + Number of candidate PSMs processed so far. + + Returns + ------- + bool + True if the batch is ready, False otherwise. + """ + return ( + num_candidate_psms % self.psm_batch_size == 0 + ) and num_candidate_psms != 0 + + def _initialize_psm_batch(self, batch: Dict[str, Any]) -> Dict[str, List]: + """ + Initializes a new candidate PSM batch. + + Parameters + ---------- + batch : Dict[str, Any] + Input batch data to base the initialization on, usually from a + depthcharge dataloader. + + Returns + ------- + psm_batch : Dict[str, List] + A dictionary representing the initialized PSM batch. + """ + psm_batch = {key: list() for key in batch.keys()} + psm_batch["seq"] = list() + return psm_batch + + def _finalize_psm_batch( + self, psm_batch: Dict[str, List[Any]] + ) -> Dict[str, torch.Tensor | List[Any]]: + """ + Prepare a candidate PSM batch for scoring by the Casanovo model. + + Parameters + ---------- + psm_batch : Dict[str, List[Any]] + The current PSM batch to finalize. + + Returns + ------- + finalized_batch : Dict[str, torch.Tensor | List[Any]] + A finalized PSM batch ready for scoring. + """ + for key in psm_batch.keys(): + if isinstance(psm_batch[key][0], torch.Tensor): + psm_batch[key] = torch.stack(psm_batch[key]) + + psm_batch["seq"] = self.tokenizer.tokenize(psm_batch["seq"]) + return psm_batch + def _calc_match_score( batch_all_aa_scores: torch.Tensor, @@ -1235,3 +1492,14 @@ def _aa_pep_score( if not fits_precursor_mz: peptide_score -= 1 return aa_scores, peptide_score + + +def generate_tgt_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. + + Parameters + ---------- + sz : int + The length of the target sequence. 
+ """ + return ~torch.triu(torch.ones(sz, sz, dtype=torch.bool)).transpose(0, 1) diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 30f86f24..c8fc7125 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -5,19 +5,19 @@ import logging import os import tempfile -import uuid import warnings from pathlib import Path from typing import Iterable, List, Optional, Union -import depthcharge.masses import lightning.pytorch as pl import lightning.pytorch.loggers -import numpy as np import torch -from depthcharge.data import AnnotatedSpectrumIndex, SpectrumIndex +import torch.utils.data +from depthcharge.tokenizers import PeptideTokenizer +from depthcharge.tokenizers.peptides import MskbPeptideTokenizer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint from lightning.pytorch.strategies import DDPStrategy -from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor +from torch.utils.data import DataLoader from .. import utils from ..config import Config @@ -26,7 +26,6 @@ from ..denovo.evaluate import aa_match_batch, aa_match_metrics from ..denovo.model import DbSpec2Pep, Spec2Pep - logger = logging.getLogger("casanovo") @@ -108,7 +107,6 @@ def __init__( filename=best_filename, enable_version_counter=False, ), - LearningRateMonitor(log_momentum=True, log_weight_decay=True), ] def __enter__(self): @@ -147,6 +145,7 @@ def db_search( config_filename=self.config.file, ) self.initialize_trainer(train=True) + self.initialize_tokenizer() self.initialize_model(train=False, db_search=True) self.model.out_writer = self.writer self.model.psm_batch_size = self.config.predict_batch_size @@ -162,12 +161,11 @@ def db_search( self.config.isotope_error_range, self.config.allowed_fixed_mods, self.config.allowed_var_mods, - self.config.residues, + self.model.tokenizer, ) - test_index = self._get_index(peak_path, False, "db search") - self.writer.set_ms_run(test_index.ms_files) - - self.initialize_data_module(test_index=test_index) + test_paths = self._get_input_paths(peak_path, False, "test") + self.writer.set_ms_run(test_paths) + self.initialize_data_module(test_paths=test_paths) self.loaders.protein_database = self.model.protein_database self.loaders.setup(stage="test", annotated=False) self.trainer.predict(self.model, self.loaders.db_dataloader()) @@ -187,12 +185,15 @@ def train( The path to the MS data files for validation. """ self.initialize_trainer(train=True) + self.initialize_tokenizer() self.initialize_model(train=True) - train_index = self._get_index(train_peak_path, True, "training") - valid_index = self._get_index(valid_peak_path, True, "validation") - self.initialize_data_module(train_index, valid_index) + train_paths = self._get_input_paths(train_peak_path, True, "train") + valid_paths = self._get_input_paths(valid_peak_path, True, "valid") + self.initialize_data_module(train_paths, valid_paths) self.loaders.setup() + # logger.info(f'TRAIN PSMs: {self.loaders.train_dataset.n_spectra}') + # logger.info(f'VAL PSMs: {self.loaders.valid_dataset.n_spectra}') self.trainer.fit( self.model, @@ -200,28 +201,35 @@ def train( self.loaders.val_dataloader(), ) - def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None: - """Log peptide precision and amino acid precision. + def log_metrics(self, test_dataloader: DataLoader) -> None: + """Log peptide precision and amino acid precision Calculate and log peptide precision and amino acid precision - based off of model predictions and spectrum annotations. 
+ based off of model predictions and spectrum annotations Parameters ---------- test_index : AnnotatedSpectrumIndex - Index containing the annotated spectra used to generate - model predictions. + Index containing the annotated spectra used to generate model + predictions """ seq_pred = [] seq_true = [] pred_idx = 0 - with test_index as t_ind: - for true_idx in range(t_ind.n_spectra): - seq_true.append(t_ind[true_idx][4]) - if pred_idx < len(self.writer.psms) and self.writer.psms[ - pred_idx - ].spectrum_id == t_ind.get_spectrum_id(true_idx): + for batch in test_dataloader: + for peak_file, scan_id, curr_seq_true in zip( + batch["peak_file"], + batch["scan_id"], + self.model.tokenizer.detokenize(batch["seq"][0]), + ): + spectrum_id_true = (peak_file, scan_id) + seq_true.append(curr_seq_true) + if ( + pred_idx < len(self.writer.psms) + and self.writer.psms[pred_idx].spectrum_id + == spectrum_id_true + ): seq_pred.append(self.writer.psms[pred_idx].sequence) pred_idx += 1 else: @@ -231,15 +239,14 @@ def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None: *aa_match_batch( seq_true, seq_pred, - depthcharge.masses.PeptideMass().masses, + self.model.tokenizer.residues, ) ) if self.config["top_match"] > 1: logger.warning( - "The behavior for calculating evaluation metrics is undefined " - "when the 'top_match' configuration option is set to a value " - "greater than 1." + "The behavior for calculating evaluation metrics is undefined when " + "the 'top_match' configuration option is set to a value greater than 1." ) logger.info("Peptide Precision: %.2f%%", 100 * pep_precision) @@ -278,17 +285,34 @@ def predict( ) self.initialize_trainer(train=False) + self.initialize_tokenizer() self.initialize_model(train=False) self.model.out_writer = self.writer - test_index = self._get_index(peak_path, evaluate, "") - self.writer.set_ms_run(test_index.ms_files) - self.initialize_data_module(test_index=test_index) - self.loaders.setup(stage="test", annotated=False) - self.trainer.predict(self.model, self.loaders.test_dataloader()) + test_paths = self._get_input_paths(peak_path, False, "test") + self.writer.set_ms_run(test_paths) + self.initialize_data_module(test_paths=test_paths) + + try: + self.loaders.setup(stage="test", annotated=evaluate) + except (KeyError, OSError) as e: + if evaluate: + error_message = ( + "Error creating annotated spectrum dataloaders. " + "This may be the result of having an unannotated peak file " + "present in the validation peak file path list.\n" + ) + + logger.error(error_message) + raise TypeError(error_message) from e + + raise + + predict_dataloader = self.loaders.predict_dataloader() + self.trainer.predict(self.model, predict_dataloader) if evaluate: - self.log_metrics(test_index) + self.log_metrics(predict_dataloader) def initialize_trainer(self, train: bool) -> None: """Initialize the lightning Trainer. 
@@ -303,6 +327,8 @@ def initialize_trainer(self, train: bool) -> None: accelerator=self.config.accelerator, devices=1, enable_checkpointing=False, + precision=self.config.precision, + logger=False, ) if train: @@ -311,42 +337,70 @@ def initialize_trainer(self, train: bool) -> None: else: devices = self.config.devices - additional_cfg = dict( - devices=devices, - callbacks=self.callbacks, - enable_checkpointing=True, - max_epochs=self.config.max_epochs, - num_sanity_val_steps=self.config.num_sanity_val_steps, - strategy=self._get_strategy(), - val_check_interval=self.config.val_check_interval, - check_val_every_n_epoch=None, - log_every_n_steps=self.config.log_every_n_steps, - ) - - if self.config.log_metrics: + # Configure loggers + logger = False + if self.config.log_metrics or self.config.tb_summarywriter: if not self.output_dir: logger.warning( "Output directory not set in model runner. " - "No loss file will be created." + "No loss file or tensorboard will be created." ) else: + logger = [] csv_log_dir = "csv_logs" - if self.overwrite_ckpt_check: - utils.check_dir_file_exists( - self.output_dir, - csv_log_dir, - ) + tb_log_dir = "tensorboard" + + if self.config.log_metrics: + if self.overwrite_ckpt_check: + utils.check_dir_file_exists( + self.output_dir, + csv_log_dir, + ) - additional_cfg.update( - { - "logger": lightning.pytorch.loggers.CSVLogger( + logger.append( + lightning.pytorch.loggers.CSVLogger( self.output_dir, version=csv_log_dir, name=None, + ) + ) + + if self.config.tb_summarywriter: + if self.overwrite_ckpt_check: + utils.check_dir_file_exists( + self.output_dir, + tb_log_dir, + ) + + logger.append( + lightning.pytorch.loggers.TensorBoardLogger( + self.output_dir, + version=tb_log_dir, + name=None, + ) + ) + + if len(logger) > 0: + self.callbacks.append( + LearningRateMonitor( + log_momentum=True, log_weight_decay=True ), - "log_every_n_steps": self.config.log_every_n_steps, - } - ) + ) + + additional_cfg = dict( + devices=devices, + callbacks=self.callbacks, + enable_checkpointing=True, + max_epochs=self.config.max_epochs, + num_sanity_val_steps=self.config.num_sanity_val_steps, + strategy=self._get_strategy(), + val_check_interval=self.config.val_check_interval, + check_val_every_n_epoch=None, + logger=logger, + accumulate_grad_batches=self.config.accumulate_grad_batches, + gradient_clip_val=self.config.gradient_clip_val, + gradient_clip_algorithm=self.config.gradient_clip_algorithm, + ) trainer_cfg.update(additional_cfg) @@ -363,15 +417,10 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None: db_search : bool Determines whether to use the DB search model subclass. """ - tb_summarywriter = None - if self.config.tb_summarywriter: - if self.output_dir is None: - logger.warning( - "Can not create tensorboard because the output directory " - "is not set in the model runner." 
- ) - else: - tb_summarywriter = self.output_dir / "tensorboard" + try: + tokenizer = self.tokenizer + except AttributeError: + raise RuntimeError("Please use `initialize_tokenizer()` first.") model_params = dict( dim_model=self.config.dim_model, @@ -379,9 +428,6 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None: dim_feedforward=self.config.dim_feedforward, n_layers=self.config.n_layers, dropout=self.config.dropout, - dim_intensity=self.config.dim_intensity, - max_peptide_len=self.config.max_peptide_len, - residues=self.config.residues, max_charge=self.config.max_charge, precursor_mass_tol=self.config.precursor_mass_tol, isotope_error_range=self.config.isotope_error_range, @@ -389,7 +435,6 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None: n_beams=self.config.n_beams, top_match=self.config.top_match, n_log=self.config.n_log, - tb_summarywriter=tb_summarywriter, train_label_smoothing=self.config.train_label_smoothing, warmup_iters=self.config.warmup_iters, cosine_schedule_period_iters=self.config.cosine_schedule_period_iters, @@ -397,6 +442,7 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None: weight_decay=self.config.weight_decay, out_writer=self.writer, calculate_precision=self.config.calculate_precision, + tokenizer=tokenizer, ) # Reconfigurable non-architecture related parameters for a @@ -409,7 +455,6 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None: min_peptide_len=self.config.min_peptide_len, top_match=self.config.top_match, n_log=self.config.n_log, - tb_summarywriter=tb_summarywriter, train_label_smoothing=self.config.train_label_smoothing, warmup_iters=self.config.warmup_iters, cosine_schedule_period_iters=self.config.cosine_schedule_period_iters, @@ -449,7 +494,9 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None: self.model = Model.load_from_checkpoint( self.model_filename, map_location=device, **loaded_model_params ) - + # Use tokenizer initialized from config file instead of loaded + # from checkpoint file + self.model.tokenizer = tokenizer architecture_params = set(model_params.keys()) - set( loaded_model_params.keys() ) @@ -470,30 +517,46 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None: map_location=device, **model_params, ) + self.model.tokenizer = tokenizer except RuntimeError: raise RuntimeError( "Weights file incompatible with the current version of " "Casanovo." ) + def initialize_tokenizer( + self, + ) -> None: + """Initialize the peptide tokenizer""" + if self.config.mskb_tokenizer: + tokenizer_cs = MskbPeptideTokenizer + else: + tokenizer_cs = PeptideTokenizer + + self.tokenizer = tokenizer_cs( + residues=self.config.residues, + replace_isoleucine_with_leucine=self.config.replace_isoleucine_with_leucine, + reverse=self.config.reverse_peptides, + start_token=None, + stop_token="$", + ) + def initialize_data_module( self, - train_index: Optional[AnnotatedSpectrumIndex] = None, - valid_index: Optional[AnnotatedSpectrumIndex] = None, - test_index: Optional[ - Union[AnnotatedSpectrumIndex, SpectrumIndex] - ] = None, + train_paths: Optional[str] = None, + valid_paths: Optional[str] = None, + test_paths: Optional[str] = None, ) -> None: """Initialize the data module. Parameters ---------- - train_index : AnnotatedSpectrumIndex, optional - A spectrum index for model training. - valid_index : AnnotatedSpectrumIndex, optional - A spectrum index for validation. 
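As a minimal sketch, the tokenizer can be built the same way `initialize_tokenizer()` above does it, assuming depthcharge's `PeptideTokenizer` accepts the keyword arguments shown in that hunk and provides the standard amino acids by default. The residue table here is truncated and purely illustrative; the real one comes from `config.residues`.

    from depthcharge.tokenizers import PeptideTokenizer

    # Truncated, illustrative residue table (the full table lives in the
    # Casanovo config under `residues`).
    residues = {
        "G": 57.021464,
        "A": 71.037114,
        "C[Carbamidomethyl]": 160.030649,
    }

    # Mirrors ModelRunner.initialize_tokenizer() with mskb_tokenizer=False.
    tokenizer = PeptideTokenizer(
        residues=residues,
        replace_isoleucine_with_leucine=True,
        reverse=False,
        start_token=None,
        stop_token="$",
    )

    # tokenize() returns a padded tensor of token indices, as used by the
    # dataloaders and by DbSpec2Pep._finalize_psm_batch above.
    tokens = tokenizer.tokenize(["LESLEK", "PEPTLDEK"])
    print(tokens.shape)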
- test_index : AnnotatedSpectrumIndex or SpectrumIndex, optional - A spectrum index for evaluation or inference. + train_paths : str, optional + A spectrum path for model training. + valid_paths : str, optional + A spectrum path for validation. + test_paths : str, optional + A spectrum path for evaluation or inference. """ try: n_devices = self.trainer.num_devices @@ -502,10 +565,20 @@ def initialize_data_module( except AttributeError: raise RuntimeError("Please use `initialize_trainer()` first.") + try: + tokenizer = self.tokenizer + except AttributeError: + raise RuntimeError("Please use `initialize_tokenizer()` first.") + + lance_dir = ( + Path(self.tmp_dir.name) + if self.config.lance_dir is None + else self.config.lance_dir + ) self.loaders = DeNovoDataModule( - train_index=train_index, - valid_index=valid_index, - test_index=test_index, + train_paths=train_paths, + valid_paths=valid_paths, + test_paths=test_paths, min_mz=self.config.min_mz, max_mz=self.config.max_mz, min_intensity=self.config.min_intensity, @@ -513,18 +586,21 @@ def initialize_data_module( n_workers=self.config.n_workers, train_batch_size=train_bs, eval_batch_size=eval_bs, + n_peaks=self.config.n_peaks, + max_charge=self.config.max_charge, + tokenizer=tokenizer, + lance_dir=lance_dir, + shuffle=self.config.shuffle, + buffer_size=self.config.buffer_size, ) - def _get_index( + def _get_input_paths( self, peak_path: Iterable[str], annotated: bool, - msg: str = "", - ) -> Union[SpectrumIndex, AnnotatedSpectrumIndex]: - """Get the spectrum index. - - If the file is a SpectrumIndex, only one is allowed. Otherwise - multiple may be specified. + mode: str, + ) -> str: + """Get the spectrum input paths. Parameters ---------- @@ -532,54 +608,30 @@ def _get_index( The peak files/directories to check. annotated : bool Are the spectra expected to be annotated? - msg : str, optional - A string to insert into the error message. - + mode : str + Either train, valid or test to specify lance file name Returns ------- - SpectrumIndex or AnnotatedSpectrumIndex - The spectrum index for training, evaluation, or inference. + The spectrum paths for training, evaluation, or inference. """ - ext = (".mgf", ".h5", ".hdf5") + ext = (".mgf", ".lance") if not annotated: - ext += (".mzml", ".mzxml") + ext += (".mzML", ".mzml", ".mzxml") # FIXME: Check if these work - msg = msg.strip() filenames = _get_peak_filenames(peak_path, ext) if not filenames: - not_found_err = f"Cound not find {msg} peak files" + not_found_err = f"Cound not find {mode} peak files" logger.error(not_found_err + " from %s", peak_path) raise FileNotFoundError(not_found_err) - is_index = any([Path(f).suffix in (".h5", ".hdf5") for f in filenames]) - if is_index: + is_lance = any([Path(f).suffix in (".lance") for f in filenames]) + if is_lance: if len(filenames) > 1: - h5_err = f"Multiple {msg} HDF5 spectrum indexes specified" - logger.error(h5_err) - raise ValueError(h5_err) - - index_fname, filenames = filenames[0], None - else: - index_fname = Path(self.tmp_dir.name) / f"{uuid.uuid4().hex}.hdf5" - - Index = AnnotatedSpectrumIndex if annotated else SpectrumIndex - valid_charge = np.arange(1, self.config.max_charge + 1) - - try: - return Index(index_fname, filenames, valid_charge=valid_charge) - except TypeError as e: - if Index == AnnotatedSpectrumIndex: - error_msg = ( - "Error creating annotated spectrum index. 
" - "This may be the result of having an unannotated MGF file " - "present in the validation peak file path list.\n" - f"Original error message: {e}" - ) - - logger.error(error_msg) - raise TypeError(error_msg) + lance_err = f"Multiple {mode} spectrum lance files specified" + logger.error(lance_err) + raise ValueError(lance_err) - raise e + return filenames def _get_strategy(self) -> Union[str, DDPStrategy]: """Get the strategy for the Trainer. diff --git a/casanovo/denovo/transformers.py b/casanovo/denovo/transformers.py new file mode 100644 index 00000000..388882af --- /dev/null +++ b/casanovo/denovo/transformers.py @@ -0,0 +1,178 @@ +"""Transformer encoder and decoder for the de novo sequencing task.""" + +from collections.abc import Callable + +import torch +from depthcharge.encoders import FloatEncoder, PeakEncoder, PositionalEncoder +from depthcharge.tokenizers import Tokenizer +from depthcharge.transformers import ( + AnalyteTransformerDecoder, + SpectrumTransformerEncoder, +) + + +class PeptideDecoder(AnalyteTransformerDecoder): + """A transformer decoder for peptide sequences + + Parameters + ---------- + n_tokens : int + The number of tokens used to tokenize peptide sequences. + d_model : int, optional + The latent dimensionality to represent peaks in the mass spectrum. + nhead : int, optional + The number of attention heads in each layer. ``d_model`` must be + divisible by ``nhead``. + dim_feedforward : int, optional + The dimensionality of the fully connected layers in the Transformer + layers of the model. + n_layers : int, optional + The number of Transformer layers. + dropout : float, optional + The dropout probability for all layers. + pos_encoder : PositionalEncoder or bool, optional + The positional encodings to use for the amino acid sequence. If + ``True``, the default positional encoder is used. ``False`` disables + positional encodings, typically only for ablation tests. + max_charge : int, optional + The maximum charge state for peptide sequences. + """ + + def __init__( + self, + n_tokens: int | Tokenizer, + d_model: int = 128, + n_head: int = 8, + dim_feedforward: int = 1024, + n_layers: int = 1, + dropout: float = 0, + positional_encoder: PositionalEncoder | bool = True, + padding_int: int | None = None, + max_charge: int = 10, + ) -> None: + """Initialize a PeptideDecoder.""" + + super().__init__( + n_tokens=n_tokens, + d_model=d_model, + nhead=n_head, + dim_feedforward=dim_feedforward, + n_layers=n_layers, + dropout=dropout, + positional_encoder=positional_encoder, + padding_int=padding_int, + ) + + self.charge_encoder = torch.nn.Embedding(max_charge, d_model) + self.mass_encoder = FloatEncoder(d_model) + + # override final layer: + # +1 in comparison to version in depthcharge to second dimension + # This includes padding (=0) as a possible class + # and avoids problems during beam search decoding + self.final = torch.nn.Linear( + d_model, + self.token_encoder.num_embeddings, + ) + + def global_token_hook( + self, + tokens: torch.Tensor, + precursors: torch.Tensor, + **kwargs: dict, + ) -> torch.Tensor: + """ + Override global_token_hook to include precursor information. + + Parameters + ---------- + tokens : list of str, torch.Tensor, or None + The partial molecular sequences for which to predict the next + token. Optionally, these may be the token indices instead + of a string. + precursors : torch.Tensor + Precursor information. + **kwargs : dict + Additional data passed with the batch. 
+ + Returns + ------- + torch.Tensor of shape (batch_size, d_model) + The global token representations. + + """ + masses = self.mass_encoder(precursors[:, None, 0]).squeeze(1) + charges = self.charge_encoder(precursors[:, 1].int() - 1) + precursors = masses + charges + return precursors + + +class SpectrumEncoder(SpectrumTransformerEncoder): + """A Transformer encoder for input mass spectra. + + Parameters + ---------- + d_model : int, optional + The latent dimensionality to represent peaks in the mass spectrum. + n_head : int, optional + The number of attention heads in each layer. ``d_model`` must be + divisible by ``n_head``. + dim_feedforward : int, optional + The dimensionality of the fully connected layers in the Transformer + layers of the model. + n_layers : int, optional + The number of Transformer layers. + dropout : float, optional + The dropout probability for all layers. + peak_encoder : bool, optional + Use positional encodings m/z values of each peak. + dim_intensity: int or None, optional + The number of features to use for encoding peak intensity. + The remaining (``d_model - dim_intensity``) are reserved for + encoding the m/z value. + """ + + def __init__( + self, + d_model: int = 128, + n_head: int = 8, + dim_feedforward: int = 1024, + n_layers: int = 1, + dropout: float = 0, + peak_encoder: PeakEncoder | Callable | bool = True, + ): + """Initialize a SpectrumEncoder""" + super().__init__( + d_model, n_head, dim_feedforward, n_layers, dropout, peak_encoder + ) + + self.latent_spectrum = torch.nn.Parameter(torch.randn(1, 1, d_model)) + + def global_token_hook( + self, + mz_array: torch.Tensor, + intensity_array: torch.Tensor, + *args: torch.Tensor, + **kwargs: dict, + ) -> torch.Tensor: + """Override global_token_hook to include + lantent_spectrum parameter + + Parameters + ---------- + mz_array : torch.Tensor of shape (n_spectra, n_peaks) + The zero-padded m/z dimension for a batch of mass spectra. + intensity_array : torch.Tensor of shape (n_spectra, n_peaks) + The zero-padded intensity dimension for a batch of mass spctra. + *args : torch.Tensor + Additional data passed with the batch. + **kwargs : dict + Additional data passed with the batch. + + Returns + ------- + torch.Tensor of shape (batch_size, d_model) + The precursor representations. + + """ + return self.latent_spectrum.squeeze(0).expand(mz_array.shape[0], -1) diff --git a/casanovo/utils.py b/casanovo/utils.py index 86e0748f..406e6874 100644 --- a/casanovo/utils.py +++ b/casanovo/utils.py @@ -17,7 +17,6 @@ from .data.psm import PepSpecMatch - SCORE_BINS = (0.0, 0.5, 0.9, 0.95, 0.99) logger = logging.getLogger("casanovo") @@ -39,6 +38,10 @@ def n_workers() -> int: int The number of workers. """ + # FIXME: remove multiprocessing Linux deadlock issue workaround when + # deadlock issue is resolved. + return 0 + # Windows or MacOS: no multiprocessing. if platform.system() in ["Windows", "Darwin"]: logger.warning( diff --git a/casanovo/version.py b/casanovo/version.py index 579db300..eb817aae 100644 --- a/casanovo/version.py +++ b/casanovo/version.py @@ -18,7 +18,7 @@ def _get_version() -> Optional[str]: """ try: # Fast, but only works in Python 3.8+. - from importlib.metadata import version, PackageNotFoundError + from importlib.metadata import PackageNotFoundError, version try: return version("casanovo") @@ -26,7 +26,7 @@ def _get_version() -> Optional[str]: return None except ImportError: # Slow, but works for all Python 3+. 
- from pkg_resources import get_distribution, DistributionNotFound + from pkg_resources import DistributionNotFound, get_distribution try: return get_distribution("casanovo").version diff --git a/docs/conf.py b/docs/conf.py index 56f7ecb0..a1955a8f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,3 +1,9 @@ +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys from importlib.metadata import version # Configuration file for the Sphinx documentation builder. @@ -8,13 +14,6 @@ # -- Path setup -------------------------------------------------------------- -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import sys - sys.path.insert(0, os.path.abspath(".")) diff --git a/pyproject.toml b/pyproject.toml index 3967bf05..6d80ff83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,12 +22,13 @@ dependencies = [ "appdirs", "lightning>=2.1", "click", - "depthcharge-ms>=0.2.3,<0.3.0", + "depthcharge-ms>=0.4.8,<0.5.0", "natsort", "numpy<2.0", "pandas", "psutil", "PyGithub", + "pylance==0.15.0", "PyYAML", "requests", "rich-click>=1.6.1", diff --git a/tests/conftest.py b/tests/conftest.py index a35c5834..4cc02aed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ """Fixtures used for testing.""" -import depthcharge import numpy as np import psims import pytest @@ -81,9 +80,13 @@ def _create_mgf( rng = np.random.default_rng(random_state) entries = [ _create_mgf_entry( - p, rng.choice([2, 3]), mod_aa_mass=mod_aa_mass, annotate=annotate + p, + i, + rng.choice([2, 3]), + mod_aa_mass=mod_aa_mass, + annotate=annotate, ) - for p in peptides + for i, p in enumerate(peptides) ] with mgf_file.open("w+") as mgf_ref: mgf_ref.write("\n".join(entries)) @@ -91,7 +94,9 @@ def _create_mgf( return mgf_file -def _create_mgf_entry(peptide, charge=2, mod_aa_mass=None, annotate=True): +def _create_mgf_entry( + peptide, title, charge=2, mod_aa_mass=None, annotate=True +): """ Create a MassIVE-KB style MGF entry for a single PSM. 
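For context, roughly what one of the generated MassIVE-KB style MGF entries looks like once the new per-spectrum `TITLE` field is included. The values and fragment peaks below are made up, and the sketch mirrors the `mgf = [...]` construction in the following hunk.

    # Illustrative only: approximate shape of a generated MGF entry after
    # the TITLE field is added (all values are made up).
    mgf_entry = "\n".join(
        [
            "BEGIN IONS",
            "TITLE=0",            # index of the peptide in the fixture list
            "PEPMASS=501.2927",   # precursor m/z
            "CHARGE=2+",
            "72.0444 1.0",        # fragment m/z and intensity pairs
            "147.1128 1.0",
            "END IONS",
        ]
    )
    print(mgf_entry)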
@@ -122,6 +127,7 @@ def _create_mgf_entry(peptide, charge=2, mod_aa_mass=None, annotate=True): mgf = [ "BEGIN IONS", + f"TITLE={title}", f"PEPMASS={precursor_mz}", f"CHARGE={charge}+", f"{frags}", @@ -247,9 +253,8 @@ def _create_mzml(peptides, mzml_file, random_state=42): return mzml_file -@pytest.fixture -def tiny_config(tmp_path): - """A config file for a tiny model.""" +def get_config_file(file_path, file_name, additional_cfg=None): + """Get Casanovo config yaml file""" cfg = { "n_head": 2, "dim_feedforward": 10, @@ -291,6 +296,16 @@ def tiny_config(tmp_path): "train_batch_size": 32, "num_sanity_val_steps": 0, "calculate_precision": False, + "lance_dir": None, + "shuffle": False, + "buffer_size": 64, + "accumulate_grad_batches": 1, + "gradient_clip_val": None, + "gradient_clip_algorithm": None, + "precision": "32-true", + "replace_isoleucine_with_leucine": True, + "reverse_peptides": False, + "mskb_tokenizer": True, "residues": { "G": 57.021464, "A": 71.037114, @@ -298,7 +313,7 @@ def tiny_config(tmp_path): "P": 97.052764, "V": 99.068414, "T": 101.047670, - "C+57.021": 160.030649, + "C[Carbamidomethyl]": 160.030649, # 103.009185 + 57.021464 "L": 113.084064, "I": 113.084064, "N": 114.042927, @@ -312,22 +327,27 @@ def tiny_config(tmp_path): "R": 156.101111, "Y": 163.063329, "W": 186.079313, - "M+15.995": 147.035400, - "N+0.984": 115.026943, - "Q+0.984": 129.042594, - "+42.011": 42.010565, - "+43.006": 43.005814, - "-17.027": -17.026549, - "+43.006-17.027": 25.980265, + # Amino acid modifications. + "M[Oxidation]": 147.035400, # Met oxidation: 131.040485 + 15.994915 + "N[Deamidated]": 115.026943, # Asn deamidation: 114.042927 + 0.984016 + "Q[Deamidated]": 129.042594, # Gln deamidation: 128.058578 + 0.984016 + # N-terminal modifications. + "[Acetyl]-": 42.010565, # Acetylation + "[Carbamyl]-": 43.005814, # Carbamylation "+43.006" + "[Ammonia-loss]-": -17.026549, # NH3 loss + "[+25.980265]-": 25.980265, # Carbamylation and NH3 loss }, - "allowed_fixed_mods": "C:C+57.021", + "allowed_fixed_mods": "C:C[Carbamidomethyl]", "allowed_var_mods": ( - "M:M+15.995,N:N+0.984,Q:Q+0.984," - "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" + "M:M[Oxidation],N:N[Deamidated],Q:Q[Deamidated]," + "nterm:[Acetyl]-,nterm:[Carbamyl]-,nterm:[Ammonia-loss]-,nterm:[+25.980265]-" ), } - cfg_file = tmp_path / "config.yml" + if additional_cfg is not None: + cfg.update(additional_cfg) + + cfg_file = file_path / file_name with cfg_file.open("w+") as out_file: yaml.dump(cfg, out_file) @@ -335,5 +355,16 @@ def tiny_config(tmp_path): @pytest.fixture -def residues_dict(): - return depthcharge.masses.PeptideMass("massivekb").masses +def tiny_config(tmp_path): + """A config file for a tiny model.""" + return get_config_file(tmp_path, "config.yml") + + +@pytest.fixture +def tiny_config_db(tmp_path): + """A config file for a db search.""" + return get_config_file( + tmp_path, + "config_db.yml", + additional_cfg={"replace_isoleucine_with_leucine": False}, + ) diff --git a/tests/test_integration.py b/tests/test_integration.py index 7dab1b5b..948cff63 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -7,79 +7,18 @@ from casanovo import casanovo - TEST_DIR = Path(__file__).resolve().parent -def test_db_search( - mgf_medium, tiny_fasta_file, tiny_config, tmp_path, monkeypatch -): - # Run a command: - monkeypatch.setattr(casanovo, "__version__", "4.1.0") - run = functools.partial( - CliRunner().invoke, casanovo.main, catch_exceptions=False - ) - - output_rootname = "db" - output_filename = 
(tmp_path / output_rootname).with_suffix(".mztab") - - search_args = [ - "db-search", - "--config", - tiny_config, - "--output_dir", - str(tmp_path), - "--output_root", - output_rootname, - str(mgf_medium), - str(tiny_fasta_file), - ] - - result = run(search_args) - - assert result.exit_code == 0 - assert output_filename.exists() - - mztab = pyteomics.mztab.MzTab(str(output_filename)) - - psms = mztab.spectrum_match_table - assert list(psms.sequence) == [ - "ATSIPAR", - "VTLSC+57.021R", - "LLIYGASTR", - "EIVMTQSPPTLSLSPGER", - "MEAPAQLLFLLLLWLPDTTR", - "ASQSVSSSYLTWYQQKPGQAPR", - "FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQDYNLP", - ] - - # Validate mztab output - validate_args = [ - "java", - "-jar", - f"{TEST_DIR}/jmzTabValidator.jar", - "--check", - f"inFile={output_filename}", - ] - - validate_result = subprocess.run( - validate_args, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - - assert validate_result.returncode == 0 - assert not any( - [ - line.startswith("[Error-") - for line in validate_result.stdout.splitlines() - ] - ) - - def test_train_and_run( - mgf_small, mzml_small, tiny_config, tmp_path, monkeypatch + mgf_small, + mzml_small, + tiny_config, + tiny_config_db, + tmp_path, + monkeypatch, + mgf_medium, + tiny_fasta_file, ): # We can use this to explicitly test different versions. monkeypatch.setattr(casanovo, "__version__", "3.0.1") @@ -92,7 +31,6 @@ def test_train_and_run( # Train a tiny model: train_args = [ "train", - "--validation_peak_path", str(mgf_small), "--config", tiny_config, @@ -100,7 +38,6 @@ def test_train_and_run( str(tmp_path), "--output_root", "train", - str(mgf_small), # The training files. ] result = run(train_args) @@ -166,7 +103,6 @@ def test_train_and_run( "--output_root", output_rootname, str(mgf_small), - str(mzml_small), "--evaluate", ] @@ -214,6 +150,66 @@ def test_train_and_run( assert output_filename.is_file() + monkeypatch.setattr(casanovo, "__version__", "4.1.0") + output_rootname = "db" + output_filename = (tmp_path / output_rootname).with_suffix(".mztab") + + search_args = [ + "db-search", + "--model", + str(model_file), + "--config", + tiny_config_db, + "--output_dir", + str(tmp_path), + "--output_root", + output_rootname, + str(mgf_medium), + str(tiny_fasta_file), + ] + + result = run(search_args) + + assert result.exit_code == 0 + assert output_filename.exists() + + mztab = pyteomics.mztab.MzTab(str(output_filename)) + + psms = mztab.spectrum_match_table + assert list(psms.sequence) == [ + "ATSIPAR", + "VTLSC[Carbamidomethyl]R", + "LLIYGASTR", + "EIVMTQSPPTLSLSPGER", + "MEAPAQLLFLLLLWLPDTTR", + "ASQSVSSSYLTWYQQKPGQAPR", + "FSGSGSGTDFTLTISSLQPEDFAVYYC[Carbamidomethyl]QQDYNLP", + ] + + # Validate mztab output + validate_args = [ + "java", + "-jar", + f"{TEST_DIR}/jmzTabValidator.jar", + "--check", + f"inFile={output_filename}", + ] + + validate_result = subprocess.run( + validate_args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + assert validate_result.returncode == 0 + assert not any( + [ + line.startswith("[Error-") + for line in validate_result.stdout.splitlines() + ] + ) + def test_auxilliary_cli(tmp_path, monkeypatch): """Test the secondary CLI commands""" diff --git a/tests/unit_tests/test_run_stats.py b/tests/unit_tests/test_run_stats.py index 9a438673..a2149381 100644 --- a/tests/unit_tests/test_run_stats.py +++ b/tests/unit_tests/test_run_stats.py @@ -4,8 +4,7 @@ import numpy as np import pandas as pd -from casanovo.utils import get_score_bins, get_peptide_lengths - +from casanovo.utils 
import get_peptide_lengths, get_score_bins np.random.seed(4000) random.seed(4000) diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py index cf04cf83..10a8d4ef 100644 --- a/tests/unit_tests/test_runner.py +++ b/tests/unit_tests/test_runner.py @@ -4,6 +4,7 @@ import unittest.mock from pathlib import Path +import depthcharge.tokenizers.peptides import pytest import torch @@ -16,16 +17,25 @@ def test_initialize_model(tmp_path, mgf_small): """Test initializing a new or existing model.""" config = Config() config.model_save_folder_path = tmp_path + # Initializing model without initializing tokenizer raises an error + with pytest.raises(RuntimeError): + ModelRunner(config=config).initialize_model(train=True) + # No model filename given, so train from scratch. - ModelRunner(config=config).initialize_model(train=True) + runner = ModelRunner(config=config) + runner.initialize_tokenizer() + runner.initialize_model(train=True) # No model filename given during inference = error. with pytest.raises(ValueError): - ModelRunner(config=config).initialize_model(train=False) + runner = ModelRunner(config=config) + runner.initialize_tokenizer() + runner.initialize_model(train=False) # Non-existing model filename given during inference = error. with pytest.raises(FileNotFoundError): runner = ModelRunner(config=config, model_filename="blah") + runner.initialize_tokenizer() runner.initialize_model(train=False) # Train a quick model. @@ -38,10 +48,12 @@ def test_initialize_model(tmp_path, mgf_small): # Resume training from previous model. runner = ModelRunner(config=config, model_filename=str(ckpt)) + runner.initialize_tokenizer() runner.initialize_model(train=True) # Inference with previous model. runner = ModelRunner(config=config, model_filename=str(ckpt)) + runner.initialize_tokenizer() runner.initialize_model(train=False) # If the model initialization throws and EOFError, then the Spec2Pep model @@ -50,6 +62,7 @@ def test_initialize_model(tmp_path, mgf_small): weights.touch() with pytest.raises(EOFError): runner = ModelRunner(config=config, model_filename=str(weights)) + runner.initialize_tokenizer() runner.initialize_model(train=False) @@ -74,6 +87,7 @@ def test_save_and_load_weights(tmp_path, mgf_small, tiny_config): # Now load the weights into a new model # The device should be meta for all the weights. runner = ModelRunner(config=other_config, model_filename=str(ckpt)) + runner.initialize_tokenizer() runner.initialize_model(train=False) obs_layers = runner.model.encoder.transformer_encoder.num_layers @@ -127,6 +141,7 @@ def test_save_and_load_weights_deprecated(tmp_path, mgf_small, tiny_config): with ModelRunner( config=config, model_filename=str(ckpt), overwrite_ckpt_check=False ) as runner: + runner.initialize_tokenizer() runner.initialize_model(train=False) assert runner.model.cosine_schedule_period_iters == 5 # Fine-tuning. 
@@ -141,7 +156,7 @@ def test_save_and_load_weights_deprecated(tmp_path, mgf_small, tiny_config): assert "max_iters" not in runner.model.opt_kwargs -def test_calculate_precision(tmp_path, mgf_small, tiny_config): +def test_calculate_precision(tmp_path, mgf_small, tiny_config, monkeypatch): """Test that this parameter is working correctly.""" config = Config(tiny_config) config.n_layers = 1 @@ -149,22 +164,42 @@ def test_calculate_precision(tmp_path, mgf_small, tiny_config): config.calculate_precision = False config.tb_summarywriter = str(tmp_path) - runner = ModelRunner(config=config, output_dir=tmp_path) - with runner: - runner.train([mgf_small], [mgf_small]) + with monkeypatch.context() as ctx: + mock_logger = unittest.mock.MagicMock() + ctx.setattr("casanovo.denovo.model.logger", mock_logger) + runner = ModelRunner(config=config, output_dir=tmp_path) + with runner: + runner.train([mgf_small], [mgf_small]) - assert "valid_aa_precision" not in runner.model.history.columns - assert "valid_pep_precision" not in runner.model.history.columns + logged_items = [ + item + for call in mock_logger.info.call_args_list + for arg in call.args + for item in (arg.split("\t") if isinstance(arg, str) else [arg]) + ] + + assert "AA precision" not in logged_items + assert "Peptide precision" not in logged_items config.calculate_precision = True - runner = ModelRunner( - config=config, output_dir=tmp_path, overwrite_ckpt_check=False - ) - with runner: - runner.train([mgf_small], [mgf_small]) + with monkeypatch.context() as ctx: + mock_logger = unittest.mock.MagicMock() + ctx.setattr("casanovo.denovo.model.logger", mock_logger) + runner = ModelRunner( + config=config, output_dir=tmp_path, overwrite_ckpt_check=False + ) + with runner: + runner.train([mgf_small], [mgf_small]) - assert "valid_aa_precision" in runner.model.history.columns - assert "valid_pep_precision" in runner.model.history.columns + logged_items = [ + item + for call in mock_logger.info.call_args_list + for arg in call.args + for item in (arg.split("\t") if isinstance(arg, str) else [arg]) + ] + + assert "AA precision" in logged_items + assert "Peptide precision" in logged_items def test_save_final_model(tmp_path, mgf_small, tiny_config): @@ -223,12 +258,12 @@ def test_evaluate( result_file.unlink() exception_string = ( - "Error creating annotated spectrum index. " - "This may be the result of having an unannotated MGF file " + "Error creating annotated spectrum dataloaders. 
" + "This may be the result of having an unannotated peak file " "present in the validation peak file path list.\n" ) - with pytest.raises(FileNotFoundError): + with pytest.raises(TypeError): with ModelRunner( config, model_filename=str(model_file), overwrite_ckpt_check=False ) as runner: @@ -254,7 +289,7 @@ def test_evaluate( result_file.unlink() # Test mix of annotated an unannotated peak files - with pytest.warns(RuntimeWarning): + with pytest.raises(TypeError): with ModelRunner( config, model_filename=str(model_file), overwrite_ckpt_check=False ) as runner: @@ -326,19 +361,16 @@ def test_metrics_logging(tmp_path, mgf_small, tiny_config): def test_log_metrics(monkeypatch, tiny_config): - def get_mock_index(psm_list): - mock_test_index = unittest.mock.MagicMock() - mock_test_index.__enter__.return_value = mock_test_index - mock_test_index.__exit__.return_value = False - mock_test_index.n_spectra = len(psm_list) - mock_test_index.get_spectrum_id = lambda idx: psm_list[idx].spectrum_id - - mock_spectra = [ - (None, None, None, None, curr_psm.sequence) - for curr_psm in psm_list + def get_mock_loader(psm_list, tokenizer): + return [ + { + "peak_file": [psm.spectrum_id[0] for psm in psm_list], + "scan_id": [psm.spectrum_id[1] for psm in psm_list], + "seq": tokenizer.tokenize( + [psm.sequence for psm in psm_list] + ).unsqueeze(0), + } ] - mock_test_index.__getitem__.side_effect = lambda idx: mock_spectra[idx] - return mock_test_index def get_mock_psm(sequence, spectrum_id): return PepSpecMatch( @@ -357,6 +389,10 @@ def get_mock_psm(sequence, spectrum_id): with ModelRunner(Config(tiny_config)) as runner: runner.writer = unittest.mock.MagicMock() + runner.model = unittest.mock.MagicMock() + runner.model.tokenizer = ( + depthcharge.tokenizers.peptides.MskbPeptideTokenizer() + ) # Test 100% peptide precision infer_psms = [ @@ -370,7 +406,7 @@ def get_mock_psm(sequence, spectrum_id): ] runner.writer.psms = infer_psms - mock_index = get_mock_index(act_psms) + mock_index = get_mock_loader(act_psms, runner.model.tokenizer) runner.log_metrics(mock_index) pep_precision = mock_logger.info.call_args_list[-3][0][1] @@ -392,7 +428,7 @@ def get_mock_psm(sequence, spectrum_id): ] runner.writer.psms = infer_psms - mock_index = get_mock_index(act_psms) + mock_index = get_mock_loader(act_psms, runner.model.tokenizer) runner.log_metrics(mock_index) pep_precision = mock_logger.info.call_args_list[-3][0][1] @@ -419,7 +455,7 @@ def get_mock_psm(sequence, spectrum_id): ] runner.writer.psms = infer_psms - mock_index = get_mock_index(act_psms) + mock_index = get_mock_loader(act_psms, runner.model.tokenizer) runner.log_metrics(mock_index) pep_precision = mock_logger.info.call_args_list[-3][0][1] @@ -437,7 +473,7 @@ def get_mock_psm(sequence, spectrum_id): ] runner.writer.psms = infer_psms - mock_index = get_mock_index(act_psms) + mock_index = get_mock_loader(act_psms, runner.model.tokenizer) runner.log_metrics(mock_index) pep_precision = mock_logger.info.call_args_list[-3][0][1] @@ -453,7 +489,7 @@ def get_mock_psm(sequence, spectrum_id): ] runner.writer.psms = infer_psms - mock_index = get_mock_index(act_psms) + mock_index = get_mock_loader(act_psms, runner.model.tokenizer) runner.log_metrics(mock_index) pep_precision = mock_logger.info.call_args_list[-3][0][1] @@ -469,7 +505,7 @@ def get_mock_psm(sequence, spectrum_id): ] runner.writer.psms = infer_psms - mock_index = get_mock_index(act_psms) + mock_index = get_mock_loader(act_psms, runner.model.tokenizer) runner.log_metrics(mock_index) pep_precision = 
mock_logger.info.call_args_list[-3][0][1] @@ -496,7 +532,7 @@ def get_mock_psm(sequence, spectrum_id): ] runner.writer.psms = infer_psms - mock_index = get_mock_index(act_psms) + mock_index = get_mock_loader(act_psms, runner.model.tokenizer) runner.log_metrics(mock_index) pep_precision = mock_logger.info.call_args_list[-3][0][1] diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index 00617457..2a701703 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -1,34 +1,42 @@ import collections +import copy import datetime import functools import hashlib import heapq import io +import math import os import pathlib import platform import re -import requests import shutil import tempfile import unittest import unittest.mock -import depthcharge.masses +import depthcharge +import depthcharge.data +import depthcharge.tokenizers.peptides import einops import github import numpy as np import pandas as pd import pytest +import requests import torch -from casanovo import casanovo -from casanovo import utils +from casanovo import casanovo, utils +from casanovo.config import Config from casanovo.data import db_utils, ms_io -from casanovo.data.datasets import SpectrumDataset, AnnotatedSpectrumDataset +from casanovo.denovo.dataloaders import DeNovoDataModule from casanovo.denovo.evaluate import aa_match, aa_match_batch, aa_match_metrics -from casanovo.denovo.model import Spec2Pep, _aa_pep_score, _calc_match_score -from depthcharge.data import SpectrumIndex, AnnotatedSpectrumIndex +from casanovo.denovo.model import ( + DbSpec2Pep, + Spec2Pep, + _aa_pep_score, + _calc_match_score, +) def test_version(): @@ -36,6 +44,7 @@ def test_version(): assert casanovo.__version__ is not None +@pytest.mark.skip(reason="Skipping due to Linux deadlock issue") def test_n_workers(monkeypatch): """Check that n_workers is correct without a GPU.""" monkeypatch.setattr("torch.cuda.is_available", lambda: False) @@ -420,18 +429,6 @@ def test_is_valid_url(): assert not casanovo._is_valid_url("foobar") -def test_tensorboard(): - """ - Test that the tensorboard.SummaryWriter object is only created when a folder - path is passed. 
- """ - model = Spec2Pep(tb_summarywriter="test_path") - assert model.tb_summarywriter is not None - - model = Spec2Pep() - assert model.tb_summarywriter is None - - def test_aa_pep_score(): """ Test the calculation of amino acid and peptide scores from the raw amino @@ -453,7 +450,10 @@ def test_aa_pep_score(): assert peptide_score == pytest.approx(0.5) -def test_peptide_generator_errors(residues_dict, tiny_fasta_file): +def test_peptide_generator_errors(tiny_fasta_file): + residues_dict = ( + depthcharge.tokenizers.PeptideTokenizer.from_massivekb().residues + ) with pytest.raises(FileNotFoundError): [ (a, b) @@ -572,8 +572,7 @@ def test_calc_match_score(): ) -def test_digest_fasta_cleave(tiny_fasta_file, residues_dict): - +def test_digest_fasta_cleave(tiny_fasta_file): # No missed cleavages expected_normal = [ "ATSIPAR", @@ -643,12 +642,12 @@ def test_digest_fasta_cleave(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) assert pdb.db_peptides.index.to_list() == expected -def test_digest_fasta_mods(tiny_fasta_file, residues_dict): +def test_digest_fasta_mods(tiny_fasta_file): # 1 modification allowed # fixed: C+57.02146 # variable: 1M+15.994915,1N+0.984016,1Q+0.984016 @@ -677,21 +676,21 @@ def test_digest_fasta_mods(tiny_fasta_file, residues_dict): "+42.011EIVMTQSPPTLSLSPGER", "+43.006EIVMTQSPPTLSLSPGER", "-17.027MEAPAQLLFLLLLWLPDTTR", - "-17.027M+15.995EAPAQLLFLLLLWLPDTTR", # + "-17.027M+15.995EAPAQLLFLLLLWLPDTTR", "MEAPAQLLFLLLLWLPDTTR", "MEAPAQ+0.984LLFLLLLWLPDTTR", "M+15.995EAPAQLLFLLLLWLPDTTR", "+43.006-17.027MEAPAQLLFLLLLWLPDTTR", - "+43.006-17.027M+15.995EAPAQLLFLLLLWLPDTTR", # + "+43.006-17.027M+15.995EAPAQLLFLLLLWLPDTTR", "+42.011MEAPAQLLFLLLLWLPDTTR", "+43.006MEAPAQLLFLLLLWLPDTTR", - "+42.011M+15.995EAPAQLLFLLLLWLPDTTR", # - "+43.006M+15.995EAPAQLLFLLLLWLPDTTR", # + "+42.011M+15.995EAPAQLLFLLLLWLPDTTR", + "+43.006M+15.995EAPAQLLFLLLLWLPDTTR", "-17.027ASQSVSSSYLTWYQQKPGQAPR", "ASQSVSSSYLTWYQQKPGQAPR", - "ASQ+0.984SVSSSYLTWYQQKPGQAPR", "ASQSVSSSYLTWYQ+0.984QKPGQAPR", "ASQSVSSSYLTWYQQ+0.984KPGQAPR", + "ASQ+0.984SVSSSYLTWYQQKPGQAPR", "ASQSVSSSYLTWYQQKPGQ+0.984APR", "+43.006-17.027ASQSVSSSYLTWYQQKPGQAPR", "+42.011ASQSVSSSYLTWYQQKPGQAPR", @@ -699,9 +698,9 @@ def test_digest_fasta_mods(tiny_fasta_file, residues_dict): "-17.027FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQDYNLP", "FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQDYNLP", "FSGSGSGTDFTLTISSLQ+0.984PEDFAVYYC+57.021QQDYNLP", + "FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQDYN+0.984LP", "FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021Q+0.984QDYNLP", "FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQ+0.984DYNLP", - "FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQDYN+0.984LP", "+43.006-17.027FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQDYNLP", "+42.011FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQDYNLP", "+43.006FSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQDYNLP", @@ -721,12 +720,12 @@ def test_digest_fasta_mods(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) assert pdb.db_peptides.index.to_list() == expected_1mod -def test_length_restrictions(tiny_fasta_file, residues_dict): +def test_length_restrictions(tiny_fasta_file): # length between 20 and 50 expected_long = [ "MEAPAQLLFLLLLWLPDTTR", @@ -752,7 +751,7 @@ def test_length_restrictions(tiny_fasta_file, 
residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) assert pdb.db_peptides.index.to_list() == expected_long @@ -771,12 +770,12 @@ def test_length_restrictions(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) assert pdb.db_peptides.index.to_list() == expected_short -def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): +def test_digest_fasta_enzyme(tiny_fasta_file): # arg-c enzyme expected_argc = [ "ATSIPAR", @@ -845,8 +844,8 @@ def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): "QSPPTL", "SPGERV", "ISSLQP", - "RATSIP", "TSIPAR", + "RATSIP", "MEAPAQ", "RASQSV", "TISSLQ", @@ -879,8 +878,8 @@ def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): "AQLLFL", "QPEDFA", "TLSC+57.021RA", - "C+57.021RASQS", "SC+57.021RASQ", + "C+57.021RASQS", "DFTLTI", "PDTTRE", "TTREIV", @@ -897,8 +896,8 @@ def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): "LWLPDT", "QLLFLL", "LQPEDF", - "REIVMT", "TREIVM", + "REIVMT", "QDYNLP", "LLLWLP", "SSYLTW", @@ -917,8 +916,8 @@ def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): "TWYQQK", "VYYC+57.021QQ", "YLTWYQ", - "YC+57.021QQDY", "YYC+57.021QQD", + "YC+57.021QQDY", ] pdb = db_utils.ProteinDatabase( @@ -936,7 +935,7 @@ def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) assert pdb.db_peptides.index.to_list() == expected_argc @@ -955,7 +954,7 @@ def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) assert pdb.db_peptides.index.to_list() == expected_aspn @@ -975,7 +974,7 @@ def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) assert pdb.db_peptides.index.to_list() == expected_argc @@ -995,7 +994,7 @@ def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) assert pdb.db_peptides.index.to_list() == expected_semispecific @@ -1015,12 +1014,109 @@ def test_digest_fasta_enzyme(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) assert pdb.db_peptides.index.to_list() == expected_nonspecific -def test_get_candidates(tiny_fasta_file, residues_dict): +def test_psm_batches(tiny_config): + peptides_one = [ + "SGSGSG", + "GSGSGT", + "SGSGTD", + "FSGSGS", + "ATSIPA", + "GASTRA", + "LSLSPG", + "ASQSVS", + "GSGTDF", + "SLSPGE", + "AQLLFL", + "QPEDFA", + ] + + peptides_two = [ + "SQSVSS", + "KPGQAP", + 
"SPPTLS", + "ASTRAT", + "RFSGSG", + "IYGAST", + "APAQLL", + "PTLSLS", + "TLSLSP", + "TLTISS", + "WYQQKP", + "TWYQQK", + ] + + def mock_get_candidates(precursor_mz, precorsor_charge): + if precorsor_charge == 1: + return pd.Series(peptides_one) + elif precorsor_charge == 2: + return pd.Series(peptides_two) + else: + return pd.Series() + + tokenizer = depthcharge.tokenizers.peptides.PeptideTokenizer( + residues=Config(tiny_config).residues + ) + db_model = DbSpec2Pep(tokenizer=tokenizer) + db_model.protein_database = unittest.mock.MagicMock() + db_model.protein_database.get_candidates = mock_get_candidates + + mock_batch = { + "precursor_mz": torch.Tensor([42.0, 84.0, 126.0]), + "precursor_charge": torch.Tensor([1, 2, 3]), + "peak_file": ["one.mgf", "two.mgf", "three.mgf"], + "scan_id": [1, 2, 3], + } + + expected_batch_all = { + "precursor_mz": torch.Tensor([42.0] * 12 + [84.0] * 12), + "precursor_charge": torch.Tensor([1] * 12 + [2] * 12), + "seq": tokenizer.tokenize(peptides_one + peptides_two), + "peak_file": ["one.mgf"] * 12 + ["two.mgf"] * 12, + "scan_id": [1] * 12 + [2] * 12, + } + + for psm_batch_size in [24, 12, 8, 10]: + db_model.psm_batch_size = psm_batch_size + psm_batches = list(db_model._psm_batches(mock_batch)) + assert len(psm_batches) == math.ceil(24 / psm_batch_size) + num_spectra = 0 + + for psm_batch in psm_batches: + end_idx = min( + num_spectra + psm_batch_size, + len(expected_batch_all["peak_file"]), + ) + assert torch.allclose( + psm_batch["precursor_mz"], + expected_batch_all["precursor_mz"][num_spectra:end_idx], + ) + assert torch.equal( + psm_batch["precursor_charge"], + expected_batch_all["precursor_charge"][num_spectra:end_idx], + ) + assert torch.equal( + psm_batch["seq"], + expected_batch_all["seq"][num_spectra:end_idx], + ) + assert ( + psm_batch["peak_file"] + == expected_batch_all["peak_file"][num_spectra:end_idx] + ) + assert ( + psm_batch["scan_id"] + == expected_batch_all["scan_id"][num_spectra:end_idx] + ) + num_spectra += len(psm_batch["peak_file"]) + + assert num_spectra == 24 + + +def test_get_candidates(tiny_fasta_file): # precursor_window is 10000 expected_smallwindow = ["LLIYGASTR"] @@ -1045,7 +1141,7 @@ def test_get_candidates(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) candidates = pdb.get_candidates(precursor_mz=496.2, charge=2) assert expected_smallwindow == list(candidates) @@ -1065,7 +1161,7 @@ def test_get_candidates(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) candidates = pdb.get_candidates(precursor_mz=496.2, charge=2) assert expected_midwindow == list(candidates) @@ -1085,14 +1181,13 @@ def test_get_candidates(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) candidates = pdb.get_candidates(precursor_mz=496.2, charge=2) assert expected_widewindow == list(candidates) -def test_get_candidates_isotope_error(tiny_fasta_file, residues_dict): - +def test_get_candidates_isotope_error(tiny_fasta_file): # Tide isotope error windows for 496.2, 2+: # 0: [980.481617, 1000.289326] # 1: [979.491114, 
999.278813] @@ -1153,7 +1248,7 @@ def test_get_candidates_isotope_error(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) pdb.db_peptides = peptide_list candidates = pdb.get_candidates(precursor_mz=496.2, charge=2) @@ -1174,7 +1269,7 @@ def test_get_candidates_isotope_error(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) pdb.db_peptides = peptide_list candidates = pdb.get_candidates(precursor_mz=496.2, charge=2) @@ -1195,7 +1290,7 @@ def test_get_candidates_isotope_error(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) pdb.db_peptides = peptide_list candidates = pdb.get_candidates(precursor_mz=496.2, charge=2) @@ -1216,25 +1311,32 @@ def test_get_candidates_isotope_error(tiny_fasta_file, residues_dict): "M:M+15.995,N:N+0.984,Q:Q+0.984," "nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" ), - residues=residues_dict, + tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) pdb.db_peptides = peptide_list candidates = pdb.get_candidates(precursor_mz=496.2, charge=2) assert expected_isotope0123 == list(candidates) -def test_beam_search_decode(): +def test_beam_search_decode(tiny_config): """ Test beam search decoding and its sub-functions. """ - model = Spec2Pep(n_beams=4, residues="massivekb", min_peptide_len=4) + config = casanovo.Config(tiny_config) + model = Spec2Pep( + n_beams=4, + residues="massivekb", + min_peptide_len=4, + tokenizer=depthcharge.tokenizers.peptides.PeptideTokenizer( + residues=config.residues + ), + ) model.decoder.reverse = False # For simplicity. - aa2idx = model.decoder._aa2idx # Sizes. batch = 1 # B length = model.max_peptide_len + 1 # L - vocab = model.decoder.vocab_size + 1 # V + vocab = len(model.tokenizer) + 1 # V beam = model.n_beams # S step = 3 @@ -1254,8 +1356,12 @@ def test_beam_search_decode(): ) # Fill scores and tokens with relevant predictions. scores[:, : step + 1, :] = 0 - for i, peptide in enumerate(["PEPK", "PEPR", "PEPG", "PEP$"]): - tokens[i, : step + 1] = torch.tensor([aa2idx[aa] for aa in peptide]) + for i, (peptide, add_stop) in enumerate( + [("PEPK", False), ("PEPR", False), ("PEPG", False), ("PEP", True)] + ): + tokens[i, : step + 1] = model.tokenizer.tokenize( + peptide, add_stop=add_stop + )[0] for j in range(step + 1): scores[i, j, tokens[1, j]] = 1 @@ -1285,14 +1391,15 @@ def test_beam_search_decode(): beam_fits_precursor, pred_cache, ) + # Verify that the correct peptides have been cached. 
correct_cached = 0 for _, _, _, pep in pred_cache[0]: - if torch.equal(pep, torch.tensor([4, 14, 4, 13])): + if torch.equal(pep, model.tokenizer.tokenize("PEPK")[0]): correct_cached += 1 - elif torch.equal(pep, torch.tensor([4, 14, 4, 18])): + elif torch.equal(pep, model.tokenizer.tokenize("PEPR")[0]): correct_cached += 1 - elif torch.equal(pep, torch.tensor([4, 14, 4])): + elif torch.equal(pep, model.tokenizer.tokenize("PEP")[0]): correct_cached += 1 else: pytest.fail( @@ -1304,16 +1411,22 @@ def test_beam_search_decode(): # Return the candidate peptide with the highest score test_cache = collections.OrderedDict((i, []) for i in range(batch)) heapq.heappush( - test_cache[0], (0.93, 0.1, 4 * [0.93], torch.tensor([4, 14, 4, 19])) + test_cache[0], + (0.93, 0.1, 4 * [0.93], model.tokenizer.tokenize("PEPY")[0]), ) heapq.heappush( - test_cache[0], (0.95, 0.2, 4 * [0.95], torch.tensor([4, 14, 4, 13])) + test_cache[0], + (0.95, 0.2, 4 * [0.95], model.tokenizer.tokenize("PEPK")[0]), ) heapq.heappush( - test_cache[0], (0.94, 0.3, 4 * [0.94], torch.tensor([4, 14, 4, 4])) + test_cache[0], + (0.94, 0.3, 4 * [0.94], model.tokenizer.tokenize("PEPP")[0]), ) - assert list(model._get_top_peptide(test_cache))[0][0][-1] == "PEPK" + assert torch.equal( + next(model._get_top_peptide(test_cache))[0][-1], + model.tokenizer.tokenize(["PEPK"])[0], + ) # Test that an empty predictions is returned when no beams have been # finished. empty_cache = collections.OrderedDict((i, []) for i in range(batch)) @@ -1321,30 +1434,30 @@ def test_beam_search_decode(): # Test multiple PSM per spectrum and if it's highest scoring peptides model.top_match = 2 assert set( - [pep[-1] for pep in list(model._get_top_peptide(test_cache))[0]] + [ + model.tokenizer.detokenize(pep[-1].unsqueeze(0))[0] + for pep in list(model._get_top_peptide(test_cache))[0] + ] ) == {"PEPK", "PEPP"} # Test _get_topk_beams(). # Set scores to proceed generating the unfinished beam. step = 4 scores[2, step, :] = 0 - scores[2, step, range(1, 5)] = torch.tensor([1.0, 2.0, 3.0, 4.0]) + next_tokens = model.tokenizer.tokenize(["P", "S", "A", "G"]).flatten() + scores[2, step, next_tokens] = torch.tensor([4.0, 3.0, 2.0, 1.0]) # Modify finished beams array to allow decoding from only one beam test_finished_beams = torch.tensor([True, True, False, True]) new_tokens, new_scores = model._get_topk_beams( tokens, scores, test_finished_beams, batch, step ) - expected_tokens = torch.tensor( - [ - [4, 14, 4, 1, 4], - [4, 14, 4, 1, 3], - [4, 14, 4, 1, 2], - [4, 14, 4, 1, 1], - ] + expected_tokens = model.tokenizer.tokenize( + ["PEPGP", "PEPGS", "PEPGA", "PEPGG"] ) + # Only the expected scores of the final step. expected_scores = torch.zeros(beam, vocab) - expected_scores[:, range(1, 5)] = torch.tensor([1.0, 2.0, 3.0, 4.0]) + expected_scores[:, next_tokens] = torch.tensor([4.0, 3.0, 2.0, 1.0]) assert torch.equal(new_tokens[:, : step + 1], expected_tokens) assert torch.equal(new_scores[:, step, :], expected_scores) @@ -1352,9 +1465,9 @@ def test_beam_search_decode(): # Test output if decoding loop isn't stopped with termination of all beams. model.max_peptide_len = 0 # 1 spectrum with 5 peaks (2 values: m/z and intensity). 
- spectra = torch.zeros(1, 5, 2) + mzs = ints = torch.zeros(1, 5) precursors = torch.tensor([[469.25364, 2.0, 235.63410]]) - assert len(list(model.beam_search_decode(spectra, precursors))[0]) == 0 + assert len(list(model.beam_search_decode(mzs, ints, precursors))[0]) == 0 model.max_peptide_len = 100 # Re-initialize scores and tokens to further test caching functionality. @@ -1365,8 +1478,9 @@ def test_beam_search_decode(): tokens = torch.zeros(batch * beam, length, dtype=torch.int64) scores[:, : step + 1, :] = 0 - for i, peptide in enumerate(["PKKP$", "EPPK$", "PEPK$", "PMKP$"]): - tokens[i, : step + 1] = torch.tensor([aa2idx[aa] for aa in peptide]) + tokens[:, : step + 1] = model.tokenizer.tokenize( + ["PKKP", "EPPK", "PEPK", "PMKP"], add_stop=True + ) i, j, s = np.arange(step), np.arange(4), torch.Tensor([4, 0.5, 3, 0.4]) scores[:, i, :] = 1 scores[j, i, tokens[j, i]] = s @@ -1387,10 +1501,16 @@ def test_beam_search_decode(): assert negative_score == 2 # Test using a single beam only. - model = Spec2Pep(n_beams=1, residues="massivekb", min_peptide_len=2) + model = Spec2Pep( + n_beams=1, + min_peptide_len=2, + tokenizer=depthcharge.tokenizers.peptides.MskbPeptideTokenizer( + residues=config.residues + ), + ) + vocab = len(model.tokenizer) + 1 beam = model.n_beams # S model.decoder.reverse = False # For simplicity. - aa2idx = model.decoder._aa2idx step = 4 # Initialize scores and tokens. @@ -1403,12 +1523,14 @@ def test_beam_search_decode(): pred_cache = collections.OrderedDict((i, []) for i in range(batch)) # Ground truth peptide is "PEPK". - true_peptide = "PEPK$" + true_peptide = "PEPK" precursors = torch.tensor([469.25364, 2.0, 235.63410]).repeat( beam * batch, 1 ) scores[:, range(step), :] = 1 - tokens[0, : step + 1] = torch.tensor([aa2idx[aa] for aa in true_peptide]) + tokens[0, : step + 1] = model.tokenizer.tokenize( + true_peptide, add_stop=True + )[0] # Test _finish_beams(). finished_beams, beam_fits_precursor, discarded_beams = model._finish_beams( @@ -1424,7 +1546,9 @@ def test_beam_search_decode(): tokens, scores, step, finished_beams, beam_fits_precursor, pred_cache ) - assert torch.equal(pred_cache[0][0][-1], torch.tensor([4, 14, 4, 13])) + assert torch.equal( + pred_cache[0][0][-1], model.tokenizer.tokenize(true_peptide)[0] + ) # Test _get_topk_beams(). step = 1 @@ -1455,9 +1579,13 @@ def test_beam_search_decode(): assert torch.equal(new_tokens[:, : step + 1], expected_tokens) # Test _finish_beams() for tokens with a negative mass. - model = Spec2Pep(n_beams=2, residues="massivekb") + model = Spec2Pep( + n_beams=2, + tokenizer=depthcharge.tokenizers.peptides.MskbPeptideTokenizer( + residues=config.residues + ), + ) beam = model.n_beams # S - aa2idx = model.decoder._aa2idx step = 1 # Ground truth peptide is "-17.027GK". @@ -1465,8 +1593,7 @@ def test_beam_search_decode(): beam * batch, 1 ) tokens = torch.zeros(batch * beam, length, dtype=torch.int64) - for i, peptide in enumerate(["GK", "AK"]): - tokens[i, : step + 1] = torch.tensor([aa2idx[aa] for aa in peptide]) + tokens[:, : step + 1] = model.tokenizer.tokenize(["GK", "AK"]) # Test _finish_beams(). finished_beams, beam_fits_precursor, discarded_beams = model._finish_beams( @@ -1477,26 +1604,34 @@ def test_beam_search_decode(): assert torch.equal(discarded_beams, torch.tensor([False, False])) # Test _finish_beams() for multiple/internal N-mods and dummy predictions. 
- model = Spec2Pep(n_beams=3, residues="massivekb", min_peptide_len=3) + model = Spec2Pep( + n_beams=3, + min_peptide_len=3, + tokenizer=depthcharge.tokenizers.peptides.PeptideTokenizer( + residues=config.residues + ), + ) beam = model.n_beams # S - model.decoder.reverse = True - aa2idx = model.decoder._aa2idx step = 4 # Ground truth peptide is irrelevant for this test. precursors = torch.tensor([1861.0044, 2.0, 940.5750]).repeat( beam * batch, 1 ) + + # sequences with invalid mass modifications will raise an exception if + # tokenized using tokenizer.tokenize tokens = torch.zeros(batch * beam, length, dtype=torch.int64) - # Reverse decoding - for i, peptide in enumerate( - [ - ["K", "A", "A", "A", "+43.006-17.027"], - ["K", "A", "A", "+42.011", "A"], - ["K", "A", "A", "+43.006", "+42.011"], - ] - ): - tokens[i, : step + 1] = torch.tensor([aa2idx[aa] for aa in peptide]) + sequences = [ + ["K", "A", "A", "A", "[+25.980265]-"], + ["K", "A", "A", "[Acetyl]-", "A"], + ["K", "A", "A", "[Carbamyl]-", "[Ammonia-loss]-"], + ] + + for i, seq in enumerate(sequences): + tokens[i, : step + 1] = torch.tensor( + [model.tokenizer.index[aa] for aa in seq] + ) # Test _finish_beams(). All should be discarded finished_beams, beam_fits_precursor, discarded_beams = model._finish_beams( @@ -1509,14 +1644,19 @@ def test_beam_search_decode(): assert torch.equal(discarded_beams, torch.tensor([False, True, True])) # Test _get_topk_beams() with finished beams in the batch. - model = Spec2Pep(n_beams=1, residues="massivekb", min_peptide_len=3) + model = Spec2Pep( + n_beams=1, + min_peptide_len=3, + tokenizer=depthcharge.tokenizers.peptides.PeptideTokenizer( + residues=config.residues + ), + ) # Sizes and other variables. batch = 2 # B beam = model.n_beams # S - model.decoder.reverse = True length = model.max_peptide_len + 1 # L - vocab = model.decoder.vocab_size + 1 # V + vocab = len(model.tokenizer) + 1 # V step = 4 # Initialize dummy scores and tokens. @@ -1531,8 +1671,8 @@ def test_beam_search_decode(): scores[:, step, range(1, 4)] = torch.tensor([1.0, 2.0, 3.0]) # Simulate one finished and one unfinished beam in the same batch. - tokens[0, :step] = torch.tensor([4, 14, 4, 28]) - tokens[1, :step] = torch.tensor([4, 14, 4, 1]) + tokens[0, :step] = model.tokenizer.tokenize("PEP", add_stop=True)[0] + tokens[1, :step] = model.tokenizer.tokenize("PEPG")[0] # Set finished beams array to allow decoding from only one beam. test_finished_beams = torch.tensor([True, False]) @@ -1542,22 +1682,23 @@ def test_beam_search_decode(): ) # Only the second peptide should have a new token predicted. - expected_tokens = torch.tensor( - [ - [4, 14, 4, 28, 0], - [4, 14, 4, 1, 3], - ] - ) + expected_tokens = tokens.clone() + expected_tokens[1, len("PEPG")] = 3 - assert torch.equal(new_tokens[:, : step + 1], expected_tokens) + assert torch.equal(new_tokens, expected_tokens) # Test that duplicate peptide scores don't lead to a conflict in the cache. - model = Spec2Pep(n_beams=5, residues="massivekb", min_peptide_len=3) + model = Spec2Pep( + n_beams=1, + min_peptide_len=3, + tokenizer=depthcharge.tokenizers.peptides.PeptideTokenizer( + residues=config.residues + ), + ) batch = 2 # B beam = model.n_beams # S - model.decoder.reverse = True length = model.max_peptide_len + 1 # L - vocab = model.decoder.vocab_size + 1 # V + vocab = len(model.tokenizer) + 1 # V step = 4 # Simulate beams with identical amino acid scores but different tokens. @@ -1591,7 +1732,7 @@ def test_eval_metrics(): the ground truth. 
A peptide prediction is correct if all its AA are correct matches. """ - model = Spec2Pep() + tokenizer = depthcharge.tokenizers.peptides.MskbPeptideTokenizer() preds = [ "SPEIK", @@ -1608,7 +1749,7 @@ def test_eval_metrics(): aa_matches, n_pred_aa, n_gt_aa = aa_match_batch( peptides1=preds, peptides2=gt, - aa_dict=model.decoder._peptide_mass.masses, + aa_dict=tokenizer.residues, mode="best", ) @@ -1623,16 +1764,12 @@ def test_eval_metrics(): assert 26 / 40 == pytest.approx(aa_recall) assert 26 / 41 == pytest.approx(aa_precision) - aa_matches, pep_match = aa_match( - None, None, depthcharge.masses.PeptideMass().masses - ) + aa_matches, pep_match = aa_match(None, None, tokenizer.residues) assert aa_matches.shape == (0,) assert not pep_match - aa_matches, pep_match = aa_match( - "PEPTIDE", None, depthcharge.masses.PeptideMass().masses - ) + aa_matches, pep_match = aa_match("PEPTIDE", None, tokenizer.residues) assert np.array_equal(aa_matches, np.zeros(len("PEPTIDE"), dtype=bool)) assert not pep_match @@ -1642,36 +1779,45 @@ def test_spectrum_id_mgf(mgf_small, tmp_path): """Test that spectra from MGF files are specified by their index.""" mgf_small2 = tmp_path / "mgf_small2.mgf" shutil.copy(mgf_small, mgf_small2) + data_module = DeNovoDataModule( + train_paths=[mgf_small, mgf_small2], + valid_paths=[mgf_small, mgf_small2], + test_paths=[mgf_small, mgf_small2], + shuffle=False, + ) + data_module.setup() - for index_func, dataset_func in [ - (SpectrumIndex, SpectrumDataset), - (AnnotatedSpectrumIndex, AnnotatedSpectrumDataset), + for dataset in [ + data_module.train_dataset, + data_module.valid_dataset, + data_module.test_dataset, ]: - index = index_func( - tmp_path / "index.hdf5", [mgf_small, mgf_small2], overwrite=True - ) - dataset = dataset_func(index) - for i, (filename, mgf_i) in enumerate( + for batch in dataset: + print(batch) + + for i, (filename, scan_id) in enumerate( [ - (mgf_small, 0), - (mgf_small, 1), - (mgf_small2, 0), - (mgf_small2, 1), + (mgf_small, "0"), + (mgf_small, "1"), + (mgf_small2, "0"), + (mgf_small2, "1"), ] ): - spectrum_id = str(filename), f"index={mgf_i}" - assert dataset.get_spectrum_id(i) == spectrum_id + assert dataset[i]["peak_file"][0] == filename.name + assert dataset[i]["scan_id"][0] == scan_id def test_spectrum_id_mzml(mzml_small, tmp_path): """Test that spectra from mzML files are specified by their scan number.""" mzml_small2 = tmp_path / "mzml_small2.mzml" shutil.copy(mzml_small, mzml_small2) - - index = SpectrumIndex( - tmp_path / "index.hdf5", [mzml_small, mzml_small2], overwrite=True + data_module = DeNovoDataModule( + test_paths=[mzml_small, mzml_small2], + shuffle=False, ) - dataset = SpectrumDataset(index) + data_module.setup(stage="test", annotated=False) + + dataset = data_module.test_dataset for i, (filename, scan_nr) in enumerate( [ (mzml_small, 17), @@ -1680,25 +1826,33 @@ def test_spectrum_id_mzml(mzml_small, tmp_path): (mzml_small2, 111), ] ): - spectrum_id = str(filename), f"scan={scan_nr}" - assert dataset.get_spectrum_id(i) == spectrum_id + assert dataset[i]["peak_file"][0] == filename.name + assert dataset[i]["scan_id"][0] == f"scan={scan_nr}" def test_train_val_step_functions(): """Test train and validation step functions operating on batches.""" + tokenizer = depthcharge.tokenizers.peptides.MskbPeptideTokenizer() model = Spec2Pep( n_beams=1, residues="massivekb", min_peptide_len=4, train_label_smoothing=0.1, + tokenizer=tokenizer, ) - spectra = torch.zeros(1, 5, 2) - precursors = torch.tensor([[469.25364, 2.0, 235.63410]]) - peptides 
= ["PEPK"] - batch = (spectra, precursors, peptides) - train_step_loss = model.training_step(batch) - val_step_loss = model.validation_step(batch) + batch = { + "mz_array": torch.zeros(1, 5), + "intensity_array": torch.zeros(1, 5), + "precursor_mz": torch.tensor(235.63410).unsqueeze(0), + "precursor_charge": torch.tensor(2.0).unsqueeze(0), + "seq": tokenizer.tokenize(["PEPK"]), + } + train_batch = {key: val.unsqueeze(0) for key, val in batch.items()} + val_batch = copy.deepcopy(train_batch) + + train_step_loss = model.training_step(train_batch) + val_step_loss = model.validation_step(val_batch) # Check if valid loss value returned assert train_step_loss > 0 @@ -1714,12 +1868,8 @@ def test_run_map(mgf_small): out_writer = ms_io.MztabWriter("dummy.mztab") # Set peak file by base file name only. out_writer.set_ms_run([os.path.basename(mgf_small.name)]) - assert os.path.basename(mgf_small.name) not in out_writer._run_map - assert os.path.abspath(mgf_small.name) in out_writer._run_map - # Set peak file by full path. - out_writer.set_ms_run([os.path.abspath(mgf_small.name)]) - assert os.path.basename(mgf_small.name) not in out_writer._run_map - assert os.path.abspath(mgf_small.name) in out_writer._run_map + assert mgf_small.name in out_writer._run_map + assert os.path.abspath(mgf_small.name) not in out_writer._run_map def test_check_dir(tmp_path):