rfc: graph: propose to support Grouped-query Attention #2018
# Support Grouped-query Attention in Graph API

## Introduction & Motivation

In typical Multi-Head Attention (MHA), the Query, Key, and Value tensors usually
have the same head number (e.g., Query, Key, and Value all have shape (N, H, S,
D), where N is the mini-batch size, S is the sequence length, H is the head
number, and D is the head size). Loading the Key and Value tensors in each
generation step becomes a performance bottleneck, especially as the sequence
length grows.

To reduce the memory bandwidth spent on loading the Key and Value tensors,
Multi-Query Attention (MQA) reduces the head number of Key and Value to 1,
which means multiple Query heads map to the same single Key and Value (the
shape of Key and Value becomes (N, 1, S, D)).

However, MQA can lead to model quality degradation and training instability.
Grouped-Query Attention (GQA), an interpolation between typical MHA and MQA, is
proposed with a single Key and Value head per subgroup of Query heads (the
shape of Key and Value becomes (N, G, S, D), where H > G and H % G = 0). GQA is
widely adopted in llama2-70B, the llama3 and llama3.1 families [[#1]][1],
Mistral, and StarCoder2 [[#2]][2].

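As a purely illustrative shape example (the head counts below are assumptions,
roughly matching a Llama-3-8B-like configuration, and are not taken from this
RFC):

```python
import torch

N, S, D = 2, 128, 64   # mini-batch size, sequence length, head size (example values)
H, G = 32, 8           # H query heads grouped onto G key/value heads; H % G == 0

query = torch.randn(N, H, S, D)  # one head per query, as in MHA
key = torch.randn(N, G, S, D)    # one key head shared by H // G = 4 query heads
value = torch.randn(N, G, S, D)  # value heads grouped the same way as key

# Compared with MHA key/value of shape (N, H, S, D), the key/value tensors are
# H / G = 4x smaller, which is what reduces the loading bandwidth per step.
print(query.shape, key.shape, value.shape)
```
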
The following figure is from the paper [[#3]][3] and shows the comparison of
MHA, MQA, and GQA.

![attention](attention.png)

oneDNN already supports Scaled Dot-Product Attention (SDPA) through the Graph API
([document](https://oneapi-src.github.io/oneDNN/dev_guide_graph_sdpa.html#doxid-dev-guide-graph-sdpa)
and
[examples](https://github.com/oneapi-src/oneDNN/blob/main/examples/graph/sdpa.cpp)).
This proposal aims to extend the support from SDPA to cover GQA.

## GQA in PyTorch

Unlike SDPA, PyTorch does not provide GQA as a fused operation. In Huggingface
Transformers [[#4]][4], GQA is implemented in the following way.

1. First, as the Query (with shape (N, H, S, D)) and Key/Value (with shape (N, G,
S, D)) have different head number dimensions and cannot perform the dot-product
directly, the Key and Value tensors are repeated along the head number dimension.

```python
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```

2. Then the repeated Key and Value tensors can be passed to the typical SDPA block.

```python
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# typical SDPA block
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None:  # no matter the length, we just slice it
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = attn_weights + causal_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
```

The overall workflow can be visualized as the following diagram:

![GQA](GQA.png)

## Proposals

oneDNN's MatMul operation is used to construct the dot-product operations in
SDPA. But the MatMul operation requires that the two inputs have the same batch
dimensions, or batch dimensions that can be simply broadcast. MQA can therefore
be implemented directly using the typical SDPA pattern and the broadcasting
rule. In GQA, however, the head number of Key and Value is different from the
head number of Query. We propose the following options to either pre-process the
Query, Key, and Value tensors before passing them into the MatMul operations, or
extend the MatMul broadcasting rules.

### Option 1

Add new ops and patterns. We can pre-process the Key and Value tensors in the
same way as they are pre-processed in the Huggingface Transformers code above.
To achieve that, the following two new operations need to be added to the oneDNN
Graph operation set:

- Unsqueeze (see [sub-rfc document](./unsqueeze_and_expand.md))
- Expand (see [sub-rfc document](./unsqueeze_and_expand.md))

Then we would need to support patterns like the one in the above diagram.

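For reference, a minimal PyTorch sketch of the subgraph that such a pattern
would need to cover is shown below; the mapping to the proposed oneDNN Graph ops
is indicated in comments and is an assumption of this proposal, not an existing
API.

```python
import torch

N, H, G, S, D = 2, 32, 8, 128, 64
n_rep = H // G

query = torch.randn(N, H, S, D)
key = torch.randn(N, G, S, D)

# Pre-processing of Key, mirroring repeat_kv() above:
k = key[:, :, None, :, :]        # (N, G, 1, S, D)      -> would map to the new Unsqueeze op
k = k.expand(N, G, n_rep, S, D)  # (N, G, n_rep, S, D)  -> would map to the new Expand op
k = k.reshape(N, H, S, D)        # (N, H, S, D)         -> maps to a reshape in the pattern

# The repeated Key (and, analogously, Value) then feeds the typical SDPA block.
scores = torch.matmul(query, k.transpose(2, 3)) / D ** 0.5  # (N, H, S, S)
```
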
Pros:

1. It keeps the pattern graph consistent with the popular implementation in the
community. Frameworks can still map the framework graph directly to an oneDNN
graph and get the fused partition.

Cons:

1. It requires adding more operations to the oneDNN Graph operation set.
2. The implementation in the community may change. Even in the Huggingface code,
an alternative implementation is explicitly mentioned (via
`torch.repeat_interleave`). Once the implementation changes, the pro listed
above no longer holds.

### Option 2 (recommended)

Add new patterns. We can reshape the Query, Key, and Value tensors from 4D to 5D
and leverage the broadcasting semantics of the MatMul operation to perform the
dot-products.

1. Reshape Query from 4D shape (N, H, S, D) to 5D shape (N, G, H / G, S, D).
2. Reshape Key from 4D shape (N, G, S, D) to 5D shape (N, G, 1, S, D).
3. Perform a 5D MatMul between (N, G, H / G, S, D) and (N, G, 1, S, D). The
third dimension is broadcast from 1 to H / G automatically per the broadcasting
rule of the MatMul operation.
4. A similar reshape and broadcast also apply to the dot-product with Value.

Here is the diagram:

![option2](option2.png)

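A PyTorch sketch of the numerics behind this option is given below; the reshapes
and the 5D broadcasted MatMul illustrate how a user or framework would construct
the pattern and are not oneDNN API calls. It also checks the result against the
repeat_kv-style reference.

```python
import math
import torch

N, H, G, S, D = 2, 32, 8, 128, 64

query = torch.randn(N, H, S, D)
key = torch.randn(N, G, S, D)
value = torch.randn(N, G, S, D)

# Steps 1 & 2: reshape the 4D tensors to 5D so that the group dimension lines up.
q5 = query.reshape(N, G, H // G, S, D)  # (N, G, H/G, S, D)
k5 = key.reshape(N, G, 1, S, D)         # (N, G, 1, S, D)
v5 = value.reshape(N, G, 1, S, D)       # (N, G, 1, S, D)

# Step 3: 5D MatMul; the third dimension of k5 is broadcast from 1 to H/G.
scores = torch.matmul(q5, k5.transpose(-2, -1)) / math.sqrt(D)  # (N, G, H/G, S, S)
probs = torch.softmax(scores, dim=-1)

# Step 4: the same broadcast applies to the dot-product with Value.
out = torch.matmul(probs, v5).reshape(N, H, S, D)               # back to the 4D layout

# Reference: repeat Key/Value to H heads and run the typical SDPA block.
k_rep = key.repeat_interleave(H // G, dim=1)
v_rep = value.repeat_interleave(H // G, dim=1)
ref_scores = torch.matmul(query, k_rep.transpose(-2, -1)) / math.sqrt(D)
ref = torch.matmul(torch.softmax(ref_scores, dim=-1), v_rep)
assert torch.allclose(out, ref, atol=1e-5)
```
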
Pros:

1. No change to the API.
2. The pattern looks simpler than the pattern in Option 1.

Cons:

1. The pattern is less intuitive from the GQA definition.
2. The pattern cannot be used to optimize a framework graph directly. Frameworks
will have to implement GQA fusion by themselves and leverage this option to
optimize the fused GQA.

> Review comment: If this turns out to be a serious con, it would be reasonable
> to add a pass that matches the Option 1 subgraph and converts it to the
> Option 2 subgraph, right?
>
> Response: We will have to support and match the subgraph in Option 1 once such
> a request pops up. With that, oneDNN will support and maintain several
> different patterns for the same GQA functionality. This may not be a real
> issue: even if we chose Option 1 as the initial step, the pattern may still
> change in the future, as mentioned in the cons of Option 1.

### Option 3

We can extend the MatMul broadcasting rules to support group broadcast.
Currently, the MatMul operation supports the following broadcasting rules.

For single 2D matrix multiplication:

| Matrix A | Matrix B | Matrix C = A x B |
| -- | -- | -- |
| M x K | K x N | M x N |

For batched matrix multiplications:

| Matrix A | Matrix B | Matrix C = A x B |
| -- | -- | -- |
| B x M x K | B x K x N | B x M x N |
| B x M x K | 1 x K x N | B x M x N |
| 1 x M x K | B x K x N | B x M x N |

This can be extended to matrix multiplications with multiple batch dimensions:

| Matrix A | Matrix B | Matrix C = A x B |
| -- | -- | -- |
| B1 x 1 x B3 x M x K | B1 x B2 x 1 x K x N | B1 x B2 x B3 x M x N |

This RFC proposes to further extend it to support group broadcast:

| Matrix A | Matrix B | Matrix C = A x B |
| -- | -- | -- |
| B1 x 1 x B3 x M x K | 1 x B2 x (B3/c) x K x N | B1 x B2 x B3 x M x N |

where c is a factor of B3.

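As a rough illustration of the intended semantics (a sketch only, not existing
oneDNN or framework behavior), the group broadcast can be emulated by repeating
Matrix B along the grouped batch dimension before a regular broadcasted matrix
multiplication:

```python
import torch

B1, B2, B3, M, K, Nc = 2, 3, 8, 4, 5, 6
c = 4  # a factor of B3: every c consecutive batches of A share one batch of B

A = torch.randn(B1, 1, B3, M, K)
B = torch.randn(1, B2, B3 // c, K, Nc)

# Group broadcast: the batch dimension of size B3/c in B is expanded to B3,
# with each entry reused for c consecutive batches of A.
B_grouped = B.repeat_interleave(c, dim=2)  # (1, B2, B3, K, Nc)
C = torch.matmul(A, B_grouped)             # (B1, B2, B3, M, Nc)
print(C.shape)                             # torch.Size([2, 3, 8, 4, 6])
```
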
This rule looks uncommon and is not covered by the typical broadcasting rules
(see broadcasting in
[ONNX](https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md) and
[NumPy](https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules)),
but it has actually been added to the
[MatMul operation of cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-graph-library.html#cudnn-backend-operation-matmul-descriptor)
in order to support GQA.

Pros:

1. We will have the most unified and simple pattern for SDPA, GQA, and MQA.
2. It follows the same support methodology as the cuDNN API.

Cons:

1. It complicates the semantics of the MatMul operation. Previously,
incompatible batch dimensions were treated as an error, but now we need to
further check whether they can be properly group broadcast. Even cuDNN
explicitly documents that the new broadcasting rule is only supported by the
fused attention.
2. Same as Option 2, the pattern cannot be used to optimize a framework graph
directly. Frameworks will have to implement GQA fusion by themselves and
leverage this option to optimize the fused GQA.
3. oneDNN relies on the MatMul primitive kernels for the reference
implementation and testing in benchdnn, and those kernels do not support the new
broadcasting rule. Extending the broadcast semantics on the graph side will also
require additional effort for the reference implementation and testing.

## Summary

We recommend going with Option 2, as it requires no API changes and only minimal
changes to the library, and the pattern is simple enough to understand. Its cons
are related to framework integration: frameworks have to implement GQA fusion by
themselves for both Option 2 and Option 3, and also for Option 1 once the
implementation in the community changes.

## References

1. [https://github.com/meta-llama/llama-models][1]
2. [https://huggingface.co/models][2]
3. [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head
Checkpoints][3]
4. [https://github.com/huggingface/transformers/blob/2782aadae2b0b0c313eac3ee70f84f0335577635/src/transformers/models/llama/modeling_llama.py#L203C1-L212C85][4]

[1]: https://github.com/meta-llama/llama-models
[2]: https://huggingface.co/models
[3]: https://arxiv.org/pdf/2305.13245
[4]: https://github.com/huggingface/transformers/blob/2782aadae2b0b0c313eac3ee70f84f0335577635/src/transformers/models/llama/modeling_llama.py#L203C1-L212C85

# Support Unsqueeze and Expand in Graph API

## Introduction & Motivation

As mentioned in the GQA RFC, if we choose to support Option 1, we need to
support the StaticUnsqueeze and StaticExpand operations in the Graph API.

## Proposal

### StaticUnsqueeze

In the frameworks, an `unsqueeze`-like op is used to add dimensions of size 1 to
the input tensor. We propose to define `StaticUnsqueeze` to map the similar
`unsqueeze` ops in the frameworks. [[#1]][1] [[#2]][2] [[#3]][3] [[#4]][4]

| Framework | TensorFlow  | PyTorch   | ONNX                   | OpenVINO  |
|-----------|-------------|-----------|------------------------|-----------|
| op        | expand_dims | unsqueeze | Unsqueeze              | Unsqueeze |
| input     | src         | src       | data                   | src       |
| input     | axis        | dim       | axes (a tensor of int) | dim       |
| output    | dst         | dst       | dst                    | dst       |

These ops in the frameworks are mostly the same, with only a slight difference:
for the ONNX `Unsqueeze`, the second input is a tensor of int, so it supports a
list of dimensions to be inserted. Both `axis` and `dim` are in the range
[-input.dim() - 1, input.dim()].

Based on the definitions of these frameworks, we define the following operation
`StaticUnsqueeze` to map these ops.

| StaticUnsqueeze | Argument Name | Required or Optional | Data Type     |
|-----------------|---------------|----------------------|---------------|
| input           | `src`         | Required             | f32,f16,bf16* |
| attribute       | `dim`         | Required             | s64           |
| output          | `dst`         | Required             | f32,f16,bf16* |

**Detailed description**:

It returns the output with a dimension of size one inserted at the specified
`dim` position. The Unsqueeze operation can return a view or a copy of `src`.

`dim`: the index at which to insert the singleton dimension, which should be in
the range `[-src.dim() - 1, src.dim()]`.

For example, when `src`'s shape is \[4\]:

1. `dim` = 0: the `dst`'s shape is [1, 4]
2. `dim` = 1: the `dst`'s shape is [4, 1]

### StaticExpand

In the frameworks, there are several operations with expand-like semantics: for
example, `expand`, `expand_as`, `repeat`, and `repeat_interleave` in PyTorch,
`Broadcast-1` and `Broadcast-3` in OpenVINO, `Expand-13` and `Expand-8` in ONNX,
etc. [[#5]][5] [[#6]][6] [[#7]][7] [[#8]][8] These op definitions are not quite
the same, but they can all implement similar expand functionality.

However, there is no operation in the Graph API corresponding to these semantics
in the frameworks. So we can add an operation `StaticExpand` to map the similar
ops from the frameworks. It replicates data from the input tensor to fit a given
shape.

#### Option 1

| StaticExpand | Argument Name  | Required or Optional | Data Type     |
|--------------|----------------|----------------------|---------------|
| input        | `src`          | Required             | f32,f16,bf16* |
| attribute    | `target_shape` | Required             | s64           |
| output       | `dst`          | Required             | f32,f16,bf16* |

**Detailed description**:

`StaticExpand` takes the input tensor `src` and builds a new tensor with shape
matching the attribute `target_shape`. `target_shape` is a 1D integer tensor
that represents the required shape of the output. It requires that the ranks of
the input and output are equal.

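A small PyTorch analogy of the equal-rank expansion (using `torch.Tensor.expand`
only as an illustration of the intended semantics):

```python
import torch

src = torch.randn(2, 1, 4)   # rank-3 input
dst = src.expand(2, 3, 4)    # target_shape = [2, 3, 4]; same rank, size-1 dim replicated
print(dst.shape)             # torch.Size([2, 3, 4])

# Because the input and output ranks must match, expanding a (N, G, S, D) tensor
# to (N, G, n_rep, S, D) for GQA would first require StaticUnsqueeze to insert
# the size-1 dimension.
```
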
Pros:

1. This definition is simple and easy to understand. It is convenient for
mapping the PyTorch `expand`.

Cons:

1. It requires that the input's rank equals the output's rank, so it cannot
cover some expand-like op functionality (e.g., broadcasting to a higher rank).

#### Option 2

Add an attribute `axes_mapping` on top of Option 1.

| StaticExpand | Argument Name  | Required or Optional | Data Type     |
|--------------|----------------|----------------------|---------------|
| input        | `src`          | Required             | f32,f16,bf16* |
| attribute    | `target_shape` | Required             | s64           |
| attribute    | `axes_mapping` | Optional*            | s64           |
| output       | `dst`          | Required             | f32,f16,bf16* |

**Detailed description**:

The attribute `axes_mapping` is a list of int. If this attribute is not set, the
behavior is the same as Option 1. If the attribute is set, the size of
`axes_mapping` should match the rank of the input data tensor, so that all axes
of the data tensor are mapped to axes of the output. For example, `axes_mapping
= [1]` enables broadcasting of a tensor with shape `[C]` to shape `[N,C,H,W]` by
replicating the initial tensor along dimensions 0, 2, and 3. Another example is
broadcasting a tensor with shape `[H,W]` to shape `[N,H,W,C]` with `axes_mapping
= [1, 2]`.

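The intended `axes_mapping` behavior can be emulated in PyTorch as a reference
sketch (the helper below is hypothetical and only illustrates the proposed
semantics): each input axis is placed at its mapped output axis, the remaining
axes get size 1, and the result is broadcast to `target_shape`.

```python
import torch

def static_expand(src, target_shape, axes_mapping):
    # Place each source axis at its mapped output position, fill the remaining
    # positions with size-1 axes, then broadcast to the target shape.
    shape = [1] * len(target_shape)
    for src_axis, dst_axis in enumerate(axes_mapping):
        shape[dst_axis] = src.shape[src_axis]
    return src.reshape(shape).expand(*target_shape)

c = torch.randn(3)                                                 # shape [C]
print(static_expand(c, [2, 3, 4, 5], axes_mapping=[1]).shape)      # [N, C, H, W]

hw = torch.randn(4, 5)                                             # shape [H, W]
print(static_expand(hw, [2, 4, 5, 3], axes_mapping=[1, 2]).shape)  # [N, H, W, C]
```
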
Pros:

1. It solves the cons of Option 1.

Cons:

1. It needs an additional attribute, which increases the difficulty of
understanding. It is also not very useful in practice, since the cons of Option
1 can be solved by the `unsqueeze` op.

## References

1. https://www.tensorflow.org/api_docs/python/tf/expand_dims
2. https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
3. https://onnx.ai/onnx/operators/onnx__Unsqueeze.html#l-onnx-doc-unsqueeze
4. https://docs.openvino.ai/2022.3/openvino_docs_ops_shape_Unsqueeze_1.html
5. https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
6. https://docs.openvino.ai/2022.3/openvino_docs_ops_movement_Broadcast_3.html
7. https://onnx.ai/onnx/operators/onnx__Expand.html
8. https://www.tensorflow.org/api_docs/python/tf/broadcast_to

[1]: https://www.tensorflow.org/api_docs/python/tf/expand_dims
[2]: https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
[3]: https://onnx.ai/onnx/operators/onnx__Unsqueeze.html#l-onnx-doc-unsqueeze
[4]: https://docs.openvino.ai/2022.3/openvino_docs_ops_shape_Unsqueeze_1.html
[5]: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
[6]: https://docs.openvino.ai/2022.3/openvino_docs_ops_movement_Broadcast_3.html
[7]: https://onnx.ai/onnx/operators/onnx__Expand.html
[8]: https://www.tensorflow.org/api_docs/python/tf/broadcast_to

> FYI - the PyTorch PR just got merged this week: pytorch/pytorch#132689