[CPU] attn supports f16 #26487
Conversation
Force-pushed from f65a8e1 to cc3d9cf
Force-pushed from 63700c7 to 667901f
@luo-cheng2021 I have updated the code according to the comments. Could you help review it, especially the PagedAttention part? Thanks.
src/plugins/intel_cpu/src/graph.cpp
Outdated
@@ -193,9 +193,12 @@ void Graph::Replicate(const std::shared_ptr<const ov::Model> &model,
auto parentNode = op2node[unusedOutput.get_node_shared_ptr()];
const auto port = unusedOutput.get_index();
const auto nodeName = std::string("stub_") + std::to_string(unusedOutput.get_index()) + "_" + parentNode->getName();
// WA: avoid PagedAttention's second output reorder.
We'd better not hardcode this here; we should find the place where the output precision is changed to f16.
Any updates here? It has to be resolved before the merge.
> Any updates here? It has to be resolved before the merge.

Still addressing it. The root cause is the ConvertPrecision transformation. Need to deal with it carefully and avoid affecting the GPU.
@dmitry-gorokhov @luo-cheng2021 Removed such WA. The root cause is that the PagedAttentionExtension op is a bit hardcoded: it doesn't provide a mechanism to change the 2nd output dtype. I made some changes (see the sketch after this list):
- Add a `set_out_type` member function to the PagedAttentionExtension op.
- When `validate_and_infer_types` executes in PagedAttentionExtension, it determines the output types. It won't break the GPU path.
- Add a `fuse_type_to_pa` in the CPU plugin, which is an extension to `ConvertPrecision`. It is used to specify the correct type for `PagedAttentionExtension`'s 2nd output. The scope is the CPU plugin only and won't break the common pass.
@xczhai Just for my better understanding: could you please describe the pattern where the Reorder is inserted? Like, does the next op after PA expect fp16 on its input?
> @xczhai Just for my better understanding: could you please describe the pattern where the Reorder is inserted? Like, does the next op after PA expect fp16 on its input?
Okay.
- At the very beginning, the PA op spec describes both output types as aligned with the input0 type. As a result, PA's output types are `f32` when entering the CPU plugin.
- During CPU plugin transformation, `ConvertPrecision` converts or fuses the op's types. As a result, PA's two output types become `f16`. But remember PA's 2nd output is dangling, without any child or `Result` node.
- When constructing the graph, every dangling output is wrapped by a `Result` node whose type is aligned with the output type. In this case, the specific pattern is `PA's 2nd output` --> `Result(f16)`.
- But in the CPU node design, PA's 2nd output is always `f32`. So the pattern is `PA's 2nd output(f32)` --> `Result(f16)`.
- The following ResolveConflict logic scans this pattern and then inserts a `Reorder`. So the pattern becomes `PA's 2nd output(f32)` --> `Reorder` --> `Result(f16)` (illustrated by the toy sketch below).
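To make the mismatch concrete, here is a small self-contained toy (not the plugin's actual code; `Port` and `resolve_edge` are made up for illustration) showing how a producer port fixed to f32 feeding a consumer expecting f16 ends up with a Reorder on the edge:

```cpp
#include <iostream>
#include <string>

// Toy model of a graph edge: a producer output port and a consumer input port.
struct Port {
    std::string owner;
    std::string precision;
};

// Mimics the ResolveConflict idea: when the two ends of an edge disagree on
// precision, a Reorder is inserted between them.
void resolve_edge(const Port& producer, const Port& consumer) {
    if (producer.precision != consumer.precision) {
        std::cout << producer.owner << "(" << producer.precision << ") --> Reorder --> "
                  << consumer.owner << "(" << consumer.precision << ")\n";
    } else {
        std::cout << producer.owner << "(" << producer.precision << ") --> "
                  << consumer.owner << "(" << consumer.precision << ")\n";
    }
}

int main() {
    // PA's 2nd output is always f32 in the CPU node, but ConvertPrecision
    // switched the wrapping Result to f16, so a Reorder shows up.
    resolve_edge({"PA's 2nd output", "f32"}, {"Result", "f16"});
    return 0;
}
```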
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
Outdated
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
Outdated
Force-pushed from ed88364 to 37479b2
Hi @dmitry-gorokhov, could you please review? Thanks!
...ns/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_multiple_query_sdp.cpp
Outdated
...ns/intel_cpu/tests/functional/custom/subgraph_tests/src/common/concat_multiple_query_sdp.cpp
Outdated
- rebase f16 impl from arm
- refactor the testcase for x64
if ((inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16()) ||
    (inType == ElementType::f16 && !ov::with_cpu_x86_avx512_core_fp16())) {
    GTEST_SKIP();
Test skip is still there
> Test skip is still there

Removed.
Force-pushed from 93e1a24 to b9c9f4c
Details:
Tickets: