disable overlap for qkv (#9079)
* disable overlap for qkv (#9072)

* disable overlap for qkv

Signed-off-by: Rachit Garg <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Rachit Garg <[email protected]>
Co-authored-by: Rachit Garg <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: michal2409 <[email protected]>

---------

Signed-off-by: Rachit Garg <[email protected]>
Signed-off-by: michal2409 <[email protected]>
Signed-off-by: Michal Futrega <[email protected]>
Co-authored-by: Rachit Garg <[email protected]>
Co-authored-by: Rachit Garg <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michal Futrega <[email protected]>
Co-authored-by: michal2409 <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Marc Romeyn <[email protected]>
7 people authored and marcromeyn committed Jun 7, 2024
1 parent 8ced4b1 commit 0df7e95
Showing 1 changed file with 26 additions and 6 deletions.
@@ -80,7 +80,7 @@ def mcore_register_adapters(self):
         if (
             self.config.sequence_parallel
             and hasattr(self.linear_qkv, "return_layernorm_output_gathered")
-            and not self.config.tp_comm_overlap
+            and not self.linear_qkv.ub_overlap_ag
         ):
             # for LoRA SP, TE v1.5 can return layernorm output gathered so there is no need
             # to perform the redundant gather in the adapter module, unless TP communication
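
The condition above now keys off the per-module TE flag (linear_qkv.ub_overlap_ag) rather than the global config.tp_comm_overlap setting, so the gathered layernorm output is reused only when the QKV projection is not overlapping its tensor-parallel all-gather. A minimal sketch of that predicate, with a defensive getattr default added as an assumption for illustration:

def can_reuse_gathered_layernorm_output(config, linear_qkv):
    # Reuse TE's gathered layernorm output for the LoRA adapter only when sequence
    # parallelism is enabled, the TE layer exposes the option, and the layer does
    # NOT overlap its all-gather with the GEMM (userbuffer overlap disabled).
    return (
        config.sequence_parallel
        and hasattr(linear_qkv, "return_layernorm_output_gathered")
        and not getattr(linear_qkv, "ub_overlap_ag", False)
    )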
@@ -142,11 +142,19 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
         if SplitAlongDim is not None:
 
             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,)
+            (query, key, value) = SplitAlongDim(
+                mixed_qkv,
+                3,
+                split_arg_list,
+            )
         else:
 
             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,)
+            (query, key, value) = torch.split(
+                mixed_qkv,
+                split_arg_list,
+                dim=3,
+            )
 
         # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
         query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
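
For reference, a self-contained sketch of the grouped-query QKV split and reshape above, using small assumed sizes (sq=4, b=2, np=8 heads, ng=2 query groups, hn=16) to make the shape comments concrete:

import torch

sq, b, np_, ng, hn = 4, 2, 8, 2, 16
mixed_qkv = torch.randn(sq, b, ng, (np_ // ng + 2) * hn)   # [4, 2, 2, 96]

# Query gets np/ng heads per group; key and value get one head per group.
split_arg_list = [(np_ // ng) * hn, hn, hn]                # [64, 16, 16]
query, key, value = torch.split(mixed_qkv, split_arg_list, dim=3)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(sq, b, -1, hn)

print(query.shape, key.shape, value.shape)
# torch.Size([4, 2, 8, 16]) torch.Size([4, 2, 2, 16]) torch.Size([4, 2, 2, 16])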
@@ -231,11 +239,21 @@ def forward(

         if self.checkpoint_core_attention:
             core_attn_out = self._checkpointed_attention_forward(
-                query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
+                query,
+                key,
+                value,
+                attention_mask,
+                attn_mask_type=attn_mask_type,
+                packed_seq_params=packed_seq_params,
             )
         else:
             core_attn_out = self.core_attention(
-                query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
+                query,
+                key,
+                value,
+                attention_mask,
+                attn_mask_type=attn_mask_type,
+                packed_seq_params=packed_seq_params,
             )
 
         if packed_seq_params is not None:
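
The checkpointed branch recomputes the attention activations during the backward pass instead of storing them. A rough illustration of that trade-off with plain PyTorch utilities (the actual _checkpointed_attention_forward comes from Megatron-Core; attention_fn below is a simplified stand-in, not the real kernel):

import torch
from torch.utils.checkpoint import checkpoint

def attention_fn(q, k, v):
    # Stand-in for self.core_attention(...); assumed layout [sq, b, np, hn].
    scores = torch.einsum("sbnh,tbnh->bnst", q, k) / (q.size(-1) ** 0.5)
    probs = scores.softmax(dim=-1)
    return torch.einsum("bnst,tbnh->sbnh", probs, v)

q = torch.randn(4, 2, 8, 16, requires_grad=True)
k = torch.randn(4, 2, 8, 16, requires_grad=True)
v = torch.randn(4, 2, 8, 16, requires_grad=True)

# Only the inputs are saved; activations inside attention_fn are recomputed in backward.
out = checkpoint(attention_fn, q, k, v, use_reentrant=False)
out.sum().backward()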
@@ -316,7 +334,9 @@ def forward(self, hidden_states):
                 intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
             elif self.activation_func == F.silu and self.config.gated_linear_unit:
                 intermediate_parallel = bias_swiglu_impl(
-                    intermediate_parallel, bias_parallel, self.config.activation_func_fp8_input_store,
+                    intermediate_parallel,
+                    bias_parallel,
+                    self.config.activation_func_fp8_input_store,
                 )
 
             else:
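
bias_swiglu_impl fuses the bias add with the SwiGLU gating of the doubled-width projection. A hedged reference version of the semantics (the fused kernel and the activation_func_fp8_input_store flag are not modeled here, and the chunk order is an assumption):

import torch
import torch.nn.functional as F

def bias_swiglu_reference(x, bias):
    # x and bias cover a doubled hidden size produced by the gated linear projection.
    x = x + bias
    gate, value = torch.chunk(x, 2, dim=-1)
    return F.silu(gate) * value

x = torch.randn(4, 2, 256)      # [sq, b, 2 * ffn_hidden]
bias = torch.randn(256)
out = bias_swiglu_reference(x, bias)
print(out.shape)                # torch.Size([4, 2, 128])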
