rebase f16 from arm
- rebase the f16 implementation from ARM
- refactor the test case for x64
xczhai committed Oct 12, 2024
1 parent e6cc130 commit eed3d2a
Showing 6 changed files with 88 additions and 15 deletions.
@@ -1040,8 +1040,6 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
auto p = past_v_scale_zp.ptr<float>(pv, b_kv, h_group);
attn_acc_value(buf_attn_score.ptr<T3>(ithr, b, 0, h_group),
buf_attn_w.ptr<T3>(b, h_group, 0, pv)[0],
attn_acc_value(buf_attn_score.ptr<float>(ithr, b, 0, h_group),
buf_attn_w.ptr<float>(b, h_group, 0, pv)[0],
v,
S,
p + 0,
@@ -37,6 +37,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTest,
ConcatSDPTest,
::testing::Combine(::testing::Values(ElementType::f16),
::testing::ValuesIn(inputShapes),
::testing::Values(false),
::testing::Values(true, false)),
ConcatSDPTest::getTestCaseName);
} // namespace
@@ -28,8 +28,9 @@ namespace test {
std::string ConcatSDPTest::getTestCaseName(const testing::TestParamInfo<ConcatSDPTestParams>& obj) {
ElementType inType;
std::vector<InputShape> inputShapes;
bool hasShapeof;
std::tie(inType, inputShapes, hasShapeof) = obj.param;
bool forceKVU8;
bool hasShapeOf;
std::tie(inType, inputShapes, forceKVU8, hasShapeOf) = obj.param;
std::ostringstream result;
result << "IS=";
for (const auto& shape : inputShapes) {
@@ -46,21 +47,24 @@ std::string ConcatSDPTest::getTestCaseName(const testing::TestParamInfo<ConcatSD
result << ")_";
}
result << "Prc=" << inType << "_";
result << "HasShapeOf=" << hasShapeof;
result << "ForceKVU8=" << forceKVU8 << "_";
result << "HasShapeOf=" << hasShapeOf;
return result.str();
}

void ConcatSDPTest::SetUp() {
ElementType inType;
std::vector<InputShape> inputShapes;
std::tie(inType, inputShapes, hasShapeOf) = this->GetParam();
std::tie(inType, inputShapes, m_forceKVU8, m_hasShapeOf) = this->GetParam();
targetDevice = ov::test::utils::DEVICE_CPU;
rel_threshold = 1e-2f;
if (inType == ElementType::bf16) {
configuration.insert({"ENFORCE_BF16", "YES"});
rel_threshold = 0.01f;
} else if (inType == ElementType::f16) {
if (inType == ElementType::bf16 || inType == ElementType::f16) {
configuration.insert({"INFERENCE_PRECISION_HINT", ov::element::Type(inType).get_type_name()});
rel_threshold = 0.01f;
}

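// Optionally force the KV cache to u8 so the quantized KV-cache path is covered as well.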
if (m_forceKVU8) {
configuration["KV_CACHE_PRECISION"] = "u8";
}
init_input_shapes(inputShapes);
ov::ParameterVector inputParams;
@@ -92,7 +96,7 @@ void ConcatSDPTest::SetUp() {
// |
// ShapeOf...
// The transformation 'SimplifyGatherShapeOf' will move ShapeOf to be the child of ReadValue
if (hasShapeOf) {
if (m_hasShapeOf) {
shapeof_k = std::make_shared<ov::op::v0::ShapeOf>(gatherK);
shapeof_v = std::make_shared<ov::op::v0::ShapeOf>(gatherV);
}
@@ -107,20 +111,20 @@ void ConcatSDPTest::SetUp() {
pastv_assign->set_friendly_name("pastv_w");

ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
if (hasShapeOf) {
if (m_hasShapeOf) {
results.push_back(std::make_shared<ov::op::v0::Result>(shapeof_k));
results.push_back(std::make_shared<ov::op::v0::Result>(shapeof_v));
}
SinkVector sinks{pastk_assign, pastv_assign};
function = std::make_shared<ov::Model>(results, sinks, inputParams, "ConcatSDP");
targetDevice = ov::test::utils::DEVICE_CPU;

functionRefs = function->clone();
pass::Manager manager;
// decompose ScaledDotProductAttention
manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
manager.run_passes(functionRefs);
}

void ConcatSDPTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
std::vector<ov::Shape> shapes(4);
shapes[0] = targetInputStaticShapes[0];
@@ -129,13 +133,15 @@ void ConcatSDPTest::generate_inputs(const std::vector<ov::Shape>& targetInputSta
shapes[3] = targetInputStaticShapes[1];
SubgraphBaseTest::generate_inputs(shapes);
}

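// Fill n elements starting at 'first' with an arithmetic sequence: value, value + stride, value + 2*stride, ...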
template<typename IT, typename T>
void strided_iota(IT first, size_t n, T value, T stride) {
for (size_t i = 0; i < n; i++) {
*first++ = value;
value += stride;
}
}

void ConcatSDPTest::generate(int idx, const std::vector<ov::Shape>& targetInputStaticShapes) {
inputs.clear();
auto create_input = [this] (std::shared_ptr<op::v0::Parameter> param, ov::Shape shape, float val) {
@@ -169,16 +175,19 @@ void ConcatSDPTest::generate(int idx, const std::vector<ov::Shape>& targetInputS
create_input(function->get_parameters()[3], targetInputStaticShapes[1], idx + 4.0f);
create_input(function->get_parameters()[4], ov::Shape{targetInputStaticShapes[0][0]}, idx + 0.0f);
}

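// Compile the model and create the infer request used by run_test().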
void ConcatSDPTest::prepare() {
compile_model();
inferRequest = compiledModel.create_infer_request();
ASSERT_TRUE(inferRequest);
}

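// Reset all variable states so the next run starts with an empty KV cache.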
void ConcatSDPTest::reset() {
for (auto&& state : inferRequest.query_state()) {
state.reset();
}
}

std::vector<ov::Tensor> ConcatSDPTest::run_test(std::shared_ptr<ov::Model> model) {
function = model;
prepare();
@@ -201,6 +210,17 @@ std::vector<ov::Tensor> ConcatSDPTest::run_test(std::shared_ptr<ov::Model> model
}
TEST_P(ConcatSDPTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED();
ElementType inType;
std::vector<InputShape> inputShapes;
bool forceKVU8;
bool hasShapeOf;
std::tie(inType, inputShapes, forceKVU8, hasShapeOf) = this->GetParam();

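// Skip when the platform lacks native support for the requested precision.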
if ((inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16()) ||
(inType == ElementType::f16 && !ov::with_cpu_x86_avx512_core_fp16())) {
GTEST_SKIP();
}

auto actualOutputs = run_test(function);
if (!hasShapeOf) {
CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
@@ -216,9 +236,14 @@ TEST_P(ConcatSDPTest, CompareWithRefs) {
}
}
}

// The result range can exceed the f16 maximum and produce 'inf'. Softmax then computes
// v - max(v); if v is inf, 'v - max(v)' evaluates to nan.
// Use f32 for the reference run instead.
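// Illustrative note: once a value saturates to +inf in f16, max(v) is also +inf,
// and IEEE arithmetic gives inf - inf == nan, which then propagates through exp() and the sum.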
if (inType == ElementType::f16) {
configuration["INFERENCE_PRECISION_HINT"] = "f32";
}

auto expectedOutputs = run_test(functionRefs);
CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
for (size_t i = 0; i < actualOutputs.size(); i++) {
@@ -34,7 +34,7 @@ namespace test {
template<typename IT, typename T>
void strided_iota(IT first, size_t n, T value, T stride);

typedef std::tuple<ElementType, std::vector<InputShape>, bool> ConcatSDPTestParams;
typedef std::tuple<ElementType, std::vector<InputShape>, bool, bool> ConcatSDPTestParams;

class ConcatSDPTest :
public testing::WithParamInterface<ConcatSDPTestParams>,
@@ -46,7 +46,8 @@ class ConcatSDPTest :
void prepare();
void reset();
std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model);
bool hasShapeOf;
bool m_forceKVU8;
bool m_hasShapeOf;
protected:
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
void SetUp() override;
@@ -37,6 +37,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTest,
ConcatSDPTest,
::testing::Combine(::testing::Values(ElementType::f32),
::testing::ValuesIn(inputShapes),
::testing::Values(true, false),
::testing::Values(true, false)),
ConcatSDPTest::getTestCaseName);

@@ -0,0 +1,47 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/opsets/opset13.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp"

#include "custom/subgraph_tests/src/classes/concat_sdp.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "utils/cpu_test_utils.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"

using namespace CPUTestUtils;

namespace ov {
namespace test {

namespace {
const std::vector<std::vector<InputShape>> inputShapes = {
// greedy search
{
// B, H, L1, S
{{1, 8, -1, 64}, {{1, 8, 10, 64}, {1, 8, 1, 64}, {1, 8, 1, 64}, {1, 8, 20, 64}, {1, 8, 1, 64}}},
// B, H, L0, S
{{1, 8, -1, 64}, {{1, 8, 0, 64}, {1, 8, 10, 64}, {1, 8, 11, 64}, {1, 8, 12, 64}, {1, 8, 32, 64}}},
},
// beam search
{
// B, H, L1, S
{{-1, 8, -1, 64}, {{4, 8, 10, 64}, {4, 8, 1, 64}, {4, 8, 1, 64}, {4, 8, 1, 64}, {4, 8, 1, 64}}},
// B, H, L0, S
{{-1, 8, -1, 64}, {{4, 8, 0, 64}, {4, 8, 10, 64}, {4, 8, 11, 64}, {4, 8, 12, 64}, {4, 8, 13, 64}}},
},
};

INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTest,
ConcatSDPTest,
::testing::Combine(::testing::Values(ElementType::bf16, ElementType::f16),
::testing::ValuesIn(inputShapes),
::testing::Values(true, false),
::testing::Values(true, false)),
ConcatSDPTest::getTestCaseName);

} // namespace

} // namespace test
} // namespace ov
