-
Notifications
You must be signed in to change notification settings - Fork 20
/
launcher.py
43 lines (36 loc) · 1.3 KB
/
launcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from model.model import captcha_model, model_conv, model_resnet
from data.datamodule import captcha_dm
import pytorch_lightning as pl
import torch.optim as optim
import torch
import os
from utils.config_util import configGetter
from utils.arg_parsers import train_arg_parser
# Solver hyper-parameters come from the 'SOLVER' section of the project
# config; main() reads the module-level names lr, batch_size and epoch.
cfg = configGetter('SOLVER')
lr, batch_size, epoch = (cfg[key] for key in ('LR', 'BATCH_SIZE', 'EPOCH'))
def main(arg):
    """Train the captcha ResNet model and save a final checkpoint.

    Uses the module-level solver settings ``lr``, ``batch_size`` and
    ``epoch`` read from the 'SOLVER' config section.

    Args:
        arg: Parsed command-line options (as produced by
            ``train_arg_parser()``). The attributes read here are
            ``log_dir``, ``exp_name``, ``gpus`` and ``save_path``.
    """
    # Seed everything and request deterministic kernels for reproducible runs.
    pl.seed_everything(42)
    m = model_resnet()
    model = captcha_model(model=m, lr=lr)
    dm = captcha_dm(batch_size=batch_size)
    # Read options from the argument actually passed in rather than from the
    # module-global `args`, which only exists when run as a script.
    # NOTE(review): version=2 is fixed, so repeated runs log into the same
    # version directory — confirm this is intended.
    tb_logger = pl.loggers.TensorBoardLogger(
        arg.log_dir, name=arg.exp_name, version=2, default_hp_metric=False)
    # NOTE(review): `gpus`, `auto_select_gpus` and `stochastic_weight_avg`
    # are Lightning 1.x Trainer flags — confirm the pinned Lightning version.
    trainer = pl.Trainer(deterministic=True,
                         gpus=arg.gpus,
                         auto_select_gpus=True,
                         precision=32,
                         logger=tb_logger,
                         fast_dev_run=False,
                         max_epochs=epoch,
                         log_every_n_steps=50,
                         stochastic_weight_avg=True
                         )
    trainer.fit(model, datamodule=dm)
    # Ensure the output directory exists before writing the checkpoint.
    os.makedirs(arg.save_path, exist_ok=True)
    trainer.save_checkpoint(os.path.join(arg.save_path, 'model.pth'))
if __name__ == "__main__":
    # Script entry point: parse CLI options, then run training with them.
    args = train_arg_parser()
    main(arg=args)