Apply isort and black reformatting
Signed-off-by: huvunvidia <[email protected]>
huvunvidia committed Jul 10, 2024
1 parent f0f6c86 commit 3e601e1
Showing 2 changed files with 33 additions and 28 deletions.
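
Every change below is mechanical isort/black output: comparison operators gain spaces (output_name=="hiddens" becomes output_name == "hiddens"), over-long assignments are wrapped in parentheses, and one short call is collapsed back onto a single line. A minimal sketch of reproducing the same pass locally follows; it assumes isort >= 5 and black are installed, and the 119-character line length is inferred from the collapsed call rather than taken from the commit (the repository's pyproject.toml is authoritative).

    # Sketch only: re-run the formatters that produced this commit on a changed file.
    # The line length (119) is an assumption; only one of the two changed file paths
    # is visible in this diff.
    import black
    import isort

    files = [
        "nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py",
    ]

    for path in files:
        with open(path, "r", encoding="utf-8") as f:
            source = f.read()
        source = isort.code(source)  # sort and group imports first
        source = black.format_str(source, mode=black.Mode(line_length=119))  # then reformat
        with open(path, "w", encoding="utf-8") as f:
            f.write(source)
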
@@ -802,7 +802,7 @@ def fwd_output_only_func(dataloader_iter, model):
# processing forward args for mcore T5
if self.mcore_t5:
# when run encoding
- if output_name=="hiddens":
+ if output_name == "hiddens":
(
encoder_input_ids,
encoder_attn_mask,
@@ -814,18 +814,18 @@ def fwd_output_only_func(dataloader_iter, model):
)

output = model(
- encoder_input_ids=encoder_input_ids,
- decoder_input_ids=None,
- encoder_attn_mask=encoder_attn_mask_3d,
- decoder_attn_mask=None,
- encoder_decoder_attn_mask=None,
- lm_labels=None,
- encoder_hidden_states=None,
- output_encoder_hidden_only=True,
+ encoder_input_ids=encoder_input_ids,
+ decoder_input_ids=None,
+ encoder_attn_mask=encoder_attn_mask_3d,
+ decoder_attn_mask=None,
+ encoder_decoder_attn_mask=None,
+ lm_labels=None,
+ encoder_hidden_states=None,
+ output_encoder_hidden_only=True,
).contiguous()

# when run decoding
- elif output_name=="logits":
+ elif output_name == "logits":
(
encoder_hidden_states,
encoder_attn_mask,
@@ -843,23 +843,26 @@ def fwd_output_only_func(dataloader_iter, model):
enc_dec_attn_mask_3d = build_attention_mask_3d(
decoder_attn_mask, encoder_attn_mask, AttnMaskType.padding
)

# re-transpose encoder_hidden_states from [batch, seq_len, hidden] to [seq_len, batch, hidden]
encoder_hidden_states = encoder_hidden_states.transpose(1, 0)

output = model(
- encoder_input_ids=None,
- decoder_input_ids=decoder_input_ids,
- encoder_attn_mask=encoder_attn_mask_3d,
- decoder_attn_mask=decoder_attn_mask_3d,
- encoder_decoder_attn_mask=enc_dec_attn_mask_3d,
- lm_labels=None,
- encoder_hidden_states=encoder_hidden_states,
- output_encoder_hidden_only=False,
+ encoder_input_ids=None,
+ decoder_input_ids=decoder_input_ids,
+ encoder_attn_mask=encoder_attn_mask_3d,
+ decoder_attn_mask=decoder_attn_mask_3d,
+ encoder_decoder_attn_mask=enc_dec_attn_mask_3d,
+ lm_labels=None,
+ encoder_hidden_states=encoder_hidden_states,
+ output_encoder_hidden_only=False,
).contiguous()

else:
- assert output_name in ["hiddens", "logits"], "output_name argument must be either 'hiddens' or 'logits'"
+ assert output_name in [
+ "hiddens",
+ "logits",
+ ], "output_name argument must be either 'hiddens' or 'logits'"

else:
# map batch and shared args into forward args
@@ -1255,7 +1258,7 @@ def dummy():

# build input arguments description
if tokens_enc is not None:
- if self.mcore_t5 is True:
+ if self.mcore_t5 is True:
batch_for_pipeline = [tokens_enc, enc_mask]
arg_names = []
else:
@@ -1273,9 +1276,7 @@ def dummy():
arg_names.append('enc_input')

if self.mcore_t5:
- forward_step_func = self._get_forward_output_only_func(
- arg_names=arg_names, output_name="hiddens"
- )
+ forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="hiddens")
else:
forward_step_func = self._get_forward_output_only_func(
arg_names=arg_names, output_name="hiddens", output_enc_hidden_only=True
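
In the decoding branch above, the encoder-decoder attention mask is built with build_attention_mask_3d(decoder_attn_mask, encoder_attn_mask, AttnMaskType.padding). As a rough illustration (not NeMo's actual helper; argument order, shapes, and mask polarity are assumptions), a 3D padding mask can be formed from two 2D padding masks like this:

    # Illustrative sketch only -- not NeMo's build_attention_mask_3d.
    # Assumes [batch, seq] padding masks with 1 = real token, 0 = padding, and a
    # True-means-masked output of shape [batch, target_len, source_len].
    import torch


    def padding_mask_3d(target_mask: torch.Tensor, source_mask: torch.Tensor) -> torch.Tensor:
        # Outer product: entry (b, t, s) is kept only when both the target position t
        # and the source position s are real (non-padding) tokens.
        keep = target_mask.unsqueeze(2) * source_mask.unsqueeze(1)
        return keep < 0.5  # invert so True marks scores to mask out


    dec_mask = torch.tensor([[1, 1, 0]])     # decoder padding mask, [batch=1, target_len=3]
    enc_mask = torch.tensor([[1, 1, 1, 0]])  # encoder padding mask, [batch=1, source_len=4]
    print(padding_mask_3d(dec_mask, enc_mask).shape)  # torch.Size([1, 3, 4])
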
12 changes: 8 additions & 4 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -134,7 +134,7 @@ def _check_and_add_adapter(self, name, module, peft_name, peft_cfg, name_key_to_
f'model.{mcore_target}',
f'model.module.{mcore_target}',
f'enc_dec_model.{mcore_target}',
- f'enc_dec_model.module.{mcore_target}',
+ f'enc_dec_model.module.{mcore_target}',
]: # simple string match for now
if not isinstance(module, IdentityOp):
swap_mcore_mixin(module, mcore_mixin)
@@ -164,7 +164,7 @@ def _get_layers_from_model(self, model):
if self.cfg.megatron_amp_O2:
layers = model.module.encoder.layers + model.module.decoder.layers
else:
- layers = model.encoder.layers + model.decoder.layers
+ layers = model.encoder.layers + model.decoder.layers
else:
if self.cfg.megatron_amp_O2:
layers = model.module.language_model.encoder.layers
@@ -178,7 +178,9 @@ def _check_and_add_peft_cfg(self, peft_cfg):
assert not self.use_mcore_gpt or hasattr(
peft_cfg, 'name_key_to_mcore_mixins'
), f"{peft_cfg.__class__.__name__} is not supported in megatron core mode yet."
- name_key_to_mcore_mixins = peft_cfg.name_key_to_mcore_mixins if (self.use_mcore_gpt or self.use_mcore_t5) else None
+ name_key_to_mcore_mixins = (
+ peft_cfg.name_key_to_mcore_mixins if (self.use_mcore_gpt or self.use_mcore_t5) else None
+ )

for adapter_name, adapter_cfg in peft_cfg.get_config_dict().items():
# mixin for mcore models
@@ -469,7 +471,9 @@ def on_load_checkpoint(self, checkpoint) -> None:
if not self.ptuning_only_and_non_first_stage:
# same as super().on_load_checkpoint() but strict=False and only check unexpected keys
# mcore uses distributed checkpointing
- use_mcore = (hasattr(self, 'mcore_gpt') and self.mcore_gpt) or (hasattr(self, 'mcore_t5') and self.mcore_t5)
+ use_mcore = (hasattr(self, 'mcore_gpt') and self.mcore_gpt) or (
+ hasattr(self, 'mcore_t5') and self.mcore_t5
+ )
if use_mcore:
for index, module in enumerate(self.get_model_module_list()):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
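
The last hunk's comment describes checkpoint loading as "same as super().on_load_checkpoint() but strict=False and only check unexpected keys". A generic PyTorch sketch of that pattern (illustrative only, not NeMo's implementation):

    # Sketch only: load a partial state dict non-strictly, tolerate missing keys
    # (e.g. when only adapter/PEFT weights are being restored), but fail loudly on
    # keys the model does not know about.
    import torch


    def load_non_strict(module: torch.nn.Module, state_dict: dict) -> None:
        result = module.load_state_dict(state_dict, strict=False)
        # result.missing_keys is deliberately ignored; unexpected keys are not.
        if result.unexpected_keys:
            raise RuntimeError(f"Unexpected keys in checkpoint: {result.unexpected_keys}")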
