From 23565176d4d7c01f0eb991fa686a3cbf82a09f39 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sun, 2 Apr 2023 14:13:46 +0800 Subject: [PATCH 01/23] [phi] move sequence_pool kernel to phi --- .../sequence_ops/sequence_pool_op.cc | 3 - .../sequence_ops/sequence_pool_op.cu | 2 - .../operators/sequence_ops/sequence_pool_op.h | 69 --- .../sequence_ops/unity_build_rule.cmake | 2 - .../phi/kernels/cpu/sequence_pool_kernel.cc | 95 ++++ paddle/phi/kernels/funcs/sequence_pooling.cc | 502 ++++++++++++++++++ paddle/phi/kernels/funcs/sequence_pooling.cu | 501 +++++++++++++++++ paddle/phi/kernels/funcs/sequence_pooling.h | 50 ++ .../phi/kernels/gpu/sequence_pool_kernel.cu | 95 ++++ paddle/phi/kernels/sequence_pool_kernel.h | 30 ++ 10 files changed, 1273 insertions(+), 76 deletions(-) create mode 100644 paddle/phi/kernels/cpu/sequence_pool_kernel.cc create mode 100644 paddle/phi/kernels/funcs/sequence_pooling.cc create mode 100644 paddle/phi/kernels/funcs/sequence_pooling.cu create mode 100644 paddle/phi/kernels/funcs/sequence_pooling.h create mode 100644 paddle/phi/kernels/gpu/sequence_pool_kernel.cu create mode 100644 paddle/phi/kernels/sequence_pool_kernel.h diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index 938b23a22a63c..c44427f98f211 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -196,9 +196,6 @@ REGISTER_OPERATOR(sequence_pool, REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp, ops::SequencePoolGradOpNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL(sequence_pool, - ops::SequencePoolKernel, - ops::SequencePoolKernel); REGISTER_OP_CPU_KERNEL(sequence_pool_grad, ops::SequencePoolGradKernel, diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu index 882ec66f501db..df5dde79274f9 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu @@ -14,7 +14,5 @@ limitations under the License. */ #include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(sequence_pool, - ops::SequencePoolKernel); REGISTER_OP_CUDA_KERNEL(sequence_pool_grad, ops::SequencePoolGradKernel); diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h index ddf0d496a77fb..bcba5590bc567 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h @@ -23,75 +23,6 @@ limitations under the License. */ namespace paddle { namespace operators { -template -class SequencePoolKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - std::string pooltype = context.Attr("pooltype"); - T pad_value = static_cast(context.Attr("pad_value")); - - auto dims = in->dims(); - auto lod = in->lod(); - auto lod_level = lod.size(); - // InferShape by lod - PADDLE_ENFORCE_GT(lod_level, - 0, - platform::errors::InvalidArgument( - "Input(X) phi::DenseTensor of SequencePoolOp " - "does not contain LoD information.")); - PADDLE_ENFORCE_LE(lod_level, - 2UL, - platform::errors::InvalidArgument( - "The lod level of input shall be no more than 2." - "Received lod level is %d.", - lod_level)); - PADDLE_ENFORCE_GE( - dims[0], - /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), - platform::errors::InvalidArgument( - "The first dimension of Input(X) must be large than batch size." - "But received first dimension of Input(X) is %d, while batch" - "size is %d.", - dims[0], - static_cast(lod[lod_level - 1].size() - 1))); - if (lod_level > 1UL) { - PADDLE_ENFORCE_EQ(lod[0][lod[0].size() - 1], - lod[1].size() - 1, - platform::errors::InvalidArgument( - "The input lod information is illegal.")); - framework::LoD out_lod; - out_lod.push_back(lod[0]); - out->set_lod(out_lod); - } - dims[0] = lod[lod_level - 1].size() - 1; - out->Resize({dims}); - out->mutable_data(context.GetPlace()); - phi::DenseTensor* index = nullptr; - - bool is_test = - context.HasAttr("is_test") ? context.Attr("is_test") : false; - - // Do not create index buffer for inference mode - if (pooltype == "MAX" && - (is_test == false || - platform::is_cpu_place(context.GetPlace()) == false)) { - index = context.Output("MaxIndex"); - index->Resize({dims}); - index->mutable_data(context.GetPlace()); - } - math::SequencePoolFunctor pool; - pool(context.template device_context(), - pooltype, - pad_value, - *in, - out, - is_test, - index); - } -}; - template class SequencePoolGradKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake index 9a87e27b24197..2b66923df028e 100644 --- a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake +++ b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake @@ -13,7 +13,6 @@ register_unity_group( sequence_expand_op.cc sequence_mask_op.cc sequence_pad_op.cc - sequence_pool_op.cc sequence_expand_as_op.cc sequence_reshape_op.cc sequence_reverse_op.cc @@ -31,7 +30,6 @@ register_unity_group( sequence_expand_op.cu sequence_mask_op.cu sequence_pad_op.cu - sequence_pool_op.cu sequence_expand_as_op.cu sequence_reshape_op.cu sequence_reverse_op.cu diff --git a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc new file mode 100644 index 0000000000000..6912e17df30da --- /dev/null +++ b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc @@ -0,0 +1,95 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sequence_pool_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void SequencePoolKernel(const Context& ctx, + const DenseTensor& x, + bool is_test, + std::string pooltype, + float pad_value, + DenseTensor* out, + DenseTensor* max_index) { + Context pad_value_ = static_cast(pad_value); + + auto dims = x.dims(); + auto lod = x.lod(); + auto lod_level = lod.size(); + // InferShape by lod + PADDLE_ENFORCE_GT( + lod_level, + 0, + errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " + "does not contain LoD information.")); + PADDLE_ENFORCE_LE( + lod_level, + 2UL, + errors::InvalidArgument("The lod level of input shall be no more than 2." + "Received lod level is %d.", + lod_level)); + PADDLE_ENFORCE_GE( + dims[0], + /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), + errors::InvalidArgument( + "The first dimension of Input(X) must be large than batch size." + "But received first dimension of Input(X) is %d, while batch" + "size is %d.", + dims[0], + static_cast(lod[lod_level - 1].size() - 1))); + if (lod_level > 1UL) { + PADDLE_ENFORCE_EQ( + lod[0][lod[0].size() - 1], + lod[1].size() - 1, + errors::InvalidArgument("The input lod information is illegal.")); + phi::LoD out_lod; + out_lod.push_back(lod[0]); + out->set_lod(out_lod); + } + dims[0] = lod[lod_level - 1].size() - 1; + out->Resize({dims}); + // out->mutable_data(ctx.GetPlace()); + ctx.template Alloc(out); + phi::DenseTensor* index = nullptr; + + bool is_test_ = ctx.HasAttr("is_test") ? is_test : false; + + auto& place = *ctx.eigen_device(); + + // Do not create index buffer for inference mode + if (pooltype == "MAX" && is_test_ == false) { + index = max_index; + index->Resize({dims}); + // index->mutable_data(ctx.GetPlace()); + ctx.template Alloc(index); + } + math::SequencePoolFunctor pool; + pool(ctx.template device_context(), + pooltype, + pad_value_, + &x, + out, + is_test_, + index); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + sequence_pool, CPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/funcs/sequence_pooling.cc b/paddle/phi/kernels/funcs/sequence_pooling.cc new file mode 100644 index 0000000000000..7d235b7ad277b --- /dev/null +++ b/paddle/phi/kernels/funcs/sequence_pooling.cc @@ -0,0 +1,502 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +#include + +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/jit/kernels.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { +namespace math { + +template +using EigenVector = phi::EigenVector; +template +using EigenMatrix = phi::EigenMatrix; + +template +class MaxSeqPoolFunctor { + public: + void operator()(const phi::CPUContext& context, + const phi::DenseTensor& input, + T pad_value, + phi::DenseTensor* output, + phi::DenseTensor* index) { + auto in_dims = input.dims(); + auto out_dims = output->dims(); + auto idx_dims = index->dims(); + PADDLE_ENFORCE_GT(in_dims.size(), + 1, + errors::InvalidArgument( + "The rank of input shall be greater than 1, but got " + "the rank is %ld. Please check the input value", + in_dims.size())); + PADDLE_ENFORCE_GT(out_dims.size(), + 1, + errors::InvalidArgument( + "The rank of output shall be greater than 1, but got " + "the rank is %ld. Please check the input value", + out_dims.size())); + for (int64_t i = 1; i < in_dims.size(); ++i) { + PADDLE_ENFORCE_EQ( + in_dims[i], + out_dims[i], + errors::InvalidArgument( + "The dimension of input and output shall be same. Expected %ld " + "== %ld, but got %ld != %ld. Please check the input value.", + in_dims[i], + out_dims[i], + in_dims[i], + out_dims[i])); + } + PADDLE_ENFORCE_EQ( + idx_dims, + out_dims, + errors::InvalidArgument( + "The dimension of index and output shall be same. Expected %ld == " + "%ld, but got %ld != %ld. Please check the input value.", + idx_dims, + out_dims, + idx_dims, + out_dims)); + + auto lod_level = input.lod().size(); + auto starts = input.lod()[lod_level - 1]; + const T* in_data = input.data(); + T* out_data = output->data(); + int* max_index = index->data(); + + int64_t num_seq = out_dims[0]; + int64_t dim = output->numel() / num_seq; + for (int64_t i = 0; i < num_seq; ++i) { + if (starts[i] == starts[i + 1]) { + for (int64_t k = 0; k < dim; ++k) { + out_data[i * dim + k] = pad_value; + max_index[i * dim + k] = -1; + } + continue; + } + for (int64_t k = 0; k < dim; ++k) { + out_data[i * dim + k] = in_data[starts[i] * dim + k]; + max_index[i * dim + k] = starts[i]; + } + for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) { + for (int64_t k = 0; k < dim; ++k) { + if (in_data[j * dim + k] > out_data[i * dim + k]) { + out_data[i * dim + k] = in_data[j * dim + k]; + max_index[i * dim + k] = j; + } + } + } + } + } +}; +// Instantisation of Max Sequence Pooling for test phase eg. no need to fill +// index buffer +template +class MaxSeqPoolFunctor { + public: + void operator()(const phi::CPUContext& context, + const phi::DenseTensor& input, + T pad_value, + phi::DenseTensor* output, + phi::DenseTensor* index) { + auto in_dims = input.dims(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_GT(in_dims.size(), + 1, + errors::InvalidArgument( + "The rank of input shall be greater than 1, but got " + "%ld <= 1. Please check the input value.", + in_dims.size())); + PADDLE_ENFORCE_GT(out_dims.size(), + 1, + errors::InvalidArgument( + "The rank of output shall be greater than 1, but got " + "%ld <= 1. Please check the input value.", + out_dims.size())); + for (int64_t i = 1; i < in_dims.size(); ++i) { + PADDLE_ENFORCE_EQ( + in_dims[i], + out_dims[i], + errors::InvalidArgument( + "The dimension of input and output shall be same. Expected %ld " + "== %ld, but got %ld != %ld. Please check the input value.", + in_dims[i], + out_dims[i], + in_dims[i], + out_dims[i])); + } + + auto lod_level = input.lod().size(); + auto starts = input.lod()[lod_level - 1]; + const T* in_data = input.data(); + T* out_data = output->data(); + + int64_t num_seq = out_dims[0]; + int64_t dim = output->numel() / num_seq; + for (int64_t i = 0; i < num_seq; ++i) { + if (starts[i] == starts[i + 1]) { + for (int64_t k = 0; k < dim; ++k) { + out_data[i * dim + k] = pad_value; + } + continue; + } + std::memcpy( + &out_data[i * dim], &in_data[starts[i] * dim], dim * sizeof(T)); + for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) { + for (int64_t k = 0; k < dim; ++k) { + if (in_data[j * dim + k] > out_data[i * dim + k]) { + out_data[i * dim + k] = in_data[j * dim + k]; + } + } + } + } + } +}; +template +class MaxSeqPoolGradFunctor { + public: + void operator()(const phi::CPUContext& context, + const phi::DenseTensor& out_grad, + const phi::DenseTensor& index, + phi::DenseTensor* in_grad) { + auto og_dims = out_grad.dims(); + auto ig_dims = in_grad->dims(); + auto idx_dims = index.dims(); + PADDLE_ENFORCE_GT(og_dims.size(), + 1, + errors::InvalidArgument( + "The rank of output@Grad shall be greater than 1, " + "but got %ld <= 1. Please check the input value.", + og_dims.size())); + PADDLE_ENFORCE_GT(ig_dims.size(), + 1, + errors::InvalidArgument( + "The rank of input@Grad shall be greater than 1, but " + "got %ld <= 1. Please check the input value.", + ig_dims.size())); + for (int64_t i = 1; i < og_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(og_dims[i], + ig_dims[i], + errors::InvalidArgument( + "The dimension of input@Grad and output@Grad shall " + "be same. Expected %ld == %ld, but got %ld != %ld. " + "Please check the input value.", + og_dims[i], + ig_dims[i], + og_dims[i], + ig_dims[i])); + } + PADDLE_ENFORCE_EQ( + idx_dims, + og_dims, + errors::InvalidArgument( + "The dimension of index and output@Grad shall be same. Expected " + "%ld == %ld, but got %ld != %ld. Please check the input value.", + idx_dims, + og_dims, + idx_dims, + og_dims)); + + const T* og_data = out_grad.data(); + const int* max_index = index.data(); + T* ig_data = in_grad->data(); + + phi::funcs::SetConstant set_zero; + set_zero(context, in_grad, static_cast(0.0)); + int64_t num_seq = og_dims[0]; + int64_t dim = out_grad.numel() / num_seq; + for (int64_t i = 0; i < num_seq; ++i) { + for (int64_t j = 0; j < dim; ++j) { + int step_id = max_index[i * dim + j]; + if (step_id == -1) continue; + ig_data[step_id * dim + j] = og_data[i * dim + j]; + } + } + } +}; + +template +class LastSeqPoolFunctor { + public: + void operator()(const phi::CPUContext& context, + const phi::DenseTensor& input, + T pad_value, + phi::DenseTensor* output) { + // Create pointers to input and output data + auto* in_data = input.data(); + auto* out_data = output->data(); + + // Calculate the size of each item in sequence + int64_t item_size = input.numel() / input.dims()[0]; + auto lod_level = input.lod().size(); + auto lod = input.lod()[lod_level - 1]; + int seq_num = static_cast(lod.size()) - 1; + for (int i = 0; i < seq_num; ++i) { + // Calculate the length of each sequence + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + if (seq_len == 0) { + for (int j = 0; j < item_size; ++j) { + out_data[j] = pad_value; + } + } else { + // Point to the begin of next sequence + in_data += seq_len * item_size; + // Copy the last item of sequence to output + std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T)); + } + out_data += item_size; + } + } +}; + +template +class FirstSeqPoolFunctor { + public: + void operator()(const phi::CPUContext& context, + const phi::DenseTensor& input, + T pad_value, + phi::DenseTensor* output) { + // Create pointers to input and output data + auto* in_data = input.data(); + auto* out_data = output->data(); + + // Calculate the size of each item in sequence + int64_t item_size = input.numel() / input.dims()[0]; + auto lod_level = input.lod().size(); + auto lod = input.lod()[lod_level - 1]; + int seq_num = static_cast(lod.size()) - 1; + for (int i = 0; i < seq_num; ++i) { + // Calculate the length of each sequence + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + if (seq_len == 0) { + for (int j = 0; j < item_size; ++j) { + out_data[j] = pad_value; + } + } else { + // Copy the first item of sequence to output + std::memcpy(out_data, in_data, item_size * sizeof(T)); + // Point to the next sequence + in_data += seq_len * item_size; + } + out_data += item_size; + } + } +}; + +template +class SumSeqPoolGradFunctor { + public: + void operator()(const phi::CPUContext& context, + const phi::DenseTensor& out_grad, + phi::DenseTensor* in_grad) { + auto lod_level = in_grad->lod().size(); + auto lod = in_grad->lod()[lod_level - 1]; + int64_t out_w = out_grad.numel() / out_grad.dims()[0]; + int64_t in_w = in_grad->numel() / in_grad->dims()[0]; + PADDLE_ENFORCE_EQ(in_w, + out_w, + errors::InvalidArgument( + "The feature size of input@Grad and output@Grad " + "shall be same. Expected %ld == %ld, but got %ld != " + "%ld. Please check the input value.", + in_w, + out_w, + in_w, + out_w)); + const T* out_g_data = out_grad.data(); + T* in_g_data = in_grad->mutable_data(context.GetPlace()); + auto blas = phi::funcs::GetBlas(context); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + int64_t h = static_cast(lod[i + 1] - lod[i]); + if (h == 0) continue; + int64_t in_offset = lod[i] * in_w; + const T* out_pos = out_g_data + i * out_w; + T* in_pos = in_g_data + in_offset; + for (int r = 0; r != h; ++r) { + blas.VCOPY(in_w, out_pos, in_pos + r * in_w); + } + } + } +}; + +template +class SequencePoolFunctor { + public: + /* max pool has index output */ + void operator()(const phi::CPUContext& context, + const std::string pooltype, + T pad_value, + const phi::DenseTensor& input, + phi::DenseTensor* output, + bool is_test, + phi::DenseTensor* index = nullptr) { + if (pooltype == "MAX") { + if (is_test) { + phi::math::MaxSeqPoolFunctor max_pool; + max_pool(context, input, pad_value, output, index); + } else { + phi::math::MaxSeqPoolFunctor max_pool; + max_pool(context, input, pad_value, output, index); + } + return; + } + if (pooltype == "LAST") { + phi::math::LastSeqPoolFunctor last_pool; + last_pool(context, input, pad_value, output); + return; + } + if (pooltype == "FIRST") { + phi::math::FirstSeqPoolFunctor first_pool; + first_pool(context, input, pad_value, output); + return; + } + auto lod_level = input.lod().size(); + auto lod = input.lod()[lod_level - 1]; + if (pooltype == "SUM") { + auto place = context.GetPlace(); + PADDLE_ENFORCE_EQ( + platform::is_cpu_place(place), + true, + errors::InvalidArgument( + "Sequence_pool should run on CPU Device when pooltype is SUM")); + const T* src = input.data(); + T* dst = output->mutable_data(place); + phi::jit::seq_pool_attr_t attr( + static_cast(input.numel() / input.dims()[0]), + phi::jit::SeqPoolType::kSum); + auto seqpool = phi::jit::KernelFuncs, + platform::CPUPlace>::Cache() + .At(attr); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + attr.h = static_cast(lod[i + 1] - lod[i]); + if (attr.h == 0) { + for (int j = 0; j < attr.w; ++j) { + dst[j] = pad_value; + } + } else { + seqpool(src, dst, &attr); + } + dst += attr.w; + src += attr.h * attr.w; + } + return; + } + auto& place = *context.eigen_device(); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + phi::DenseTensor out_t = output->Slice(i, i + 1); + int64_t w = input.numel() / input.dims()[0]; + if (lod[i] == lod[i + 1]) { + for (int j = 0; j < w; ++j) { + out_t.data()[j] = pad_value; + } + continue; + } + phi::DenseTensor in_t = + input.Slice(static_cast(lod[i]), static_cast(lod[i + 1])); + int64_t h = static_cast(lod[i + 1] - lod[i]); + auto in_e = EigenMatrix::From(in_t, phi::make_ddim({h, w})); + auto out_e = EigenVector::Flatten(out_t); + if (pooltype == "AVERAGE") { + out_e.device(place) = in_e.mean(Eigen::array({{0}})); + } else if (pooltype == "SQRT") { + out_e.device(place) = in_e.sum(Eigen::array({{0}})) / + std::sqrt(static_cast(h)); + } else { + PADDLE_THROW(errors::InvalidArgument( + "unsupported pooling pooltype: %s. Only support \"AVERAGE\" and " + "\"SQRT\"", + pooltype)); + } + } + } +}; + +template +class SequencePoolGradFunctor { + public: + void operator()(const phi::CPUContext& context, + const std::string pooltype, + const phi::DenseTensor& out_grad, + phi::DenseTensor* in_grad, + /* max pool has index */ + const phi::DenseTensor* index = nullptr) { + if (pooltype == "MAX") { + phi::math::MaxSeqPoolGradFunctor max_pool_grad; + max_pool_grad(context, out_grad, *index, in_grad); + return; + } + + if (pooltype == "LAST" || pooltype == "FIRST") { + // set X@Grad be zero at first when pooltype is LAST/FIRST + phi::funcs::SetConstant functor; + functor(context, in_grad, 0); + } + + if (pooltype == "SUM") { + phi::math::SumSeqPoolGradFunctor sum_pool_grad; + sum_pool_grad(context, out_grad, in_grad); + return; + } + + auto lod_level = in_grad->lod().size(); + auto lod = in_grad->lod()[lod_level - 1]; + auto& place = *context.eigen_device(); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + if (lod[i] == lod[i + 1]) continue; + auto in_g_t = in_grad->Slice(static_cast(lod[i]), + static_cast(lod[i + 1])); + auto out_g_t = out_grad.Slice(i, i + 1); + int64_t h = static_cast(lod[i + 1] - lod[i]); + int64_t w = in_grad->numel() / in_grad->dims()[0]; + auto in_g_e = EigenMatrix::From(in_g_t, {h, w}); + auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); + auto out_g_e_v = EigenVector::Flatten(out_g_t); + Eigen::DSizes bcast(h, 1); + + if (pooltype == "AVERAGE") { + in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); + } else if (pooltype == "SQRT") { + in_g_e.device(place) = + (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); + } else if (pooltype == "LAST") { + in_g_e.chip(h - 1, 0).device(place) = out_g_e_v; + } else if (pooltype == "FIRST") { + in_g_e.chip(0, 0).device(place) = out_g_e_v; + } else { + PADDLE_THROW(errors::InvalidArgument( + "unsupported pooling pooltype: %s. Only support \"AVERAGE\", " + "\"SQRT\", \"LAST\" and \"FIRST\"", + pooltype)); + } + } + } +}; + +template class SequencePoolFunctor; +template class SequencePoolFunctor; +template class SequencePoolGradFunctor; +template class SequencePoolGradFunctor; + +} // namespace math +} // namespace phi diff --git a/paddle/phi/kernels/funcs/sequence_pooling.cu b/paddle/phi/kernels/funcs/sequence_pooling.cu new file mode 100644 index 0000000000000..01b8af816328b --- /dev/null +++ b/paddle/phi/kernels/funcs/sequence_pooling.cu @@ -0,0 +1,501 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/core/macros.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +namespace phi { +namespace math { + +template +struct MaxPoolFunctor { + HOSTDEVICE void operator()(const T* input, + const T pad_value, + const size_t start, + const size_t end, + const size_t item_dim, + T* output, + int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + T max_val = static_cast(-FLT_MAX); + int max_index = -1; + if (start == end) { + output[tid] = pad_value; + index[tid] = -1; + } else { + for (int i = start; i < end; ++i) { + if (max_val < input[item_dim * i + tid]) { + max_val = input[item_dim * i + tid]; + max_index = i; + } + } + output[tid] = max_val; + index[tid] = max_index; + } + } + } +}; + +template +struct AvgPoolFunctor { + HOSTDEVICE void operator()(const T* input, + const T pad_value, + const size_t start, + const size_t end, + const size_t item_dim, + T* output, + int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + if (start == end) { + output[tid] = pad_value; + } else { + T val = static_cast(0); + for (int i = start; i < end; ++i) { + val += input[item_dim * i + tid]; + } + // end, start is lod, so end - start != 0 + output[tid] = val / static_cast(end - start); + } + } + } +}; + +template +struct SumPoolFunctor { + HOSTDEVICE void operator()(const T* input, + const T pad_value, + const size_t start, + const size_t end, + const size_t item_dim, + T* output, + int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + if (start == end) { + output[tid] = pad_value; + } else { + T val = static_cast(0); + for (int i = start; i < end; ++i) { + val += input[item_dim * i + tid]; + } + output[tid] = val; + } + } + } +}; + +template +struct SqrtPoolFunctor { + HOSTDEVICE void operator()(const T* input, + const T pad_value, + const size_t start, + const size_t end, + const size_t item_dim, + T* output, + int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + if (start == end) { + output[tid] = pad_value; + } else { + T val = static_cast(0); + for (int i = start; i < end; ++i) { + val += input[item_dim * i + tid]; + } + // end, start is lod, so end - start != 0 + output[tid] = val / sqrt(end - start); + } + } + } +}; + +template +struct LastPoolFunctor { + HOSTDEVICE void operator()(const T* input, + const T pad_value, + const size_t start, + const size_t end, + const size_t item_dim, + T* output, + int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + if (start == end) { + output[tid] = pad_value; + } else { + output[tid] = input[item_dim * (end - 1) + tid]; + } + } + } +}; + +template +struct FirstPoolFunctor { + HOSTDEVICE void operator()(const T* input, + const T pad_value, + const size_t start, + const size_t end, + const size_t item_dim, + T* output, + int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + if (start == end) { + output[tid] = pad_value; + } else { + output[tid] = input[item_dim * start + tid]; + } + } + } +}; + +template +__global__ void sequence_pool_kernel(Range_OP op, + const T* input, + const T pad_value, + const size_t* lod, + const size_t lod_size, + const size_t item_dim, + T* output, + int* index) { + int bid = blockIdx.x; + if (bid >= lod_size - 1) return; + size_t start = lod[bid]; + size_t end = lod[bid + 1]; + int* index_offset = nullptr; + if (index != nullptr) { + index_offset = &index[bid * item_dim]; + } + op(input, + pad_value, + start, + end, + item_dim, + &output[bid * item_dim], + index_offset); +} + +template +class SequencePoolFunctor { + public: + void operator()(const phi::GPUContext& context, + const std::string pooltype, + T pad_value, + const phi::DenseTensor& input, + phi::DenseTensor* output, + bool is_test, + phi::DenseTensor* index = nullptr) { + auto lod_level = input.lod().size(); + auto& lod = input.lod()[lod_level - 1]; + const size_t item_dim = output->numel() / output->dims()[0]; + dim3 threads(1024, 1); + dim3 grid(std::max(static_cast(lod.size()) - 1, 1), 1); + phi::MixVector mix_vector(&lod); + if (pooltype == "MAX") { + sequence_pool_kernel> + <<>>( + MaxPoolFunctor(), + input.data(), + pad_value, + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + output->mutable_data(context.GetPlace()), + index->data()); + } else if (pooltype == "AVERAGE") { + sequence_pool_kernel> + <<>>( + AvgPoolFunctor(), + input.data(), + pad_value, + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + output->mutable_data(context.GetPlace()), + nullptr); + } else if (pooltype == "SUM") { + sequence_pool_kernel> + <<>>( + SumPoolFunctor(), + input.data(), + pad_value, + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + output->mutable_data(context.GetPlace()), + nullptr); + } else if (pooltype == "SQRT") { + sequence_pool_kernel> + <<>>( + SqrtPoolFunctor(), + input.data(), + pad_value, + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + output->mutable_data(context.GetPlace()), + nullptr); + } else if (pooltype == "LAST") { + sequence_pool_kernel> + <<>>( + LastPoolFunctor(), + input.data(), + pad_value, + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + output->mutable_data(context.GetPlace()), + nullptr); + } else if (pooltype == "FIRST") { + sequence_pool_kernel> + <<>>( + FirstPoolFunctor(), + input.data(), + pad_value, + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + output->mutable_data(context.GetPlace()), + nullptr); + } else { + PADDLE_THROW(errors::InvalidArgument( + "unsupported pooling pooltype: %s. Only support \"MAX\", " + "\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"", + pooltype)); + } + } +}; + +template +struct MaxPoolGradFunctor { + HOSTDEVICE void operator()(const T* out_grad, + const size_t start, + const size_t end, + const size_t item_dim, + T* in_grad, + const int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + for (int i = start; i < end; ++i) { + if (i == index[tid]) { + in_grad[item_dim * i + tid] = out_grad[tid]; + } else { + in_grad[item_dim * i + tid] = static_cast(0); + } + } + } + } +}; + +template +struct AvgPoolGradFunctor { + HOSTDEVICE void operator()(const T* out_grad, + const size_t start, + const size_t end, + const size_t item_dim, + T* in_grad, + const int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + for (int i = start; i < end; ++i) { + in_grad[item_dim * i + tid] = out_grad[tid] / (end - start); + } + } + } +}; + +template +struct SumPoolGradFunctor { + HOSTDEVICE void operator()(const T* out_grad, + const size_t start, + const size_t end, + const size_t item_dim, + T* in_grad, + const int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + for (int i = start; i < end; ++i) { + in_grad[item_dim * i + tid] = out_grad[tid]; + } + } + } +}; + +template +struct SqrtPoolGradFunctor { + HOSTDEVICE void operator()(const T* out_grad, + const size_t start, + const size_t end, + const size_t item_dim, + T* in_grad, + const int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + for (int i = start; i < end; ++i) { + in_grad[item_dim * i + tid] = + out_grad[tid] / (sqrt(static_cast(end - start))); + } + } + } +}; + +template +struct LastPoolGradFunctor { + HOSTDEVICE void operator()(const T* out_grad, + const size_t start, + const size_t end, + const size_t item_dim, + T* in_grad, + const int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + for (int i = start; i < end; ++i) { + if (i == end - 1) { + in_grad[item_dim * i + tid] = out_grad[tid]; + } else { + in_grad[item_dim * i + tid] = static_cast(0); + } + } + } + } +}; + +template +struct FirstPoolGradFunctor { + HOSTDEVICE void operator()(const T* out_grad, + const size_t start, + const size_t end, + const size_t item_dim, + T* in_grad, + const int* index) { + for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { + for (int i = start; i < end; ++i) { + if (i == start) { + in_grad[item_dim * i + tid] = out_grad[tid]; + } else { + in_grad[item_dim * i + tid] = static_cast(0); + } + } + } + } +}; + +template +__global__ void sequence_pool_grad_kernel(Range_OP op, + const T* out_grad, + const size_t* lod, + const size_t lod_size, + const size_t item_dim, + T* in_grad, + const int* index) { + int bid = blockIdx.x; + if (bid >= lod_size - 1) return; + size_t start = lod[bid]; + size_t end = lod[bid + 1]; + const int* index_offset = nullptr; + if (index != nullptr) { + index_offset = &index[bid * item_dim]; + } + op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset); +} + +template +class SequencePoolGradFunctor { + public: + void operator()(const phi::GPUContext& context, + const std::string pooltype, + const phi::DenseTensor& out_grad, + phi::DenseTensor* in_grad, + /* max pool has index */ + const phi::DenseTensor* index = nullptr) { + auto lod_level = in_grad->lod().size(); + auto& lod = in_grad->lod()[lod_level - 1]; + const size_t item_dim = in_grad->numel() / in_grad->dims()[0]; + dim3 threads(1024, 1); + dim3 grid(std::max(static_cast(lod.size()) - 1, 1), 1); + phi::MixVector mix_vector(&lod); + if (pooltype == "MAX") { + sequence_pool_grad_kernel> + <<>>( + MaxPoolGradFunctor(), + out_grad.data(), + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + in_grad->mutable_data(context.GetPlace()), + index->data()); + } else if (pooltype == "AVERAGE") { + sequence_pool_grad_kernel> + <<>>( + AvgPoolGradFunctor(), + out_grad.data(), + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + in_grad->mutable_data(context.GetPlace()), + nullptr); + } else if (pooltype == "SUM") { + sequence_pool_grad_kernel> + <<>>( + SumPoolGradFunctor(), + out_grad.data(), + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + in_grad->mutable_data(context.GetPlace()), + nullptr); + } else if (pooltype == "SQRT") { + sequence_pool_grad_kernel> + <<>>( + SqrtPoolGradFunctor(), + out_grad.data(), + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + in_grad->mutable_data(context.GetPlace()), + nullptr); + } else if (pooltype == "LAST") { + sequence_pool_grad_kernel> + <<>>( + LastPoolGradFunctor(), + out_grad.data(), + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + in_grad->mutable_data(context.GetPlace()), + nullptr); + } else if (pooltype == "FIRST") { + sequence_pool_grad_kernel> + <<>>( + FirstPoolGradFunctor(), + out_grad.data(), + mix_vector.CUDAData(context.GetPlace()), + lod.size(), + item_dim, + in_grad->mutable_data(context.GetPlace()), + nullptr); + + } else { + PADDLE_THROW(errors::InvalidArgument( + "unsupported pooling pooltype: %s. Only support \"MAX\", " + "\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"", + pooltype)); + } + } +}; + +// sequence pooling +template class SequencePoolFunctor; +template class SequencePoolFunctor; +template class SequencePoolGradFunctor; +template class SequencePoolGradFunctor; + +} // namespace math +} // namespace phi diff --git a/paddle/phi/kernels/funcs/sequence_pooling.h b/paddle/phi/kernels/funcs/sequence_pooling.h new file mode 100644 index 0000000000000..87929f237d692 --- /dev/null +++ b/paddle/phi/kernels/funcs/sequence_pooling.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/device_context.h" + +namespace phi { +namespace math { + +template +class SequencePoolFunctor { + public: + /* max pool has index output */ + void operator()(const DeviceContext& context, + const std::string pooltype, + T pad_value, + const phi::DenseTensor& input, + phi::DenseTensor* output, + bool is_test = false, + phi::DenseTensor* index = nullptr); +}; + +template +class SequencePoolGradFunctor { + public: + void operator()(const DeviceContext& context, + const std::string pooltype, + const phi::DenseTensor& out_grad, + phi::DenseTensor* in_grad, + /* max pool has index */ + const phi::DenseTensor* index = nullptr); +}; + +} // namespace math +} // namespace phi diff --git a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu new file mode 100644 index 0000000000000..89a13005b39fc --- /dev/null +++ b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu @@ -0,0 +1,95 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sequence_pool_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void SequencePoolKernel(const Context& ctx, + const DenseTensor& x, + bool is_test, + std::string pooltype, + float pad_value, + DenseTensor* out, + DenseTensor* max_index) { + Context pad_value_ = static_cast(pad_value); + + auto dims = x.dims(); + auto lod = x.lod(); + auto lod_level = lod.size(); + // InferShape by lod + PADDLE_ENFORCE_GT( + lod_level, + 0, + errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " + "does not contain LoD information.")); + PADDLE_ENFORCE_LE( + lod_level, + 2UL, + errors::InvalidArgument("The lod level of input shall be no more than 2." + "Received lod level is %d.", + lod_level)); + PADDLE_ENFORCE_GE( + dims[0], + /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), + errors::InvalidArgument( + "The first dimension of Input(X) must be large than batch size." + "But received first dimension of Input(X) is %d, while batch" + "size is %d.", + dims[0], + static_cast(lod[lod_level - 1].size() - 1))); + if (lod_level > 1UL) { + PADDLE_ENFORCE_EQ( + lod[0][lod[0].size() - 1], + lod[1].size() - 1, + errors::InvalidArgument("The input lod information is illegal.")); + phi::LoD out_lod; + out_lod.push_back(lod[0]); + out->set_lod(out_lod); + } + dims[0] = lod[lod_level - 1].size() - 1; + out->Resize({dims}); + // out->mutable_data(ctx.GetPlace()); + ctx.template Alloc(out); + phi::DenseTensor* index = nullptr; + + bool is_test_ = ctx.HasAttr("is_test") ? is_test : false; + + auto& place = *ctx.eigen_device(); + + // Do not create index buffer for inference mode + if (pooltype == "MAX") { + index = max_index; + index->Resize({dims}); + // index->mutable_data(ctx.GetPlace()); + ctx.template Alloc(index); + } + math::SequencePoolFunctor pool; + pool(ctx.template device_context(), + pooltype, + pad_value_, + &x, + out, + is_test_, + index); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + sequence_pool, GPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/sequence_pool_kernel.h b/paddle/phi/kernels/sequence_pool_kernel.h new file mode 100644 index 0000000000000..9e9adcb839525 --- /dev/null +++ b/paddle/phi/kernels/sequence_pool_kernel.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +namespace phi { +template +void SequencePoolKernel(const Context& ctx, + const DenseTensor& x, + bool is_test, + std::string pooltype, + float pad_value, + DenseTensor* out, + DenseTensor* max_index); + +} // namespace phi From cc3d41f24e8e1fd790ff818badfb111c5f039074 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sun, 2 Apr 2023 22:12:41 +0800 Subject: [PATCH 02/23] mv kernels impl --- .../sequence_ops/unity_build_rule.cmake | 2 + .../phi/kernels/cpu/sequence_pool_kernel.cc | 76 +---------------- paddle/phi/kernels/funcs/sequence_pooling.cc | 2 + paddle/phi/kernels/funcs/sequence_pooling.cu | 2 + paddle/phi/kernels/funcs/sequence_pooling.h | 2 + .../phi/kernels/gpu/sequence_pool_kernel.cu | 76 +---------------- .../kernels/impl/sequence_pool_kernel_impl.h | 82 +++++++++++++++++++ paddle/phi/ops/compat/sequence_pool.cc | 29 +++++++ 8 files changed, 121 insertions(+), 150 deletions(-) create mode 100644 paddle/phi/kernels/impl/sequence_pool_kernel_impl.h create mode 100644 paddle/phi/ops/compat/sequence_pool.cc diff --git a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake index 2b66923df028e..9a87e27b24197 100644 --- a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake +++ b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake @@ -13,6 +13,7 @@ register_unity_group( sequence_expand_op.cc sequence_mask_op.cc sequence_pad_op.cc + sequence_pool_op.cc sequence_expand_as_op.cc sequence_reshape_op.cc sequence_reverse_op.cc @@ -30,6 +31,7 @@ register_unity_group( sequence_expand_op.cu sequence_mask_op.cu sequence_pad_op.cu + sequence_pool_op.cu sequence_expand_as_op.cu sequence_reshape_op.cu sequence_reverse_op.cu diff --git a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc index 6912e17df30da..e4e813bf49e33 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc @@ -12,82 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/kernels/sequence_pool_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -namespace phi { -template -void SequencePoolKernel(const Context& ctx, - const DenseTensor& x, - bool is_test, - std::string pooltype, - float pad_value, - DenseTensor* out, - DenseTensor* max_index) { - Context pad_value_ = static_cast(pad_value); - - auto dims = x.dims(); - auto lod = x.lod(); - auto lod_level = lod.size(); - // InferShape by lod - PADDLE_ENFORCE_GT( - lod_level, - 0, - errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " - "does not contain LoD information.")); - PADDLE_ENFORCE_LE( - lod_level, - 2UL, - errors::InvalidArgument("The lod level of input shall be no more than 2." - "Received lod level is %d.", - lod_level)); - PADDLE_ENFORCE_GE( - dims[0], - /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), - errors::InvalidArgument( - "The first dimension of Input(X) must be large than batch size." - "But received first dimension of Input(X) is %d, while batch" - "size is %d.", - dims[0], - static_cast(lod[lod_level - 1].size() - 1))); - if (lod_level > 1UL) { - PADDLE_ENFORCE_EQ( - lod[0][lod[0].size() - 1], - lod[1].size() - 1, - errors::InvalidArgument("The input lod information is illegal.")); - phi::LoD out_lod; - out_lod.push_back(lod[0]); - out->set_lod(out_lod); - } - dims[0] = lod[lod_level - 1].size() - 1; - out->Resize({dims}); - // out->mutable_data(ctx.GetPlace()); - ctx.template Alloc(out); - phi::DenseTensor* index = nullptr; - - bool is_test_ = ctx.HasAttr("is_test") ? is_test : false; - - auto& place = *ctx.eigen_device(); - - // Do not create index buffer for inference mode - if (pooltype == "MAX" && is_test_ == false) { - index = max_index; - index->Resize({dims}); - // index->mutable_data(ctx.GetPlace()); - ctx.template Alloc(index); - } - math::SequencePoolFunctor pool; - pool(ctx.template device_context(), - pooltype, - pad_value_, - &x, - out, - is_test_, - index); -} - -} // namespace phi +#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" PD_REGISTER_KERNEL( sequence_pool, CPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) { diff --git a/paddle/phi/kernels/funcs/sequence_pooling.cc b/paddle/phi/kernels/funcs/sequence_pooling.cc index 7d235b7ad277b..c3ada22a5cfc6 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.cc +++ b/paddle/phi/kernels/funcs/sequence_pooling.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { +namespace funcs { namespace math { template ; template class SequencePoolGradFunctor; } // namespace math +} // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/sequence_pooling.cu b/paddle/phi/kernels/funcs/sequence_pooling.cu index 01b8af816328b..423300b750443 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.cu +++ b/paddle/phi/kernels/funcs/sequence_pooling.cu @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/sequence_pooling.h" namespace phi { +namespace funcs { namespace math { template @@ -498,4 +499,5 @@ template class SequencePoolGradFunctor; template class SequencePoolGradFunctor; } // namespace math +} // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/sequence_pooling.h b/paddle/phi/kernels/funcs/sequence_pooling.h index 87929f237d692..5005f7aa1e405 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.h +++ b/paddle/phi/kernels/funcs/sequence_pooling.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/platform/device_context.h" namespace phi { +namespace funcs { namespace math { template @@ -47,4 +48,5 @@ class SequencePoolGradFunctor { }; } // namespace math +} // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu index 89a13005b39fc..44cc833b1925a 100644 --- a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu +++ b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu @@ -12,82 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/kernels/sequence_pool_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -namespace phi { -template -void SequencePoolKernel(const Context& ctx, - const DenseTensor& x, - bool is_test, - std::string pooltype, - float pad_value, - DenseTensor* out, - DenseTensor* max_index) { - Context pad_value_ = static_cast(pad_value); - - auto dims = x.dims(); - auto lod = x.lod(); - auto lod_level = lod.size(); - // InferShape by lod - PADDLE_ENFORCE_GT( - lod_level, - 0, - errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " - "does not contain LoD information.")); - PADDLE_ENFORCE_LE( - lod_level, - 2UL, - errors::InvalidArgument("The lod level of input shall be no more than 2." - "Received lod level is %d.", - lod_level)); - PADDLE_ENFORCE_GE( - dims[0], - /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), - errors::InvalidArgument( - "The first dimension of Input(X) must be large than batch size." - "But received first dimension of Input(X) is %d, while batch" - "size is %d.", - dims[0], - static_cast(lod[lod_level - 1].size() - 1))); - if (lod_level > 1UL) { - PADDLE_ENFORCE_EQ( - lod[0][lod[0].size() - 1], - lod[1].size() - 1, - errors::InvalidArgument("The input lod information is illegal.")); - phi::LoD out_lod; - out_lod.push_back(lod[0]); - out->set_lod(out_lod); - } - dims[0] = lod[lod_level - 1].size() - 1; - out->Resize({dims}); - // out->mutable_data(ctx.GetPlace()); - ctx.template Alloc(out); - phi::DenseTensor* index = nullptr; - - bool is_test_ = ctx.HasAttr("is_test") ? is_test : false; - - auto& place = *ctx.eigen_device(); - - // Do not create index buffer for inference mode - if (pooltype == "MAX") { - index = max_index; - index->Resize({dims}); - // index->mutable_data(ctx.GetPlace()); - ctx.template Alloc(index); - } - math::SequencePoolFunctor pool; - pool(ctx.template device_context(), - pooltype, - pad_value_, - &x, - out, - is_test_, - index); -} - -} // namespace phi +#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" PD_REGISTER_KERNEL( sequence_pool, GPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) { diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h new file mode 100644 index 0000000000000..6094515eddee0 --- /dev/null +++ b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/sequence_pool_kernel.h" + +namespace phi { +template +void SequencePoolKernel(const Context& ctx, + const DenseTensor& x, + bool is_test, + std::string pooltype, + float pad_value, + DenseTensor* out, + DenseTensor* max_index) { + T pad_value_ = static_cast(pad_value); + + auto dims = x.dims(); + auto lod = x.lod(); + auto lod_level = lod.size(); + // InferShape by lod + PADDLE_ENFORCE_GT( + lod_level, + 0, + errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " + "does not contain LoD information.")); + PADDLE_ENFORCE_LE( + lod_level, + 2UL, + errors::InvalidArgument("The lod level of input shall be no more than 2." + "Received lod level is %d.", + lod_level)); + PADDLE_ENFORCE_GE( + dims[0], + /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), + errors::InvalidArgument( + "The first dimension of Input(X) must be large than batch size." + "But received first dimension of Input(X) is %d, while batch" + "size is %d.", + dims[0], + static_cast(lod[lod_level - 1].size() - 1))); + if (lod_level > 1UL) { + PADDLE_ENFORCE_EQ( + lod[0][lod[0].size() - 1], + lod[1].size() - 1, + errors::InvalidArgument("The input lod information is illegal.")); + phi::LoD out_lod; + out_lod.push_back(lod[0]); + out->set_lod(out_lod); + } + dims[0] = lod[lod_level - 1].size() - 1; + out->Resize({dims}); + // out->mutable_data(ctx.GetPlace()); + ctx.template Alloc(out); + phi::DenseTensor* index = nullptr; + + // Do not create index buffer for inference mode + if (pooltype == "MAX" && + (is_test || (ctx.GetPlace() == phi::CPUPlace()) == false)) { + index = max_index; + index->Resize({dims}); + // index->mutable_data(ctx.GetPlace()); + ctx.template Alloc(index); + } + funcs::math::SequencePoolFunctor pool; + pool(ctx, pooltype, pad_value_, &x, out, is_test, index); +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/sequence_pool.cc b/paddle/phi/ops/compat/sequence_pool.cc new file mode 100644 index 0000000000000..db781a6677b30 --- /dev/null +++ b/paddle/phi/ops/compat/sequence_pool.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SequencePoolOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("sequence_pool", + {"X"}, + {"is_test", "pooltype", "pad_value"}, + {"Out", "MaxIndex"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(sequence_pool, phi::SequencePoolOpArgumentMapping); From 38af314234a6d7ee3c832a5afff01d936e1f0b70 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Mon, 3 Apr 2023 18:10:43 +0800 Subject: [PATCH 03/23] fix parameter error --- paddle/phi/kernels/impl/sequence_pool_kernel_impl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h index 6094515eddee0..e55d20225b952 100644 --- a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h @@ -75,8 +75,8 @@ void SequencePoolKernel(const Context& ctx, // index->mutable_data(ctx.GetPlace()); ctx.template Alloc(index); } - funcs::math::SequencePoolFunctor pool; - pool(ctx, pooltype, pad_value_, &x, out, is_test, index); + phi::funcs::math::SequencePoolFunctor pool; + pool(ctx, pooltype, pad_value_, x, out, is_test, index); } } // namespace phi From 3fd894415ad1cdf3562566ccde05725400a57b35 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Mon, 3 Apr 2023 21:55:20 +0800 Subject: [PATCH 04/23] clean include --- paddle/phi/kernels/cpu/sequence_pool_kernel.cc | 7 ++++--- paddle/phi/kernels/funcs/sequence_pooling.cc | 18 ++++++++---------- paddle/phi/kernels/funcs/sequence_pooling.cu | 2 -- paddle/phi/kernels/funcs/sequence_pooling.h | 6 +----- .../kernels/impl/sequence_pool_kernel_impl.h | 5 ++--- paddle/phi/kernels/sequence_pool_kernel.h | 1 - 6 files changed, 15 insertions(+), 24 deletions(-) diff --git a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc index e4e813bf49e33..4bc0e03983d5b 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/kernels/sequence_pool_kernel.h" + #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" PD_REGISTER_KERNEL( - sequence_pool, CPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) { - kernel->OutputAt(1).SetDataType(phi::DataType::INT64); -} + sequence_pool, CPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/sequence_pooling.cc b/paddle/phi/kernels/funcs/sequence_pooling.cc index c3ada22a5cfc6..1ea1707069312 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.cc +++ b/paddle/phi/kernels/funcs/sequence_pooling.cc @@ -23,7 +23,6 @@ limitations under the License. */ namespace phi { namespace funcs { -namespace math { template { phi::DenseTensor* index = nullptr) { if (pooltype == "MAX") { if (is_test) { - phi::math::MaxSeqPoolFunctor max_pool; + phi::funcs::MaxSeqPoolFunctor max_pool; max_pool(context, input, pad_value, output, index); } else { - phi::math::MaxSeqPoolFunctor max_pool; + phi::funcs::MaxSeqPoolFunctor max_pool; max_pool(context, input, pad_value, output, index); } return; } if (pooltype == "LAST") { - phi::math::LastSeqPoolFunctor last_pool; + phi::funcs::LastSeqPoolFunctor last_pool; last_pool(context, input, pad_value, output); return; } if (pooltype == "FIRST") { - phi::math::FirstSeqPoolFunctor first_pool; + phi::funcs::FirstSeqPoolFunctor first_pool; first_pool(context, input, pad_value, output); return; } @@ -377,7 +376,7 @@ class SequencePoolFunctor { if (pooltype == "SUM") { auto place = context.GetPlace(); PADDLE_ENFORCE_EQ( - platform::is_cpu_place(place), + place == phi::CPUPlace(), true, errors::InvalidArgument( "Sequence_pool should run on CPU Device when pooltype is SUM")); @@ -387,7 +386,7 @@ class SequencePoolFunctor { static_cast(input.numel() / input.dims()[0]), phi::jit::SeqPoolType::kSum); auto seqpool = phi::jit::KernelFuncs, - platform::CPUPlace>::Cache() + phi::CPUPlace>::Cache() .At(attr); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { attr.h = static_cast(lod[i + 1] - lod[i]); @@ -443,7 +442,7 @@ class SequencePoolGradFunctor { /* max pool has index */ const phi::DenseTensor* index = nullptr) { if (pooltype == "MAX") { - phi::math::MaxSeqPoolGradFunctor max_pool_grad; + phi::funcs::MaxSeqPoolGradFunctor max_pool_grad; max_pool_grad(context, out_grad, *index, in_grad); return; } @@ -455,7 +454,7 @@ class SequencePoolGradFunctor { } if (pooltype == "SUM") { - phi::math::SumSeqPoolGradFunctor sum_pool_grad; + phi::funcs::SumSeqPoolGradFunctor sum_pool_grad; sum_pool_grad(context, out_grad, in_grad); return; } @@ -499,6 +498,5 @@ template class SequencePoolFunctor; template class SequencePoolGradFunctor; template class SequencePoolGradFunctor; -} // namespace math } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/sequence_pooling.cu b/paddle/phi/kernels/funcs/sequence_pooling.cu index 423300b750443..25abdb40ae2ec 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.cu +++ b/paddle/phi/kernels/funcs/sequence_pooling.cu @@ -22,7 +22,6 @@ limitations under the License. */ namespace phi { namespace funcs { -namespace math { template struct MaxPoolFunctor { @@ -498,6 +497,5 @@ template class SequencePoolFunctor; template class SequencePoolGradFunctor; template class SequencePoolGradFunctor; -} // namespace math } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/sequence_pooling.h b/paddle/phi/kernels/funcs/sequence_pooling.h index 5005f7aa1e405..937022728b233 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.h +++ b/paddle/phi/kernels/funcs/sequence_pooling.h @@ -15,13 +15,10 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/device_context.h" namespace phi { namespace funcs { -namespace math { template class SequencePoolFunctor { @@ -47,6 +44,5 @@ class SequencePoolGradFunctor { const phi::DenseTensor* index = nullptr); }; -} // namespace math } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h index e55d20225b952..0cd40e72b54de 100644 --- a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h @@ -14,8 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/sequence_pool_kernel.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" namespace phi { template @@ -75,7 +74,7 @@ void SequencePoolKernel(const Context& ctx, // index->mutable_data(ctx.GetPlace()); ctx.template Alloc(index); } - phi::funcs::math::SequencePoolFunctor pool; + phi::funcs::SequencePoolFunctor pool; pool(ctx, pooltype, pad_value_, x, out, is_test, index); } diff --git a/paddle/phi/kernels/sequence_pool_kernel.h b/paddle/phi/kernels/sequence_pool_kernel.h index 9e9adcb839525..0480376984fd6 100644 --- a/paddle/phi/kernels/sequence_pool_kernel.h +++ b/paddle/phi/kernels/sequence_pool_kernel.h @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/sequence_pooling.h" namespace phi { template From 279c02a53791051f1c8a5a8fed0ddef5bab7a949 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 4 Apr 2023 00:43:26 +0800 Subject: [PATCH 05/23] fix compat filename --- paddle/phi/ops/compat/{sequence_pool.cc => sequence_pool_sig.cc} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename paddle/phi/ops/compat/{sequence_pool.cc => sequence_pool_sig.cc} (100%) diff --git a/paddle/phi/ops/compat/sequence_pool.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc similarity index 100% rename from paddle/phi/ops/compat/sequence_pool.cc rename to paddle/phi/ops/compat/sequence_pool_sig.cc From 314af781988593e7fcdf045771bfd1da8bcd9691 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 4 Apr 2023 21:05:19 +0800 Subject: [PATCH 06/23] [phi] move fluid sequence_pool_grad to phi --- .../sequence_ops/sequence_pool_op.cc | 9 +--- .../operators/sequence_ops/sequence_pool_op.h | 49 ------------------- .../kernels/cpu/sequence_pool_grad_kernel.cc | 26 ++++++++++ .../kernels/gpu/sequence_pool_grad_kernel.cu} | 8 ++- .../phi/kernels/gpu/sequence_pool_kernel.cu | 4 +- .../impl/sequence_pool_grad_kernel_impl.h | 40 +++++++++++++++ .../kernels/impl/sequence_pool_kernel_impl.h | 3 +- .../phi/kernels/sequence_pool_grad_kernel.h | 30 ++++++++++++ paddle/phi/ops/compat/sequence_pool_sig.cc | 10 ++++ 9 files changed, 112 insertions(+), 67 deletions(-) delete mode 100644 paddle/fluid/operators/sequence_ops/sequence_pool_op.h create mode 100644 paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc rename paddle/{fluid/operators/sequence_ops/sequence_pool_op.cu => phi/kernels/gpu/sequence_pool_grad_kernel.cu} (64%) create mode 100644 paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/sequence_pool_grad_kernel.h diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index c44427f98f211..d616bca2c4e3b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -12,10 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" - -#include -#include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -196,7 +193,3 @@ REGISTER_OPERATOR(sequence_pool, REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp, ops::SequencePoolGradOpNoNeedBufferVarsInferer); - -REGISTER_OP_CPU_KERNEL(sequence_pool_grad, - ops::SequencePoolGradKernel, - ops::SequencePoolGradKernel); diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h deleted file mode 100644 index bcba5590bc567..0000000000000 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/sequence_pooling.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -template -class SequencePoolGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* out_g = - context.Input(framework::GradVarName("Out")); - auto* in_g = context.Output(framework::GradVarName("X")); - std::string pooltype = context.Attr("pooltype"); - const phi::DenseTensor* index = nullptr; - if (pooltype == "MAX") { - index = context.Input("MaxIndex"); - } - in_g->mutable_data(context.GetPlace()); - math::SequencePoolGradFunctor pool; - pool(context.template device_context(), - pooltype, - *out_g, - in_g, - index); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc new file mode 100644 index 0000000000000..9f15ed0147031 --- /dev/null +++ b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sequence_pool_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(sequence_pool_grad, + CPU, + ALL_LAYOUT, + phi::SequencePoolGradKernel, + float, + double) {} diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu b/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu similarity index 64% rename from paddle/fluid/operators/sequence_ops/sequence_pool_op.cu rename to paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu index df5dde79274f9..0aaeea6e7a6e1 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu +++ b/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(sequence_pool_grad, - ops::SequencePoolGradKernel); +PD_REGISTER_KERNEL( + sequence_pool_grad, GPU, ALL_LAYOUT, phi::SequencePoolGradKernel, float) {} diff --git a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu index 44cc833b1925a..e151357c038ad 100644 --- a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu +++ b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu @@ -16,6 +16,4 @@ limitations under the License. */ #include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" PD_REGISTER_KERNEL( - sequence_pool, GPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) { - kernel->OutputAt(1).SetDataType(phi::DataType::INT64); -} + sequence_pool, GPU, ALL_LAYOUT, phi::SequencePoolKernel, float) {} diff --git a/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h new file mode 100644 index 0000000000000..442034b4b1dba --- /dev/null +++ b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +namespace phi { + +template +void SequencePoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& max_index, + const DenseTensor& out_grad, + bool is_test, + std::string pooltype, + float pad_value, + DenseTensor* x_grad) { + const phi::DenseTensor* index = nullptr; + if (pooltype == "MAX") { + index = &max_index; + } + dev_ctx.template Alloc(x_grad); + phi::funcs::SequencePoolGradFunctor pool; + pool(dev_ctx, pooltype, out_grad, x_grad, index); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h index 0cd40e72b54de..a417c04aae68c 100644 --- a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/sequence_pooling.h" namespace phi { @@ -62,7 +63,6 @@ void SequencePoolKernel(const Context& ctx, } dims[0] = lod[lod_level - 1].size() - 1; out->Resize({dims}); - // out->mutable_data(ctx.GetPlace()); ctx.template Alloc(out); phi::DenseTensor* index = nullptr; @@ -71,7 +71,6 @@ void SequencePoolKernel(const Context& ctx, (is_test || (ctx.GetPlace() == phi::CPUPlace()) == false)) { index = max_index; index->Resize({dims}); - // index->mutable_data(ctx.GetPlace()); ctx.template Alloc(index); } phi::funcs::SequencePoolFunctor pool; diff --git a/paddle/phi/kernels/sequence_pool_grad_kernel.h b/paddle/phi/kernels/sequence_pool_grad_kernel.h new file mode 100644 index 0000000000000..cb44c71135d16 --- /dev/null +++ b/paddle/phi/kernels/sequence_pool_grad_kernel.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void SequencePoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& max_index, + const DenseTensor& out_grad, + bool is_test, + std::string pooltype, + float pad_value, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc index db781a6677b30..1af02c68e002e 100644 --- a/paddle/phi/ops/compat/sequence_pool_sig.cc +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -24,6 +24,16 @@ KernelSignature SequencePoolOpArgumentMapping( {"Out", "MaxIndex"}); } +KernelSignature SequencePoolGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("sequence_pool_grad", + {"X", "MaxIndex", GradVarName("Out")}, + {"is_test", "pooltype", "pad_value"}, + {GradVarName("X")}); +} + } // namespace phi PD_REGISTER_ARG_MAPPING_FN(sequence_pool, phi::SequencePoolOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(sequence_pool_grad, + phi::SequencePoolGradOpArgumentMapping); From e1b4fe3f3381b0f61181df86924c1149ad800f39 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 4 Apr 2023 21:24:19 +0800 Subject: [PATCH 07/23] [phi][compat] sig rm GradVarName --- paddle/phi/ops/compat/sequence_pool_sig.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc index 1af02c68e002e..992fa92ecb995 100644 --- a/paddle/phi/ops/compat/sequence_pool_sig.cc +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -27,9 +27,9 @@ KernelSignature SequencePoolOpArgumentMapping( KernelSignature SequencePoolGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("sequence_pool_grad", - {"X", "MaxIndex", GradVarName("Out")}, + {"X", "MaxIndex", "Out"}, {"is_test", "pooltype", "pad_value"}, - {GradVarName("X")}); + {"X"}); } } // namespace phi From 720e1052068e47fede7bb6a163e84714c32f4b92 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 4 Apr 2023 23:07:54 +0800 Subject: [PATCH 08/23] [phi] fix sequence_pool out type --- paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h | 4 ++-- paddle/phi/kernels/impl/sequence_pool_kernel_impl.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h index 442034b4b1dba..1cce08baea70f 100644 --- a/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h @@ -32,8 +32,8 @@ void SequencePoolGradKernel(const Context& dev_ctx, if (pooltype == "MAX") { index = &max_index; } - dev_ctx.template Alloc(x_grad); - phi::funcs::SequencePoolGradFunctor pool; + dev_ctx.template Alloc(x_grad); + phi::funcs::SequencePoolGradFunctor pool; pool(dev_ctx, pooltype, out_grad, x_grad, index); } diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h index a417c04aae68c..7a35990fc1092 100644 --- a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h @@ -71,7 +71,7 @@ void SequencePoolKernel(const Context& ctx, (is_test || (ctx.GetPlace() == phi::CPUPlace()) == false)) { index = max_index; index->Resize({dims}); - ctx.template Alloc(index); + ctx.template Alloc(index); } phi::funcs::SequencePoolFunctor pool; pool(ctx, pooltype, pad_value_, x, out, is_test, index); From 4f5ca20aab4d0a397d428c7b4d064cfb136a41b6 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Wed, 5 Apr 2023 18:30:09 +0800 Subject: [PATCH 09/23] [phi] rm impl, add const string --- .../kernels/cpu/sequence_pool_grad_kernel.cc | 24 +++++- .../phi/kernels/cpu/sequence_pool_kernel.cc | 65 ++++++++++++++- .../kernels/gpu/sequence_pool_grad_kernel.cu | 3 + .../phi/kernels/gpu/sequence_pool_kernel.cu | 2 +- .../impl/sequence_pool_grad_kernel_impl.h | 40 ---------- .../kernels/impl/sequence_pool_kernel_impl.h | 80 ------------------- .../phi/kernels/sequence_pool_grad_kernel.h | 2 +- paddle/phi/kernels/sequence_pool_kernel.h | 2 +- 8 files changed, 93 insertions(+), 125 deletions(-) delete mode 100644 paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h delete mode 100644 paddle/phi/kernels/impl/sequence_pool_kernel_impl.h diff --git a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc index 9f15ed0147031..011a47343d849 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc @@ -16,7 +16,29 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +namespace phi { + +template +void SequencePoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& max_index, + const DenseTensor& out_grad, + bool is_test, + const std::string pooltype, + float pad_value, + DenseTensor* x_grad) { + const phi::DenseTensor* index = nullptr; + if (pooltype == "MAX") { + index = &max_index; + } + dev_ctx.template Alloc(x_grad); + phi::funcs::SequencePoolGradFunctor pool; + pool(dev_ctx, pooltype, out_grad, x_grad, index); +} + +} // namespace phi PD_REGISTER_KERNEL(sequence_pool_grad, CPU, diff --git a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc index 4bc0e03983d5b..fc8c8a06b53ad 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc @@ -16,7 +16,70 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +namespace phi { + +template +void SequencePoolKernel(const Context& ctx, + const DenseTensor& x, + bool is_test, + const std::string pooltype, + float pad_value, + DenseTensor* out, + DenseTensor* max_index) { + T pad_value_ = static_cast(pad_value); + + auto dims = x.dims(); + auto lod = x.lod(); + auto lod_level = lod.size(); + // InferShape by lod + PADDLE_ENFORCE_GT( + lod_level, + 0, + errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " + "does not contain LoD information.")); + PADDLE_ENFORCE_LE( + lod_level, + 2UL, + errors::InvalidArgument("The lod level of input shall be no more than 2." + "Received lod level is %d.", + lod_level)); + PADDLE_ENFORCE_GE( + dims[0], + /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), + errors::InvalidArgument( + "The first dimension of Input(X) must be large than batch size." + "But received first dimension of Input(X) is %d, while batch" + "size is %d.", + dims[0], + static_cast(lod[lod_level - 1].size() - 1))); + if (lod_level > 1UL) { + PADDLE_ENFORCE_EQ( + lod[0][lod[0].size() - 1], + lod[1].size() - 1, + errors::InvalidArgument("The input lod information is illegal.")); + phi::LoD out_lod; + out_lod.push_back(lod[0]); + out->set_lod(out_lod); + } + dims[0] = lod[lod_level - 1].size() - 1; + out->Resize({dims}); + ctx.template Alloc(out); + phi::DenseTensor* index = nullptr; + + // Do not create index buffer for inference mode + if (pooltype == "MAX" && + (is_test || (ctx.GetPlace() == phi::CPUPlace()) == false)) { + index = max_index; + index->Resize({dims}); + ctx.template Alloc(index); + } + phi::funcs::SequencePoolFunctor pool; + pool(ctx, pooltype, pad_value_, x, out, is_test, index); +} + +} // namespace phi PD_REGISTER_KERNEL( sequence_pool, CPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu b/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu index 0aaeea6e7a6e1..1dc973ec3e5fe 100644 --- a/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu @@ -12,5 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/kernels/sequence_pool_grad_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" + PD_REGISTER_KERNEL( sequence_pool_grad, GPU, ALL_LAYOUT, phi::SequencePoolGradKernel, float) {} diff --git a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu index e151357c038ad..a89183254fc8e 100644 --- a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu +++ b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/kernels/sequence_pool_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" PD_REGISTER_KERNEL( sequence_pool, GPU, ALL_LAYOUT, phi::SequencePoolKernel, float) {} diff --git a/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h deleted file mode 100644 index 1cce08baea70f..0000000000000 --- a/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/sequence_pooling.h" - -namespace phi { - -template -void SequencePoolGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& max_index, - const DenseTensor& out_grad, - bool is_test, - std::string pooltype, - float pad_value, - DenseTensor* x_grad) { - const phi::DenseTensor* index = nullptr; - if (pooltype == "MAX") { - index = &max_index; - } - dev_ctx.template Alloc(x_grad); - phi::funcs::SequencePoolGradFunctor pool; - pool(dev_ctx, pooltype, out_grad, x_grad, index); -} - -} // namespace phi diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h deleted file mode 100644 index 7a35990fc1092..0000000000000 --- a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/kernels/funcs/sequence_pooling.h" - -namespace phi { -template -void SequencePoolKernel(const Context& ctx, - const DenseTensor& x, - bool is_test, - std::string pooltype, - float pad_value, - DenseTensor* out, - DenseTensor* max_index) { - T pad_value_ = static_cast(pad_value); - - auto dims = x.dims(); - auto lod = x.lod(); - auto lod_level = lod.size(); - // InferShape by lod - PADDLE_ENFORCE_GT( - lod_level, - 0, - errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " - "does not contain LoD information.")); - PADDLE_ENFORCE_LE( - lod_level, - 2UL, - errors::InvalidArgument("The lod level of input shall be no more than 2." - "Received lod level is %d.", - lod_level)); - PADDLE_ENFORCE_GE( - dims[0], - /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), - errors::InvalidArgument( - "The first dimension of Input(X) must be large than batch size." - "But received first dimension of Input(X) is %d, while batch" - "size is %d.", - dims[0], - static_cast(lod[lod_level - 1].size() - 1))); - if (lod_level > 1UL) { - PADDLE_ENFORCE_EQ( - lod[0][lod[0].size() - 1], - lod[1].size() - 1, - errors::InvalidArgument("The input lod information is illegal.")); - phi::LoD out_lod; - out_lod.push_back(lod[0]); - out->set_lod(out_lod); - } - dims[0] = lod[lod_level - 1].size() - 1; - out->Resize({dims}); - ctx.template Alloc(out); - phi::DenseTensor* index = nullptr; - - // Do not create index buffer for inference mode - if (pooltype == "MAX" && - (is_test || (ctx.GetPlace() == phi::CPUPlace()) == false)) { - index = max_index; - index->Resize({dims}); - ctx.template Alloc(index); - } - phi::funcs::SequencePoolFunctor pool; - pool(ctx, pooltype, pad_value_, x, out, is_test, index); -} - -} // namespace phi diff --git a/paddle/phi/kernels/sequence_pool_grad_kernel.h b/paddle/phi/kernels/sequence_pool_grad_kernel.h index cb44c71135d16..cc66f45645a39 100644 --- a/paddle/phi/kernels/sequence_pool_grad_kernel.h +++ b/paddle/phi/kernels/sequence_pool_grad_kernel.h @@ -23,7 +23,7 @@ void SequencePoolGradKernel(const Context& dev_ctx, const DenseTensor& max_index, const DenseTensor& out_grad, bool is_test, - std::string pooltype, + const std::string pooltype, float pad_value, DenseTensor* x_grad); diff --git a/paddle/phi/kernels/sequence_pool_kernel.h b/paddle/phi/kernels/sequence_pool_kernel.h index 0480376984fd6..84e1990000a19 100644 --- a/paddle/phi/kernels/sequence_pool_kernel.h +++ b/paddle/phi/kernels/sequence_pool_kernel.h @@ -21,7 +21,7 @@ template void SequencePoolKernel(const Context& ctx, const DenseTensor& x, bool is_test, - std::string pooltype, + const std::string pooltype, float pad_value, DenseTensor* out, DenseTensor* max_index); From 55b65924f00675ca2b550c1b019b01dd9c075bee Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Wed, 5 Apr 2023 23:46:26 +0800 Subject: [PATCH 10/23] [phi] fix const str --- paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/sequence_pool_kernel.cc | 2 +- paddle/phi/kernels/sequence_pool_grad_kernel.h | 2 +- paddle/phi/kernels/sequence_pool_kernel.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc index 011a47343d849..8d450abe69e2a 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc @@ -26,7 +26,7 @@ void SequencePoolGradKernel(const Context& dev_ctx, const DenseTensor& max_index, const DenseTensor& out_grad, bool is_test, - const std::string pooltype, + const std::string& pooltype, float pad_value, DenseTensor* x_grad) { const phi::DenseTensor* index = nullptr; diff --git a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc index fc8c8a06b53ad..2ddedaaec06c9 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc @@ -24,7 +24,7 @@ template void SequencePoolKernel(const Context& ctx, const DenseTensor& x, bool is_test, - const std::string pooltype, + const std::string& pooltype, float pad_value, DenseTensor* out, DenseTensor* max_index) { diff --git a/paddle/phi/kernels/sequence_pool_grad_kernel.h b/paddle/phi/kernels/sequence_pool_grad_kernel.h index cc66f45645a39..1117a96d4f1a2 100644 --- a/paddle/phi/kernels/sequence_pool_grad_kernel.h +++ b/paddle/phi/kernels/sequence_pool_grad_kernel.h @@ -23,7 +23,7 @@ void SequencePoolGradKernel(const Context& dev_ctx, const DenseTensor& max_index, const DenseTensor& out_grad, bool is_test, - const std::string pooltype, + const std::string& pooltype, float pad_value, DenseTensor* x_grad); diff --git a/paddle/phi/kernels/sequence_pool_kernel.h b/paddle/phi/kernels/sequence_pool_kernel.h index 84e1990000a19..7d4dcba67add9 100644 --- a/paddle/phi/kernels/sequence_pool_kernel.h +++ b/paddle/phi/kernels/sequence_pool_kernel.h @@ -21,7 +21,7 @@ template void SequencePoolKernel(const Context& ctx, const DenseTensor& x, bool is_test, - const std::string pooltype, + const std::string& pooltype, float pad_value, DenseTensor* out, DenseTensor* max_index); From b22dc4b16c46708ac3d4a4269a2aa1a924ef9646 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sat, 8 Apr 2023 00:47:45 +0800 Subject: [PATCH 11/23] fix sequence_pooling cmake --- paddle/fluid/operators/math/CMakeLists.txt | 11 +- .../fluid/operators/math/sequence_pooling.cc | 504 ------------------ .../fluid/operators/math/sequence_pooling.cu | 503 ----------------- .../fluid/operators/math/sequence_pooling.h | 52 -- .../sequence_ops/unity_build_rule.cmake | 1 - paddle/phi/kernels/CMakeLists.txt | 1 + paddle/phi/kernels/funcs/CMakeLists.txt | 1 + paddle/phi/kernels/funcs/sequence_pooling.h | 2 +- 8 files changed, 8 insertions(+), 1067 deletions(-) delete mode 100644 paddle/fluid/operators/math/sequence_pooling.cc delete mode 100644 paddle/fluid/operators/math/sequence_pooling.cu delete mode 100644 paddle/fluid/operators/math/sequence_pooling.h diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 61cc7dc9f4b64..9278ff5e904af 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -18,8 +18,6 @@ math_library(sample_prob) math_library(sampler DEPS generator) # math_library(math_function DEPS blas dense_tensor tensor) - -math_library(sequence_pooling DEPS math_function jit_kernel_helper) if(WITH_ASCEND_CL) math_library(beam_search DEPS math_function beam_search_npu) elseif(WITH_XPU) @@ -45,10 +43,11 @@ cc_test( vol2col_test SRCS vol2col_test.cc DEPS vol2col) -cc_test( - sequence_pooling_test - SRCS sequence_pooling_test.cc - DEPS sequence_pooling) +# TODO mv test +#cc_test( +# sequence_pooling_test +# SRCS sequence_pooling_test.cc +# DEPS sequence_pooling) cc_test( beam_search_test SRCS beam_search_test.cc diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc deleted file mode 100644 index 8dbeff2bce135..0000000000000 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ /dev/null @@ -1,504 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/sequence_pooling.h" - -#include - -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/eigen/common.h" -#include "paddle/phi/kernels/funcs/jit/kernels.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { -namespace math { - -template -using EigenVector = phi::EigenVector; -template -using EigenMatrix = phi::EigenMatrix; - -template -class MaxSeqPoolFunctor { - public: - void operator()(const phi::CPUContext& context, - const phi::DenseTensor& input, - T pad_value, - phi::DenseTensor* output, - phi::DenseTensor* index) { - auto in_dims = input.dims(); - auto out_dims = output->dims(); - auto idx_dims = index->dims(); - PADDLE_ENFORCE_GT(in_dims.size(), - 1, - platform::errors::InvalidArgument( - "The rank of input shall be greater than 1, but got " - "the rank is %ld. Please check the input value", - in_dims.size())); - PADDLE_ENFORCE_GT(out_dims.size(), - 1, - platform::errors::InvalidArgument( - "The rank of output shall be greater than 1, but got " - "the rank is %ld. Please check the input value", - out_dims.size())); - for (int64_t i = 1; i < in_dims.size(); ++i) { - PADDLE_ENFORCE_EQ( - in_dims[i], - out_dims[i], - platform::errors::InvalidArgument( - "The dimension of input and output shall be same. Expected %ld " - "== %ld, but got %ld != %ld. Please check the input value.", - in_dims[i], - out_dims[i], - in_dims[i], - out_dims[i])); - } - PADDLE_ENFORCE_EQ( - idx_dims, - out_dims, - platform::errors::InvalidArgument( - "The dimension of index and output shall be same. Expected %ld == " - "%ld, but got %ld != %ld. Please check the input value.", - idx_dims, - out_dims, - idx_dims, - out_dims)); - - auto lod_level = input.lod().size(); - auto starts = input.lod()[lod_level - 1]; - const T* in_data = input.data(); - T* out_data = output->data(); - int* max_index = index->data(); - - int64_t num_seq = out_dims[0]; - int64_t dim = output->numel() / num_seq; - for (int64_t i = 0; i < num_seq; ++i) { - if (starts[i] == starts[i + 1]) { - for (int64_t k = 0; k < dim; ++k) { - out_data[i * dim + k] = pad_value; - max_index[i * dim + k] = -1; - } - continue; - } - for (int64_t k = 0; k < dim; ++k) { - out_data[i * dim + k] = in_data[starts[i] * dim + k]; - max_index[i * dim + k] = starts[i]; - } - for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) { - for (int64_t k = 0; k < dim; ++k) { - if (in_data[j * dim + k] > out_data[i * dim + k]) { - out_data[i * dim + k] = in_data[j * dim + k]; - max_index[i * dim + k] = j; - } - } - } - } - } -}; -// Instantisation of Max Sequence Pooling for test phase eg. no need to fill -// index buffer -template -class MaxSeqPoolFunctor { - public: - void operator()(const phi::CPUContext& context, - const phi::DenseTensor& input, - T pad_value, - phi::DenseTensor* output, - phi::DenseTensor* index) { - auto in_dims = input.dims(); - auto out_dims = output->dims(); - PADDLE_ENFORCE_GT(in_dims.size(), - 1, - platform::errors::InvalidArgument( - "The rank of input shall be greater than 1, but got " - "%ld <= 1. Please check the input value.", - in_dims.size())); - PADDLE_ENFORCE_GT(out_dims.size(), - 1, - platform::errors::InvalidArgument( - "The rank of output shall be greater than 1, but got " - "%ld <= 1. Please check the input value.", - out_dims.size())); - for (int64_t i = 1; i < in_dims.size(); ++i) { - PADDLE_ENFORCE_EQ( - in_dims[i], - out_dims[i], - platform::errors::InvalidArgument( - "The dimension of input and output shall be same. Expected %ld " - "== %ld, but got %ld != %ld. Please check the input value.", - in_dims[i], - out_dims[i], - in_dims[i], - out_dims[i])); - } - - auto lod_level = input.lod().size(); - auto starts = input.lod()[lod_level - 1]; - const T* in_data = input.data(); - T* out_data = output->data(); - - int64_t num_seq = out_dims[0]; - int64_t dim = output->numel() / num_seq; - for (int64_t i = 0; i < num_seq; ++i) { - if (starts[i] == starts[i + 1]) { - for (int64_t k = 0; k < dim; ++k) { - out_data[i * dim + k] = pad_value; - } - continue; - } - std::memcpy( - &out_data[i * dim], &in_data[starts[i] * dim], dim * sizeof(T)); - for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) { - for (int64_t k = 0; k < dim; ++k) { - if (in_data[j * dim + k] > out_data[i * dim + k]) { - out_data[i * dim + k] = in_data[j * dim + k]; - } - } - } - } - } -}; -template -class MaxSeqPoolGradFunctor { - public: - void operator()(const phi::CPUContext& context, - const phi::DenseTensor& out_grad, - const phi::DenseTensor& index, - phi::DenseTensor* in_grad) { - auto og_dims = out_grad.dims(); - auto ig_dims = in_grad->dims(); - auto idx_dims = index.dims(); - PADDLE_ENFORCE_GT(og_dims.size(), - 1, - platform::errors::InvalidArgument( - "The rank of output@Grad shall be greater than 1, " - "but got %ld <= 1. Please check the input value.", - og_dims.size())); - PADDLE_ENFORCE_GT(ig_dims.size(), - 1, - platform::errors::InvalidArgument( - "The rank of input@Grad shall be greater than 1, but " - "got %ld <= 1. Please check the input value.", - ig_dims.size())); - for (int64_t i = 1; i < og_dims.size(); ++i) { - PADDLE_ENFORCE_EQ(og_dims[i], - ig_dims[i], - platform::errors::InvalidArgument( - "The dimension of input@Grad and output@Grad shall " - "be same. Expected %ld == %ld, but got %ld != %ld. " - "Please check the input value.", - og_dims[i], - ig_dims[i], - og_dims[i], - ig_dims[i])); - } - PADDLE_ENFORCE_EQ( - idx_dims, - og_dims, - platform::errors::InvalidArgument( - "The dimension of index and output@Grad shall be same. Expected " - "%ld == %ld, but got %ld != %ld. Please check the input value.", - idx_dims, - og_dims, - idx_dims, - og_dims)); - - const T* og_data = out_grad.data(); - const int* max_index = index.data(); - T* ig_data = in_grad->data(); - - phi::funcs::SetConstant set_zero; - set_zero(context, in_grad, static_cast(0.0)); - int64_t num_seq = og_dims[0]; - int64_t dim = out_grad.numel() / num_seq; - for (int64_t i = 0; i < num_seq; ++i) { - for (int64_t j = 0; j < dim; ++j) { - int step_id = max_index[i * dim + j]; - if (step_id == -1) continue; - ig_data[step_id * dim + j] = og_data[i * dim + j]; - } - } - } -}; - -template -class LastSeqPoolFunctor { - public: - void operator()(const phi::CPUContext& context, - const phi::DenseTensor& input, - T pad_value, - phi::DenseTensor* output) { - // Create pointers to input and output data - auto* in_data = input.data(); - auto* out_data = output->data(); - - // Calculate the size of each item in sequence - int64_t item_size = input.numel() / input.dims()[0]; - auto lod_level = input.lod().size(); - auto lod = input.lod()[lod_level - 1]; - int seq_num = static_cast(lod.size()) - 1; - for (int i = 0; i < seq_num; ++i) { - // Calculate the length of each sequence - int64_t seq_len = static_cast(lod[i + 1] - lod[i]); - if (seq_len == 0) { - for (int j = 0; j < item_size; ++j) { - out_data[j] = pad_value; - } - } else { - // Point to the begin of next sequence - in_data += seq_len * item_size; - // Copy the last item of sequence to output - std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T)); - } - out_data += item_size; - } - } -}; - -template -class FirstSeqPoolFunctor { - public: - void operator()(const phi::CPUContext& context, - const phi::DenseTensor& input, - T pad_value, - phi::DenseTensor* output) { - // Create pointers to input and output data - auto* in_data = input.data(); - auto* out_data = output->data(); - - // Calculate the size of each item in sequence - int64_t item_size = input.numel() / input.dims()[0]; - auto lod_level = input.lod().size(); - auto lod = input.lod()[lod_level - 1]; - int seq_num = static_cast(lod.size()) - 1; - for (int i = 0; i < seq_num; ++i) { - // Calculate the length of each sequence - int64_t seq_len = static_cast(lod[i + 1] - lod[i]); - if (seq_len == 0) { - for (int j = 0; j < item_size; ++j) { - out_data[j] = pad_value; - } - } else { - // Copy the first item of sequence to output - std::memcpy(out_data, in_data, item_size * sizeof(T)); - // Point to the next sequence - in_data += seq_len * item_size; - } - out_data += item_size; - } - } -}; - -template -class SumSeqPoolGradFunctor { - public: - void operator()(const phi::CPUContext& context, - const phi::DenseTensor& out_grad, - phi::DenseTensor* in_grad) { - auto lod_level = in_grad->lod().size(); - auto lod = in_grad->lod()[lod_level - 1]; - int64_t out_w = out_grad.numel() / out_grad.dims()[0]; - int64_t in_w = in_grad->numel() / in_grad->dims()[0]; - PADDLE_ENFORCE_EQ(in_w, - out_w, - platform::errors::InvalidArgument( - "The feature size of input@Grad and output@Grad " - "shall be same. Expected %ld == %ld, but got %ld != " - "%ld. Please check the input value.", - in_w, - out_w, - in_w, - out_w)); - const T* out_g_data = out_grad.data(); - T* in_g_data = in_grad->mutable_data(context.GetPlace()); - auto blas = phi::funcs::GetBlas(context); - for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { - int64_t h = static_cast(lod[i + 1] - lod[i]); - if (h == 0) continue; - int64_t in_offset = lod[i] * in_w; - const T* out_pos = out_g_data + i * out_w; - T* in_pos = in_g_data + in_offset; - for (int r = 0; r != h; ++r) { - blas.VCOPY(in_w, out_pos, in_pos + r * in_w); - } - } - } -}; - -template -class SequencePoolFunctor { - public: - /* max pool has index output */ - void operator()(const phi::CPUContext& context, - const std::string pooltype, - T pad_value, - const phi::DenseTensor& input, - phi::DenseTensor* output, - bool is_test, - phi::DenseTensor* index = nullptr) { - if (pooltype == "MAX") { - if (is_test) { - math::MaxSeqPoolFunctor max_pool; - max_pool(context, input, pad_value, output, index); - } else { - math::MaxSeqPoolFunctor max_pool; - max_pool(context, input, pad_value, output, index); - } - return; - } - if (pooltype == "LAST") { - math::LastSeqPoolFunctor last_pool; - last_pool(context, input, pad_value, output); - return; - } - if (pooltype == "FIRST") { - math::FirstSeqPoolFunctor first_pool; - first_pool(context, input, pad_value, output); - return; - } - auto lod_level = input.lod().size(); - auto lod = input.lod()[lod_level - 1]; - if (pooltype == "SUM") { - auto place = context.GetPlace(); - PADDLE_ENFORCE_EQ( - platform::is_cpu_place(place), - true, - platform::errors::InvalidArgument( - "Sequence_pool should run on CPU Device when pooltype is SUM")); - const T* src = input.data(); - T* dst = output->mutable_data(place); - phi::jit::seq_pool_attr_t attr( - static_cast(input.numel() / input.dims()[0]), - phi::jit::SeqPoolType::kSum); - auto seqpool = phi::jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(attr); - for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { - attr.h = static_cast(lod[i + 1] - lod[i]); - if (attr.h == 0) { - for (int j = 0; j < attr.w; ++j) { - dst[j] = pad_value; - } - } else { - seqpool(src, dst, &attr); - } - dst += attr.w; - src += attr.h * attr.w; - } - return; - } - auto& place = *context.eigen_device(); - for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { - phi::DenseTensor out_t = output->Slice(i, i + 1); - int64_t w = input.numel() / input.dims()[0]; - if (lod[i] == lod[i + 1]) { - for (int j = 0; j < w; ++j) { - out_t.data()[j] = pad_value; - } - continue; - } - phi::DenseTensor in_t = - input.Slice(static_cast(lod[i]), static_cast(lod[i + 1])); - int64_t h = static_cast(lod[i + 1] - lod[i]); - auto in_e = EigenMatrix::From(in_t, phi::make_ddim({h, w})); - auto out_e = EigenVector::Flatten(out_t); - if (pooltype == "AVERAGE") { - out_e.device(place) = in_e.mean(Eigen::array({{0}})); - } else if (pooltype == "SQRT") { - out_e.device(place) = in_e.sum(Eigen::array({{0}})) / - std::sqrt(static_cast(h)); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "unsupported pooling pooltype: %s. Only support \"AVERAGE\" and " - "\"SQRT\"", - pooltype)); - } - } - } -}; - -template -class SequencePoolGradFunctor { - public: - void operator()(const phi::CPUContext& context, - const std::string pooltype, - const phi::DenseTensor& out_grad, - phi::DenseTensor* in_grad, - /* max pool has index */ - const phi::DenseTensor* index = nullptr) { - if (pooltype == "MAX") { - math::MaxSeqPoolGradFunctor max_pool_grad; - max_pool_grad(context, out_grad, *index, in_grad); - return; - } - - if (pooltype == "LAST" || pooltype == "FIRST") { - // set X@Grad be zero at first when pooltype is LAST/FIRST - phi::funcs::SetConstant functor; - functor(context, in_grad, 0); - } - - if (pooltype == "SUM") { - math::SumSeqPoolGradFunctor sum_pool_grad; - sum_pool_grad(context, out_grad, in_grad); - return; - } - - auto lod_level = in_grad->lod().size(); - auto lod = in_grad->lod()[lod_level - 1]; - auto& place = *context.eigen_device(); - for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { - if (lod[i] == lod[i + 1]) continue; - auto in_g_t = in_grad->Slice(static_cast(lod[i]), - static_cast(lod[i + 1])); - auto out_g_t = out_grad.Slice(i, i + 1); - int64_t h = static_cast(lod[i + 1] - lod[i]); - int64_t w = in_grad->numel() / in_grad->dims()[0]; - auto in_g_e = EigenMatrix::From(in_g_t, {h, w}); - auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); - auto out_g_e_v = EigenVector::Flatten(out_g_t); - Eigen::DSizes bcast(h, 1); - - if (pooltype == "AVERAGE") { - in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); - } else if (pooltype == "SQRT") { - in_g_e.device(place) = - (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); - } else if (pooltype == "LAST") { - in_g_e.chip(h - 1, 0).device(place) = out_g_e_v; - } else if (pooltype == "FIRST") { - in_g_e.chip(0, 0).device(place) = out_g_e_v; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "unsupported pooling pooltype: %s. Only support \"AVERAGE\", " - "\"SQRT\", \"LAST\" and \"FIRST\"", - pooltype)); - } - } - } -}; - -template class SequencePoolFunctor; -template class SequencePoolFunctor; -template class SequencePoolGradFunctor; -template class SequencePoolGradFunctor; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sequence_pooling.cu b/paddle/fluid/operators/math/sequence_pooling.cu deleted file mode 100644 index e56f0025a0e66..0000000000000 --- a/paddle/fluid/operators/math/sequence_pooling.cu +++ /dev/null @@ -1,503 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -#include "paddle/fluid/operators/math/sequence_pooling.h" -#include "paddle/fluid/platform/macros.h" -#include "paddle/phi/backends/gpu/gpu_primitives.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { -namespace math { - -template -struct MaxPoolFunctor { - HOSTDEVICE void operator()(const T* input, - const T pad_value, - const size_t start, - const size_t end, - const size_t item_dim, - T* output, - int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - T max_val = static_cast(-FLT_MAX); - int max_index = -1; - if (start == end) { - output[tid] = pad_value; - index[tid] = -1; - } else { - for (int i = start; i < end; ++i) { - if (max_val < input[item_dim * i + tid]) { - max_val = input[item_dim * i + tid]; - max_index = i; - } - } - output[tid] = max_val; - index[tid] = max_index; - } - } - } -}; - -template -struct AvgPoolFunctor { - HOSTDEVICE void operator()(const T* input, - const T pad_value, - const size_t start, - const size_t end, - const size_t item_dim, - T* output, - int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - if (start == end) { - output[tid] = pad_value; - } else { - T val = static_cast(0); - for (int i = start; i < end; ++i) { - val += input[item_dim * i + tid]; - } - // end, start is lod, so end - start != 0 - output[tid] = val / static_cast(end - start); - } - } - } -}; - -template -struct SumPoolFunctor { - HOSTDEVICE void operator()(const T* input, - const T pad_value, - const size_t start, - const size_t end, - const size_t item_dim, - T* output, - int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - if (start == end) { - output[tid] = pad_value; - } else { - T val = static_cast(0); - for (int i = start; i < end; ++i) { - val += input[item_dim * i + tid]; - } - output[tid] = val; - } - } - } -}; - -template -struct SqrtPoolFunctor { - HOSTDEVICE void operator()(const T* input, - const T pad_value, - const size_t start, - const size_t end, - const size_t item_dim, - T* output, - int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - if (start == end) { - output[tid] = pad_value; - } else { - T val = static_cast(0); - for (int i = start; i < end; ++i) { - val += input[item_dim * i + tid]; - } - // end, start is lod, so end - start != 0 - output[tid] = val / sqrt(end - start); - } - } - } -}; - -template -struct LastPoolFunctor { - HOSTDEVICE void operator()(const T* input, - const T pad_value, - const size_t start, - const size_t end, - const size_t item_dim, - T* output, - int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - if (start == end) { - output[tid] = pad_value; - } else { - output[tid] = input[item_dim * (end - 1) + tid]; - } - } - } -}; - -template -struct FirstPoolFunctor { - HOSTDEVICE void operator()(const T* input, - const T pad_value, - const size_t start, - const size_t end, - const size_t item_dim, - T* output, - int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - if (start == end) { - output[tid] = pad_value; - } else { - output[tid] = input[item_dim * start + tid]; - } - } - } -}; - -template -__global__ void sequence_pool_kernel(Range_OP op, - const T* input, - const T pad_value, - const size_t* lod, - const size_t lod_size, - const size_t item_dim, - T* output, - int* index) { - int bid = blockIdx.x; - if (bid >= lod_size - 1) return; - size_t start = lod[bid]; - size_t end = lod[bid + 1]; - int* index_offset = nullptr; - if (index != nullptr) { - index_offset = &index[bid * item_dim]; - } - op(input, - pad_value, - start, - end, - item_dim, - &output[bid * item_dim], - index_offset); -} - -template -class SequencePoolFunctor { - public: - void operator()(const phi::GPUContext& context, - const std::string pooltype, - T pad_value, - const phi::DenseTensor& input, - phi::DenseTensor* output, - bool is_test, - phi::DenseTensor* index = nullptr) { - auto lod_level = input.lod().size(); - auto& lod = input.lod()[lod_level - 1]; - const size_t item_dim = output->numel() / output->dims()[0]; - dim3 threads(1024, 1); - dim3 grid(std::max(static_cast(lod.size()) - 1, 1), 1); - phi::MixVector mix_vector(&lod); - if (pooltype == "MAX") { - sequence_pool_kernel> - <<>>( - MaxPoolFunctor(), - input.data(), - pad_value, - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - output->mutable_data(context.GetPlace()), - index->data()); - } else if (pooltype == "AVERAGE") { - sequence_pool_kernel> - <<>>( - AvgPoolFunctor(), - input.data(), - pad_value, - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - output->mutable_data(context.GetPlace()), - nullptr); - } else if (pooltype == "SUM") { - sequence_pool_kernel> - <<>>( - SumPoolFunctor(), - input.data(), - pad_value, - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - output->mutable_data(context.GetPlace()), - nullptr); - } else if (pooltype == "SQRT") { - sequence_pool_kernel> - <<>>( - SqrtPoolFunctor(), - input.data(), - pad_value, - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - output->mutable_data(context.GetPlace()), - nullptr); - } else if (pooltype == "LAST") { - sequence_pool_kernel> - <<>>( - LastPoolFunctor(), - input.data(), - pad_value, - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - output->mutable_data(context.GetPlace()), - nullptr); - } else if (pooltype == "FIRST") { - sequence_pool_kernel> - <<>>( - FirstPoolFunctor(), - input.data(), - pad_value, - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - output->mutable_data(context.GetPlace()), - nullptr); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "unsupported pooling pooltype: %s. Only support \"MAX\", " - "\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"", - pooltype)); - } - } -}; - -template -struct MaxPoolGradFunctor { - HOSTDEVICE void operator()(const T* out_grad, - const size_t start, - const size_t end, - const size_t item_dim, - T* in_grad, - const int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - for (int i = start; i < end; ++i) { - if (i == index[tid]) { - in_grad[item_dim * i + tid] = out_grad[tid]; - } else { - in_grad[item_dim * i + tid] = static_cast(0); - } - } - } - } -}; - -template -struct AvgPoolGradFunctor { - HOSTDEVICE void operator()(const T* out_grad, - const size_t start, - const size_t end, - const size_t item_dim, - T* in_grad, - const int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - for (int i = start; i < end; ++i) { - in_grad[item_dim * i + tid] = out_grad[tid] / (end - start); - } - } - } -}; - -template -struct SumPoolGradFunctor { - HOSTDEVICE void operator()(const T* out_grad, - const size_t start, - const size_t end, - const size_t item_dim, - T* in_grad, - const int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - for (int i = start; i < end; ++i) { - in_grad[item_dim * i + tid] = out_grad[tid]; - } - } - } -}; - -template -struct SqrtPoolGradFunctor { - HOSTDEVICE void operator()(const T* out_grad, - const size_t start, - const size_t end, - const size_t item_dim, - T* in_grad, - const int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - for (int i = start; i < end; ++i) { - in_grad[item_dim * i + tid] = - out_grad[tid] / (sqrt(static_cast(end - start))); - } - } - } -}; - -template -struct LastPoolGradFunctor { - HOSTDEVICE void operator()(const T* out_grad, - const size_t start, - const size_t end, - const size_t item_dim, - T* in_grad, - const int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - for (int i = start; i < end; ++i) { - if (i == end - 1) { - in_grad[item_dim * i + tid] = out_grad[tid]; - } else { - in_grad[item_dim * i + tid] = static_cast(0); - } - } - } - } -}; - -template -struct FirstPoolGradFunctor { - HOSTDEVICE void operator()(const T* out_grad, - const size_t start, - const size_t end, - const size_t item_dim, - T* in_grad, - const int* index) { - for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) { - for (int i = start; i < end; ++i) { - if (i == start) { - in_grad[item_dim * i + tid] = out_grad[tid]; - } else { - in_grad[item_dim * i + tid] = static_cast(0); - } - } - } - } -}; - -template -__global__ void sequence_pool_grad_kernel(Range_OP op, - const T* out_grad, - const size_t* lod, - const size_t lod_size, - const size_t item_dim, - T* in_grad, - const int* index) { - int bid = blockIdx.x; - if (bid >= lod_size - 1) return; - size_t start = lod[bid]; - size_t end = lod[bid + 1]; - const int* index_offset = nullptr; - if (index != nullptr) { - index_offset = &index[bid * item_dim]; - } - op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset); -} - -template -class SequencePoolGradFunctor { - public: - void operator()(const phi::GPUContext& context, - const std::string pooltype, - const phi::DenseTensor& out_grad, - phi::DenseTensor* in_grad, - /* max pool has index */ - const phi::DenseTensor* index = nullptr) { - auto lod_level = in_grad->lod().size(); - auto& lod = in_grad->lod()[lod_level - 1]; - const size_t item_dim = in_grad->numel() / in_grad->dims()[0]; - dim3 threads(1024, 1); - dim3 grid(std::max(static_cast(lod.size()) - 1, 1), 1); - phi::MixVector mix_vector(&lod); - if (pooltype == "MAX") { - sequence_pool_grad_kernel> - <<>>( - MaxPoolGradFunctor(), - out_grad.data(), - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - in_grad->mutable_data(context.GetPlace()), - index->data()); - } else if (pooltype == "AVERAGE") { - sequence_pool_grad_kernel> - <<>>( - AvgPoolGradFunctor(), - out_grad.data(), - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - in_grad->mutable_data(context.GetPlace()), - nullptr); - } else if (pooltype == "SUM") { - sequence_pool_grad_kernel> - <<>>( - SumPoolGradFunctor(), - out_grad.data(), - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - in_grad->mutable_data(context.GetPlace()), - nullptr); - } else if (pooltype == "SQRT") { - sequence_pool_grad_kernel> - <<>>( - SqrtPoolGradFunctor(), - out_grad.data(), - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - in_grad->mutable_data(context.GetPlace()), - nullptr); - } else if (pooltype == "LAST") { - sequence_pool_grad_kernel> - <<>>( - LastPoolGradFunctor(), - out_grad.data(), - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - in_grad->mutable_data(context.GetPlace()), - nullptr); - } else if (pooltype == "FIRST") { - sequence_pool_grad_kernel> - <<>>( - FirstPoolGradFunctor(), - out_grad.data(), - mix_vector.CUDAData(context.GetPlace()), - lod.size(), - item_dim, - in_grad->mutable_data(context.GetPlace()), - nullptr); - - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "unsupported pooling pooltype: %s. Only support \"MAX\", " - "\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"", - pooltype)); - } - } -}; - -// sequence pooling -template class SequencePoolFunctor; -template class SequencePoolFunctor; -template class SequencePoolGradFunctor; -template class SequencePoolGradFunctor; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/sequence_pooling.h b/paddle/fluid/operators/math/sequence_pooling.h deleted file mode 100644 index 6a8e943d5d834..0000000000000 --- a/paddle/fluid/operators/math/sequence_pooling.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" - -namespace paddle { -namespace operators { -namespace math { - -template -class SequencePoolFunctor { - public: - /* max pool has index output */ - void operator()(const DeviceContext& context, - const std::string pooltype, - T pad_value, - const phi::DenseTensor& input, - phi::DenseTensor* output, - bool is_test = false, - phi::DenseTensor* index = nullptr); -}; - -template -class SequencePoolGradFunctor { - public: - void operator()(const DeviceContext& context, - const std::string pooltype, - const phi::DenseTensor& out_grad, - phi::DenseTensor* in_grad, - /* max pool has index */ - const phi::DenseTensor* index = nullptr); -}; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake index 9a87e27b24197..607f81bce1e30 100644 --- a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake +++ b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake @@ -31,7 +31,6 @@ register_unity_group( sequence_expand_op.cu sequence_mask_op.cu sequence_pad_op.cu - sequence_pool_op.cu sequence_expand_as_op.cu sequence_reshape_op.cu sequence_reverse_op.cu diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index b0da4df9c8da9..fe8548f9fb5fc 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -74,6 +74,7 @@ set(COMMON_KERNEL_DEPS phi_dynload_warpctc phi_dynload_warprnnt sequence_padding + sequence_pooling sequence_scale fft phi_data_layout_transform diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index 20e97cb887b26..bd1774d756c4b 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -25,6 +25,7 @@ math_library(maxouting) math_library(matrix_bit_code) math_library(sequence_scale) math_library(sequence_padding DEPS lod_utils) +math_library(sequence_pooling DEPS math_function jit_kernel_helper) cc_library( phi_data_layout_transform diff --git a/paddle/phi/kernels/funcs/sequence_pooling.h b/paddle/phi/kernels/funcs/sequence_pooling.h index 937022728b233..8602d5e4cfc00 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.h +++ b/paddle/phi/kernels/funcs/sequence_pooling.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include -#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/dense_tensor.h" namespace phi { namespace funcs { From dd0af520a7fb112feb289222555798081953bf8d Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sat, 8 Apr 2023 13:48:34 +0800 Subject: [PATCH 12/23] [phi] mv sequence_pooling_test --- paddle/fluid/operators/math/CMakeLists.txt | 5 -- test/cpp/phi/kernels/CMakeLists.txt | 5 ++ .../cpp/phi/kernels}/sequence_pooling_test.cc | 47 ++++++++++--------- 3 files changed, 30 insertions(+), 27 deletions(-) rename {paddle/fluid/operators/math => test/cpp/phi/kernels}/sequence_pooling_test.cc (79%) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 9278ff5e904af..fbb192f29d0be 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -43,11 +43,6 @@ cc_test( vol2col_test SRCS vol2col_test.cc DEPS vol2col) -# TODO mv test -#cc_test( -# sequence_pooling_test -# SRCS sequence_pooling_test.cc -# DEPS sequence_pooling) cc_test( beam_search_test SRCS beam_search_test.cc diff --git a/test/cpp/phi/kernels/CMakeLists.txt b/test/cpp/phi/kernels/CMakeLists.txt index a9e897eb614dc..3e7f394f186da 100644 --- a/test/cpp/phi/kernels/CMakeLists.txt +++ b/test/cpp/phi/kernels/CMakeLists.txt @@ -105,3 +105,8 @@ cc_test( sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) + +cc_test( + sequence_pooling_test + SRCS sequence_pooling_test.cc + DEPS sequence_pooling) diff --git a/paddle/fluid/operators/math/sequence_pooling_test.cc b/test/cpp/phi/kernels/sequence_pooling_test.cc similarity index 79% rename from paddle/fluid/operators/math/sequence_pooling_test.cc rename to test/cpp/phi/kernels/sequence_pooling_test.cc index dac5eb63bfc13..dfc841cfd7695 100644 --- a/paddle/fluid/operators/math/sequence_pooling_test.cc +++ b/test/cpp/phi/kernels/sequence_pooling_test.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,13 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/sequence_pooling.h" - #include +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/tensor_utils.h" + template void TestSequencePoolingSum(const DeviceContext &context, - const paddle::framework::LoD &lod, + const phi::LoD &lod, const int64_t second_dim) { phi::DenseTensor cpu_out_grad; phi::DenseTensor cpu_in_grad; @@ -30,17 +34,17 @@ void TestSequencePoolingSum(const DeviceContext &context, auto out_dims = phi::make_ddim({static_cast(out_first_dim), second_dim}); - cpu_out_grad.mutable_data(out_dims, paddle::platform::CPUPlace()); + cpu_out_grad.mutable_data(out_dims, phi::CPUPlace()); for (int64_t i = 0; i < cpu_out_grad.numel(); ++i) { cpu_out_grad.data()[i] = static_cast(i); } // copy to dst out_grad auto place = context.GetPlace(); - if (paddle::platform::is_cpu_place(place)) { + if (place == phi::CPUPlace()) { out_grad = cpu_out_grad; } else { - paddle::framework::TensorCopySync(cpu_out_grad, place, &out_grad); + phi::Copy(context, cpu_out_grad, place, true, &out_grad); } // construct in_grad @@ -53,7 +57,7 @@ void TestSequencePoolingSum(const DeviceContext &context, PADDLE_ENFORCE_EQ( in_grad.dims().size(), out_grad.dims().size(), - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of input and output shall be same. Expected %ld == " "%ld, but got %ld != %ld. Please check the input value.", in_grad.dims().size(), @@ -64,7 +68,7 @@ void TestSequencePoolingSum(const DeviceContext &context, PADDLE_ENFORCE_EQ( in_grad.dims()[i], out_grad.dims()[i], - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of input and output shall be same. Expected %ld == " "%ld, but got %ld != %ld. Please check the input value.", in_grad.dims()[i], @@ -74,21 +78,20 @@ void TestSequencePoolingSum(const DeviceContext &context, } // call functor - paddle::operators::math::SequencePoolGradFunctor()( + phi::funcs::SequencePoolGradFunctor()( context, "SUM", out_grad, &in_grad); - if (paddle::platform::is_cpu_place(place)) { + if (place == phi::CPUPlace()) { cpu_in_grad = in_grad; } else { - paddle::framework::TensorCopySync( - in_grad, paddle::platform::CPUPlace(), &cpu_in_grad); + phi::Copy(context, in_grad, phi::CPUPlace(), true, &cpu_in_grad); cpu_in_grad.set_lod(in_grad.lod()); } EXPECT_EQ(in_grad.numel(), static_cast(lod[0].back() * second_dim)); EXPECT_EQ(in_grad.lod(), lod); - if (paddle::platform::is_cpu_place(place)) { + if (place == phi::CPUPlace()) { for (size_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) { int64_t begin = in_grad.lod()[0][i]; int64_t end = in_grad.lod()[0][i + 1]; @@ -116,30 +119,30 @@ void TestSequencePoolingSum(const DeviceContext &context, } TEST(SequencePoolingGrad, CPU_SUM) { - auto place = paddle::platform::CPUPlace(); + auto place = phi::CPUPlace(); auto *context = static_cast( - paddle::platform::DeviceContextPool::Instance().Get(place)); + phi::DeviceContextPool::Instance().Get(place)); - paddle::framework::LoD lod1; + phi::LoD lod1; lod1.push_back(std::vector{0, 10}); TestSequencePoolingSum(*context, lod1, 128); - paddle::framework::LoD lod2; + phi::LoD lod2; lod2.push_back(std::vector{0, 2, 7, 10}); TestSequencePoolingSum(*context, lod2, 128); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) TEST(SequencePoolingGrad, CUDA_SUM) { - auto place = paddle::platform::CUDAPlace(0); + auto place = phi::GPUPlace(0); auto *context = static_cast( - paddle::platform::DeviceContextPool::Instance().Get(place)); + phi::DeviceContextPool::Instance().Get(place)); - paddle::framework::LoD lod1; + phi::LoD lod1; lod1.push_back(std::vector{0, 10}); TestSequencePoolingSum(*context, lod1, 128); - paddle::framework::LoD lod2; + phi::LoD lod2; lod2.push_back(std::vector{0, 2, 7, 10}); TestSequencePoolingSum(*context, lod2, 128); } From d2e10acf67f742882f44bd778fb77d77f89af5a8 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sat, 8 Apr 2023 18:34:03 +0800 Subject: [PATCH 13/23] [phi] fix grad sig --- paddle/phi/ops/compat/sequence_pool_sig.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc index 992fa92ecb995..bb8baf4e4af8b 100644 --- a/paddle/phi/ops/compat/sequence_pool_sig.cc +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -27,9 +27,9 @@ KernelSignature SequencePoolOpArgumentMapping( KernelSignature SequencePoolGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("sequence_pool_grad", - {"X", "MaxIndex", "Out"}, + {"X", "MaxIndex", "Out@GRAD"}, {"is_test", "pooltype", "pad_value"}, - {"X"}); + {"X@GRAD"}); } } // namespace phi From c51b9d37f393e0a99dece9fc9532db7923288290 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sat, 8 Apr 2023 21:31:27 +0800 Subject: [PATCH 14/23] [phi] fix sequence_pool is_test error --- paddle/phi/kernels/cpu/sequence_pool_kernel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc index 2ddedaaec06c9..bc29b38654fe7 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc @@ -70,7 +70,7 @@ void SequencePoolKernel(const Context& ctx, // Do not create index buffer for inference mode if (pooltype == "MAX" && - (is_test || (ctx.GetPlace() == phi::CPUPlace()) == false)) { + (is_test == false || (ctx.GetPlace() == phi::CPUPlace()) == false)) { index = max_index; index->Resize({dims}); ctx.template Alloc(index); From eb3886206e00eb88107d88027bf9560755b6035b Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sat, 8 Apr 2023 22:15:23 +0800 Subject: [PATCH 15/23] [phi] fix sequence_pooling gpu include --- paddle/phi/kernels/funcs/sequence_pooling.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/phi/kernels/funcs/sequence_pooling.cu b/paddle/phi/kernels/funcs/sequence_pooling.cu index 25abdb40ae2ec..c4d3279fe4ff9 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.cu +++ b/paddle/phi/kernels/funcs/sequence_pooling.cu @@ -16,6 +16,8 @@ limitations under the License. */ #include #include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/float16.h" #include "paddle/phi/core/macros.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sequence_pooling.h" From 295dbbf3ee8946e4839bb52a2b606cbe02690086 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sun, 9 Apr 2023 10:33:00 +0800 Subject: [PATCH 16/23] [phi] mv to impl --- .../kernels/cpu/sequence_pool_grad_kernel.cc | 23 +----- .../phi/kernels/cpu/sequence_pool_kernel.cc | 65 +-------------- .../kernels/gpu/sequence_pool_grad_kernel.cu | 3 + .../phi/kernels/gpu/sequence_pool_kernel.cu | 3 + .../impl/sequence_pool_grad_kernel_impl.h | 39 +++++++++ .../kernels/impl/sequence_pool_kernel_impl.h | 80 +++++++++++++++++++ 6 files changed, 127 insertions(+), 86 deletions(-) create mode 100644 paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/sequence_pool_kernel_impl.h diff --git a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc index 8d450abe69e2a..dc8ab37274c31 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc @@ -17,28 +17,7 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/sequence_pooling.h" - -namespace phi { - -template -void SequencePoolGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& max_index, - const DenseTensor& out_grad, - bool is_test, - const std::string& pooltype, - float pad_value, - DenseTensor* x_grad) { - const phi::DenseTensor* index = nullptr; - if (pooltype == "MAX") { - index = &max_index; - } - dev_ctx.template Alloc(x_grad); - phi::funcs::SequencePoolGradFunctor pool; - pool(dev_ctx, pooltype, out_grad, x_grad, index); -} - -} // namespace phi +#include "paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h" PD_REGISTER_KERNEL(sequence_pool_grad, CPU, diff --git a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc index bc29b38654fe7..4bc0e03983d5b 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_kernel.cc @@ -16,70 +16,7 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/sequence_pooling.h" - -namespace phi { - -template -void SequencePoolKernel(const Context& ctx, - const DenseTensor& x, - bool is_test, - const std::string& pooltype, - float pad_value, - DenseTensor* out, - DenseTensor* max_index) { - T pad_value_ = static_cast(pad_value); - - auto dims = x.dims(); - auto lod = x.lod(); - auto lod_level = lod.size(); - // InferShape by lod - PADDLE_ENFORCE_GT( - lod_level, - 0, - errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " - "does not contain LoD information.")); - PADDLE_ENFORCE_LE( - lod_level, - 2UL, - errors::InvalidArgument("The lod level of input shall be no more than 2." - "Received lod level is %d.", - lod_level)); - PADDLE_ENFORCE_GE( - dims[0], - /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), - errors::InvalidArgument( - "The first dimension of Input(X) must be large than batch size." - "But received first dimension of Input(X) is %d, while batch" - "size is %d.", - dims[0], - static_cast(lod[lod_level - 1].size() - 1))); - if (lod_level > 1UL) { - PADDLE_ENFORCE_EQ( - lod[0][lod[0].size() - 1], - lod[1].size() - 1, - errors::InvalidArgument("The input lod information is illegal.")); - phi::LoD out_lod; - out_lod.push_back(lod[0]); - out->set_lod(out_lod); - } - dims[0] = lod[lod_level - 1].size() - 1; - out->Resize({dims}); - ctx.template Alloc(out); - phi::DenseTensor* index = nullptr; - - // Do not create index buffer for inference mode - if (pooltype == "MAX" && - (is_test == false || (ctx.GetPlace() == phi::CPUPlace()) == false)) { - index = max_index; - index->Resize({dims}); - ctx.template Alloc(index); - } - phi::funcs::SequencePoolFunctor pool; - pool(ctx, pooltype, pad_value_, x, out, is_test, index); -} - -} // namespace phi +#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" PD_REGISTER_KERNEL( sequence_pool, CPU, ALL_LAYOUT, phi::SequencePoolKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu b/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu index 1dc973ec3e5fe..fe991a1fef431 100644 --- a/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/kernels/sequence_pool_grad_kernel.h" +#include "paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h" + #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" PD_REGISTER_KERNEL( sequence_pool_grad, GPU, ALL_LAYOUT, phi::SequencePoolGradKernel, float) {} diff --git a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu index a89183254fc8e..7baf83e75dc3d 100644 --- a/paddle/phi/kernels/gpu/sequence_pool_kernel.cu +++ b/paddle/phi/kernels/gpu/sequence_pool_kernel.cu @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/kernels/sequence_pool_kernel.h" +#include "paddle/phi/kernels/impl/sequence_pool_kernel_impl.h" + #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" PD_REGISTER_KERNEL( sequence_pool, GPU, ALL_LAYOUT, phi::SequencePoolKernel, float) {} diff --git a/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h new file mode 100644 index 0000000000000..e8eb17e612b5f --- /dev/null +++ b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +namespace phi { + +template +void SequencePoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& max_index, + const DenseTensor& out_grad, + bool is_test, + const std::string& pooltype, + float pad_value, + DenseTensor* x_grad) { + const phi::DenseTensor* index = nullptr; + if (pooltype == "MAX") { + index = &max_index; + } + dev_ctx.template Alloc(x_grad); + phi::funcs::SequencePoolGradFunctor pool; + pool(dev_ctx, pooltype, out_grad, x_grad, index); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h new file mode 100644 index 0000000000000..c66dc1e123ca7 --- /dev/null +++ b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +namespace phi { + +template +void SequencePoolKernel(const Context& ctx, + const DenseTensor& x, + bool is_test, + const std::string& pooltype, + float pad_value, + DenseTensor* out, + DenseTensor* max_index) { + T pad_value_ = static_cast(pad_value); + + auto dims = x.dims(); + auto lod = x.lod(); + auto lod_level = lod.size(); + // InferShape by lod + PADDLE_ENFORCE_GT( + lod_level, + 0, + errors::InvalidArgument("Input(X) phi::DenseTensor of SequencePoolOp " + "does not contain LoD information.")); + PADDLE_ENFORCE_LE( + lod_level, + 2UL, + errors::InvalidArgument("The lod level of input shall be no more than 2." + "Received lod level is %d.", + lod_level)); + PADDLE_ENFORCE_GE( + dims[0], + /*batch size = */ static_cast(lod[lod_level - 1].size() - 1), + errors::InvalidArgument( + "The first dimension of Input(X) must be large than batch size." + "But received first dimension of Input(X) is %d, while batch" + "size is %d.", + dims[0], + static_cast(lod[lod_level - 1].size() - 1))); + if (lod_level > 1UL) { + PADDLE_ENFORCE_EQ( + lod[0][lod[0].size() - 1], + lod[1].size() - 1, + errors::InvalidArgument("The input lod information is illegal.")); + phi::LoD out_lod; + out_lod.push_back(lod[0]); + out->set_lod(out_lod); + } + dims[0] = lod[lod_level - 1].size() - 1; + out->Resize({dims}); + ctx.template Alloc(out); + phi::DenseTensor* index = nullptr; + + // Do not create index buffer for inference mode + if (pooltype == "MAX" && + (is_test == false || (ctx.GetPlace() == phi::CPUPlace()) == false)) { + index = max_index; + index->Resize({dims}); + ctx.template Alloc(index); + } + phi::funcs::SequencePoolFunctor pool; + pool(ctx, pooltype, pad_value_, x, out, is_test, index); +} + +} // namespace phi From 554361a368f4656d25eba1b2aff246d34b871302 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sun, 9 Apr 2023 11:04:40 +0800 Subject: [PATCH 17/23] [phi] fix SequencePoolFunctor cu include --- paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc | 1 - paddle/phi/kernels/funcs/sequence_pooling.cu | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc index dc8ab37274c31..9f15ed0147031 100644 --- a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/sequence_pooling.h" #include "paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h" PD_REGISTER_KERNEL(sequence_pool_grad, diff --git a/paddle/phi/kernels/funcs/sequence_pooling.cu b/paddle/phi/kernels/funcs/sequence_pooling.cu index c4d3279fe4ff9..fd0011233b372 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.cu +++ b/paddle/phi/kernels/funcs/sequence_pooling.cu @@ -16,9 +16,8 @@ limitations under the License. */ #include #include "paddle/phi/backends/gpu/gpu_primitives.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/float16.h" #include "paddle/phi/core/macros.h" +#include "paddle/phi/core/mixed_vector.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sequence_pooling.h" From cee77c3dbb94de8fa6fcf07191cfa93bfeb7698c Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sat, 15 Apr 2023 16:04:19 +0800 Subject: [PATCH 18/23] [phi] modify out max_index int32_t --- paddle/phi/kernels/impl/sequence_pool_kernel_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h index 56633b6b54270..e448820516afe 100644 --- a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h @@ -68,7 +68,7 @@ void SequencePoolKernel(const Context& ctx, (is_test == false || (ctx.GetPlace() == phi::CPUPlace()) == false)) { index = max_index; index->Resize({dims}); - ctx.template Alloc(index); + ctx.template Alloc(index); } phi::funcs::SequencePoolFunctor pool; pool(ctx, pooltype, pad_value_, x, out, is_test, index); From f5815bb3af0c545a3119d2ca062c0b5e363d46dc Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Mon, 24 Apr 2023 16:38:30 +0800 Subject: [PATCH 19/23] [phi] add pooltype mapping determine --- paddle/phi/ops/compat/sequence_pool_sig.cc | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc index 224f83d1f32f9..96682e250c6e9 100644 --- a/paddle/phi/ops/compat/sequence_pool_sig.cc +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -23,10 +23,20 @@ KernelSignature SequencePoolOpArgumentMapping( KernelSignature SequencePoolGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("sequence_pool_grad", - {"X", "MaxIndex", "Out@GRAD"}, - {"is_test", "pooltype", "pad_value"}, - {"X@GRAD"}); + const auto& use_pooltype_maxindex = + paddle::any_cast(ctx.Attr("pooltype")); + + if (use_pooltype_maxindex == "SUN") { + return KernelSignature("sequence_pool_grad", + {"X", "Out@GRAD"}, + {"is_test", "pooltype", "pad_value"}, + {"X@GRAD"}); + } else { + return KernelSignature("sequence_pool_grad", + {"X", "MaxIndex", "Out@GRAD"}, + {"is_test", "pooltype", "pad_value"}, + {"X@GRAD"}); + } } } // namespace phi From 56fb9e8bb8b105a5615842810e8d4b223aa544e2 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 25 Apr 2023 10:56:11 +0800 Subject: [PATCH 20/23] [phi] fix sequence_pool_sig --- paddle/phi/ops/compat/sequence_pool_sig.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc index 96682e250c6e9..7f0449c2d9555 100644 --- a/paddle/phi/ops/compat/sequence_pool_sig.cc +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -26,14 +26,14 @@ KernelSignature SequencePoolGradOpArgumentMapping( const auto& use_pooltype_maxindex = paddle::any_cast(ctx.Attr("pooltype")); - if (use_pooltype_maxindex == "SUN") { + if (use_pooltype_maxindex == "MAX") { return KernelSignature("sequence_pool_grad", - {"X", "Out@GRAD"}, + {"X", "MaxIndex", "Out@GRAD"}, {"is_test", "pooltype", "pad_value"}, {"X@GRAD"}); } else { return KernelSignature("sequence_pool_grad", - {"X", "MaxIndex", "Out@GRAD"}, + {"X", "Out@GRAD"}, {"is_test", "pooltype", "pad_value"}, {"X@GRAD"}); } From 8f059d7a4e681f19ad2e8ccb5275b804f49784ea Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 25 Apr 2023 12:24:05 +0800 Subject: [PATCH 21/23] [phi] fix sequence_pool_sig sum --- paddle/phi/ops/compat/sequence_pool_sig.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc index 7f0449c2d9555..7fa75ad92b81d 100644 --- a/paddle/phi/ops/compat/sequence_pool_sig.cc +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -23,10 +23,10 @@ KernelSignature SequencePoolOpArgumentMapping( KernelSignature SequencePoolGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - const auto& use_pooltype_maxindex = + const auto& pooltype_value = paddle::any_cast(ctx.Attr("pooltype")); - if (use_pooltype_maxindex == "MAX") { + if (pooltype_value == "MAX" || pooltype_value == "SUM") { return KernelSignature("sequence_pool_grad", {"X", "MaxIndex", "Out@GRAD"}, {"is_test", "pooltype", "pad_value"}, From edfafafc62b4c472968310c35a3d7db7c6d9e9bb Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 25 Apr 2023 14:07:58 +0800 Subject: [PATCH 22/23] [phi] try ci --- paddle/phi/ops/compat/sequence_pool_sig.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc index 7fa75ad92b81d..1b21ba6d5f4c3 100644 --- a/paddle/phi/ops/compat/sequence_pool_sig.cc +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -26,7 +26,7 @@ KernelSignature SequencePoolGradOpArgumentMapping( const auto& pooltype_value = paddle::any_cast(ctx.Attr("pooltype")); - if (pooltype_value == "MAX" || pooltype_value == "SUM") { + if (pooltype_value == "MAX") { return KernelSignature("sequence_pool_grad", {"X", "MaxIndex", "Out@GRAD"}, {"is_test", "pooltype", "pad_value"}, From fee21b2242ee2f2fcc65db4270dbf32fd23259ec Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 25 Apr 2023 17:24:47 +0800 Subject: [PATCH 23/23] [phi] fix max_index optional --- .../impl/sequence_pool_grad_kernel_impl.h | 4 ++-- paddle/phi/kernels/sequence_pool_grad_kernel.h | 2 +- paddle/phi/ops/compat/sequence_pool_sig.cc | 18 ++++-------------- 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h index e8eb17e612b5f..da9bdc1f1fdf5 100644 --- a/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h @@ -21,7 +21,7 @@ namespace phi { template void SequencePoolGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& max_index, + const paddle::optional& max_index, const DenseTensor& out_grad, bool is_test, const std::string& pooltype, @@ -29,7 +29,7 @@ void SequencePoolGradKernel(const Context& dev_ctx, DenseTensor* x_grad) { const phi::DenseTensor* index = nullptr; if (pooltype == "MAX") { - index = &max_index; + index = max_index.get_ptr(); } dev_ctx.template Alloc(x_grad); phi::funcs::SequencePoolGradFunctor pool; diff --git a/paddle/phi/kernels/sequence_pool_grad_kernel.h b/paddle/phi/kernels/sequence_pool_grad_kernel.h index 1117a96d4f1a2..a88f9ceb4b1ae 100644 --- a/paddle/phi/kernels/sequence_pool_grad_kernel.h +++ b/paddle/phi/kernels/sequence_pool_grad_kernel.h @@ -20,7 +20,7 @@ namespace phi { template void SequencePoolGradKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& max_index, + const paddle::optional& max_index, const DenseTensor& out_grad, bool is_test, const std::string& pooltype, diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc index 1b21ba6d5f4c3..224f83d1f32f9 100644 --- a/paddle/phi/ops/compat/sequence_pool_sig.cc +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -23,20 +23,10 @@ KernelSignature SequencePoolOpArgumentMapping( KernelSignature SequencePoolGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - const auto& pooltype_value = - paddle::any_cast(ctx.Attr("pooltype")); - - if (pooltype_value == "MAX") { - return KernelSignature("sequence_pool_grad", - {"X", "MaxIndex", "Out@GRAD"}, - {"is_test", "pooltype", "pad_value"}, - {"X@GRAD"}); - } else { - return KernelSignature("sequence_pool_grad", - {"X", "Out@GRAD"}, - {"is_test", "pooltype", "pad_value"}, - {"X@GRAD"}); - } + return KernelSignature("sequence_pool_grad", + {"X", "MaxIndex", "Out@GRAD"}, + {"is_test", "pooltype", "pad_value"}, + {"X@GRAD"}); } } // namespace phi