Update prepare_buffer_fusing and propagate_padding for reshape
Signed-off-by: Andrew Park <[email protected]>
andrew-k-park committed Nov 28, 2024
1 parent 79493c2 commit 4166ab5
Showing 3 changed files with 64 additions and 12 deletions.
19 changes: 19 additions & 0 deletions src/plugins/intel_gpu/src/graph/gemm.cpp
@@ -176,6 +176,21 @@ std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<con
         return transposed_input_pshape;
     };
 
+    auto get_input_padding = [&](const layout& layout, size_t input_rank, size_t output_rank) {
+        auto pad = layout.data_padding;
+        std::vector<tensor::value_type> pad_lower, pad_upper;
+        for (size_t i = 0; i < input_rank; i++) {
+            pad_lower.push_back(pad._lower_size[i]);
+            pad_upper.push_back(pad._upper_size[i]);
+        }
+
+        size_t ones_to_add = std::max(output_rank, static_cast<size_t>(4)) - input_rank;
+        pad_lower.insert(pad_lower.begin(), ones_to_add, 0);
+        pad_upper.insert(pad_upper.begin(), ones_to_add, 0);
+
+        return padding(pad_lower, pad_upper);
+    };
+
     auto input0_pshape = input_layouts[0].get_partial_shape();
     auto input1_pshape = input_layouts[1].get_partial_shape();
 
@@ -190,6 +205,10 @@ std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<con
     std::vector<layout> layouts = input_layouts;
     layouts[0].set_partial_shape(transposed_input0_pshape);
     layouts[1].set_partial_shape(transposed_input1_pshape);
+    if (layouts[0].data_padding)
+        layouts[0].data_padding = get_input_padding(layouts[0], input_rank, output_rank);
+    if (layouts[1].data_padding)
+        layouts[1].data_padding = get_input_padding(layouts[1], weight_rank, output_rank);
 
     if (primitive->input_size() == 3) {
         auto bias_pshape = input_layouts[2].get_partial_shape();
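The new `get_input_padding` helper widens a tensor's padding descriptor so its rank matches the gemm output rank (never below 4), prepending zero pads for the new leading dimensions. Below is a minimal standalone sketch of that idea, with cldnn's `padding` and `tensor::value_type` replaced by plain vectors; `SimplePadding` and `extend_padding_rank` are hypothetical names, not the plugin's API.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical stand-in for cldnn::padding: per-dimension lower/upper pad sizes.
struct SimplePadding {
    std::vector<int32_t> lower, upper;
};

// Keep the first `input_rank` pad entries and prepend zeros until the rank
// reaches max(output_rank, 4), mirroring the new get_input_padding helper.
SimplePadding extend_padding_rank(const SimplePadding& pad, size_t input_rank, size_t output_rank) {
    std::vector<int32_t> lower(pad.lower.begin(), pad.lower.begin() + input_rank);
    std::vector<int32_t> upper(pad.upper.begin(), pad.upper.begin() + input_rank);

    const size_t zeros_to_add = std::max(output_rank, static_cast<size_t>(4)) - input_rank;
    lower.insert(lower.begin(), zeros_to_add, 0);
    upper.insert(upper.begin(), zeros_to_add, 0);
    return {std::move(lower), std::move(upper)};
}

int main() {
    // Rank-3 input with 2 elements of upper padding on the innermost axis.
    SimplePadding pad{{0, 0, 0}, {0, 0, 2}};
    SimplePadding extended = extend_padding_rank(pad, /*input_rank=*/3, /*output_rank=*/4);
    for (int32_t v : extended.upper) std::cout << v << ' ';  // prints: 0 0 0 2
    std::cout << '\n';
    return 0;
}
```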
26 changes: 25 additions & 1 deletion src/plugins/intel_gpu/src/graph/include/reshape_inst.h
@@ -6,7 +6,9 @@
 #include "intel_gpu/primitives/reshape.hpp"
 #include "intel_gpu/runtime/tensor_accessor.hpp"
 #include "openvino/core/partial_shape.hpp"
+#include "concatenation_inst.h"
 #include "crop_inst.h"
+#include "kv_cache_inst.h"
 #include "rope_inst.h"
 #include "mvn_inst.h"
 #include "primitive_inst.h"
@@ -50,6 +52,9 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
             return true;
         }
 
+        if (batch_can_be_squeezed())
+            return true;
+
         // TODO: This function is to limit condition to a specific case (crop + reshape) among cases for the base mode
         if (!input().is_type<crop>())
             return false;
@@ -91,6 +96,25 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
         return true;
     }
 
+    bool batch_can_be_squeezed() const {
+        auto prim = typed_desc();
+        if (prim->mode == reshape::reshape_mode::base) {
+            if (!input().is_type<kv_cache>() || !prim->output_pattern.empty() || !get_dependency(1).is_type<concatenation>())
+                return false;
+
+            const auto& kv_cache_ps = input().get_output_layout(false).get_partial_shape();
+            const auto& concat_ps = get_dependency(1).get_output_layout(false).get_partial_shape();
+            if (concat_ps.size() != 1 || concat_ps[0].is_dynamic())
+                return false;
+
+            if (kv_cache_ps.size() - 1 != static_cast<size_t>(concat_ps[0].get_length()))
+                return false;
+
+            return true;
+        }
+        return false;
+    }
+
     bool has_padding() const {
         return (this->get_output_layout().data_padding
                 || input().get_output_layout(false).data_padding
@@ -144,7 +168,7 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
         if (input_layout.data_padding.is_dynamic()) {
             auto prim = typed_desc();
             // TODO: If outer padding exists, output padding propagation is not supported in the base mode
-            if (prim->mode == reshape::reshape_mode::base)
+            if (prim->mode == reshape::reshape_mode::base && !batch_can_be_squeezed())
                 return;
 
             ov::PartialShape pattern_shape = { static_cast<int64_t>(prim->output_pattern.size()) };
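The shape arithmetic behind `batch_can_be_squeezed()` can be checked in isolation: a base-mode reshape fed by a kv_cache node counts as a batch squeeze when its runtime pattern (the second dependency, a static 1D concatenation) holds exactly one element fewer than the kv_cache output rank. A self-contained sketch under those assumptions follows; the free function below is illustrative, not the plugin's API.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Shape-only model of the batch_can_be_squeezed() predicate: `pattern_shape`
// models the concatenation output shape (must be 1D and static), and its single
// value is the number of dimensions in the reshape pattern. The check passes
// when the pattern describes one dimension fewer than the kv_cache output rank.
bool batch_can_be_squeezed(size_t kv_cache_rank, const std::vector<int64_t>& pattern_shape) {
    if (pattern_shape.size() != 1 || pattern_shape[0] < 0)  // not 1D, or dynamic
        return false;
    return kv_cache_rank - 1 == static_cast<size_t>(pattern_shape[0]);
}

int main() {
    // e.g. a kv_cache output [batch, heads, seq_len, head_size] reshaped with a
    // 3-element pattern: the leading batch dimension is squeezed away.
    std::cout << std::boolalpha
              << batch_can_be_squeezed(4, {3}) << '\n'   // true
              << batch_can_be_squeezed(4, {4}) << '\n';  // false
    return 0;
}
```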
31 changes: 20 additions & 11 deletions src/plugins/intel_gpu/src/graph/reshape.cpp
@@ -20,19 +20,19 @@ namespace cldnn {
 GPU_DEFINE_PRIMITIVE_TYPE_ID(reshape)
 
 padding propagate_padding(const layout& in_layout, const ov::PartialShape& out_shape, reshape::reshape_mode mode, const ov::ITensorAccessor& ta) {
-    if (mode == reshape::reshape_mode::base)
-        return padding();
-
     auto in_pad = in_layout.data_padding;
     if (!in_pad.is_dynamic()) {
         return padding();
     }
 
     std::vector<int64_t> axes;
-    if (auto t = ta(1)) {
-        axes = ov::get_tensor_data_as<int64_t, std::vector<int64_t>>(t);
-    } else {
-        OPENVINO_THROW("[GPU] Can't propagate padding for reshape op as axes data is not available");
+    // axes data is only needed when reshape mode is unsqueeze or squeeze
+    if (mode != reshape::reshape_mode::base) {
+        if (auto t = ta(1)) {
+            axes = ov::get_tensor_data_as<int64_t, std::vector<int64_t>>(t);
+        } else {
+            OPENVINO_THROW("[GPU] Can't propagate padding for reshape op as axes data is not available");
+        }
     }
 
     auto rank = in_layout.get_partial_shape().size();
Expand Down Expand Up @@ -76,7 +76,7 @@ padding propagate_padding(const layout& in_layout, const ov::PartialShape& out_s
update_pad_mask.push_back(0);
}
}
} else {
} else if (mode == reshape::reshape_mode::squeeze) {
std::unordered_set<int64_t> unique_axes;
std::transform(axes.begin(), axes.end(), std::inserter(unique_axes, unique_axes.end()), [=](int64_t axis) {
return ov::util::normalize(axis, rank);
@@ -96,6 +96,11 @@ padding propagate_padding(const layout& in_layout, const ov::PartialShape& out_s
                 return padding();
             }
         }
-    }
+    } else {
+        // padding propagation is allowed only if the batch dimension can be squeezed
+        update_pad_lower = std::vector<int32_t>(pad_lower.begin() + 1, pad_lower.end());
+        update_pad_upper = std::vector<int32_t>(pad_upper.begin() + 1, pad_upper.end());
+        update_pad_mask = std::vector<int32_t>(pad_mask.begin() + 1, pad_mask.end());
+    }
 
     // TODO: rework this method
@@ -189,10 +194,14 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& node,
             op.set_special_zero(prim->special_zero);
             op.set_friendly_name(prim->id.c_str());
             output_shapes = ov::op::v1::shape_infer(&op, input_shapes, ta);
-            // If the reshape is base mode, it is currently not set as can_be_optimized at prepare_buffer_fusing.
-            // So we can just run the reshape kernel
+            // If the reshape is in base mode, it is currently not set as can_be_optimized at prepare_buffer_fusing;
+            // padding propagation is allowed only if the batch dimension can be squeezed.
+            // In other cases, we can just run the reshape kernel.
             // TODO: allow propagatable reshapes
-            out_pad = padding();
+            if (node.batch_can_be_squeezed())
+                out_pad = propagate_padding(input_layout, output_shapes[0], prim->mode, ta);
+            else
+                out_pad = padding();
             break;
         }
         case reshape::reshape_mode::squeeze: {
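The new base-mode branch of `propagate_padding` reduces to dropping the leading (batch) entry from each pad vector before rebuilding the output padding. A toy illustration of that slice, using a plain vector in place of the in-layout pad arrays (`drop_batch_pad` is a hypothetical name):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// When the batch dimension is squeezed away, the propagated pad is just the
// input pad minus its first (batch) entry -- the same begin() + 1 slice used
// for pad_lower, pad_upper, and pad_mask in the diff above.
std::vector<int32_t> drop_batch_pad(const std::vector<int32_t>& pad) {
    return std::vector<int32_t>(pad.begin() + 1, pad.end());
}

int main() {
    // Dynamic upper padding on a rank-4 kv_cache output, e.g. on the sequence axis.
    std::vector<int32_t> pad_upper = {0, 0, 16, 0};
    for (int32_t v : drop_batch_pad(pad_upper)) std::cout << v << ' ';  // prints: 0 16 0
    std::cout << '\n';
    return 0;
}
```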
