[static op generation] transpose #54155

Closed · wants to merge 16 commits
10 changes: 10 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.cc
@@ -371,5 +371,15 @@ phi::KernelKey GetConvExpectedKernelType(
return phi::KernelKey(input_data_type, ctx.GetPlace());
}

phi::KernelKey GetTranspose2ExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
auto data_type = op_ptr->IndicateVarDataType(ctx, "X");
auto& data_format = ctx.Attr<std::string>("data_format");
phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
return phi::KernelKey(
ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
}

} // namespace operators
} // namespace paddle
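
For context: after the manual Transpose2Op::GetExpectedKernelType is deleted below, the generated op reaches this function through the get_expected_kernel_type hook in op_compat.yaml. A minimal sketch of that wiring, abbreviated from this diff's op_compat.yaml hunk (the op-alias line is assumed):

# op_compat.yaml (sketch; alias line assumed, extra attrs copied from the hunk below)
- op : transpose (transpose2)
  extra :
    outputs : [XShape]
    attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", str mkldnn_data_type = "float32"]
  get_expected_kernel_type :
    transpose2 : GetTranspose2ExpectedKernelType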
4 changes: 4 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.h
@@ -92,5 +92,9 @@ phi::KernelKey GetConvExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

phi::KernelKey GetTranspose2ExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

} // namespace operators
} // namespace paddle
149 changes: 0 additions & 149 deletions paddle/fluid/operators/transpose_op.cc
@@ -107,138 +107,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
}
};

void Transpose2Op::InferShape(framework::InferShapeContext *ctx) const {
using CompatMetaTensor = framework::CompatMetaTensor;
Contributor: Transpose2Op is coupled with FusedTransposeOpMaker, and InferShape would also need to be modified to support this; it is recommended to close this task for now.

CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
phi::TransposeInferMeta(x, axis, &out);

if (!ctx->HasOutput("XShape")) return;
const auto &in_dims = ctx->GetInputDim("X");
std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
x_shape_dim[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
x_shape_dim[i + 1] = in_dims[i];
}
ctx->SetOutputDim("XShape", phi::make_ddim(x_shape_dim));
ctx->ShareLoD("X", /*->*/ "XShape");
}

phi::KernelKey Transpose2Op::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
auto &data_format = ctx.Attr<std::string>("data_format");
phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
return phi::KernelKey(
ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
}

void Transpose2OpMaker::Make() {
AddInput(
"X",
"(Tensor) The input tensor, tensors with rank up to 6 are supported.");
AddOutput("Out", "(Tensor)The output tensor.");
AddAttr<std::vector<int>>(
"axis",
"(vector<int>) A list of values, and the size of the list should be "
"the same with the input tensor rank. This operator permutes the input "
"tensor's axes according to the values given.");
AddOutput("XShape", "(Tensor)The output tensor.").AsIntermediate().AsExtra();
AddComment(R"DOC(
Transpose Operator.

The input tensor will be permuted according to the axes given.
The behavior of this operator is similar to how `numpy.transpose` works.

- suppose the input `X` is a 2-D tensor:
$$
X = \begin{pmatrix}
0 &1 &2 \\
3 &4 &5
\end{pmatrix}$$

the given `axes` is: $[1, 0]$, and $Y$ = transpose($X$, axis)

then the output $Y$ is:

$$
Y = \begin{pmatrix}
0 &3 \\
1 &4 \\
2 &5
\end{pmatrix}$$

- Given a input tensor with shape $(N, C, H, W)$ and the `axes` is
$[0, 2, 3, 1]$, then shape of the output tensor will be: $(N, H, W, C)$.

)DOC");
Apply();
}

template <typename T>
class Transpose2GradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("transpose2_grad");
grad_op->SetInput("XShape", this->Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};

class Transpose2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

public:
void Apply() override {
paddle::Tensor xshape = this->GetSingleForwardOutput("XShape");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::Tensor dx = this->GetSingleInputGrad("X");
auto *dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
std::vector<int> axis =
static_cast<std::vector<int>>(this->Attr<std::vector<int>>("axis"));
VLOG(6) << "Runing transpose2_grad composite func";
prim::transpose_grad<prim::DescTensor>(out_grad, axis, dx_ptr);
this->RecoverOutputName(dx, dx_name);
}
};

template <typename T>
class Transpose2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("transpose2");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetOutput("XShape", this->Input("XShape"));
grad_op->SetAttrMap(this->Attrs());
}
};

class Transpose2OpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::proto::VarType::Type data_type =
OperatorWithKernel::IndicateVarDataType(ctx,
framework::GradVarName("Out"));
std::string data_format = ctx.Attr<std::string>("data_format");
phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
return phi::KernelKey(
ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
}
};

class TransposeGradInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
@@ -257,10 +125,6 @@ DECLARE_INFER_SHAPE_FUNCTOR(transpose_grad,
TransposeGradInferShapeFunctor,
PD_INFER_META(phi::TransposeGradInferMeta));

DECLARE_INFER_SHAPE_FUNCTOR(transpose2_grad,
Transpose2GradInferShapeFunctor,
PD_INFER_META(phi::TransposeGradInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(
transpose,
@@ -274,16 +138,3 @@ REGISTER_OPERATOR(transpose_grad,
ops::TransposeOpGrad,
ops::TransposeGradInferVarType,
TransposeGradInferShapeFunctor);

REGISTER_OPERATOR(transpose2,
ops::Transpose2Op,
ops::Transpose2OpMaker,
ops::Transpose2GradMaker<paddle::framework::OpDesc>,
ops::Transpose2GradMaker<paddle::imperative::OpBase>,
ops::Transpose2CompositeGradOpMaker);
REGISTER_OPERATOR(transpose2_grad,
ops::Transpose2OpGrad,
ops::TransposeGradInferVarType,
ops::Transpose2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>,
Transpose2GradInferShapeFunctor);
18 changes: 18 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -2173,6 +2173,24 @@
data_type : out_grad
no_need_buffer : x

- backward_op : transpose_double_grad
forward : transpose_grad (Tensor grad_out, int[] perm) -> Tensor(grad_x)
args : (Tensor grad_x_grad, int[] perm)
output : Tensor(grad_out_grad)
invoke : transpose(grad_x_grad, perm)

- backward_op : transpose_grad
forward : transpose (Tensor x, int[] perm) -> Tensor(out)
args : (Tensor out_grad, int[] perm)
output : Tensor(x_grad)
infer_meta :
func : TransposeGradInferMeta
param : [out_grad, perm]
kernel :
func : transpose_grad
backward : transpose_double_grad
composite: transpose_grad(out_grad, perm, x_grad)

- backward_op : triangular_solve_grad
forward : triangular_solve (Tensor x, Tensor y, bool upper=true, bool transpose=false, bool unitriangular=false) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, bool upper, bool transpose, bool unitriangular)
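
Note: together, the entries moved into backward.yaml above restore transpose's full autodiff chain in the generated-op path. Since transpose is linear in x, the second-order grad folds back into the forward op; a condensed sketch of the chain (abbreviated from the hunk above, forward entry from ops.yaml):

# how the moved entries chain together (sketch)
- op : transpose
  backward : transpose_grad                           # first-order grad kernel
- backward_op : transpose_grad
  backward : transpose_double_grad                    # enables higher-order autodiff
  composite : transpose_grad(out_grad, perm, x_grad)  # prim decomposition for the compiler path
- backward_op : transpose_double_grad
  invoke : transpose(grad_x_grad, perm)               # linearity: double grad is just another transpose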
18 changes: 0 additions & 18 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -868,24 +868,6 @@
kernel :
func : trans_layout_grad

- backward_op : transpose_double_grad
forward : transpose_grad (Tensor grad_out, int[] perm) -> Tensor(grad_x)
args : (Tensor grad_x_grad, int[] perm)
output : Tensor(grad_out_grad)
invoke : transpose(grad_x_grad, perm)

- backward_op : transpose_grad
forward : transpose (Tensor x, int[] perm) -> Tensor(out)
args : (Tensor out_grad, int[] perm)
output : Tensor(x_grad)
infer_meta :
func : TransposeGradInferMeta
param : [out_grad, perm]
kernel :
func : transpose_grad
backward : transpose_double_grad
composite: transpose_grad(out_grad, perm, x_grad)

- backward_op : tril_grad
forward : tril(Tensor x, int diagonal) -> Tensor(out)
args : (Tensor out_grad, int diagonal)
9 changes: 0 additions & 9 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -1008,15 +1008,6 @@
func : transpose
backward : trans_layout_grad

- op : transpose
args : (Tensor x, int[] perm)
output : Tensor
infer_meta :
func : TransposeInferMeta
kernel :
func : transpose
backward : transpose_grad

- op : tril
args : (Tensor x, int diagonal)
output : Tensor(out)
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -2494,6 +2494,8 @@
extra :
outputs : [XShape]
attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", str mkldnn_data_type = "float32"]
get_expected_kernel_type:
transpose2: GetTranspose2ExpectedKernelType

- op : triangular_solve
backward : triangular_solve_grad
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -2242,6 +2242,15 @@
func : trace
backward : trace_grad

- op : transpose
args : (Tensor x, int[] perm)
output : Tensor
Contributor suggestion (a fuller sketch of the resulting entry follows this file's hunk):
output : Tensor(out), Tensor(xshape)
intermediate : xshape

infer_meta :
func : TransposeInferMeta
kernel :
func : transpose
backward : transpose_grad

- op : triangular_solve
args : (Tensor x, Tensor y, bool upper=true, bool transpose=false, bool unitriangular=false)
output : Tensor
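
If the contributor's suggestion above were adopted, the generated-op entry would presumably look something like the following sketch; note that TransposeInferMeta as used here only infers out, so a variant that also sets xshape would likely be needed (assumed below):

# ops.yaml (hypothetical, per the reviewer's suggestion)
- op : transpose
  args : (Tensor x, int[] perm)
  output : Tensor(out), Tensor(xshape)
  infer_meta :
    func : TransposeInferMeta   # assumed: would need to also infer xshape
  kernel :
    func : transpose
  intermediate : xshape
  backward : transpose_grad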
6 changes: 0 additions & 6 deletions paddle/phi/ops/compat/transpose_sig.cc
@@ -28,11 +28,5 @@ KernelSignature TransposeGradOpArgumentMapping(

} // namespace phi

PD_REGISTER_BASE_KERNEL_NAME(transpose2, transpose);
PD_REGISTER_BASE_KERNEL_NAME(transpose2_grad, transpose_grad);

PD_REGISTER_ARG_MAPPING_FN(transpose2, phi::TransposeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(transpose2_grad,
phi::TransposeGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(transpose, phi::TransposeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(transpose_grad, phi::TransposeGradOpArgumentMapping);
Contributor: Do not modify transpose; the changes should target the transpose2-related entries instead.