Add label smoothing to learning rate scheduler branch (#262)
* Remove unused custom_encoder option (#254)

* resolves issue #238: remove custom_encoder option

* fixed lint issue

* fixed lint issue

* Revert "fixed lint issue"

This reverts commit bd1366c.

* lint

* lint issue

* Consistently format changelog.

---------

Co-authored-by: Isha Gokhale <[email protected]>
Co-authored-by: Wout Bittremieux <[email protected]>

* Correctly report AA precision and recall during validation (#253)

Fixes #252.

Co-authored-by: Melih Yilmaz <[email protected]>

* Remove gradient calculation during inference  (#258)

* Remove force_grad in inference

* Upgrade required PyTorch version

* Update CHANGELOG.md

* Update CHANGELOG.md

* Fix typo in torch version

* Specify correct Pytorch version change

---------

Co-authored-by: Wout Bittremieux <[email protected]>

* Add label smoothing

* Modify config file

* Minor fix config.yaml

* Run black

* Lint casanovo.py

---------

Co-authored-by: ishagokhale <[email protected]>
Co-authored-by: Isha Gokhale <[email protected]>
Co-authored-by: Wout Bittremieux <[email protected]>
Co-authored-by: Wout Bittremieux <[email protected]>
5 people authored Nov 2, 2023
1 parent 5557d97 commit 7358564
Showing 6 changed files with 54 additions and 58 deletions.
15 changes: 6 additions & 9 deletions CHANGELOG.md
@@ -16,31 +16,28 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Changed

- The CLI has been overhauled to use subcommands.
- Upgraded to Lightning >=2.0
- Upgraded to Lightning >=2.0.
- Checkpointing is configured to save the top-k models instead of all.
- Log steps rather than epochs as units of progress during training.
- Validation performance metrics are logged (and added to tensorboard) at the validation epoch, and training loss is logged at the end of training epoch, i.e. training and validation metrics are logged asynchronously.
- Irrelevant warning messages on the console output and in the log file are no longer shown.
- Nicely format logged warnings.
- `every_n_train_steps` has been renamed to `val_check_interval` in accordance to the corresponding Pytorch Lightning parameter.
- Training batches are randomly shuffled.

### Fixed

- Casanovo runs on CPU and can passes all tests.
- Enable gradients during prediction and validation to avoid NaNs from occuring as a temporary workaround until a new Pytorch version is available.
- Upgrade to depthcharge v0.2.3 for `PeptideTransformerDecoder` hotfix.
- Upgraded to Torch >=2.1.

### Removed

- Remove config option for a custom Pytorch Lightning logger.
- Remove superfluous `custom_encoder` config option.

### Fixed

- Casanovo now runs on CPU and can passes all tests.
- Upgrade to Depthcharge v0.2.0 to fix sinusoidal encoding.
- Casanovo runs on CPU and can pass all tests.
- Correctly refer to input peak files by their full file path.
- Specifying custom residues to retrain Casanovo is now possible.
- Upgrade to depthcharge v0.2.3 to fix sinusoidal encoding and for the `PeptideTransformerDecoder` hotfix.
- Correctly report amino acid precision and recall during validation.

## [3.3.0] - 2023-04-04

1 change: 1 addition & 0 deletions casanovo/config.py
@@ -53,6 +53,7 @@ class Config:
residues=dict,
n_log=int,
tb_summarywriter=str,
train_label_smoothing=float,
lr_schedule=str,
warmup_iters=int,
max_iters=int,
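
The fragment above extends the `Config` class's mapping from option names to their expected Python types with the new `train_label_smoothing` option. As a rough illustration of how such a mapping can be used to coerce values read from `config.yaml`, here is a hypothetical helper; it is a sketch, not the actual `Config` implementation, and `CONFIG_TYPES` and `load_config` are invented names.

```python
import yaml

# Illustrative subset of the name -> type mapping shown above; the real Config
# class in casanovo/config.py covers many more options.
CONFIG_TYPES = {
    "train_label_smoothing": float,
    "lr_schedule": str,
    "warmup_iters": int,
    "max_iters": int,
}


def load_config(path: str) -> dict:
    """Hypothetical helper: read a YAML config and coerce known options."""
    with open(path) as f:
        config = yaml.safe_load(f)
    for key, expected_type in CONFIG_TYPES.items():
        if config.get(key) is not None:
            # Coerce rather than reject, so a value parsed as a string
            # (e.g. "5e-4") still ends up with the expected type.
            config[key] = expected_type(config[key])
    return config
```
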
5 changes: 2 additions & 3 deletions casanovo/config.yaml
@@ -79,9 +79,6 @@ dropout: 0.0
# Number of dimensions to use for encoding peak intensity
# Projected up to ``dim_model`` by default and summed with the peak m/z encoding
dim_intensity:
# Option to provide a pre-trained spectrum encoder when training
# Trained from scratch by default
custom_encoder:
# Max decoded peptide length
max_length: 100
# Type of learning rate schedule to use. One of {constant, linear, cosine}.
@@ -94,6 +91,8 @@ max_iters: 600_000
learning_rate: 5e-4
# Regularization term for weight updates
weight_decay: 1e-5
# Amount of label smoothing when computing the training loss
train_label_smoothing: 0.01

# TRAINING/INFERENCE OPTIONS
# Number of spectra in one training batch
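
As a standalone illustration of what the `train_label_smoothing` option added above controls, the following sketch compares `torch.nn.CrossEntropyLoss` with and without `label_smoothing` on a toy batch. It is not Casanovo's training code; the logits and targets are made up.

```python
import torch

# Toy logits for two predictions over three classes, plus their class targets.
logits = torch.tensor([[4.0, 1.0, 0.5], [0.2, 3.0, 0.1]])
targets = torch.tensor([0, 1])

plain = torch.nn.CrossEntropyLoss()
# 0.01 matches the default set in config.yaml above: 1% of each target's
# probability mass is spread uniformly over all classes.
smoothed = torch.nn.CrossEntropyLoss(label_smoothing=0.01)

# For confident, correct predictions the smoothed loss is slightly larger,
# which discourages the model from becoming overconfident.
print(f"without smoothing: {plain(logits, targets).item():.4f}")
print(f"with smoothing:    {smoothed(logits, targets).item():.4f}")
```
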
86 changes: 42 additions & 44 deletions casanovo/denovo/model.py
@@ -43,9 +43,6 @@ class Spec2Pep(pl.LightningModule, ModelMixin):
(``dim_model - dim_intensity``) are reserved for encoding the m/z value.
If ``None``, the intensity will be projected up to ``dim_model`` using a
linear layer, then summed with the m/z encoding for each peak.
custom_encoder : Optional[Union[SpectrumEncoder, PairedSpectrumEncoder]]
A pretrained encoder to use. The ``dim_model`` of the encoder must be
the same as that specified by the ``dim_model`` parameter here.
max_length : int
The maximum peptide length to decode.
residues: Union[Dict[str, float], str]
@@ -76,6 +73,8 @@ class Spec2Pep(pl.LightningModule, ModelMixin):
tb_summarywriter: Optional[str]
Folder path to record performance metrics during training. If ``None``,
don't use a ``SummaryWriter``.
train_label_smoothing: float
Smoothing factor when calculating the training loss.
warmup_iters: int
The number of warm up iterations for the learning rate scheduler.
max_iters: int
@@ -97,7 +96,6 @@ def __init__(
n_layers: int = 9,
dropout: float = 0.0,
dim_intensity: Optional[int] = None,
custom_encoder: Optional[SpectrumEncoder] = None,
max_length: int = 100,
residues: Union[Dict[str, float], str] = "canonical",
max_charge: int = 5,
@@ -110,6 +108,7 @@ def __init__(
tb_summarywriter: Optional[
torch.utils.tensorboard.SummaryWriter
] = None,
train_label_smoothing: float = 0.01,
lr_schedule=None,
warmup_iters: int = 100_000,
max_iters: int = 600_000,
@@ -121,17 +120,14 @@ def __init__(
self.save_hyperparameters()

# Build the model.
if custom_encoder is not None:
self.encoder = custom_encoder
else:
self.encoder = SpectrumEncoder(
dim_model=dim_model,
n_head=n_head,
dim_feedforward=dim_feedforward,
n_layers=n_layers,
dropout=dropout,
dim_intensity=dim_intensity,
)
self.encoder = SpectrumEncoder(
dim_model=dim_model,
n_head=n_head,
dim_feedforward=dim_feedforward,
n_layers=n_layers,
dropout=dropout,
dim_intensity=dim_intensity,
)
self.decoder = PeptideDecoder(
dim_model=dim_model,
n_head=n_head,
@@ -142,7 +138,10 @@ def __init__(
max_charge=max_charge,
)
self.softmax = torch.nn.Softmax(2)
self.celoss = torch.nn.CrossEntropyLoss(ignore_index=0)
self.celoss = torch.nn.CrossEntropyLoss(
ignore_index=0, label_smoothing=train_label_smoothing
)
self.val_celoss = torch.nn.CrossEntropyLoss(ignore_index=0)
# Optimizer settings.
self.lr_schedule = lr_schedule
self.warmup_iters = warmup_iters
@@ -732,7 +731,10 @@ def training_step(
"""
pred, truth = self._forward_step(*batch)
pred = pred[:, :-1, :].reshape(-1, self.decoder.vocab_size + 1)
loss = self.celoss(pred, truth.flatten())
if mode == "train":
loss = self.celoss(pred, truth.flatten())
else:
loss = self.val_celoss(pred, truth.flatten())
self.log(
f"{mode}_CELoss",
loss.detach(),
@@ -760,9 +762,7 @@ def validation_step(
The loss of the validation step.
"""
# Record the loss.
# FIXME: Temporary workaround to avoid the NaN bug.
with torch.set_grad_enabled(True):
loss = self.training_step(batch, mode="valid")
loss = self.training_step(batch, mode="valid")
if not self.calculate_precision:
return loss

@@ -775,8 +775,8 @@

aa_precision, _, pep_precision = evaluate.aa_match_metrics(
*evaluate.aa_match_batch(
peptides_pred,
peptides_true,
peptides_pred,
self.decoder._peptide_mass.masses,
)
)
@@ -813,30 +813,28 @@ def predict_step(
and amino acid-level confidence scores.
"""
predictions = []
# FIXME: Temporary workaround to avoid the NaN bug.
with torch.set_grad_enabled(True):
for (
precursor_charge,
precursor_mz,
spectrum_i,
spectrum_preds,
) in zip(
batch[1][:, 1].cpu().detach().numpy(),
batch[1][:, 2].cpu().detach().numpy(),
batch[2],
self.forward(batch[0], batch[1]),
):
for peptide_score, aa_scores, peptide in spectrum_preds:
predictions.append(
(
spectrum_i,
precursor_charge,
precursor_mz,
peptide,
peptide_score,
aa_scores,
)
for (
precursor_charge,
precursor_mz,
spectrum_i,
spectrum_preds,
) in zip(
batch[1][:, 1].cpu().detach().numpy(),
batch[1][:, 2].cpu().detach().numpy(),
batch[2],
self.forward(batch[0], batch[1]),
):
for peptide_score, aa_scores, peptide in spectrum_preds:
predictions.append(
(
spectrum_i,
precursor_charge,
precursor_mz,
peptide,
peptide_score,
aa_scores,
)
)

return predictions
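
Taken together, the changes to `__init__` and `training_step` keep two criteria: a smoothed cross-entropy for training and an unsmoothed one for validation, so the logged validation loss remains comparable regardless of the smoothing factor. The minimal sketch below mirrors that pattern in a hypothetical toy module; `ToyDecoder` and `step` are invented names, not part of Casanovo.

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical module mirroring the train/validation loss split above."""

    def __init__(
        self, n_features: int, n_classes: int, train_label_smoothing: float = 0.01
    ):
        super().__init__()
        self.linear = nn.Linear(n_features, n_classes)
        # Smoothed loss for training only; index 0 (padding) is ignored.
        self.celoss = nn.CrossEntropyLoss(
            ignore_index=0, label_smoothing=train_label_smoothing
        )
        # Unsmoothed loss for validation, so the reported metric does not
        # depend on the smoothing factor chosen for training.
        self.val_celoss = nn.CrossEntropyLoss(ignore_index=0)

    def step(
        self, features: torch.Tensor, targets: torch.Tensor, mode: str = "train"
    ) -> torch.Tensor:
        logits = self.linear(features)
        if mode == "train":
            return self.celoss(logits, targets)
        return self.val_celoss(logits, targets)


# Usage sketch: the same batch is scored by the smoothed criterion in training
# mode and by the unsmoothed one in validation mode.
model = ToyDecoder(n_features=8, n_classes=5)
features = torch.randn(4, 8)
targets = torch.randint(1, 5, (4,))  # avoid the ignored padding index 0
print(model.step(features, targets, mode="train").item())
print(model.step(features, targets, mode="valid").item())
```
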

3 changes: 2 additions & 1 deletion casanovo/denovo/model_runner.py
@@ -50,6 +50,7 @@ def __init__(
self.trainer = None
self.model = None
self.loaders = None

self.writer = None

# Configure checkpoints.
@@ -212,7 +213,6 @@ def initialize_model(self, train: bool) -> None:
n_layers=self.config.n_layers,
dropout=self.config.dropout,
dim_intensity=self.config.dim_intensity,
custom_encoder=self.config.custom_encoder,
max_length=self.config.max_length,
residues=self.config.residues,
max_charge=self.config.max_charge,
@@ -222,6 +222,7 @@
top_match=self.config.top_match,
n_log=self.config.n_log,
tb_summarywriter=self.config.tb_summarywriter,
train_label_smoothing=self.config.train_label_smoothing,
lr_schedule=self.config.lr_schedule,
warmup_iters=self.config.warmup_iters,
max_iters=self.config.max_iters,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
"scikit-learn",
"spectrum_utils",
"tensorboard",
"torch>=2.0",
"torch>=2.1",
"tqdm",
]
dynamic = ["version"]