wip

michal-miotk committed Nov 24, 2024
1 parent d1bec7b commit 0a9756d
Showing 8 changed files with 56 additions and 87 deletions.
16 changes: 1 addition & 15 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/rnn.hpp
@@ -38,8 +38,6 @@ struct RNNParams : public primitive_base<PType> {
const input_info& R,
const input_info& B,
const input_info& seq_lenghts,
- const primitive_id& out1_prim_id = "",
- const primitive_id& out2_prim_id = "",
const float clip = 0,
bool input_forget = false,
const std::vector<activation_func>& activations = {activation_func::logistic,
@@ -58,15 +56,13 @@ struct RNNParams : public primitive_base<PType> {
R(R),
B(B),
seq_lenghts(seq_lenghts),
- out1_prim_id(out1_prim_id),
- out2_prim_id(out2_prim_id),
clip(clip),
input_forget(input_forget),
activations(activations),
activation_params(activation_params),
offset_order(offset_order),
direction(direction) {
- std::vector<std::string> pids{initial_hidden_state.pid, initial_cell_state.pid, W.pid, R.pid, B.pid, seq_lenghts.pid, out1_prim_id, out2_prim_id};
+ std::vector<std::string> pids{initial_hidden_state.pid, initial_cell_state.pid, W.pid, R.pid, B.pid, seq_lenghts.pid};
for (auto pid : pids) {
if (!pid.empty()) {
primitive_base<PType>::input.push_back(pid);
@@ -81,8 +77,6 @@ struct RNNParams : public primitive_base<PType> {
input_info R;
input_info B;
input_info seq_lenghts;
- primitive_id out1_prim_id;
- primitive_id out2_prim_id;
/// @brief Cell clip threshold T. It is applied to the input of activations [-T, T]. No clip is applied if it is not specified.
float clip;
bool input_forget;
@@ -108,8 +102,6 @@ struct RNNParams : public primitive_base<PType> {
seed = hash_combine(seed, W.pid);
seed = hash_combine(seed, R.pid);
seed = hash_combine(seed, B.pid);
- seed = hash_combine(seed, out1_prim_id);
- seed = hash_combine(seed, out2_prim_id);
seed = hash_combine(seed, clip);
seed = hash_range(seed, activations.begin(), activations.end());
for (auto& act_param : activation_params) {
@@ -141,8 +133,6 @@ struct RNNParams : public primitive_base<PType> {
cmp_fields(W) &&
cmp_fields(R) &&
cmp_fields(B) &&
- cmp_fields(out1_prim_id) &&
- cmp_fields(out2_prim_id) &&
cmp_fields(clip) &&
cmp_fields(activations) &&
cmp_fields(offset_order) &&
@@ -159,8 +149,6 @@ struct RNNParams : public primitive_base<PType> {
ob << R;
ob << B;
ob << seq_lenghts;
- ob << out1_prim_id;
- ob << out2_prim_id;
ob << clip;
ob << activations;
ob << activation_params;
@@ -177,8 +165,6 @@ struct RNNParams : public primitive_base<PType> {
ib >> R;
ib >> B;
ib >> seq_lenghts;
- ib >> out1_prim_id;
- ib >> out2_prim_id;
ib >> clip;
ib >> activations;
ib >> activation_params;
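Note on the API change: with out1_prim_id and out2_prim_id dropped from RNNParams, callers simply pass two fewer arguments. A minimal usage sketch against the trimmed lstm_seq signature, mirroring the updated call site in src/plugins/intel_gpu/src/plugin/ops/rnn.cpp further down (inputs, clip, activations, activation_params and direction follow that call site; num_outputs stands in for static_cast<int>(op->get_output_size())):

    // Input order follows ov::op::v5::LSTMSequence: inputs[0]=X,
    // [1]=initial_hidden_state, [2]=initial_cell_state, [3]=sequence_lengths,
    // [4]=W, [5]=R, [6]=B (note the seq_lenghts spelling inside the primitive).
    cldnn::lstm_seq prim("lstm_seq0",
                         inputs[0], inputs[1], inputs[2],   // X, H, C
                         inputs[4], inputs[5], inputs[6],   // W, R, B
                         inputs[3],                         // sequence lengths
                         clip, /*input_forget=*/false,
                         activations, activation_params,
                         cldnn::lstm_weights_order::fizo, direction,
                         cldnn::padding(), num_outputs);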
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/impls/ocl/lstm_cell.cpp
@@ -1,4 +1,4 @@
- // Copyright (C) 2018-2024 Intel Corporation
+ // Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

41 changes: 38 additions & 3 deletions src/plugins/intel_gpu/src/graph/impls/onednn/lstm_seq_onednn.hpp
@@ -21,7 +21,41 @@ struct LSTMSeqImplementationManager : public ImplementationManager {

bool validate_impl(const program_node& node) const override {
assert(node.is_type<lstm_seq>());
- return node.get_input_layout(0).format == cldnn::format::bfyx || node.get_input_layout(0).format == cldnn::format::fbyx \
+ const auto& info = node.get_program().get_engine().get_device_info();
+ if (info.arch == gpu_arch::unknown)
+ return false;
+
+ const auto& lstm_seq_node = node.as<lstm_seq>();
+ const auto& lstm_seq_prim = lstm_seq_node.get_primitive();
+ auto in0_dt = node.get_input_layout(0).data_type;
+ auto in1_dt = node.get_input_layout(1).data_type;
+ auto in2_dt = node.get_input_layout(2).data_type;
+ auto in3_dt = node.get_input_layout(3).data_type;
+ auto in4_dt = node.get_input_layout(4).data_type;
+ auto in5_dt = node.get_input_layout(5).data_type;
+ auto out0_dt = node.get_output_layout(0).data_type;
+ auto out1_dt = node.get_output_layout(1).data_type;
+ auto out2_dt = node.get_output_layout(2).data_type;
+ bool cell_state_check = one_of(in2_dt, {data_types::f16, data_types::bf16, data_types::f32}) &&
+ one_of(out2_dt, {data_types::f16, data_types::bf16, data_types::f32});
+ bool f16_case = everyone_is(data_types::f16, in0_dt, in1_dt, in3_dt, in4_dt, out0_dt, out1_dt);
+ bool bf16_case = everyone_is(data_types::bf16, in0_dt, in1_dt, in3_dt, in4_dt, out0_dt, out1_dt);
+ bool f32_case = everyone_is(data_types::f32, in0_dt, in1_dt, in3_dt, in4_dt, in5_dt, out0_dt, out1_dt);
+ bool u8u8u8_case = one_of(out0_dt, {data_types::u8, data_types::f32}) && everyone_is(data_types::i8, in3_dt, in4_dt) &&
+ everyone_is(data_types::u8, in0_dt, in1_dt, out1_dt) && everyone_is(data_types::f32, in2_dt, in5_dt, out2_dt);
+ bool f32u8f32_case = everyone_is(data_types::u8, in0_dt) && everyone_is(data_types::i8, in3_dt, in4_dt) &&
+ one_of(out0_dt, {data_types::u8, data_types::f32}) && everyone_is(data_types::f32, in1_dt, in5_dt, out1_dt);
+ bool s8s8s8_case = everyone_is(data_types::i8, in0_dt, in1_dt, out0_dt, out1_dt) && one_of(out0_dt, {data_types::i8, data_types::f32}) &&
+ everyone_is(data_types::f32, in2_dt, in5_dt, out2_dt);
+ bool f32s8f32_case = everyone_is(data_types::i8, in0_dt, in3_dt, in4_dt) && one_of(out0_dt, {data_types::i8, data_types::f32}) &&
+ everyone_is(data_types::f32, in1_dt, in5_dt, out1_dt);
+
+ if (!cell_state_check)
+ return false;
+ if (!f16_case && !f32_case && !bf16_case && !u8u8u8_case && !f32u8f32_case && !s8s8s8_case && !f32s8f32_case)
+ return false;
+
+ return node.get_input_layout(0).format == cldnn::format::bfyx || node.get_input_layout(0).format == cldnn::format::fbyx
|| node.get_input_layout(0).format == cldnn::format::ybfx;
}

@@ -31,12 +65,13 @@ struct LSTMSeqImplementationManager : public ImplementationManager {
std::vector<format::type> out_fmts(node.get_outputs_count(), format::any);

size_t out_rank = node.get_output_layout().get_rank();
- for (size_t idx = 0 ; idx < node.get_dependencies().size() ; idx++) {
+ for (size_t idx = 0; idx < node.get_dependencies().size(); idx++) {
if (node.get_dependency(idx).is_constant())
continue;

+ auto target_format = format::get_default_format(out_rank);

- if (idx == 0)
- in_fmts[idx] = format::fbyx;
+ in_fmts[idx] = target_format;
}
out_fmts[0] = format::ybfx;
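Note: the dtype checks in validate_impl above rely on variadic one_of / everyone_is helpers. A minimal sketch of their assumed semantics (the names match oneDNN's dnnl::impl::utils; which header this file actually pulls them from is not visible in the diff):

    #include <initializer_list>

    // True if val equals any of the listed candidates.
    template <typename T>
    bool one_of(T val, std::initializer_list<T> candidates) {
        for (const auto& c : candidates)
            if (val == c)
                return true;
        return false;
    }

    // True if every trailing argument compares equal to the first (C++17 fold).
    template <typename T, typename... Rest>
    bool everyone_is(T ref, Rest... rest) {
        return ((rest == ref) && ...);
    }

Read this way, f16_case says "inputs 0, 1, 3, 4 and outputs 0, 1 are all f16", and cell_state_check says "the cell state (input 2 / output 2) is one of f16, bf16, f32".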
18 changes: 1 addition & 17 deletions src/plugins/intel_gpu/src/graph/include/lstm_cell_inst.h
@@ -1,4 +1,4 @@
- // Copyright (C) 2018-2024 Intel Corporation
+ // Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

@@ -15,15 +15,6 @@ struct typed_program_node<lstm_cell> : public typed_program_node_base<lstm_cell>

public:
using parent::parent;

- program_node& input() const { return get_dependency(0); }
- lstm_weights_order offset_order() const { return get_primitive()->offset_order; }
- float clip() const {
- float clip_val = get_primitive()->clip;
- OPENVINO_ASSERT(clip_val >= 0, "Clip value < 0");
- return clip_val;
- }
- ov::op::RecurrentSequenceDirection direction() const { return get_primitive()->direction; }
};

using lstm_cell_node = typed_program_node<lstm_cell>;
@@ -41,13 +32,6 @@ class typed_primitive_inst<lstm_cell> : public typed_primitive_inst_base<lstm_cell>

public:
typed_primitive_inst(network& network, lstm_cell_node const& node);
- lstm_weights_order offset_order() const { return get_typed_desc<lstm_cell>()->offset_order; }
- float clip() const {
- float clip_val = get_typed_desc<lstm_cell>()->clip;
- OPENVINO_ASSERT(clip_val >= 0, "Clip value < 0");
- return clip_val;
- }
- ov::op::RecurrentSequenceDirection direction() const { return get_typed_desc<lstm_cell>()->direction; }
bool has_cell() const { return !get_typed_desc<lstm_cell>()->initial_cell_state.pid.empty(); }
};

21 changes: 1 addition & 20 deletions src/plugins/intel_gpu/src/graph/include/lstm_seq_inst.h
@@ -1,4 +1,4 @@
- // Copyright (C) 2018-2024 Intel Corporation
+ // Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

@@ -15,18 +15,7 @@ struct typed_program_node<lstm_seq> : public typed_program_node_base<lstm_seq> {

public:
using parent::parent;

- program_node& input() const { return get_dependency(0); }
- lstm_weights_order offset_order() const { return get_primitive()->offset_order; }
- float clip() const {
- float clip_val = get_primitive()->clip;
- OPENVINO_ASSERT(clip_val >= 0, "Clip value < 0");
- return clip_val;
- }
- ov::op::RecurrentSequenceDirection direction() const { return get_primitive()->direction; }
- std::vector<activation_func> activations() const {
- return get_primitive()->activations;
- }
};

using lstm_seq_node = typed_program_node<lstm_seq>;
@@ -44,14 +33,6 @@ class typed_primitive_inst<lstm_seq> : public typed_primitive_inst_base<lstm_seq>

public:
typed_primitive_inst(network& network, lstm_seq_node const& node);
- lstm_weights_order offset_order() const { return get_typed_desc<lstm_seq>()->offset_order; }
- float clip() const {
- float clip_val = get_typed_desc<lstm_seq>()->clip;
- if (clip_val < 0)
- throw std::range_error("Clip value < 0");
- return clip_val;
- }
- uint32_t direction() const { return get_typed_desc<lstm_seq>()->direction == ov::op::RecurrentSequenceDirection::FORWARD ? 0 : 1; }
bool has_cell() const { return !get_typed_desc<lstm_seq>()->initial_cell_state.pid.empty(); }
};

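Note: with the offset_order() / clip() / direction() accessors deleted from both the node and instance wrappers (here and in lstm_cell_inst.h above), call sites presumably read the fields straight off the primitive descriptor. A hedged sketch of that pattern (node is illustrative):

    // get_primitive() returns the shared primitive descriptor
    // (std::shared_ptr<const lstm_seq>); the fields mirror RNNParams.
    const auto desc = node.get_primitive();
    const float clip = desc->clip;           // cell clip threshold, 0 = disabled
    const auto dir = desc->direction;        // ov::op::RecurrentSequenceDirection
    const auto order = desc->offset_order;   // lstm_weights_order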
8 changes: 2 additions & 6 deletions src/plugins/intel_gpu/src/graph/layout_optimizer.cpp
@@ -44,7 +44,6 @@
#include <vector>
#include <memory>
#include <utility>
- #include <openvino/op/constant.hpp>

#include "pass_manager.h"

@@ -1363,9 +1362,6 @@ format layout_optimizer::get_preferred_format(program_node& node) {
node.as<dft>().get_primitive()->direction == dft_direction::forward) {
node.set_preferred_input_fmt(0, format::get_default_format(node.get_input_layouts()[0].get_rank()));
}
- } else if (node.is_type<lstm_seq>()) {
- node.set_preferred_input_fmt(0, format::fbyx);
- expected = format::fbyx;
}

if (allow_new_shape_infer && node.get_preferred_input_fmt() != format::any) {
@@ -1440,8 +1436,8 @@ void layout_optimizer::add_all_onednn_impls_optimization_attribute() {
}

bool layout_optimizer::has_all_enabled_onednn_impls_optimization_attribute() {
- return is_enabled_onednn_for<concatenation>() && is_enabled_onednn_for<convolution>() && is_enabled_onednn_for<deconvolution>() && \
- is_enabled_onednn_for<fully_connected>() && is_enabled_onednn_for<gemm>() && is_enabled_onednn_for<lstm_seq>() && \
+ return is_enabled_onednn_for<concatenation>() && is_enabled_onednn_for<convolution>() && is_enabled_onednn_for<deconvolution>() &&
+ is_enabled_onednn_for<fully_connected>() && is_enabled_onednn_for<gemm>() && is_enabled_onednn_for<lstm_seq>() &&
is_enabled_onednn_for<pooling>() && is_enabled_onednn_for<reduce>() && is_enabled_onednn_for<reorder>();
}

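Note: this removal pairs with the query_formats change in lstm_seq_onednn.hpp above: lstm_seq no longer pins its first input to fbyx in the layout optimizer; the oneDNN implementation now requests the rank-default format for each non-constant dependency and ybfx for output 0. A small sketch of the assumed helper behavior:

    // format::get_default_format maps a tensor rank to the plain layout,
    // e.g. rank 4 -> bfyx, rank 5 -> bfzyx (assumed from cldnn usage).
    auto fmt4 = cldnn::format::get_default_format(4);  // format::bfyx
    auto fmt5 = cldnn::format::get_default_format(5);  // format::bfzyx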
31 changes: 12 additions & 19 deletions src/plugins/intel_gpu/src/plugin/ops/rnn.cpp
@@ -73,13 +73,11 @@ static void CreateLSTMCellOp(ProgramBuilder& p, const std::shared_ptr<ov::op::v4::LSTMCell>& op) {
std::vector<cldnn::activation_additional_params> activation_params;
GetLSTMActivationParams(op, activations, activation_params);
float clip = op->get_clip();
- assert(!inputs[5].pid.empty());
- if (p.use_new_shape_infer()) {
- p.add_primitive(*op, cldnn::lstm_cell(layerName+".out0", inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], \
- cldnn::input_info(), "", layerName + "_md_write.1", clip, false, activations, \
- activation_params, cldnn::lstm_weights_order::fizo, ov::op::RecurrentSequenceDirection::FORWARD, cldnn::padding(), \
- static_cast<int>(op->get_output_size())));
- }
+ OPENVINO_ASSERT(!inputs[5].pid.empty());
+ OPENVINO_ASSERT(p.use_new_shape_infer());
+ p.add_primitive(*op, cldnn::lstm_cell(layerName+".out0", inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], cldnn::input_info(),
+ clip, false, activations, activation_params, cldnn::lstm_weights_order::fizo, ov::op::RecurrentSequenceDirection::FORWARD, cldnn::padding(),
+ static_cast<int>(op->get_output_size())));
}

static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op::v5::LSTMSequence>& op) {
@@ -90,20 +88,15 @@ static void CreateLSTMSequenceOp(ProgramBuilder& p, const std::shared_ptr<ov::op::v5::LSTMSequence>& op) {
std::vector<cldnn::activation_additional_params> activation_params;
GetLSTMActivationParams(op, activations, activation_params);
const float clip = op->get_clip();
- if (op->get_input_shape(2).size() != 3 || op->get_input_shape(3).size() != 1 \
- || op->get_input_shape(4).size() != 3 || op->get_input_shape(5).size() != 3 || op->get_input_shape(6).size() != 2) {
- OPENVINO_THROW("Wrong input shapes for LSTMSequence op ", op->get_friendly_name());
- }
+ OPENVINO_ASSERT(op->get_input_shape(2).size() == 3 && op->get_input_shape(3).size() == 1 && op->get_input_shape(4).size() == 3 &&
+ op->get_input_shape(5).size() == 3 && op->get_input_shape(6).size() == 2, "Wrong input shapes for LSTMSequence op ", op->get_friendly_name());
auto direction = op->get_direction();

- if (p.use_new_shape_infer()) {
- cldnn::lstm_seq prim(layerName, inputs[0], inputs[1], \
- inputs[2], inputs[4], inputs[5], inputs[6], inputs[3], "", "", \
- clip, false, activations, activation_params, cldnn::lstm_weights_order::fizo, direction, cldnn::padding(), \
- static_cast<int>(op->get_output_size()));
- prim.output_data_types = get_output_data_types(op);
- p.add_primitive(*op, prim);
- }
+ OPENVINO_ASSERT(p.use_new_shape_infer());
+ cldnn::lstm_seq prim(layerName, inputs[0], inputs[1], inputs[2], inputs[4], inputs[5], inputs[6], inputs[3], clip, false, activations,
+ activation_params, cldnn::lstm_weights_order::fizo, direction, cldnn::padding(), static_cast<int>(op->get_output_size()));
+ prim.output_data_types = get_output_data_types(op);
+ p.add_primitive(*op, prim);
}

REGISTER_FACTORY_IMPL(v4, LSTMCell);
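Note: both creators now fail fast instead of silently building nothing when the new shape-infer path is disabled, and the plain C assert became a checked one. OPENVINO_ASSERT throws ov::AssertFailure with the supplied context message when the condition is false and, unlike assert, stays active in release builds. A sketch of the pattern with a hypothetical message (the asserts in this diff carry none):

    OPENVINO_ASSERT(p.use_new_shape_infer(),
                    "[GPU] LSTMSequence requires new shape inference: ",
                    op->get_friendly_name());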
6 changes: 0 additions & 6 deletions src/plugins/intel_gpu/src/plugin/plugin.cpp
@@ -278,9 +278,6 @@ ov::SupportedOpsMap Plugin::query_model(const std::shared_ptr<const ov::Model>&

ExecutionConfig config = m_configs_map.at(device_id);
config.set_user_property(orig_config);
- if (ctx->get_engine().get_device_info().supports_immad) {
- config.set_property(ov::intel_gpu::use_onednn(true));
- }
config.apply_user_properties(ctx->get_engine().get_device_info());

ProgramBuilder prog(ctx->get_engine(), config);
@@ -330,9 +327,6 @@ std::shared_ptr<ov::ICompiledModel> Plugin::import_model(std::istream& model,

ExecutionConfig config = m_configs_map.at(device_id);
config.set_user_property(_orig_config);
- if (context_impl->get_engine().get_device_info().supports_immad) {
- config.set_property(ov::intel_gpu::use_onednn(true));
- }
config.apply_user_properties(context_impl->get_engine().get_device_info());

cldnn::BinaryInputBuffer ib(model, context_impl->get_engine());