train_net.py
import os
import torch
import torch.distributed as dist
import torch.multiprocessing
from lib.config import cfg, logger
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
from lib.datasets import make_data_loader
from lib.utils.net_utils import load_model, save_model, save_trained_config, load_pretrain
from lib.evaluators import make_evaluator
from lib.utils.msg_utils import send_msg

torch.backends.cudnn.benchmark = False
# torch.autograd.set_detect_anomaly(True)

if cfg.fix_random:
    # Make runs reproducible: fixed seed and deterministic cuDNN kernels.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def train(cfg, network):
    # Build dataloaders, trainer, optimizer, LR scheduler, recorder and evaluator.
    train_loader = make_data_loader(cfg,
                                    is_train=True,
                                    is_distributed=cfg.distributed,
                                    max_iter=cfg.ep_iter)
    if cfg.skip_eval:
        val_loader = None
    else:
        val_loader = make_data_loader(cfg, is_train=False)
    trainer = make_trainer(cfg, network, train_loader)
    optimizer = make_optimizer(cfg, network)
    scheduler = make_lr_scheduler(cfg, optimizer)
    recorder = make_recorder(cfg)
    evaluator = make_evaluator(cfg)

    # Resume from the latest checkpoint if available; otherwise optionally load
    # pretrained weights before training starts.
    begin_epoch = load_model(network,
                             optimizer,
                             scheduler,
                             recorder,
                             cfg.trained_model_dir,
                             resume=cfg.resume)
    if begin_epoch == 0 and cfg.pretrain != '':
        load_pretrain(network, cfg.pretrain)
    set_lr_scheduler(cfg, scheduler)
    save_trained_config(cfg)

    for epoch in range(begin_epoch, cfg.train.epoch):
        recorder.epoch = epoch
        if cfg.distributed:
            # Reshuffle the distributed sampler each epoch.
            train_loader.batch_sampler.sampler.set_epoch(epoch)
        train_loader.dataset.epoch = epoch

        trainer.train(epoch, train_loader, optimizer, recorder, scheduler)
        # scheduler.step()

        # Periodically checkpoint (numbered and "latest") and evaluate on rank 0 only.
        if (epoch + 1) % cfg.save_ep == 0 and cfg.local_rank == 0:
            save_model(network, optimizer, scheduler, recorder,
                       cfg.trained_model_dir, epoch)
        if (epoch + 1) % cfg.save_latest_ep == 0 and cfg.local_rank == 0:
            save_model(network,
                       optimizer,
                       scheduler,
                       recorder,
                       cfg.trained_model_dir,
                       epoch,
                       last=True)
        if not cfg.skip_eval and (epoch + 1) % cfg.eval_ep == 0 and cfg.local_rank == 0:
            trainer.val(epoch, val_loader, evaluator, recorder)

    return network


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
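

# Launch note (assumption; the exact CLI flags depend on lib.config, which is not shown
# here): with cfg.distributed enabled this script expects the RANK environment variable,
# which torchrun / torch.distributed.launch set automatically, e.g.
#   torchrun --nproc_per_node=<num_gpus> train_net.py ...
# A single-process run can use `python train_net.py ...` directly.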


def main():
    if cfg.distributed:
        # Derive the local rank from the global RANK set by the launcher and
        # initialise the NCCL process group.
        cfg.local_rank = int(os.environ['RANK']) % torch.cuda.device_count()
        torch.cuda.set_device(cfg.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    network = make_network(cfg)
    send_msg('Training start: ' + cfg.exp_name + ' Machine: ' + cfg.machine, cfg)
    train(cfg, network)
    send_msg('Training success: ' + cfg.exp_name + ' Machine: ' + cfg.machine, cfg)
    if cfg.local_rank == 0:
        logger.info('Success!')
        logger.info('=' * 80)
    # Hard-exit so the process (and any lingering worker subprocesses) terminates immediately.
    os.system('kill -9 {}'.format(os.getpid()))


if __name__ == "__main__":
    main()