From fe053396734c58e445e3e8c9973e6ca8c125682a Mon Sep 17 00:00:00 2001 From: gouzil <66515297+gouzil@users.noreply.github.com> Date: Thu, 27 Apr 2023 21:23:35 +0800 Subject: [PATCH] =?UTF-8?q?[phi]=20Move=20sequence=5Fpool=20to=20phi=20-?= =?UTF-8?q?=20Step=203=20=EF=BC=9Asequence=5Fpool=5Fgrad=5Fop=20(#52680)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [phi] move sequence_pool kernel to phi * mv kernels impl * fix parameter error * clean include * fix compat filename * [phi] move fluid sequence_pool_grad to phi * [phi][compat] sig rm GradVarName * [phi] fix sequence_pool out type * [phi] rm impl, add const string * [phi] fix const str * fix sequence_pooling cmake * [phi] mv sequence_pooling_test * [phi] fix grad sig * [phi] fix sequence_pool is_test error * [phi] fix sequence_pooling gpu include * [phi] mv to impl * [phi] fix SequencePoolFunctor cu include * [phi] modify out max_index int32_t * [phi] add pooltype mapping determine * [phi] fix sequence_pool_sig * [phi] fix sequence_pool_sig sum * [phi] try ci * [phi] fix max_index optional --- paddle/fluid/operators/math/CMakeLists.txt | 1 - .../sequence_ops/sequence_pool_op.cc | 12 +---- .../operators/sequence_ops/sequence_pool_op.h | 49 ------------------- .../sequence_ops/unity_build_rule.cmake | 1 - .../kernels/cpu/sequence_pool_grad_kernel.cc | 26 ++++++++++ .../kernels/gpu/sequence_pool_grad_kernel.cu} | 14 ++++-- .../impl/sequence_pool_grad_kernel_impl.h | 39 +++++++++++++++ .../kernels/impl/sequence_pool_kernel_impl.h | 2 +- .../phi/kernels/sequence_pool_grad_kernel.h | 30 ++++++++++++ paddle/phi/ops/compat/sequence_pool_sig.cc | 10 ++++ 10 files changed, 116 insertions(+), 68 deletions(-) delete mode 100644 paddle/fluid/operators/sequence_ops/sequence_pool_op.h create mode 100644 paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc rename paddle/{fluid/operators/sequence_ops/sequence_pool_op.cu => phi/kernels/gpu/sequence_pool_grad_kernel.cu} (59%) create mode 100644 paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/sequence_pool_grad_kernel.h diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 405661f3a1ab9..c67703ba52814 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -14,7 +14,6 @@ math_library(sample_prob) math_library(sampler DEPS generator) # math_library(math_function DEPS blas dense_tensor tensor) - if(WITH_XPU) math_library(beam_search DEPS math_function beam_search_xpu) else() diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index f51cc7d903664..d616bca2c4e3b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -12,10 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" - -#include -#include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -196,10 +193,3 @@ REGISTER_OPERATOR(sequence_pool, REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp, ops::SequencePoolGradOpNoNeedBufferVarsInferer); - -PD_REGISTER_STRUCT_KERNEL(sequence_pool_grad, - CPU, - ALL_LAYOUT, - ops::SequencePoolGradKernel, - float, - double) {} diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h b/paddle/fluid/operators/sequence_ops/sequence_pool_op.h deleted file mode 100644 index 7b00395bb6bb2..0000000000000 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/sequence_pooling.h" - -namespace paddle { -namespace operators { - -template -class SequencePoolGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* out_g = - context.Input(framework::GradVarName("Out")); - auto* in_g = context.Output(framework::GradVarName("X")); - std::string pooltype = context.Attr("pooltype"); - const phi::DenseTensor* index = nullptr; - if (pooltype == "MAX") { - index = context.Input("MaxIndex"); - } - in_g->mutable_data(context.GetPlace()); - phi::funcs::SequencePoolGradFunctor pool; - pool(context.template device_context(), - pooltype, - *out_g, - in_g, - index); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake index b23035082e4aa..b256a227a9646 100644 --- a/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake +++ b/paddle/fluid/operators/sequence_ops/unity_build_rule.cmake @@ -30,7 +30,6 @@ register_unity_group( sequence_expand_op.cu sequence_mask_op.cu sequence_pad_op.cu - sequence_pool_op.cu sequence_expand_as_op.cu sequence_reshape_op.cu sequence_reverse_op.cu diff --git a/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc new file mode 100644 index 0000000000000..9f15ed0147031 --- /dev/null +++ b/paddle/phi/kernels/cpu/sequence_pool_grad_kernel.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sequence_pool_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(sequence_pool_grad, + CPU, + ALL_LAYOUT, + phi::SequencePoolGradKernel, + float, + double) {} diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu b/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu similarity index 59% rename from paddle/fluid/operators/sequence_ops/sequence_pool_op.cu rename to paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu index 796b02cb03e32..fe991a1fef431 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu +++ b/paddle/phi/kernels/gpu/sequence_pool_grad_kernel.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -11,8 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" -namespace ops = paddle::operators; -PD_REGISTER_STRUCT_KERNEL( - sequence_pool_grad, GPU, ALL_LAYOUT, ops::SequencePoolGradKernel, float) {} +#include "paddle/phi/kernels/sequence_pool_grad_kernel.h" +#include "paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + sequence_pool_grad, GPU, ALL_LAYOUT, phi::SequencePoolGradKernel, float) {} diff --git a/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h new file mode 100644 index 0000000000000..da9bdc1f1fdf5 --- /dev/null +++ b/paddle/phi/kernels/impl/sequence_pool_grad_kernel_impl.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/funcs/sequence_pooling.h" + +namespace phi { + +template +void SequencePoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& max_index, + const DenseTensor& out_grad, + bool is_test, + const std::string& pooltype, + float pad_value, + DenseTensor* x_grad) { + const phi::DenseTensor* index = nullptr; + if (pooltype == "MAX") { + index = max_index.get_ptr(); + } + dev_ctx.template Alloc(x_grad); + phi::funcs::SequencePoolGradFunctor pool; + pool(dev_ctx, pooltype, out_grad, x_grad, index); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h index 56633b6b54270..e448820516afe 100644 --- a/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/sequence_pool_kernel_impl.h @@ -68,7 +68,7 @@ void SequencePoolKernel(const Context& ctx, (is_test == false || (ctx.GetPlace() == phi::CPUPlace()) == false)) { index = max_index; index->Resize({dims}); - ctx.template Alloc(index); + ctx.template Alloc(index); } phi::funcs::SequencePoolFunctor pool; pool(ctx, pooltype, pad_value_, x, out, is_test, index); diff --git a/paddle/phi/kernels/sequence_pool_grad_kernel.h b/paddle/phi/kernels/sequence_pool_grad_kernel.h new file mode 100644 index 0000000000000..a88f9ceb4b1ae --- /dev/null +++ b/paddle/phi/kernels/sequence_pool_grad_kernel.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void SequencePoolGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& max_index, + const DenseTensor& out_grad, + bool is_test, + const std::string& pooltype, + float pad_value, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/ops/compat/sequence_pool_sig.cc b/paddle/phi/ops/compat/sequence_pool_sig.cc index 6c4d6691db4bb..224f83d1f32f9 100644 --- a/paddle/phi/ops/compat/sequence_pool_sig.cc +++ b/paddle/phi/ops/compat/sequence_pool_sig.cc @@ -21,6 +21,16 @@ KernelSignature SequencePoolOpArgumentMapping( {"Out", "MaxIndex"}); } +KernelSignature SequencePoolGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("sequence_pool_grad", + {"X", "MaxIndex", "Out@GRAD"}, + {"is_test", "pooltype", "pad_value"}, + {"X@GRAD"}); +} + } // namespace phi PD_REGISTER_ARG_MAPPING_FN(sequence_pool, phi::SequencePoolOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(sequence_pool_grad, + phi::SequencePoolGradOpArgumentMapping);