diff --git a/casanovo/config.py b/casanovo/config.py
index a611fc96..68d83fc6 100644
--- a/casanovo/config.py
+++ b/casanovo/config.py
@@ -54,6 +54,7 @@ class Config:
         n_log=int,
         tb_summarywriter=str,
         train_label_smoothing=float,
+        lr_schedule=str,
         warmup_iters=int,
         max_iters=int,
         learning_rate=float,
diff --git a/casanovo/config.yaml b/casanovo/config.yaml
index 896f67bc..3615bffb 100644
--- a/casanovo/config.yaml
+++ b/casanovo/config.yaml
@@ -81,6 +81,8 @@ dropout: 0.0
 dim_intensity:
 # Max decoded peptide length
 max_length: 100
+# Type of learning rate schedule to use. One of {constant, linear, cosine}.
+lr_schedule: "cosine"
 # Number of warmup iterations for learning rate scheduler
 warmup_iters: 100_000
 # Max number of iterations for learning rate scheduler
diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py
index 39d2027a..211dae67 100644
--- a/casanovo/denovo/model.py
+++ b/casanovo/denovo/model.py
@@ -109,6 +109,7 @@ def __init__(
             torch.utils.tensorboard.SummaryWriter
         ] = None,
         train_label_smoothing: float = 0.01,
+        lr_schedule: Optional[str] = None,
         warmup_iters: int = 100_000,
         max_iters: int = 600_000,
         out_writer: Optional[ms_io.MztabWriter] = None,
@@ -142,6 +143,7 @@ def __init__(
         )
         self.val_celoss = torch.nn.CrossEntropyLoss(ignore_index=0)
         # Optimizer settings.
+        self.lr_schedule = lr_schedule
         self.warmup_iters = warmup_iters
         self.max_iters = max_iters
         self.opt_kwargs = kwargs
@@ -951,7 +953,7 @@ def configure_optimizers(
         self,
     ) -> Tuple[torch.optim.Optimizer, Dict[str, Any]]:
         """
-        Initialize the optimizer.
+        Initialize the optimizer and learning rate scheduler.
 
         This is used by pytorch-lightning when preparing the model for training.
 
@@ -961,16 +963,34 @@
             The initialized Adam optimizer and its learning rate scheduler.
         """
         optimizer = torch.optim.Adam(self.parameters(), **self.opt_kwargs)
+        # Linear learning rate scheduler for the warmup period.
+        lr_schedulers = [
+            torch.optim.lr_scheduler.LinearLR(
+                optimizer, start_factor=1e-10, total_iters=self.warmup_iters
+            )
+        ]
+        if self.lr_schedule == "cosine":
+            lr_schedulers.append(
+                CosineScheduler(optimizer, max_iters=self.max_iters)
+            )
+        elif self.lr_schedule == "linear":
+            lr_schedulers.append(
+                torch.optim.lr_scheduler.LinearLR(
+                    optimizer,
+                    start_factor=1,
+                    end_factor=0,
+                    total_iters=self.max_iters,
+                )
+            )
+        # Combine the learning rate schedulers.
+        lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler(lr_schedulers)
         # Apply learning rate scheduler per step.
-        lr_scheduler = CosineWarmupScheduler(
-            optimizer, warmup=self.warmup_iters, max_iters=self.max_iters
-        )
         return [optimizer], {"scheduler": lr_scheduler, "interval": "step"}
 
 
-class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
+class CosineScheduler(torch.optim.lr_scheduler._LRScheduler):
     """
-    Learning rate scheduler with linear warm up followed by cosine shaped decay.
+    Learning rate scheduler with cosine-shaped decay.
 
     Parameters
     ----------
@@ -982,10 +1002,8 @@ class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
         The total number of iterations.
""" - def __init__( - self, optimizer: torch.optim.Optimizer, warmup: int, max_iters: int - ): - self.warmup, self.max_iters = warmup, max_iters + def __init__(self, optimizer: torch.optim.Optimizer, max_iters: int): + self.max_iters = max_iters super().__init__(optimizer) def get_lr(self): @@ -994,8 +1012,6 @@ def get_lr(self): def get_lr_factor(self, epoch): lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_iters)) - if epoch <= self.warmup: - lr_factor *= epoch / self.warmup return lr_factor diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 296ce90a..9b05a909 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -223,6 +223,7 @@ def initialize_model(self, train: bool) -> None: n_log=self.config.n_log, tb_summarywriter=self.config.tb_summarywriter, train_label_smoothing=self.config.train_label_smoothing, + lr_schedule=self.config.lr_schedule, warmup_iters=self.config.warmup_iters, max_iters=self.config.max_iters, lr=self.config.learning_rate, diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index efa89a05..54c96782 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -535,3 +535,45 @@ def test_run_map(mgf_small): out_writer.set_ms_run([os.path.abspath(mgf_small.name)]) assert os.path.basename(mgf_small.name) not in out_writer._run_map assert os.path.abspath(mgf_small.name) in out_writer._run_map + + +def test_lr_schedule(): + """Test that the learning rate schedule is setup correctly""" + + # Constant lr schedule is setup correctly + model = Spec2Pep(lr_schedule="constant") + assert model.lr_schedule is not None + assert model.lr_schedule == "constant" + + # Learning rate schedule is applied every step rather than epoch + _, schedule = model.configure_optimizers() + assert schedule["interval"] == "step" + + # Constant lr schedule includes only a warmup period + schedulers = schedule["scheduler"].state_dict()["_schedulers"] + assert len(schedulers) == 1 + + # Default warmup period lasts correct number of iters + assert schedulers[0]["start_factor"] == 1e-10 + assert schedulers[0]["end_factor"] == 1 + assert schedulers[0]["total_iters"] == 100000 + + # Linear lr schedule with custom warmup and max iters + model = Spec2Pep(lr_schedule="linear", warmup_iters=10, max_iters=100) + _, schedule = model.configure_optimizers() + schedulers = schedule["scheduler"].state_dict()["_schedulers"] + + assert len(schedulers) == 2 + assert schedulers[0]["total_iters"] == 10 + + assert schedulers[1]["start_factor"] == 1 + assert schedulers[1]["end_factor"] == 0 + assert schedulers[1]["total_iters"] == 100 + + # Cosine lr schedule + model = Spec2Pep(lr_schedule="cosine") + _, schedule = model.configure_optimizers() + schedulers = schedule["scheduler"].state_dict()["_schedulers"] + + assert len(schedulers) == 2 + assert schedulers[1]["_last_lr"][0] == [0.001]