add test case for prepare_buffer_fusing
andrew-k-park committed Nov 29, 2024
1 parent 4166ab5 commit 87195ad
Showing 1 changed file with 61 additions and 0 deletions.
@@ -15,6 +15,8 @@
#include "convolution_inst.h"
#include "gather_inst.h"
#include "gemm_inst.h"
#include "kv_cache_inst.h"
#include "read_value_inst.h"
#include "reshape_inst.h"
#include "fully_connected_inst.h"
#include "permute_inst.h"
@@ -1494,3 +1496,62 @@ TEST(prepare_buffer_fusing, inner_axis_data_offset_with_gemm_user) {
auto& crop_node = prog->get_node("crop2").as<crop>();
ASSERT_FALSE(crop_node.can_be_optimized());
}

TEST(prepare_buffer_fusing, skip_reshape_batch_can_be_squeezed) {
    auto& engine = get_test_engine();

    auto input_beam_idx_lay = layout{ov::PartialShape{-1}, data_types::i32, format::bfyx};
    auto input_present_lay = layout{ov::PartialShape{-1, 8, -1, 64}, data_types::f32, format::bfyx};
    auto input_param_lay = layout{ov::PartialShape{1}, data_types::f32, format::bfyx};
    auto gemm_input_lay = layout{ov::PartialShape{-1, -1, -1}, data_types::f32, format::bfyx};

    ov::op::util::VariableInfo info{ov::PartialShape{-1, 8, -1, 64}, data_types::f32, "v0"};
    auto input_kv_lay = layout{info.data_shape, info.data_type, format::bfyx};
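    // Stateful KV-cache pattern: read_value -> gather (beam reordering) -> kv_cache concat on axis 2,
    // followed by a reshape of the 4D cache output to {-1, -1, 64} that feeds a gemm.
    // "concat2" only supplies the runtime target shape for the reshape.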
    topology topology(input_layout("beam_idx", input_beam_idx_lay),
                      input_layout("present", input_present_lay),
                      input_layout("param1", input_param_lay),
                      input_layout("param2", input_param_lay),
                      input_layout("param3", input_param_lay),
                      input_layout("gemm_input", gemm_input_lay),
                      read_value("kv_cache", std::vector<input_info>{}, info.variable_id, {input_kv_lay}),
                      gather("gather",
                             input_info("kv_cache"),
                             input_info("beam_idx"),
                             0,                                       // axis
                             input_kv_lay.get_partial_shape().size(), // input rank
                             ov::Shape{},                             // output shape
                             0,                                       // batch_dim
                             true),                                   // support_neg_ind
                      kv_cache("concat1", {input_info("gather"), input_info("present")}, info, 2, 0, false),
                      concatenation("concat2", {input_info("param1"), input_info("param2"), input_info("param3")}, 0),
                      reshape("reshape", input_info("concat1"), input_info("concat2"), false, ov::PartialShape{-1, -1, 64}, cldnn::reshape::reshape_mode::base),
                      gemm("gemm",
                           {input_info("gemm_input"), input_info("reshape")},
                           data_types::f32,
                           std::vector<int64_t>{ 0, 1, 2 },
                           std::vector<int64_t>{ 0, 1, 2 },
                           std::vector<int64_t>{ 0, 1, 2 },
                           1.0f,
                           0.0f),
                      reorder("reorder", input_info("gemm"), format::bfyx, data_types::f32));

    // optimize_data enables the graph optimization passes (which is what runs prepare_buffer_fusing);
    // allow_new_shape_infer is required for the dynamic shapes used above.
    ExecutionConfig config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
    config.set_property(ov::intel_gpu::optimize_data(true));

    network network(engine, topology, config);
    auto reshape_inst = network.get_primitive("reshape");
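    // prepare_buffer_fusing is expected to mark the reshape as optimized out (an in-place view),
    // both on the program node and on the runtime primitive instance.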

    ASSERT_EQ(reshape_inst->get_node().can_be_optimized(), true);
    ASSERT_EQ(reshape_inst->can_be_optimized(), true);

    auto pad = tensor(0);
    pad.feature[0] = 1;
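    // The kv_cache concat axis (2) of the 4D shape is expected to map to the feature axis of the
    // squeezed 3D output, so only the feature bit of the reshape's dynamic pad mask should be set.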
    {
        std::vector<tensor::value_type> dynamic_pad_mask;
        const auto& dynamic_pad_dims = reshape_inst->get_output_layout(0).data_padding._dynamic_dims_mask;
        for (size_t i = 0; i < dynamic_pad_dims.size(); i++)
            dynamic_pad_mask.push_back(dynamic_pad_dims[i]);
        ASSERT_EQ(tensor(dynamic_pad_mask, 0), pad);
    }
}
