Merge branch 'add_lr_schedule_options' into label-smoothing
melihyilmaz authored Nov 2, 2023
2 parents b044bc6 + 5557d97 commit 5716c7a
Showing 5 changed files with 74 additions and 12 deletions.
1 change: 1 addition & 0 deletions casanovo/config.py
@@ -54,6 +54,7 @@ class Config:
n_log=int,
tb_summarywriter=str,
train_label_smoothing=float,
lr_schedule=str,
warmup_iters=int,
max_iters=int,
learning_rate=float,
2 changes: 2 additions & 0 deletions casanovo/config.yaml
@@ -81,6 +81,8 @@ dropout: 0.0
dim_intensity:
# Max decoded peptide length
max_length: 100
# Type of learning rate schedule to use. One of {constant, linear, cosine}.
lr_schedule: "cosine"
# Number of warmup iterations for learning rate scheduler
warmup_iters: 100_000
# Max number of iterations for learning rate scheduler
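
For orientation, all three options share the linear warmup that configure_optimizers builds in model.py below; "constant" stops there, while "linear" and "cosine" chain an additional decay scheduler on top. The snippet below is a standalone sketch, not part of the commit: it evaluates the closed-form factor of each individual piece at a few milestone iterations, assuming the default warmup_iters=100_000 and max_iters=600_000 from model.py (the helper names are invented for illustration).

import math

def warmup_factor(step: int, warmup_iters: int = 100_000) -> float:
    # Approximate closed form of the warmup LinearLR (its start_factor of 1e-10 is ~0):
    # ramps linearly from ~0 to 1.0, then stays at 1.0.
    return min(step, warmup_iters) / warmup_iters

def linear_decay_factor(step: int, max_iters: int = 600_000) -> float:
    # Closed form of the extra LinearLR appended for lr_schedule: "linear".
    return max(0.0, 1.0 - step / max_iters)

def cosine_decay_factor(step: int, max_iters: int = 600_000) -> float:
    # Same half-cosine factor as the new CosineScheduler in model.py.
    return 0.5 * (1 + math.cos(math.pi * step / max_iters))

for step in (0, 50_000, 100_000, 300_000, 600_000):
    print(
        f"step {step:>7,}: warmup={warmup_factor(step):.2f}  "
        f"linear={linear_decay_factor(step):.2f}  "
        f"cosine={cosine_decay_factor(step):.2f}"
    )

How the warmup and the decay are combined per optimizer step is shown in the configure_optimizers changes below.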
40 changes: 28 additions & 12 deletions casanovo/denovo/model.py
@@ -109,6 +109,7 @@ def __init__(
torch.utils.tensorboard.SummaryWriter
] = None,
train_label_smoothing: float = 0.01,
lr_schedule=None,
warmup_iters: int = 100_000,
max_iters: int = 600_000,
out_writer: Optional[ms_io.MztabWriter] = None,
@@ -142,6 +143,7 @@
)
self.val_celoss = torch.nn.CrossEntropyLoss(ignore_index=0)
# Optimizer settings.
self.lr_schedule = lr_schedule
self.warmup_iters = warmup_iters
self.max_iters = max_iters
self.opt_kwargs = kwargs
@@ -951,7 +953,7 @@ def configure_optimizers(
self,
) -> Tuple[torch.optim.Optimizer, Dict[str, Any]]:
"""
Initialize the optimizer.
Initialize the optimizer and lr_scheduler.
This is used by pytorch-lightning when preparing the model for training.
@@ -961,16 +963,34 @@ def configure_optimizers(
The initialized Adam optimizer and its learning rate scheduler.
"""
optimizer = torch.optim.Adam(self.parameters(), **self.opt_kwargs)
# Add linear learning rate scheduler for warmup
lr_schedulers = [
torch.optim.lr_scheduler.LinearLR(
optimizer, start_factor=1e-10, total_iters=self.warmup_iters
)
]
if self.lr_schedule == "cosine":
lr_schedulers.append(
CosineScheduler(optimizer, max_iters=self.max_iters)
)
elif self.lr_schedule == "linear":
lr_schedulers.append(
torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=1,
end_factor=0,
total_iters=self.max_iters,
)
)
# Combine learning rate schedulers
lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler(lr_schedulers)
# Apply learning rate scheduler per step.
lr_scheduler = CosineWarmupScheduler(
optimizer, warmup=self.warmup_iters, max_iters=self.max_iters
)
return [optimizer], {"scheduler": lr_scheduler, "interval": "step"}


class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
class CosineScheduler(torch.optim.lr_scheduler._LRScheduler):
"""
Learning rate scheduler with linear warm up followed by cosine shaped decay.
Learning rate scheduler with cosine shaped decay.
Parameters
----------
@@ -982,10 +1002,8 @@ class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
The total number of iterations.
"""

def __init__(
self, optimizer: torch.optim.Optimizer, warmup: int, max_iters: int
):
self.warmup, self.max_iters = warmup, max_iters
def __init__(self, optimizer: torch.optim.Optimizer, max_iters: int):
self.max_iters = max_iters
super().__init__(optimizer)

def get_lr(self):
@@ -994,8 +1012,6 @@ def get_lr(self):

def get_lr_factor(self, epoch):
lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_iters))
if epoch <= self.warmup:
lr_factor *= epoch / self.warmup
return lr_factor


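To see how ChainedScheduler combines the warmup with a decay, here is a minimal, self-contained sketch (not from the repository) that rebuilds the lr_schedule="linear" branch of the new configure_optimizers with a dummy parameter and toy iteration counts; both schedulers step every iteration, so their factors multiply. For the "cosine" option, the second scheduler would instead be the CosineScheduler defined above.

import torch

# Dummy parameter/optimizer, only used to drive the schedulers (illustrative values).
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=5e-4)

warmup_iters, max_iters = 5, 20
lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler(
    [
        # Warmup shared by all options: ramp from ~0 up to the base learning rate.
        torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1e-10, total_iters=warmup_iters
        ),
        # lr_schedule: "linear" -> decay from the base learning rate to 0 over max_iters.
        torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1, end_factor=0, total_iters=max_iters
        ),
    ]
)

for step in range(1, max_iters + 1):
    optimizer.step()     # in training this follows a forward/backward pass
    lr_scheduler.step()  # Lightning calls this every step ("interval": "step")
    if step % 5 == 0:
        print(f"step {step:2d}: lr = {optimizer.param_groups[0]['lr']:.2e}")

Printing the learning rate should show it rising over the first five steps and then decaying toward zero by step 20.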
1 change: 1 addition & 0 deletions casanovo/denovo/model_runner.py
@@ -223,6 +223,7 @@ def initialize_model(self, train: bool) -> None:
n_log=self.config.n_log,
tb_summarywriter=self.config.tb_summarywriter,
train_label_smoothing=self.config.train_label_smoothing,
lr_schedule=self.config.lr_schedule,
warmup_iters=self.config.warmup_iters,
max_iters=self.config.max_iters,
lr=self.config.learning_rate,
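For completeness, the new option travels from config.yaml through ModelRunner.initialize_model (above) into the model. Below is a rough sketch of the equivalent direct call with illustrative values; it is not part of the commit, the import path follows the file layout above, and the remaining keyword defaults come from model.py.

from casanovo.denovo.model import Spec2Pep

# Illustrative values; in practice they come from config.yaml via ModelRunner.
model = Spec2Pep(
    lr_schedule="cosine",   # one of "constant", "linear", "cosine"
    warmup_iters=100_000,
    max_iters=600_000,
)
optimizers, scheduler_cfg = model.configure_optimizers()
assert scheduler_cfg["interval"] == "step"  # the schedule advances once per training step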
42 changes: 42 additions & 0 deletions tests/unit_tests/test_unit.py
@@ -535,3 +535,45 @@ def test_run_map(mgf_small):
out_writer.set_ms_run([os.path.abspath(mgf_small.name)])
assert os.path.basename(mgf_small.name) not in out_writer._run_map
assert os.path.abspath(mgf_small.name) in out_writer._run_map


def test_lr_schedule():
"""Test that the learning rate schedule is setup correctly"""

# Constant lr schedule is set up correctly
model = Spec2Pep(lr_schedule="constant")
assert model.lr_schedule is not None
assert model.lr_schedule == "constant"

# Learning rate schedule is applied every step rather than epoch
_, schedule = model.configure_optimizers()
assert schedule["interval"] == "step"

# Constant lr schedule includes only a warmup period
schedulers = schedule["scheduler"].state_dict()["_schedulers"]
assert len(schedulers) == 1

# Default warmup period lasts correct number of iters
assert schedulers[0]["start_factor"] == 1e-10
assert schedulers[0]["end_factor"] == 1
assert schedulers[0]["total_iters"] == 100000

# Linear lr schedule with custom warmup and max iters
model = Spec2Pep(lr_schedule="linear", warmup_iters=10, max_iters=100)
_, schedule = model.configure_optimizers()
schedulers = schedule["scheduler"].state_dict()["_schedulers"]

assert len(schedulers) == 2
assert schedulers[0]["total_iters"] == 10

assert schedulers[1]["start_factor"] == 1
assert schedulers[1]["end_factor"] == 0
assert schedulers[1]["total_iters"] == 100

# Cosine lr schedule
model = Spec2Pep(lr_schedule="cosine")
_, schedule = model.configure_optimizers()
schedulers = schedule["scheduler"].state_dict()["_schedulers"]

assert len(schedulers) == 2
assert schedulers[1]["_last_lr"][0] == [0.001]
