Skip to content

Commit

Permalink
use existing constant, clarify comment
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <[email protected]>
  • Loading branch information
kylesayrs committed Nov 21, 2024
1 parent 449cfdf commit 8a9e89d
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions src/llmcompressor/utils/fsdp/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
try:
from accelerate import Accelerator
except ImportError:
Accelerator = None
try:
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp._common_utils import TrainingState
from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE, TrainingState
except ImportError:
FullyShardedDataParallel = None
Accelerator = None

from contextlib import nullcontext

Expand All @@ -14,22 +16,21 @@
"fix_fsdp_module_name",
]

FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"


def summon_full_params_context(model, offload_to_cpu: bool = False):
    """
    Build the context manager used to summon the full parameters of a model
    via ``FullyShardedDataParallel.summon_full_params``.

    A no-op ``nullcontext`` is returned instead when FSDP could not be
    imported, or when the model's ``training_state`` shows that it is
    already inside a summon_full_params context.

    :param model: model whose parameters should be summoned
    :param offload_to_cpu: forwarded unchanged to
        ``FullyShardedDataParallel.summon_full_params``
    :return: context manager to enter around code that needs the parameters
    """
    # FSDP import failed: there is nothing to summon (and TrainingState is
    # undefined), so this check must come first
    if FullyShardedDataParallel is None:
        return nullcontext()

    # do not call summon_full_params from within a summon_full_params context
    if (
        hasattr(model, "training_state")
        and model.training_state is TrainingState.SUMMON_FULL_PARAMS
    ):
        return nullcontext()

    return FullyShardedDataParallel.summon_full_params(
        model, offload_to_cpu=offload_to_cpu
    )


def main_process_first_context():
Expand All @@ -46,12 +47,15 @@ def main_process_first_context():
def fix_fsdp_module_name(name: str) -> str:
    """
    Remove FSDP wrapper prefixes from a module name.
    Accounts for scenario where FSDP_WRAPPED_MODULE is
    at the end of the name, as well as in the middle.

    If FSDP could not be imported, the name is returned unchanged.

    :param name: name to strip
    :return: stripped name
    """
    # FSDP_WRAPPED_MODULE is only defined when the FSDP import succeeded
    if FullyShardedDataParallel is None:
        return name

    # first drop "<wrapper>." (start / middle), then ".<wrapper>" (end)
    return name.replace(FSDP_WRAPPED_MODULE + ".", "").replace(
        "." + FSDP_WRAPPED_MODULE, ""
    )

0 comments on commit 8a9e89d

Please sign in to comment.