Skip to content

Commit

Permalink
update inject_model_helper
Browse files Browse the repository at this point in the history
  • Loading branch information
zigzagcai committed Sep 26, 2024
1 parent 57e99a0 commit a572ca1
Showing 1 changed file with 30 additions and 17 deletions.
47 changes: 30 additions & 17 deletions internlm/train/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,16 @@ def inject_config(model: nn.Module) -> None:


def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Optional[Dict] = None) -> None:
# get inject_info
"""
Inject model helper functions.
Args:
model (Union[nn.Module, nn.ModuleList]):
For built-in models, it is nn.Module for no pp and nn.ModuleList for pp.
For injected models, it is nn.Module.
inject_info (Optional[Dict]): configurations for injected_models.
"""
# parse inject_info
if inject_info is not None:
inject = inject_info.get("inject", False)
interactive = inject_info.get("interactive", False)
Expand All @@ -928,33 +937,37 @@ def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Opt
"norm": inject_norm,
}

# inject config
if inject:
inject_config(model)

if not isinstance(model, nn.ModuleList):
model = [model]

# inject modules
for _chunk in model:
# Special case for pure dp mode: skip
if (
isinstance(gpc.config.parallel["tensor"], dict)
and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name
and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL)
):
continue
# In-place replacement or check for modules: "embed", "linear", "norm"
# (1) If inject=True, in-place replacement
# (2) If inject=False, check
for mod in modules:
inject_funcs[mod](_chunk, inject, interactive)

# reset parameters and move model to device
# reset parameters if needed, model should have reset_parameters() method
if reset_params:
_chunk.reset_parameters()
for _chunk in model:
if inject:
if reset_params:
_chunk.reset_parameters()
# If model is initialized on cpu, model should be moved to cuda device after injection
if not next(_chunk.parameters()).is_cuda:
_chunk.to(get_current_device())

# inject configs
if inject:
inject_config(model[0])
if gpc.is_rank_for_log():
logger.info(
f"inject is enabled, please check the model carefully, "
f"if there are any problems, please report issue to us. "
f"The injected model is \n {model}"
)
# print injected model
if inject and gpc.is_rank_for_log():
logger.info(
f"inject is enabled, please check the model carefully, "
f"if there are any problems, please report issue to us. "
f"The injected model is \n {model}"
)

0 comments on commit a572ca1

Please sign in to comment.