train.py
"""
PyTorch Lightning example code, designed for use in TU Delft CV lab.
Copyright (c) 2022 Robert-Jan Bruintjes, TU Delft.
"""
# Package imports, from conda or pip
import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from omegaconf import OmegaConf
import torchmetrics
# Imports of own files
import model_factory
import dataset_factory


class Runner(pl.LightningModule):
    def __init__(self, cfg, model):
        super().__init__()
        self.cfg = cfg
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss()
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()
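        # NOTE: torchmetrics.Accuracy() without arguments matches the
        # torchmetrics releases this example was written against. On
        # torchmetrics >= 0.11 the task must be given explicitly, roughly
        # (assuming a hypothetical cfg.model.num_classes entry):
        #   torchmetrics.Accuracy(task="multiclass", num_classes=cfg.model.num_classes)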

    def forward(self, x):
        # Runner needs to redirect any model.forward() calls to the actual
        # network
        return self.model(x)

    def configure_optimizers(self):
        if self.cfg.optimize.optimizer == 'Adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.cfg.optimize.lr)
        else:
            raise NotImplementedError(f"Optimizer {self.cfg.optimize.optimizer}")
        return optimizer

    def _step(self, batch):
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        return loss, y_hat

    def training_step(self, batch, batch_idx):
        loss, y_hat = self._step(batch)
        preds = torch.argmax(y_hat, dim=1)
        self.train_accuracy(preds, batch[1])
        # Log step-level loss & accuracy
        self.log("train/loss_step", loss)
        self.log("train/acc_step", self.train_accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, y_hat = self._step(batch)
        preds = torch.argmax(y_hat, dim=1)
        self.val_accuracy(preds, batch[1])
        # Log step-level loss & accuracy
        self.log("val/loss_step", loss)
        self.log("val/acc_step", self.val_accuracy)
        return loss

    def test_step(self, batch, batch_idx):
        loss, y_hat = self._step(batch)
        preds = torch.argmax(y_hat, dim=1)
        self.test_accuracy(preds, batch[1])
        # Log test loss & accuracy
        self.log("test/loss", loss)
        self.log("test/acc", self.test_accuracy)
        return loss

    def on_train_epoch_end(self):
        # Log the epoch-level training accuracy
        self.log('train/acc', self.train_accuracy.compute())
        self.train_accuracy.reset()

    def on_validation_epoch_end(self):
        # Log the epoch-level validation accuracy
        self.log('val/acc', self.val_accuracy.compute())
        self.val_accuracy.reset()


def main():
    # Load defaults and overwrite by command-line arguments
    cfg = OmegaConf.load("config.yaml")
    cmd_cfg = OmegaConf.from_cli()
    cfg = OmegaConf.merge(cfg, cmd_cfg)
    print(OmegaConf.to_yaml(cfg))
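
    # Any entry in config.yaml can be overridden from the command line via
    # OmegaConf's dot-list syntax, e.g. (illustrative values):
    #   python train.py optimize.lr=0.0003 train.epochs=50 wandb.log=True
    #
    # The structure of config.yaml is inferred from the keys this file reads;
    # the sketch below is illustrative, not the repository's actual config:
    #
    #   seed: 0
    #   train:
    #     epochs: 10
    #   optimize:
    #     optimizer: Adam
    #     lr: 0.001
    #   wandb:
    #     log: False
    #     dir: ./wandb
    #     project: example-project
    #     experiment_name: baseline
    #     entity: my-team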

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Seed everything. Note that this does not make training entirely
    # deterministic.
    pl.seed_everything(cfg.seed, workers=True)

    # Set cache dir to W&B logging directory
    os.environ["WANDB_CACHE_DIR"] = os.path.join(cfg.wandb.dir, 'cache')
    wandb_logger = WandbLogger(
        save_dir=cfg.wandb.dir,
        project=cfg.wandb.project,
        name=cfg.wandb.experiment_name,
        log_model='all' if cfg.wandb.log else None,
        offline=not cfg.wandb.log,
        # Keyword args passed to wandb.init()
        entity=cfg.wandb.entity,
        config=OmegaConf.to_object(cfg),
    )
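
    # The factory interfaces are defined in model_factory.py and
    # dataset_factory.py (not shown here). From their use below,
    # model_factory.factory(cfg) is expected to return an nn.Module and
    # dataset_factory.factory(cfg) a (train, val, test) tuple of data
    # loaders, where the test loader may be None if not implemented.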

    # Create model using factory pattern
    model = model_factory.factory(cfg)

    # Create datasets using factory pattern
    loaders = dataset_factory.factory(cfg)
    train_dataset_loader, val_dataset_loader, test_dataset_loader = loaders

    # Tie it all together with PyTorch Lightning: Runner contains the model,
    # optimizer, loss function and metrics; Trainer executes the
    # training/validation loops and model checkpointing.
    runner = Runner(cfg, model)
    trainer = pl.Trainer(
        max_epochs=cfg.train.epochs,
        logger=wandb_logger,
        # Use DDP training by default, even for CPU training
        strategy="ddp_find_unused_parameters_false",
        gpus=torch.cuda.device_count(),
    )
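    # NOTE: `gpus=` and this strategy string follow the PyTorch Lightning 1.x
    # API this example targets. On Lightning >= 2.0 the equivalent would look
    # roughly like the following (untested sketch):
    #   trainer = pl.Trainer(max_epochs=cfg.train.epochs, logger=wandb_logger,
    #                        accelerator="auto", devices="auto", strategy="ddp")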

    # Train + validate (if validation dataset is implemented)
    trainer.fit(runner, train_dataset_loader, val_dataset_loader)

    # Test (if test dataset is implemented)
    if test_dataset_loader is not None:
        trainer.test(runner, test_dataset_loader)


if __name__ == '__main__':
    main()