Use device mesh to hybrid shard across nodes
2015aroras committed Apr 9, 2024
1 parent c9ceb5c commit aea7251
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions scripts/train.py
@@ -11,7 +11,9 @@
import torch.multiprocessing as mp
import wandb
from packaging import version
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

from olmo.config import CheckpointType, TrainConfig
from olmo.data import build_train_dataloader
@@ -24,6 +26,7 @@
get_default_device,
get_global_rank,
get_local_rank,
get_local_world_size,
get_world_size,
peak_gpu_memory,
seed_all,
@@ -133,8 +136,32 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
        param_init_fn = dummy_init_fn
    else:
        param_init_fn = None

    # Set up device mesh for hybrid sharding in order to specify which nodes are associated with a given model replica
    device_mesh: Optional[DeviceMesh] = None
    if cfg.fsdp.sharding_strategy in (ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2):
        if version.parse(torch.__version__) < version.parse("2.2.0"):
            # Device mesh was not added to PyTorch until v2.2.0
            raise OLMoConfigurationError(
                "OLMo training does not correctly support hybrid sharding before torch 2.2.0"
            )

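        # Defaults to the number of nodes, i.e. one model replica per node, each sharded across that node's ranks.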
        num_model_replicas = cfg.fsdp.hybrid_sharding_num_model_replicas or (
            get_world_size() // get_local_world_size()
        )

        if num_model_replicas <= 0:
            raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must be a positive integer")

        num_nodes = get_world_size() // get_local_world_size()
        if num_nodes % num_model_replicas != 0:
            raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes")

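        # Mesh dim 0 indexes the model replicas (parameters replicated across it); mesh dim 1 is the group of ranks each replica is sharded over.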
        device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))

    fsdp_model = FSDP(
        olmo_model,
        device_mesh=device_mesh,
        sharding_strategy=cfg.fsdp.sharding_strategy,
        mixed_precision=cfg.fsdp_precision,
        auto_wrap_policy=wrap_policy,
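
For reference, here is a small standalone sketch of the mesh arithmetic the new code performs. It is not part of the commit; the concrete values (4 nodes of 8 GPUs, 2 model replicas) are assumed purely for illustration.

# Illustrative hybrid-sharding mesh arithmetic (assumed values; not OLMo code).
world_size = 32           # e.g. 4 nodes x 8 GPUs, as reported by get_world_size()
local_world_size = 8      # GPUs per node, as reported by get_local_world_size()
num_model_replicas = 2    # cfg.fsdp.hybrid_sharding_num_model_replicas; defaults to the node count

num_nodes = world_size // local_world_size              # 4
assert num_nodes % num_model_replicas == 0              # replicas must divide the node count
shard_group_size = world_size // num_model_replicas     # 16 ranks hold one full copy of the model

# train.py passes this shape to init_device_mesh("cuda", ...); with HYBRID_SHARD,
# parameters are sharded within each group of 16 ranks and replicated across the 2 groups.
mesh_shape = (num_model_replicas, shard_group_size)
print(mesh_shape)  # (2, 16)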
