Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU]: SearchSorted basic implementation. #27356

Merged
merged 22 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7835bb3
[GPU]: Added stub for SearchSorted
pkowalc1 Oct 31, 2024
ae7aebf
[gpu]:[SearchSorted]: Added unittests.
pkowalc1 Nov 4, 2024
1c99ccd
[GPU]: SearchSorted: Added comments.
pkowalc1 Nov 6, 2024
824ac24
[GPU]: SearchSorted: Fixed unittests.
pkowalc1 Nov 6, 2024
1f4308c
[GPU]: SearchSorted: WIP
pkowalc1 Nov 6, 2024
985c2dc
[GPU]: SearchSorted: WIP.
pkowalc1 Nov 6, 2024
ce3604b
[GPU]: SearchSorted: WIP.
pkowalc1 Nov 6, 2024
11e5186
[GPU]: SearchSorted: WIP.
pkowalc1 Nov 7, 2024
af8af33
[GPU]: SortedSearch: Added ref impl for static shapes.
pkowalc1 Nov 12, 2024
72cc70e
[GPU]: SearchSorted: Changed unitttests data to handle unsigned types…
pkowalc1 Nov 13, 2024
4dc8abf
[GPU]: SearchSorte: Fixed and added more unitttests.
pkowalc1 Nov 13, 2024
117f216
[TEMPLATE]: SearchSorted: Fixed a bug when sorted has exactly one ele…
pkowalc1 Nov 14, 2024
bf46223
[GPU]: Added func tests for SearchSorted
pkowalc1 Nov 14, 2024
8f8c333
[GPU]: SearchSorted: Added dynamic shape support.
pkowalc1 Nov 14, 2024
a7f3a5e
[GPU]: SearchSorted: Added hack to properly handle inputs for searchs…
pkowalc1 Nov 15, 2024
071b9fb
[GPU]: SearchSorted: Fixed kernel for dynamic shapes.
pkowalc1 Nov 15, 2024
392ce25
[GPU]: SearchSorted: Fixed compilation warnings.
pkowalc1 Nov 15, 2024
294a498
[gpu]: SearchSorted: review fixes.
pkowalc1 Nov 25, 2024
6610323
[GPU]: SearchSorted: Removed not need code.
pkowalc1 Nov 25, 2024
d4af5e1
Merge branch 'master' into search_sorted_basic_gpu_impl
pkowalc1 Nov 25, 2024
6fb9fc2
Merge branch 'master' into search_sorted_basic_gpu_impl
pkowalc1 Nov 26, 2024
f54730f
[GPU]:SearchSorted: Fixed naming.
pkowalc1 Nov 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ void search_sorted(const T* sorted,
}

const size_t size = shape_size(values_shape);
const size_t sorted_inner_dim = sorted_shape.back();

auto func = [&](size_t i) {
auto it = values_transform.begin();
Expand All @@ -44,15 +45,12 @@ void search_sorted(const T* sorted,
Coordinate sorted_coord_begin = values_coord;
sorted_coord_begin.back() = 0;

Coordinate sorted_coord_last = values_coord;
sorted_coord_last.back() = sorted_shape.back();

const auto sorted_index_begin = coordinate_index(sorted_coord_begin, sorted_shape);
const auto sorted_index_last = coordinate_index(sorted_coord_last, sorted_shape);

const T* idx_ptr = compare_func(sorted + sorted_index_begin, sorted + sorted_index_last, value);
const T* sorted_begin_ptr = sorted + sorted_index_begin;
const T* sorted_end_ptr = sorted_begin_ptr + sorted_inner_dim;
const T* idx_ptr = compare_func(sorted_begin_ptr, sorted_end_ptr, value);

const ptrdiff_t sorted_index = (idx_ptr - sorted) - sorted_index_begin;
const ptrdiff_t sorted_index = idx_ptr - sorted_begin_ptr;

out[values_index] = static_cast<TOut>(sorted_index);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ REGISTER_FACTORY(v13, BitwiseXor);
REGISTER_FACTORY(v15, ROIAlignRotated);
REGISTER_FACTORY(v15, BitwiseRightShift);
REGISTER_FACTORY(v15, BitwiseLeftShift);
REGISTER_FACTORY(v15, SearchSorted);

// --------------------------- Supported internal ops --------------------------- //
REGISTER_FACTORY(internal, NonMaxSuppressionIEInternal);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "primitive.hpp"

namespace cldnn {

/// @brief
/// @details
struct search_sorted : public primitive_base<search_sorted> {
CLDNN_DECLARE_PRIMITIVE(search_sorted)

search_sorted() : primitive_base("", {}) {}

/// @brief Constructs search_sorted primitive.
/// @param id This primitive id.
/// @param sorted Sorted input.
/// @param values Values input.
/// @param right_mode Enable/Disable right mode(check specification for details)..
search_sorted(const primitive_id& id, const input_info& sorted, const input_info& values, bool right_mode)
: primitive_base(id, {sorted, values}),
right_mode(right_mode) {}

/// @brief Enable/Disable right mode(check specification for details).
bool right_mode = false;

size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, right_mode);
return seed;
}

bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;

auto rhs_casted = downcast<const search_sorted>(rhs);

return right_mode == rhs_casted.right_mode;
}

void save(BinaryOutputBuffer& ob) const override {
primitive_base<search_sorted>::save(ob);
ob << right_mode;
}

void load(BinaryInputBuffer& ib) override {
primitive_base<search_sorted>::load(ib);
ib >> right_mode;
}
};
} // namespace cldnn
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ void register_implementations() {
REGISTER_OCL(unique_gather);
REGISTER_OCL(scaled_dot_product_attention);
REGISTER_OCL(rope);
REGISTER_OCL(search_sorted);
}

} // namespace ocl
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ REGISTER_OCL(unique_count);
REGISTER_OCL(unique_gather);
REGISTER_OCL(scaled_dot_product_attention);
REGISTER_OCL(rope);
REGISTER_OCL(search_sorted);

#undef REGISTER_OCL

Expand Down
107 changes: 107 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "primitive_base.hpp"
#include "search_sorted/search_sorted_kernel_base.h"
#include "search_sorted/search_sorted_kernel_selector.h"
#include "search_sorted_inst.h"

namespace cldnn {
namespace ocl {

struct search_sorted_impl : typed_primitive_impl_ocl<search_sorted> {
using parent = typed_primitive_impl_ocl<search_sorted>;
using parent::parent;
using kernel_selector_t = kernel_selector::search_sorted_kernel_selector;
using kernel_params_t = kernel_selector::search_sorted_params;

DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::search_sorted_impl)

std::unique_ptr<primitive_impl> clone() const override {
return make_unique<search_sorted_impl>(*this);
}

void load(BinaryInputBuffer& ib) override {
parent::load(ib);
if (is_dynamic()) {
auto& kernel_selector = kernel_selector_t::Instance();
auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
}
}

void update_dispatch_data(const kernel_impl_params& impl_param) override {
// If model loaded from cache, params are not initialized, so we create a new object and reuse it in the future
if (_kernel_data.params == nullptr) {
_kernel_data.params = std::make_shared<kernel_params_t>(get_kernel_params(impl_param, true));
}

update_shapes(*_kernel_data.params, impl_param);
(_kernel_data.update_dispatch_data_func)(*_kernel_data.params, _kernel_data);
}

static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool shape_agnostic = false) {
const auto& primitive = impl_param.typed_desc<search_sorted>();
auto params = get_default_params<kernel_selector::search_sorted_params>(impl_param, shape_agnostic);

// Manually add all inputs except first one, since get_default_params does not handle it.
for (size_t i = 1; i < impl_param.input_layouts.size(); ++i) {
params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(i)));
}

params.right_mode = primitive->right_mode;
return params;
}

// [NOTE]: Has to be added as a separete static function, since it is called via static dispatching in
// typed_primitive_impl_ocl::create()..
static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) {
auto updated_impl_params = canonicalize_fused_shapes(impl_params);

for (auto& input_layout : updated_impl_params.input_layouts) {
input_layout.set_partial_shape(extend_shape_to_rank_from_begin(input_layout.get_partial_shape()));
}

for (auto& output_layout : updated_impl_params.output_layouts) {
output_layout.set_partial_shape(extend_shape_to_rank_from_begin(output_layout.get_partial_shape()));
}

return updated_impl_params;
}

kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
return static_canonicalize_shapes(impl_params);
}
};

namespace detail {

attach_search_sorted_impl::attach_search_sorted_impl() {
auto types = {
data_types::i8,
data_types::u8,
data_types::i16,
data_types::u16,
data_types::i32,
data_types::u32,
data_types::i64,
data_types::f16,
data_types::f32,
};

auto formats = {format::bfyx, format::bfzyx};

implementation_map<search_sorted>::add(impl_types::ocl,
mlukasze marked this conversation as resolved.
Show resolved Hide resolved
shape_types::any,
typed_primitive_impl_ocl<search_sorted>::create<search_sorted_impl>,
types,
formats);
}

} // namespace detail
} // namespace ocl
} // namespace cldnn

BIND_BINARY_BUFFER_WITH_TYPE(cldnn::ocl::search_sorted_impl)
BIND_BINARY_BUFFER_WITH_TYPE(cldnn::search_sorted)
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,4 @@ REGISTER_DEFAULT_IMPLS(unique_count, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(unique_gather, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(scaled_dot_product_attention, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(rope, OCL_S, OCL_D);
REGISTER_DEFAULT_IMPLS(search_sorted, OCL_S, OCL_D);
46 changes: 46 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/search_sorted_inst.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include <intel_gpu/primitives/search_sorted.hpp>

#include "primitive_inst.h"

namespace cldnn {

template <>
struct typed_program_node<search_sorted> : public typed_program_node_base<search_sorted> {
using parent = typed_program_node_base<search_sorted>;
typed_program_node(const std::shared_ptr<search_sorted> prim, program& prog) : parent(prim, prog) {}

public:
using parent::parent;

program_node& input(size_t idx = 0) const {
return get_dependency(idx);
}
std::vector<size_t> get_shape_infer_dependencies() const override {
return {};
}
};

using search_sorted_node = typed_program_node<search_sorted>;

template <>
class typed_primitive_inst<search_sorted> : public typed_primitive_inst_base<search_sorted> {
using parent = typed_primitive_inst_base<search_sorted>;
using parent::parent;

public:
typed_primitive_inst(network& network, search_sorted_node const& desc);
template <typename ShapeType>
static std::vector<layout> calc_output_layouts(search_sorted_node const& node,
kernel_impl_params const& impl_param);
static layout calc_output_layout(search_sorted_node const& node, kernel_impl_params const& impl_param);
static std::string to_string(search_sorted_node const& node);
};

using search_sorted_inst = typed_primitive_inst<search_sorted>;

} // namespace cldnn
59 changes: 59 additions & 0 deletions src/plugins/intel_gpu/src/graph/search_sorted.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <json_object.h>
#include <search_sorted_inst.h>

#include <sstream>

#include "openvino/core/enum_names.hpp"
#include "primitive_type_base.h"
#include "search_sorted_shape_inference.hpp"

namespace cldnn {
GPU_DEFINE_PRIMITIVE_TYPE_ID(search_sorted)

search_sorted_inst::typed_primitive_inst(network& network, search_sorted_node const& node) : parent(network, node) {}

layout search_sorted_inst::calc_output_layout(search_sorted_node const& node, kernel_impl_params const& impl_param) {
return calc_output_layouts<ov::PartialShape>(node, impl_param)[0];
}

template <typename ShapeType>
std::vector<layout> search_sorted_inst::calc_output_layouts(search_sorted_node const& node,
kernel_impl_params const& impl_param) {
auto primitive = impl_param.typed_desc<search_sorted>();

auto input0_layout = impl_param.get_input_layout(0);
auto input1_layout = impl_param.get_input_layout(1);

const data_types output_type = impl_param.desc->output_data_types[0].value_or(data_types::i64);

std::vector<ShapeType> input_shapes = {
input0_layout.get<ShapeType>(), // sorted shape
input1_layout.get<ShapeType>(), // values shape
};

std::vector<ShapeType> output_shapes;

ov::op::v15::SearchSorted op;
op.set_right_mode(primitive->right_mode);
output_shapes = shape_infer(&op, input_shapes);

return {layout{output_shapes[0], output_type, input1_layout.format}};
}

std::string search_sorted_inst::to_string(search_sorted_node const& node) {
auto node_info = node.desc_to_json();
json_composite search_sorted_info;
search_sorted_info.add("sorted id", node.input(0).id());
search_sorted_info.add("values id", node.input(1).id());
search_sorted_info.add("right_mode", node.get_primitive()->right_mode);
node_info->add("search_sorted info", search_sorted_info);
std::stringstream primitive_description;
node_info->dump(primitive_description);
return primitive_description.str();
}

} // namespace cldnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "include/batch_headers/fetch_data.cl"

#if RIGHT_MODE == 0
#define CMP <=
#else
#define CMP <
#endif

OUTPUT_TYPE FUNC(binary_search_thread)(const INPUT0_TYPE search_val,
const __global INPUT0_TYPE* restrict sorted,
OUTPUT_TYPE sorted_begin_idx,
OUTPUT_TYPE sorted_end_idx) {
while(sorted_begin_idx != sorted_end_idx) {
const OUTPUT_TYPE half_offset = (sorted_end_idx-sorted_begin_idx)/2;
const OUTPUT_TYPE half_idx = sorted_begin_idx+half_offset;
const INPUT0_TYPE half_val = sorted[half_idx];
const bool comp_result = half_val CMP search_val;
if ( search_val CMP half_val )
sorted_end_idx = half_idx;
else
sorted_begin_idx = half_idx + 1;
}

return sorted_begin_idx;
}

#undef CMP

KERNEL(search_sorted_ref)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* restrict sorted,
const __global INPUT1_TYPE* restrict values,
__global OUTPUT_TYPE* restrict output)
{
// INPUT0_TYPE has to be egual to INPUT1_TYPE
const int this_thread_idx = get_global_id(0);
const INPUT0_TYPE search_val = values[this_thread_idx];

const int SORTED_STRIDE = INPUT0_BATCH_NUM*INPUT0_FEATURE_NUM*INPUT0_SIZE_Y*INPUT0_SIZE_Z;

// NOTE: SORTED_STRIDE-1 handles here a special case when sorted is actually 1D
// tensor and values is ND tensor. In such case we effectively want sorted_offset
// to be 0.
const int sorted_offset = min(this_thread_idx/INPUT1_SIZE_X, SORTED_STRIDE-1);

OUTPUT_TYPE sorted_begin_idx = sorted_offset * INPUT0_SIZE_X;
const OUTPUT_TYPE idx = FUNC_CALL(binary_search_thread)(search_val,
sorted + sorted_begin_idx,
0,
INPUT0_SIZE_X);

output[this_thread_idx] = idx;
}
3 changes: 2 additions & 1 deletion src/plugins/intel_gpu/src/kernel_selector/common_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ enum class KernelType {
RMS,
SWIGLU,
ROPE,
DYNAMIC_QUANTIZE
DYNAMIC_QUANTIZE,
SEARCH_SORTED
};

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading
Loading