# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import random
import os
import numpy as np
from typing import Callable, Any
from copy import deepcopy

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero import GatheredParameters
from deepspeed.git_version_info import torch_info
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum

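# Context manager that makes a test run reproducible: it seeds the Python, NumPy and
# accelerator RNGs (offset by LOCAL_RANK so each process gets a distinct stream),
# forces deterministic CUDA/cuDNN kernels, and restores the previous RNG and
# environment state on exit.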
class EnableDeterminism:

    def __init__(self, seed: int):
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.seed = seed + local_rank
        self.saved_random_state = None
        self.saved_np_random_state = None
        self.saved_cuda_launch_blocking = None
        self.saved_cublas_workspace_config = None
        self.saved_deterministic_algorithms = None

    def __enter__(self):
        self.saved_random_state = random.getstate()
        self.saved_np_random_state = np.random.get_state()
        self.saved_acc_rng_state = get_accelerator().get_rng_state()
        self.saved_cuda_launch_blocking = os.environ.get("CUDA_LAUNCH_BLOCKING", "")
        self.saved_cublas_workspace_config = os.environ.get("CUBLAS_WORKSPACE_CONFIG", "")
        self.saved_deterministic_algorithms = torch.are_deterministic_algorithms_enabled()

        random.seed(self.seed)
        np.random.seed(self.seed)
        get_accelerator().manual_seed(self.seed)
        get_accelerator().manual_seed_all(self.seed)
        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
        torch.use_deterministic_algorithms(True)
        # Enable CUDNN deterministic mode
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def __exit__(self, type, value, traceback):
        random.setstate(self.saved_random_state)
        np.random.set_state(self.saved_np_random_state)
        get_accelerator().set_rng_state(self.saved_acc_rng_state)
        os.environ["CUDA_LAUNCH_BLOCKING"] = self.saved_cuda_launch_blocking
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = self.saved_cublas_workspace_config
        torch.use_deterministic_algorithms(self.saved_deterministic_algorithms)

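# Decorator form of EnableDeterminism: the wrapped function runs entirely inside the
# determinism context, so its RNG draws are reproducible across runs and ranks.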
def enable_determinism(seed: int):

    def decorator(func: Callable) -> Callable:

        def wrapper(*args: Any, **kwargs: Any):
            with EnableDeterminism(seed):
                return func(*args, **kwargs)

        return wrapper

    return decorator

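# Returns True if the installed torch / NCCL / CUDA stack (or an NPU accelerator) is
# new enough to run bf16 tests: torch >= 1.10, NCCL >= 2.10 and CUDA >= 11.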
def bf16_required_version_check(accelerator_check=True):
    split_version = lambda x: map(int, x.split('.')[:2])
    TORCH_MAJOR, TORCH_MINOR = split_version(torch_info['version'])
    NCCL_MAJOR, NCCL_MINOR = split_version(torch_info['nccl_version'])
    CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])

    # Sometimes bf16 tests are runnable even if not natively supported by accelerator
    if accelerator_check:
        accelerator_pass = get_accelerator().is_bf16_supported()
    else:
        accelerator_pass = True

    torch_version_available = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
    cuda_version_available = CUDA_MAJOR >= 11
    nccl_version_available = NCCL_MAJOR > 2 or (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)
    npu_available = get_accelerator().device_name() == 'npu'

    if torch_version_available and cuda_version_available and nccl_version_available and accelerator_pass:
        return True
    elif npu_available:
        return True
    else:
        return False

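# One mixed-precision training step: the DDP baseline runs under torch.autocast with a
# GradScaler, while the DeepSpeed engine consumes inputs cast to `dtype`; the losses
# and the (gathered) parameters of both are asserted to match within rtol/atol.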
def train_amp(baseline_model,
              baseline_optimizer,
              target_engine,
              dtype,
              scaler,
              x, y,
              rtol, atol):
    # Runs the forward pass with autocasting.
    with torch.autocast(device_type="cuda", dtype=dtype):
        baseline_optimizer.zero_grad()
        baseline_loss = baseline_model(x, y)

    scaler.scale(baseline_loss).backward()
    scaler.step(baseline_optimizer)
    scaler.update()

    target_loss = target_engine(x.to(dtype), y.to(dtype))
    assert torch.allclose(baseline_loss.half(), target_loss, rtol=rtol, atol=atol)
    target_engine.backward(target_loss)
    target_engine.step()

    with GatheredParameters(target_engine.parameters()):
        for p1, p2 in zip(baseline_model.parameters(), target_engine.parameters()):
            assert torch.allclose(p1.half(), p2, rtol=rtol, atol=atol)

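# One full-precision training step: losses, parameters, and gradients (fetched with
# deepspeed.utils.safe_get_full_grad) of the DDP baseline and the DeepSpeed engine are
# compared at each stage of the step (forward, backward, optimizer update).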
def train_no_amp(baseline_model,
                 baseline_optimizer,
                 target_engine,
                 x, y,
                 rtol, atol):
    baseline_loss = baseline_model(x, y)
    target_loss = target_engine(x, y)
    assert torch.allclose(baseline_loss, target_loss, rtol=rtol, atol=atol)

    with GatheredParameters(target_engine.parameters()):
        for p1, p2 in zip(baseline_model.parameters(), target_engine.parameters()):
            assert torch.allclose(p1, p2, rtol=rtol, atol=atol)

    baseline_loss.backward()
    target_engine.backward(target_loss)

    for p1, p2 in zip(baseline_model.parameters(), target_engine.parameters()):
        g2 = deepspeed.utils.safe_get_full_grad(p2)
        assert torch.allclose(p1, p2, rtol=rtol, atol=atol)
        assert torch.allclose(p1.grad, g2, rtol=rtol, atol=atol)

    baseline_optimizer.step()
    target_engine.step()
    baseline_model.zero_grad()

    with GatheredParameters(target_engine.parameters()):
        for p1, p2 in zip(baseline_model.parameters(), target_engine.parameters()):
            assert torch.allclose(p1, p2, rtol=rtol, atol=atol)

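# End-to-end comparison: builds an fp32 DDP baseline (AdamW, weight_decay=0) and a
# DeepSpeed engine (ZeRO stage, offload device and dtype taken from `args`) from the
# same initial weights, then runs several training steps and asserts that both stay
# numerically aligned.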
@enable_determinism(123)
def compare_loss(args, model_cls, rtol=1e-2, atol=1e-2):
    iteration = 5
    hidden_dim = 10
    dtype = eval(args.dtype)
    zero_stage = args.zero_stage
    offload_device = eval(f"OffloadDeviceEnum.{args.offload_device}")

    get_accelerator().set_device(args.local_rank)

    if dtype == torch.bfloat16 and not bf16_required_version_check():
        raise ValueError(
            "DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA >= 11.0 and HW support for BFloat16 to run correctly"
        )
    if offload_device == OffloadDeviceEnum.nvme:
        if zero_stage != 3:
            raise ValueError(f"Nvme offload not supported for zero stage {zero_stage}")

    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.015
            },
        },
        "zero_optimization": {
            "stage": zero_stage,
        },
    }
    if offload_device == OffloadDeviceEnum.cpu:
        config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
    elif offload_device == OffloadDeviceEnum.nvme:
        tmpdir = os.getcwd()
        config_dict["zero_optimization"]["offload_optimizer"] = {
            "device": offload_device,
            "nvme_path": str(tmpdir)
        }
    if dtype == torch.float16:
        config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
    elif dtype == torch.bfloat16:
        config_dict["bf16"] = {"enabled": True}

    device = torch.device(get_accelerator().current_device_name())
    model = model_cls(hidden_dim)

    deepspeed.init_distributed(dist_backend='nccl')

    # Baseline: fp32 DDP model optimized with torch.optim.AdamW (weight_decay=0).
    i = get_accelerator().current_device()
    lr = config_dict["optimizer"]["params"]["lr"]
    baseline_model = DDP(deepcopy(model).to(device=device, dtype=torch.float32), device_ids=[i], output_device=i)
    baseline_optimizer = torch.optim.AdamW(baseline_model.parameters(), lr=lr, weight_decay=0.0)

    use_amp = dtype != torch.float32
    scaler = GradScaler() if use_amp else None

    # Target: the same weights wrapped in a DeepSpeed engine. Under ZeRO-3 the model is
    # constructed inside zero.Init and the baseline weights are copied in on rank 0.
    stage_3_enabled = config_dict["zero_optimization"]["stage"] == 3
    if stage_3_enabled:
        with deepspeed.zero.Init(config_dict_or_path=config_dict):
            target_model = model_cls(hidden_dim)
        with GatheredParameters(target_model.parameters(), modifier_rank=0):
            for p1, p2 in zip(target_model.parameters(), model.parameters()):
                p1.data.copy_(p2.data)
    else:
        target_model = deepcopy(model)

    if args.use_torch_adam:
        ds_optimizer = torch.optim.Adam(target_model.parameters(), lr=lr)
        del config_dict["optimizer"]
        target_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                                      model=target_model,
                                                      optimizer=ds_optimizer)
    else:
        target_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                                      model=target_model,
                                                      model_parameters=target_model.parameters())

    train_batch_size = config_dict["train_micro_batch_size_per_gpu"]
    xs = [torch.randn(train_batch_size, hidden_dim, device=device, dtype=torch.float32) for _ in range(iteration)]
    ys = [torch.randn_like(x) for x in xs]

    for i, (x, y) in enumerate(zip(xs, ys)):
        if use_amp:
            train_amp(baseline_model, baseline_optimizer, target_engine, dtype, scaler, x, y, rtol, atol)
        else:
            train_no_amp(baseline_model, baseline_optimizer, target_engine, x, y, rtol, atol)
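

# Example driver (a hedged sketch, not part of the original file): the SimpleModel class
# and the command-line flags below are assumptions inferred from how compare_loss reads
# `args`; the real callers of this module may define them differently.
if __name__ == "__main__":
    import argparse

    import torch.nn as nn

    class SimpleModel(nn.Module):
        # Hypothetical minimal model matching the (x, y) -> loss call convention used above.
        def __init__(self, hidden_dim):
            super().__init__()
            self.linear = nn.Linear(hidden_dim, hidden_dim)
            self.loss_fn = nn.MSELoss()

        def forward(self, x, y):
            return self.loss_fn(self.linear(x), y)

    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", default="torch.float32",
                        choices=["torch.float32", "torch.float16", "torch.bfloat16"])
    parser.add_argument("--zero_stage", type=int, default=1)
    parser.add_argument("--offload_device", default="none", choices=["none", "cpu", "nvme"])
    parser.add_argument("--use_torch_adam", action="store_true")
    parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")))
    args = parser.parse_args()

    # Typically launched with the deepspeed or torchrun launcher so that LOCAL_RANK and
    # the NCCL process group environment are set up before compare_loss runs.
    compare_loss(args, SimpleModel)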