diff --git a/src/llmcompressor/utils/fsdp/context.py b/src/llmcompressor/utils/fsdp/context.py index 177b2c02f..8a9252278 100644 --- a/src/llmcompressor/utils/fsdp/context.py +++ b/src/llmcompressor/utils/fsdp/context.py @@ -1,10 +1,12 @@ try: from accelerate import Accelerator +except ImportError: + Accelerator = None +try: from torch.distributed.fsdp import FullyShardedDataParallel - from torch.distributed.fsdp._common_utils import TrainingState + from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE, TrainingState except ImportError: FullyShardedDataParallel = None - Accelerator = None from contextlib import nullcontext @@ -14,22 +16,21 @@ "fix_fsdp_module_name", ] -FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" - def summon_full_params_context(model, offload_to_cpu: bool = False): - if FullyShardedDataParallel is not None: - # avoid nested summon_full_param context - if ( - hasattr(model, "training_state") - and model.training_state is TrainingState.SUMMON_FULL_PARAMS - ): - return nullcontext() - return FullyShardedDataParallel.summon_full_params( - model, offload_to_cpu=offload_to_cpu - ) + if FullyShardedDataParallel is None: + return nullcontext() - return nullcontext() + # do not call from within summon_full_param context + if ( + hasattr(model, "training_state") + and model.training_state is TrainingState.SUMMON_FULL_PARAMS + ): + return nullcontext() + + return FullyShardedDataParallel.summon_full_params( + model, offload_to_cpu=offload_to_cpu + ) def main_process_first_context(): @@ -46,12 +47,15 @@ def main_process_first_context(): def fix_fsdp_module_name(name: str) -> str: """ Remove FSDP wrapper prefixes from a module name. - Accounts for scenario where FSDP_WRAPPER_NAME is + Accounts for scenario where FSDP_WRAPPED_MODULE is at the end of the name, as well as in the middle. :param name: name to strip :return: stripped name """ - return name.replace(FSDP_WRAPPER_NAME + ".", "").replace( - "." + FSDP_WRAPPER_NAME, "" + if FullyShardedDataParallel is None: + return name + + return name.replace(FSDP_WRAPPED_MODULE + ".", "").replace( + "." + FSDP_WRAPPED_MODULE, "" )