-
Notifications
You must be signed in to change notification settings - Fork 1
/
set_opt.py
48 lines (38 loc) · 1.74 KB
/
set_opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.optim as optim
def setup_optimizer(model, lr, weight_decay, epochs):
    """Build the AdamW optimizer and cosine-annealing LR schedule for a model.

    S4 requires a specific optimizer setup. The S4 layer (A, B, C, dt)
    parameters typically require a smaller learning rate (typically 0.001),
    with no weight decay. The rest of the model can be trained with a higher
    learning rate (e.g. 0.004, 0.01) and weight decay (if desired).

    A parameter opts into special treatment by carrying an ``_optim``
    attribute: a dict of optimizer hyperparameters. Parameters sharing an
    equal ``_optim`` dict are placed together in their own parameter group,
    and the dict's entries override the defaults for that group. Every
    parameter without ``_optim`` goes into group 0, which is always present
    (it is simply empty when all parameters carry ``_optim``).

    Args:
        model: Object whose ``parameters()`` yields the tensors to optimize.
        lr: Learning rate of the general group; also the default for special
            groups whose ``_optim`` does not set ``lr``.
        weight_decay: Weight decay of the general group; also the default
            for special groups whose ``_optim`` does not set it.
        epochs: Passed to ``CosineAnnealingLR`` as ``T_max``.

    Returns:
        A tuple ``(optimizer, scheduler)``.

    Raises:
        TypeError: If an ``_optim`` dict contains an unhashable value (the
            dicts are de-duplicated through ``frozenset``).

    Side effects:
        Prints one summary line per parameter group to stdout.
    """
    # All parameters in the model
    all_parameters = list(model.parameters())

    # Partition in a single pass: general parameters don't carry the special
    # _optim attribute; the others are bucketed by their (hashable) set of
    # hyperparameter items. The dict keeps first-seen order of the buckets.
    general_params = []
    special_params = {}
    for p in all_parameters:
        if hasattr(p, "_optim"):
            hp_key = frozenset(getattr(p, "_optim").items())
            special_params.setdefault(hp_key, []).append(p)
        else:
            general_params.append(p)

    # Create the optimizer with the general parameters. An explicit group
    # dict is used instead of a bare list so that a model whose parameters
    # all carry _optim does not fail with "optimizer got an empty parameter
    # list"; for a non-empty list the resulting group is the same.
    optimizer = optim.AdamW(
        [{"params": general_params}], lr=lr, weight_decay=weight_decay
    )

    # Add one group per unique hyperparameter dict.
    # NOTE: frozensets compare by subset relation, so this sort is only a
    # partial order (unrelated sets keep first-seen order). It is kept as-is
    # because it determines the index of each param group, and changing it
    # could break loading of previously saved optimizer state.
    hps = []
    for hp_key in sorted(special_params):
        hp = dict(hp_key)
        hps.append(hp)
        optimizer.add_param_group({"params": special_params[hp_key], **hp})

    # Create a lr scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Print optimizer info: only hyperparameters that some special group
    # overrides are reported, in sorted key order.
    keys = sorted({k for hp in hps for k in hp})
    for i, g in enumerate(optimizer.param_groups):
        group_hps = {k: g.get(k, None) for k in keys}
        print(' | '.join([
            f"Optimizer group {i}",
            f"{len(g['params'])} tensors",
        ] + [f"{k} {v}" for k, v in group_hps.items()]))

    return optimizer, scheduler