slurm.py
# Copyright (c) 2020-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from logging import getLogger
import os
import sys
import torch
import socket
import signal
import subprocess

logger = getLogger()


def sig_handler(signum, frame):
    logger.warning("Signal handler called with signal " + str(signum))
    prod_id = int(os.environ['SLURM_PROCID'])
    logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
    if prod_id == 0:
        logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID'])
        os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])
    else:
        logger.warning("Not the master process, no need to requeue.")
    sys.exit(-1)


def term_handler(signum, frame):
    logger.warning("Signal handler called with signal " + str(signum))
    logger.warning("Bypassing SIGTERM.")


def init_signal_handler():
    """
    Handle signals sent by SLURM for time limit / pre-emption.
    """
    signal.signal(signal.SIGUSR1, sig_handler)
    signal.signal(signal.SIGTERM, term_handler)
    logger.warning("Signal handler installed.")
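

# Note: for sig_handler to fire before the time limit, the job is typically
# submitted with something like `#SBATCH --signal=USR1@120` (send SIGUSR1 120
# seconds before the limit); depending on cluster defaults, `#SBATCH --requeue`
# may also be needed for `scontrol requeue` to succeed. This is an assumption
# about the submission script, which is not part of this file.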


def init_distributed_mode(params):
    """
    Handle single and multi-GPU / multi-node / SLURM jobs.
    Initialize the following variables:
        - n_nodes
        - node_id
        - local_rank
        - global_rank
        - world_size
    """
    params.is_slurm_job = 'SLURM_JOB_ID' in os.environ and not params.debug_slurm
    print("SLURM job: %s" % str(params.is_slurm_job))

    # SLURM job
    if params.is_slurm_job:

        assert params.local_rank == -1  # on the cluster, this is handled by SLURM

        SLURM_VARIABLES = [
            'SLURM_JOB_ID',
            'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS', 'SLURM_TASKS_PER_NODE',
            'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU',
            'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID'
        ]

        PREFIX = "%i - " % int(os.environ['SLURM_PROCID'])
        for name in SLURM_VARIABLES:
            value = os.environ.get(name, None)
            print(PREFIX + "%s: %s" % (name, str(value)))

        # # job ID
        # params.job_id = os.environ['SLURM_JOB_ID']

        # number of nodes / node ID
        params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
        params.node_id = int(os.environ['SLURM_NODEID'])

        # local rank on the current node / global rank
        params.local_rank = int(os.environ['SLURM_LOCALID'])
        params.global_rank = int(os.environ['SLURM_PROCID'])

        # number of processes / GPUs per node
        params.world_size = int(os.environ['SLURM_NTASKS'])
        params.n_gpu_per_node = params.world_size // params.n_nodes

        # define master address and master port
        hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
        params.master_addr = hostnames.split()[0].decode('utf-8')
        assert 10001 <= params.master_port <= 20000 or params.world_size == 1
        print(PREFIX + "Master address: %s" % params.master_addr)
        print(PREFIX + "Master port   : %i" % params.master_port)

        # set environment variables for 'env://'
        os.environ['MASTER_ADDR'] = params.master_addr
        os.environ['MASTER_PORT'] = str(params.master_port)
        os.environ['WORLD_SIZE'] = str(params.world_size)
        os.environ['RANK'] = str(params.global_rank)

    # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
    elif params.local_rank != -1:

        assert params.master_port == -1

        # read environment variables
        params.global_rank = int(os.environ['RANK'])
        params.world_size = int(os.environ['WORLD_SIZE'])
        params.n_gpu_per_node = int(os.environ['NGPU'])

        # number of nodes / node ID
        params.n_nodes = params.world_size // params.n_gpu_per_node
        params.node_id = params.global_rank // params.n_gpu_per_node
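
        # Example launch command for this branch (a sketch; `train.py` is a
        # placeholder script name, and NGPU is assumed to be exported by the
        # launch script since torch.distributed.launch does not set it):
        #   export NGPU=8
        #   python -m torch.distributed.launch --nproc_per_node=$NGPU train.py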

    # local job (single GPU)
    else:
        assert params.local_rank == -1
        assert params.master_port == -1
        params.n_nodes = 1
        params.node_id = 0
        params.local_rank = 0
        params.global_rank = 0
        params.world_size = 1
        params.n_gpu_per_node = 1

    # sanity checks
    assert params.n_nodes >= 1
    assert 0 <= params.node_id < params.n_nodes
    assert 0 <= params.local_rank <= params.global_rank < params.world_size
    assert params.world_size == params.n_nodes * params.n_gpu_per_node

    # define whether this is the master process / if we are in distributed mode
    params.is_master = params.node_id == 0 and params.local_rank == 0
    params.multi_node = params.n_nodes > 1
    params.multi_gpu = params.world_size > 1

    # summary
    PREFIX = "%i - " % params.global_rank
    print(PREFIX + "Number of nodes: %i" % params.n_nodes)
    print(PREFIX + "Node ID        : %i" % params.node_id)
    print(PREFIX + "Local rank     : %i" % params.local_rank)
    print(PREFIX + "Global rank    : %i" % params.global_rank)
    print(PREFIX + "World size     : %i" % params.world_size)
    print(PREFIX + "GPUs per node  : %i" % params.n_gpu_per_node)
    print(PREFIX + "Master         : %s" % str(params.is_master))
    print(PREFIX + "Multi-node     : %s" % str(params.multi_node))
    print(PREFIX + "Multi-GPU      : %s" % str(params.multi_gpu))
    print(PREFIX + "Hostname       : %s" % socket.gethostname())

    # set GPU device
    if not params.cpu:
        torch.cuda.set_device(params.local_rank)

    # initialize multi-GPU
    if params.multi_gpu:

        # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization
        # 'env://' will read these environment variables:
        # MASTER_PORT - required; has to be a free port on machine with rank 0
        # MASTER_ADDR - required (except for rank 0); address of rank 0 node
        # WORLD_SIZE - required; can be set either here, or in a call to init function
        # RANK - required; can be set either here, or in a call to init function

        print("Initializing PyTorch distributed ...")
        torch.distributed.init_process_group(
            init_method='env://',
            backend='nccl',
        )
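

# Example usage from a training script (a minimal sketch, not part of the
# original module; the argument names simply mirror the attributes this file
# reads from `params`, and the argument parser is assumed to live in the caller):
#
#     import argparse
#     from slurm import init_signal_handler, init_distributed_mode
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--local_rank", type=int, default=-1)
#     parser.add_argument("--master_port", type=int, default=-1)
#     parser.add_argument("--debug_slurm", action="store_true")
#     parser.add_argument("--cpu", action="store_true")
#     params = parser.parse_args()
#
#     init_signal_handler()          # requeue on SIGUSR1 (SLURM pre-emption / time limit)
#     init_distributed_mode(params)  # fills n_nodes, node_id, ranks, world_size, etc.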