-
Notifications
You must be signed in to change notification settings - Fork 31
/
fsdp_cifar10_example.py
152 lines (125 loc) · 5.34 KB
/
fsdp_cifar10_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
import logging
import os
import torch.distributed as dist
from distributed_shampoo import (
compile_fsdp_parameter_metadata,
FSDPShampooConfig,
PrecisionConfig,
)
from distributed_shampoo.examples.trainer_utils import (
get_data_loader_and_sampler,
get_model_and_loss_fn,
instantiate_optimizer,
Parser,
set_seed,
setup_distribution,
train_model,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# Emit every record at DEBUG level or above, tagged with source file and line.
logging.basicConfig(
    level=logging.DEBUG,
    format="[%(filename)s:%(lineno)d] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Reproducibility: pin the cuBLAS workspace configuration through the environment.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# This process's local rank, global rank, and the world size, read (in that
# order) from the environment populated by the launcher (e.g. torchrun).
# A missing variable raises KeyError at import time.
LOCAL_RANK, WORLD_RANK, WORLD_SIZE = (
    int(os.environ[env_name]) for env_name in ("LOCAL_RANK", "RANK", "WORLD_SIZE")
)
if __name__ == "__main__":
    """Multi-GPU CIFAR-10 Fully Sharded Data Parallel (FSDP) Training Example Script

    Uses torch.distributed to launch a distributed training run. The model is wrapped
    with torch.distributed.fsdp.FullyShardedDataParallel, and Distributed Shampoo is
    configured for FSDP through FSDPShampooConfig.

    The LOCAL_RANK, RANK, and WORLD_SIZE environment variables must be set (torchrun
    sets them), since they are read when this module is imported.

    Requirements:
        - Python 3.10 or above
        - PyTorch / TorchVision

    To run this training script with a single node, one can run from the optimizers directory:

    SGD (with learning rate = 1e-2, momentum = 0.9):
        torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_TRAINERS -m distributed_shampoo.examples.fsdp_cifar10_example --optimizer-type SGD --lr 1e-2 --momentum 0.9

    Adam (with default parameters):
        torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_TRAINERS -m distributed_shampoo.examples.fsdp_cifar10_example --optimizer-type ADAM

    Distributed Shampoo (with default Adam grafting, precondition frequency = 100):
        torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_TRAINERS -m distributed_shampoo.examples.fsdp_cifar10_example --optimizer-type DISTRIBUTED_SHAMPOO --precondition-frequency 100 --grafting-type ADAM --num-trainers-per-group -1 --use-bias-correction --use-decoupled-weight-decay --use-merge-dims

    To use distributed checkpointing, append the flag --use-distributed-checkpoint with optional --checkpoint-dir argument.

    The script will produce lifetime and window loss values retrieved from the forward pass over the data.

    Guaranteed reproducibility on a single GPU.
    """

    args = Parser.get_args()

    # set seed for reproducibility
    set_seed(args.seed)

    # initialize distributed process group; the returned device is used for the
    # model and for training below
    device = setup_distribution(
        backend=args.backend,
        world_rank=WORLD_RANK,
        world_size=WORLD_SIZE,
        local_rank=LOCAL_RANK,
    )

    # instantiate model and loss function, then wrap the model with FSDP.
    # use_orig_params=True exposes the original (unflattened) parameters through the
    # wrapped module; presumably required so compile_fsdp_parameter_metadata below can
    # describe each original parameter's shard -- confirm in distributed_shampoo.
    model, loss_function = get_model_and_loss_fn(device)
    model = FSDP(model, use_orig_params=True)

    # instantiate data loader (and its sampler) for this rank
    data_loader, sampler = get_data_loader_and_sampler(
        args.data_path, WORLD_SIZE, WORLD_RANK, args.local_batch_size
    )

    # instantiate optimizer (SGD, Adam, DistributedShampoo).
    # NOTE(review): distributed_config / precision_config are built for every
    # optimizer type; presumably only DISTRIBUTED_SHAMPOO consumes them -- confirm
    # in instantiate_optimizer.
    optimizer = instantiate_optimizer(
        args.optimizer_type,
        model,
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        beta3=args.beta3,
        epsilon=args.epsilon,
        momentum=args.momentum,
        dampening=args.dampening,
        weight_decay=args.weight_decay,
        max_preconditioner_dim=args.max_preconditioner_dim,
        precondition_frequency=args.precondition_frequency,
        start_preconditioning_step=args.start_preconditioning_step,
        inv_root_override=args.inv_root_override,
        exponent_multiplier=args.exponent_multiplier,
        use_nesterov=args.use_nesterov,
        use_bias_correction=args.use_bias_correction,
        use_decoupled_weight_decay=args.use_decoupled_weight_decay,
        grafting_type=args.grafting_type,
        grafting_epsilon=args.grafting_epsilon,
        grafting_beta2=args.grafting_beta2,
        use_merge_dims=args.use_merge_dims,
        use_pytorch_compile=args.use_pytorch_compile,
        distributed_config=FSDPShampooConfig(
            param_to_metadata=compile_fsdp_parameter_metadata(model),
        ),
        precision_config=PrecisionConfig(
            computation_dtype=args.computation_dtype.value,
            factor_matrix_dtype=args.factor_matrix_dtype.value,
            inv_factor_matrix_dtype=args.inv_factor_matrix_dtype.value,
            corrected_eigenvalues_dtype=args.corrected_eigenvalues_dtype.value,
            factor_matrix_eigenvectors_dtype=args.factor_matrix_eigenvectors_dtype.value,
            filtered_grad_dtype=args.filtered_grad_dtype.value,
            momentum_dtype=args.momentum_dtype.value,
            grafting_state_dtype=args.grafting_state_dtype.value,
        ),
        use_protected_eigh=args.use_protected_eigh,
        track_root_inv_residuals=args.track_root_inv_residuals,
        preconditioner_computation_type=args.preconditioner_computation_type,
    )

    # train model
    train_model(
        model,
        WORLD_SIZE,
        loss_function,
        sampler,
        data_loader,
        optimizer,
        device=device,
        epochs=args.epochs,
        window_size=args.window_size,
        local_rank=LOCAL_RANK,
    )

    # clean up process group
    dist.destroy_process_group()