Skip to content

Commit

Permalink
Fix FSDP auto_wrap using characters instead of full str for layers (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr authored Sep 4, 2024
1 parent b5235f2 commit ab89fc7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 35 deletions.
6 changes: 2 additions & 4 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,7 @@ def __post_init__(self):

def set_state_dict_type(self):
"""
Set the state dict config based on the `StateDictType.
Set the state dict config based on the `StateDictType`.
"""
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
Expand Down Expand Up @@ -1538,9 +1538,7 @@ def set_auto_wrap_policy(self, model):

# First base off of `_no_split_modules`
no_split_modules = getattr(model, "_no_split_modules", None)
default_transformer_cls_names_to_wrap = (
",".join(model._no_split_modules) if no_split_modules is not None else ""
)
default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else []
if self.auto_wrap_policy == transformer_auto_wrap_policy:
if self.transformer_cls_names_to_wrap is None:
self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap
Expand Down
67 changes: 36 additions & 31 deletions tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
set_seed(42)

BERT_BASE_CASED = "bert-base-cased"
LLAMA_TESTING = "hf-internal-testing/tiny-random-LlamaForCausalLM"
FP16 = "fp16"
BF16 = "bf16"
dtypes = [FP16, BF16]
Expand Down Expand Up @@ -136,38 +137,42 @@ def test_state_dict_type(self):
assert fsdp_plugin.state_dict_config.rank0_only

def test_auto_wrap_policy(self):
model = AutoModel.from_pretrained(BERT_BASE_CASED)
for policy in FSDP_AUTO_WRAP_POLICY:
env = self.fsdp_env.copy()
env["FSDP_AUTO_WRAP_POLICY"] = policy
transformer_cls_to_wrap = None
min_num_params = None
if policy == "TRANSFORMER_BASED_WRAP":
env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "BertLayer"
transformer_cls_to_wrap = "BertLayer"
elif policy == "SIZE_BASED_WRAP":
env["FSDP_MIN_NUM_PARAMS"] = "2000"
min_num_params = 2000
# First test via env
with mockenv_context(**env):
fsdp_plugin = FullyShardedDataParallelPlugin()
for model_name in [LLAMA_TESTING, BERT_BASE_CASED]:
model = AutoModel.from_pretrained(model_name)
layer_to_wrap = "LlamaDecoderLayer" if model_name == LLAMA_TESTING else "BertLayer"
for policy in FSDP_AUTO_WRAP_POLICY:
env = self.fsdp_env.copy()
env["FSDP_AUTO_WRAP_POLICY"] = policy
transformer_cls_to_wrap = None
min_num_params = None
env.pop("FSDP_TRANSFORMER_CLS_TO_WRAP", None)
env.pop("FSDP_MIN_NUM_PARAMS", None)
if policy == "TRANSFORMER_BASED_WRAP":
env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = layer_to_wrap
transformer_cls_to_wrap = layer_to_wrap
elif policy == "SIZE_BASED_WRAP":
env["FSDP_MIN_NUM_PARAMS"] = "2000"
min_num_params = 2000
# First test via env
with mockenv_context(**env):
fsdp_plugin = FullyShardedDataParallelPlugin()
fsdp_plugin.set_auto_wrap_policy(model)
if policy == "NO_WRAP":
assert fsdp_plugin.auto_wrap_policy is None
else:
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)

# Then manually set the policy
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=policy,
transformer_cls_names_to_wrap=transformer_cls_to_wrap,
min_num_params=min_num_params,
)
fsdp_plugin.set_auto_wrap_policy(model)
if policy == "NO_WRAP":
assert fsdp_plugin.auto_wrap_policy is None
else:
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)

# Then manually set the policy
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=policy,
transformer_cls_names_to_wrap=transformer_cls_to_wrap,
min_num_params=min_num_params,
)
fsdp_plugin.set_auto_wrap_policy(model)
if policy == "NO_WRAP":
assert fsdp_plugin.auto_wrap_policy is None
else:
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
if policy == "NO_WRAP":
assert fsdp_plugin.auto_wrap_policy is None
else:
assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)

env = self.fsdp_env.copy()
env["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
Expand Down

0 comments on commit ab89fc7

Please sign in to comment.