diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 7fa7810878e..0cb6c59dd9a 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -1500,7 +1500,7 @@ def __post_init__(self):
 
     def set_state_dict_type(self):
         """
-        Set the state dict config based on the `StateDictType.
+        Set the state dict config based on the `StateDictType`.
         """
         from torch.distributed.fsdp.fully_sharded_data_parallel import (
             FullOptimStateDictConfig,
@@ -1538,9 +1538,7 @@ def set_auto_wrap_policy(self, model):
 
         # First base off of `_no_split_modules`
         no_split_modules = getattr(model, "_no_split_modules", None)
-        default_transformer_cls_names_to_wrap = (
-            ",".join(model._no_split_modules) if no_split_modules is not None else ""
-        )
+        default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else []
         if self.auto_wrap_policy == transformer_auto_wrap_policy:
             if self.transformer_cls_names_to_wrap is None:
                 self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap
diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index aa53efe0dd2..d1234f23538 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -49,6 +49,7 @@
 set_seed(42)
 
 BERT_BASE_CASED = "bert-base-cased"
+LLAMA_TESTING = "hf-internal-testing/tiny-random-LlamaForCausalLM"
 FP16 = "fp16"
 BF16 = "bf16"
 dtypes = [FP16, BF16]
@@ -136,38 +137,42 @@ def test_state_dict_type(self):
                     assert fsdp_plugin.state_dict_config.rank0_only
 
     def test_auto_wrap_policy(self):
-        model = AutoModel.from_pretrained(BERT_BASE_CASED)
-        for policy in FSDP_AUTO_WRAP_POLICY:
-            env = self.fsdp_env.copy()
-            env["FSDP_AUTO_WRAP_POLICY"] = policy
-            transformer_cls_to_wrap = None
-            min_num_params = None
-            if policy == "TRANSFORMER_BASED_WRAP":
-                env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "BertLayer"
-                transformer_cls_to_wrap = "BertLayer"
-            elif policy == "SIZE_BASED_WRAP":
-                env["FSDP_MIN_NUM_PARAMS"] = "2000"
-                min_num_params = 2000
-            # First test via env
-            with mockenv_context(**env):
-                fsdp_plugin = FullyShardedDataParallelPlugin()
+        for model_name in [LLAMA_TESTING, BERT_BASE_CASED]:
+            model = AutoModel.from_pretrained(model_name)
+            layer_to_wrap = "LlamaDecoderLayer" if model_name == LLAMA_TESTING else "BertLayer"
+            for policy in FSDP_AUTO_WRAP_POLICY:
+                env = self.fsdp_env.copy()
+                env["FSDP_AUTO_WRAP_POLICY"] = policy
+                transformer_cls_to_wrap = None
+                min_num_params = None
+                env.pop("FSDP_TRANSFORMER_CLS_TO_WRAP", None)
+                env.pop("FSDP_MIN_NUM_PARAMS", None)
+                if policy == "TRANSFORMER_BASED_WRAP":
+                    env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = layer_to_wrap
+                    transformer_cls_to_wrap = layer_to_wrap
+                elif policy == "SIZE_BASED_WRAP":
+                    env["FSDP_MIN_NUM_PARAMS"] = "2000"
+                    min_num_params = 2000
+                # First test via env
+                with mockenv_context(**env):
+                    fsdp_plugin = FullyShardedDataParallelPlugin()
+                    fsdp_plugin.set_auto_wrap_policy(model)
+                    if policy == "NO_WRAP":
+                        assert fsdp_plugin.auto_wrap_policy is None
+                    else:
+                        assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
+
+                # Then manually set the policy
+                fsdp_plugin = FullyShardedDataParallelPlugin(
+                    auto_wrap_policy=policy,
+                    transformer_cls_names_to_wrap=transformer_cls_to_wrap,
+                    min_num_params=min_num_params,
+                )
                 fsdp_plugin.set_auto_wrap_policy(model)
-                if policy == "NO_WRAP":
-                    assert fsdp_plugin.auto_wrap_policy is None
-                else:
-                    assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
-
-            # Then manually set the policy
-            fsdp_plugin = FullyShardedDataParallelPlugin(
-                auto_wrap_policy=policy,
-                transformer_cls_names_to_wrap=transformer_cls_to_wrap,
-                min_num_params=min_num_params,
-            )
-            fsdp_plugin.set_auto_wrap_policy(model)
-            if policy == "NO_WRAP":
-                assert fsdp_plugin.auto_wrap_policy is None
-            else:
-                assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
+                if policy == "NO_WRAP":
+                    assert fsdp_plugin.auto_wrap_policy is None
+                else:
+                    assert isinstance(fsdp_plugin.auto_wrap_policy, functools.partial)
 
         env = self.fsdp_env.copy()
         env["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
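
Note (not part of the patch): the substantive change in `set_auto_wrap_policy` is that the fallback derived from `model._no_split_modules` is now a list of class names rather than a comma-joined string. A minimal, self-contained sketch of why that matters, using an illustrative class name rather than any specific model:

    # Sketch only: contrasts the old and new fallback for transformer_cls_names_to_wrap.
    no_split_modules = ["LlamaDecoderLayer"]

    # Old fallback: a comma-joined string ("" when _no_split_modules is missing).
    old_default = ",".join(no_split_modules) if no_split_modules is not None else ""

    # New fallback: a plain list ([] when _no_split_modules is missing).
    new_default = list(no_split_modules) if no_split_modules is not None else []

    assert old_default == "LlamaDecoderLayer"
    assert new_default == ["LlamaDecoderLayer"]

    # Iterating over the old string yields single characters, not class names,
    # which is presumably why transformer-based wrapping could fail to resolve the layer class.
    assert list(old_default)[:3] == ["L", "l", "a"]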