From d9ac0686ff35b792dd790bec9d1eea17d4c21fd5 Mon Sep 17 00:00:00 2001 From: wenming2014 <2279939962@qq.com> Date: Wed, 9 Sep 2020 05:49:56 +0000 Subject: [PATCH] add pool1d, pool2d, pool3d, pad PEs and ops and C++/python tests --- cinn/common/ir_util.h | 10 +- cinn/hlir/op/CMakeLists.txt | 1 + cinn/hlir/op/broadcast.cc | 41 ++- cinn/hlir/op/nn.cc | 508 ++++++++++++++++++++++++++++-- cinn/hlir/op/op_broadcast_test.cc | 4 +- cinn/hlir/op/op_nn_test.cc | 308 ++++++++++++++++++ cinn/hlir/op/transform.cc | 70 ++-- cinn/hlir/pe/broadcast.cc | 41 ++- cinn/hlir/pe/broadcast.h | 25 +- cinn/hlir/pe/nn.cc | 332 ++++++++++++++++++- cinn/hlir/pe/nn.h | 77 +++++ cinn/hlir/pe/pe_broadcast_test.cc | 109 ++++++- cinn/ir/ir_operators.cc | 29 +- cinn/ir/ir_operators.h | 6 +- cinn/pybind/framework.cc | 2 +- python/tests/pool_utils.py | 425 +++++++++++++++++++++++++ python/tests/test_op_nn.py | 181 +++++++++++ python/tests/test_utils.py | 23 +- 18 files changed, 2077 insertions(+), 115 deletions(-) create mode 100644 cinn/hlir/op/op_nn_test.cc create mode 100644 python/tests/pool_utils.py diff --git a/cinn/common/ir_util.h b/cinn/common/ir_util.h index f0def23430..973af21845 100644 --- a/cinn/common/ir_util.h +++ b/cinn/common/ir_util.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include @@ -96,9 +97,14 @@ Expr make_const(Type t, T v) { } template -Expr FoldExpr(FuncOp funcOp, Expr init_value, const std::vector &values) { +Expr FoldExpr(FuncOp funcOp, const std::vector &values) { + Expr init_value; for (const Expr &val : values) { - init_value = funcOp(init_value, val); + if (!init_value.defined()) { + init_value = val; + } else { + init_value = funcOp(val, init_value); + } } return init_value; } diff --git a/cinn/hlir/op/CMakeLists.txt b/cinn/hlir/op/CMakeLists.txt index 3a97d1f7ac..340a992249 100644 --- a/cinn/hlir/op/CMakeLists.txt +++ b/cinn/hlir/op/CMakeLists.txt @@ -11,3 +11,4 @@ foreach(cpp ${srcs}) endforeach() cc_test(test_op_broadcast SRCS op_broadcast_test.cc DEPS core) +cc_test(test_op_nn SRCS op_nn_test.cc DEPS core) diff --git a/cinn/hlir/op/broadcast.cc b/cinn/hlir/op/broadcast.cc index 1ae029a463..a66c0dea64 100644 --- a/cinn/hlir/op/broadcast.cc +++ b/cinn/hlir/op/broadcast.cc @@ -21,18 +21,23 @@ std::shared_ptr StrategyForElementwiseAdd(const framework::NodeAttr const std::vector &out_type, const Target &target) { framework::CINNCompute add_compute([&attrs](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of add compute is empty! 
Please check.\n"; CINNValuePack a = args[0]; - Expr A_expr = a[0]; - Expr B_expr = a[1]; + CHECK_GE(a.size(), 2U) << "at least 2 input tensors for add compute\n"; + Expr A_expr = a[0]; + Expr B_expr = a[1]; CHECK(A_expr.as_tensor()); CHECK(B_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - ir::Tensor B = B_expr.as_tensor_ref(); - auto attr_store = attrs.attr_store; - auto iter = attr_store.find("axis"); + ir::Tensor A = A_expr.as_tensor_ref(); + ir::Tensor B = B_expr.as_tensor_ref(); Expr axis; - if (iter != attr_store.end()) { - axis = Expr(std::get(iter->second)); + bool trans_a; + for (auto &iter : attrs.attr_store) { + if (iter.first == "axis") { + axis = Expr(std::get(iter.second)); + } else { + LOG(ERROR) << "unsupported attr_store: " << iter.first << std::endl; + } } auto out = pe::Add(A, B, UniqName("C"), axis); @@ -42,10 +47,11 @@ std::shared_ptr StrategyForElementwiseAdd(const framework::NodeAttr }); framework::CINNSchedule add_schedule([](lang::Args args, lang::RetValue *ret) { - CINNValuePack arg_pack = args[0]; - Expr A [[maybe_unused]] = arg_pack[0]; + CHECK(!args.empty()) << "The input argument of add schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL); - *ret = arg_pack; + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; }); auto strategy = std::make_shared(); @@ -59,9 +65,11 @@ std::shared_ptr StrategyForElementwiseMul(const framework::NodeAttr const std::vector &out_type, const Target &target) { framework::CINNCompute mul_compute([&attrs](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of elementwise_mul compute is empty! Please check.\n"; CINNValuePack a = args[0]; - Expr A_expr = a[0]; - Expr B_expr = a[1]; + CHECK_GE(a.size(), 2U) << "at least 2 input tensors for elementwise_mul compute\n"; + Expr A_expr = a[0]; + Expr B_expr = a[1]; CHECK(A_expr.as_tensor()); CHECK(B_expr.as_tensor()); ir::Tensor A = A_expr.as_tensor_ref(); @@ -80,10 +88,11 @@ std::shared_ptr StrategyForElementwiseMul(const framework::NodeAttr }); framework::CINNSchedule mul_schedule([](lang::Args args, lang::RetValue *ret) { - CINNValuePack arg_pack = args[0]; - Expr A [[maybe_unused]] = arg_pack[0]; + CHECK(!args.empty()) << "The input argument of elementwise_mul schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL); - *ret = arg_pack; + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; }); auto strategy = std::make_shared(); diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc index 1eadf59eae..9b48f9cb67 100644 --- a/cinn/hlir/op/nn.cc +++ b/cinn/hlir/op/nn.cc @@ -1,5 +1,4 @@ #include "cinn/hlir/pe/nn.h" - #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" @@ -20,8 +19,10 @@ std::shared_ptr StrategyForRelu(const framework::NodeAttr &attrs, const std::vector &out_type, const Target &target) { framework::CINNCompute relu_compute([](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of relu compute is empty! 
Please check.\n"; CINNValuePack a = args[0]; - Expr A = a[0]; + CHECK(!a.empty()) << "at least one input tensor for relu compute\n"; + Expr A = a[0]; CHECK(A.as_tensor()); auto out = pe::Relu(A.as_tensor_ref(), 0.0, UniqName("Relu_output")); auto stages = CreateStages({out}); @@ -29,10 +30,11 @@ std::shared_ptr StrategyForRelu(const framework::NodeAttr &attrs, }); framework::CINNSchedule relu_schedule([](lang::Args args, lang::RetValue *ret) { - CINNValuePack arg_pack = args[0]; - Expr A [[maybe_unused]] = arg_pack[0]; + CHECK(!args.empty()) << "The input argument of relu schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL); - *ret = arg_pack; + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; }); auto strategy = std::make_shared(); @@ -63,8 +65,10 @@ std::shared_ptr StrategyForRelu6(const framework::NodeAttr &attrs, const std::vector &out_type, const Target &target) { framework::CINNCompute relu_compute([](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of relu6 compute is empty! Please check.\n"; CINNValuePack a = args[0]; - Expr A = a[0]; + CHECK(!a.empty()) << "at least one input tensor for relu6 compute\n"; + Expr A = a[0]; CHECK(A.as_tensor()); auto out = pe::Relu6(A.as_tensor_ref(), 0.0, UniqName("Relu6_output")); auto stages = CreateStages({out}); @@ -72,10 +76,11 @@ std::shared_ptr StrategyForRelu6(const framework::NodeAttr &attrs, }); framework::CINNSchedule relu_schedule([](lang::Args args, lang::RetValue *ret) { - CINNValuePack arg_pack = args[0]; - Expr A [[maybe_unused]] = arg_pack[0]; + CHECK(!args.empty()) << "The input argument of relu6 schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL); - *ret = arg_pack; + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; }); auto strategy = std::make_shared(); @@ -109,9 +114,11 @@ std::shared_ptr StrategyForConv2d(const framework::NodeAttr &attrs, groups = std::get(attrs.attr_store.at("groups")); } framework::CINNCompute conv2d_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of conv2d compute is empty! Please check.\n"; CINNValuePack a = args[0]; - Expr A = a[0]; - Expr B = a[1]; + CHECK_GE(a.size(), 2U) << "at least 2 input tensors for conv2d compute\n"; + Expr A = a[0]; + Expr B = a[1]; CHECK(A.as_tensor()); CHECK(B.as_tensor()); CHECK_EQ(padding.size(), 2) << "The size of padding in conv2d op is not 2! Please check."; @@ -135,10 +142,11 @@ std::shared_ptr StrategyForConv2d(const framework::NodeAttr &attrs, }); framework::CINNSchedule conv2d_schedule([](lang::Args args, lang::RetValue *ret) { - CINNValuePack arg_pack = args[0]; - Expr A [[maybe_unused]] = arg_pack[0]; + CHECK(!args.empty()) << "The input argument of conv2d schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 4UL); - *ret = arg_pack; + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; }); auto strategy = std::make_shared(); @@ -197,9 +205,11 @@ std::shared_ptr StrategyForBatchNorm(const framework::NodeAttr &attr epsilon = std::get(attrs.attr_store.at("epsilon")); } framework::CINNCompute batchnorm_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of batchnorm compute is empty! 
Please check.\n"; CINNValuePack a = args[0]; - Expr A = a[0]; - Expr B = a[1]; + CHECK_GE(a.size(), 2U) << "at least 2 input tensors for batchnorm compute\n"; + Expr A = a[0]; + Expr B = a[1]; CHECK(A.as_tensor()); CHECK(B.as_tensor()); auto out = pe::BatchNorm_NCHW(A.as_tensor_ref(), B.as_tensor_ref(), epsilon, UniqName("BatchNorm_output")); @@ -208,10 +218,11 @@ std::shared_ptr StrategyForBatchNorm(const framework::NodeAttr &attr }); framework::CINNSchedule batchnorm_schedule([](lang::Args args, lang::RetValue *ret) { - CINNValuePack arg_pack = args[0]; - Expr A [[maybe_unused]] = arg_pack[0]; + CHECK(!args.empty()) << "The input argument of batchnorm schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL); - *ret = arg_pack; + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; }); auto strategy = std::make_shared(); @@ -237,6 +248,438 @@ std::vector InferDtypeForBatchNorm(const std::vector &inputs_type, c return res; } +std::shared_ptr StrategyForPool1d(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const Target &target) { + framework::CINNCompute pool1d_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of pool1d compute is empty! Please check.\n"; + CINNValuePack a = args[0]; + CHECK(!a.empty()) << "The input tensor of pool1d compute is empty! Please check.\n"; + Expr A = a[0]; + CHECK(A.as_tensor()); + auto attr_store = attrs.attr_store; + std::vector kernel_size; // [kernel_w] + std::vector stride_size; // [stride_w] + std::vector padding_size; // [padding_left, padding_right] + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; + std::string data_format = "NCW"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "kernel_size") { + kernel_size = std::get>(iter.second); + } else if (iter.first == "stride_size") { + stride_size = std::get>(iter.second); + } else if (iter.first == "padding_size") { + padding_size = std::get>(iter.second); + } else if (iter.first == "pool_type") { + pool_type = std::get(iter.second); + } else if (iter.first == "ceil_mode") { + ceil_mode = std::get(iter.second); + } else if (iter.first == "exclusive") { + exclusive = std::get(iter.second); + } else if (iter.first == "data_format") { + data_format = std::get(iter.second); + } else { + LOG(ERROR) << "unsupported attr: " << iter.first << std::endl; + } + } + CHECK(!kernel_size.empty()); + CHECK(!stride_size.empty()); + CHECK(!padding_size.empty()); + auto out = pe::Pool1d(A.as_tensor_ref(), + kernel_size, + stride_size, + padding_size, + pool_type, + ceil_mode, + exclusive, + data_format, + UniqName("T_Pool1d_out")); + + auto stages = CreateStages(out); + std::vector res; + for (auto &t : out) { + res.push_back(CINNValue(Expr(t.get()))); + } + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule pool1d_schedule([](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of pool1d schedule is empty! 
Please check.\n"; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 3UL); + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(pool1d_compute, pool1d_schedule, "strategy.pool1d.x86", 1); + + return strategy; +} + +std::vector> InferShapeForPool1d(const std::vector> &inputs_shape, + const framework::NodeAttr &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; + auto attr_store = attrs.attr_store; + std::vector kernel_size; // [kernel_w] + std::vector stride_size; // [stride_w] + std::vector padding_size; // [padding_left, padding_right] + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; + std::string data_format = "NCW"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "kernel_size") { + kernel_size = std::get>(iter.second); + } else if (iter.first == "stride_size") { + stride_size = std::get>(iter.second); + } else if (iter.first == "padding_size") { + padding_size = std::get>(iter.second); + } else if (iter.first == "ceil_mode") { + ceil_mode = std::get(iter.second); + } else if (iter.first == "exclusive") { + exclusive = std::get(iter.second); + } else if (iter.first == "data_format") { + data_format = std::get(iter.second); + } + } + CHECK_EQ(kernel_size.size(), 1U); + CHECK_EQ(stride_size.size(), 1U); + CHECK_EQ(padding_size.size(), 2U); + + std::vector output_shape0 = inputs_shape[0]; + std::vector output_shape1 = inputs_shape[0]; + CHECK_EQ(output_shape0.size(), 3U); + int width_axis = -1; + if (data_format == "NCW") { + width_axis = 2; + } else if (data_format == "NWC") { + width_axis = 1; + } else { + LOG(ERROR) << "unsupported data_format: " << data_format << std::endl; + } + + output_shape0[width_axis] += padding_size[0] + padding_size[1]; + if (ceil_mode) { + output_shape0[width_axis] += stride_size[0]; + output_shape1[width_axis] = + (inputs_shape[0][width_axis] - kernel_size[0] + padding_size[0] + padding_size[1] + stride_size[0] - 1) / + stride_size[0] + + 1; + } else { + output_shape1[width_axis] = + (inputs_shape[0][width_axis] - kernel_size[0] + padding_size[0] + padding_size[1]) / stride_size[0] + 1; + } + + std::vector> res{output_shape0, output_shape1}; + return res; +} + +std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const Target &target) { + framework::CINNCompute pool2d_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of pool2d compute is empty! Please check.\n"; + CINNValuePack a = args[0]; + CHECK(!a.empty()) << "The input tensor of pool2d compute is empty! 
Please check.\n"; + Expr A = a[0]; + CHECK(A.as_tensor()); + auto attr_store = attrs.attr_store; + std::vector kernel_size; // [kernel_h, kernel_w] + std::vector stride_size; // [stride_h, stride_w] + std::vector padding_size; // [padding_top, padding_left, padding_bottom, padding_right] + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; + std::string data_format = "NCHW"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "kernel_size") { + kernel_size = std::get>(iter.second); + } else if (iter.first == "stride_size") { + stride_size = std::get>(iter.second); + } else if (iter.first == "padding_size") { + padding_size = std::get>(iter.second); + } else if (iter.first == "pool_type") { + pool_type = std::get(iter.second); + } else if (iter.first == "ceil_mode") { + ceil_mode = std::get(iter.second); + } else if (iter.first == "exclusive") { + exclusive = std::get(iter.second); + } else if (iter.first == "data_format") { + data_format = std::get(iter.second); + } else { + LOG(ERROR) << "unsupported attr: " << iter.first << std::endl; + } + } + CHECK(!kernel_size.empty()); + CHECK(!stride_size.empty()); + CHECK(!padding_size.empty()); + auto out = pe::Pool2d(A.as_tensor_ref(), + kernel_size, + stride_size, + padding_size, + pool_type, + ceil_mode, + exclusive, + data_format, + UniqName("T_Pool2d_out")); + + auto stages = CreateStages(out); + std::vector res; + for (auto &t : out) { + res.push_back(CINNValue(Expr(t.get()))); + } + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule pool2d_schedule([](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of pool2d schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 3UL); + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(pool2d_compute, pool2d_schedule, "strategy.pool2d.x86", 1); + + return strategy; +} + +std::vector> InferShapeForPool2d(const std::vector> &inputs_shape, + const framework::NodeAttr &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! 
Please check again."; + auto attr_store = attrs.attr_store; + std::vector kernel_size; + std::vector stride_size; + std::vector padding_size; + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; + std::string data_format = "NCHW"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "kernel_size") { + kernel_size = std::get>(iter.second); + } else if (iter.first == "stride_size") { + stride_size = std::get>(iter.second); + } else if (iter.first == "padding_size") { + padding_size = std::get>(iter.second); + } else if (iter.first == "ceil_mode") { + ceil_mode = std::get(iter.second); + } else if (iter.first == "exclusive") { + exclusive = std::get(iter.second); + } else if (iter.first == "data_format") { + data_format = std::get(iter.second); + } + } + CHECK_EQ(kernel_size.size(), 2U); + CHECK_EQ(stride_size.size(), 2U); + + std::vector output_shape0 = inputs_shape[0]; + std::vector output_shape1 = inputs_shape[0]; + CHECK_EQ(output_shape0.size(), 4U); + int height_axis = -1; + int width_axis = -1; + if (data_format == "NCHW") { + height_axis = 2; + width_axis = 3; + } else if (data_format == "NHWC") { + height_axis = 1; + width_axis = 2; + } else { + LOG(ERROR) << "unsupported data_format: " << data_format << std::endl; + } + + output_shape0[height_axis] += padding_size[0] + padding_size[2]; + output_shape0[width_axis] += padding_size[1] + padding_size[3]; + if (ceil_mode) { + output_shape0[height_axis] += stride_size[0] - 1; + output_shape0[width_axis] += stride_size[1] - 1; + output_shape1[height_axis] = + (inputs_shape[0][height_axis] - kernel_size[0] + padding_size[0] + padding_size[2] + stride_size[0] - 1) / + stride_size[0] + + 1; + output_shape1[width_axis] = + (inputs_shape[0][width_axis] - kernel_size[1] + padding_size[1] + padding_size[3] + stride_size[1] - 1) / + stride_size[1] + + 1; + } else { + output_shape1[height_axis] = + (inputs_shape[0][height_axis] - kernel_size[0] + padding_size[0] + padding_size[2]) / stride_size[0] + 1; + output_shape1[width_axis] = + (inputs_shape[0][width_axis] - kernel_size[1] + padding_size[1] + padding_size[3]) / stride_size[1] + 1; + } + + std::vector> res{output_shape0, output_shape1}; + return res; +} + +std::shared_ptr StrategyForPool3d(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const Target &target) { + framework::CINNCompute pool3d_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of pool3d compute is empty! Please check.\n"; + CINNValuePack a = args[0]; + CHECK(!a.empty()) << "The input tensor of pool3d compute is empty! 
Please check.\n"; + Expr A = a[0]; + CHECK(A.as_tensor()); + auto attr_store = attrs.attr_store; + std::vector kernel_size; // [kernel_d, kernel_h, kernel_w] + std::vector stride_size; // [stride_d, stride_h, stride_w] + std::vector + padding_size; // [padding_front, padding_top, padding_left, padding_back, padding_bottom, padding_right] + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; + std::string data_format = "NCDHW"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "kernel_size") { + kernel_size = std::get>(iter.second); + } else if (iter.first == "stride_size") { + stride_size = std::get>(iter.second); + } else if (iter.first == "padding_size") { + padding_size = std::get>(iter.second); + } else if (iter.first == "pool_type") { + pool_type = std::get(iter.second); + } else if (iter.first == "ceil_mode") { + ceil_mode = std::get(iter.second); + } else if (iter.first == "exclusive") { + exclusive = std::get(iter.second); + } else if (iter.first == "data_format") { + data_format = std::get(iter.second); + } else { + LOG(ERROR) << "unsupported attr: " << iter.first << std::endl; + } + } + CHECK(!kernel_size.empty()); + CHECK(!stride_size.empty()); + CHECK(!padding_size.empty()); + auto out = pe::Pool3d(A.as_tensor_ref(), + kernel_size, + stride_size, + padding_size, + pool_type, + ceil_mode, + exclusive, + data_format, + UniqName("T_Pool3d_out")); + + auto stages = CreateStages(out); + std::vector res; + for (auto &t : out) { + res.push_back(CINNValue(Expr(t.get()))); + } + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + framework::CINNSchedule pool3d_schedule([](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of pool3d schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 3UL); + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(pool3d_compute, pool3d_schedule, "strategy.pool3d.x86", 1); + + return strategy; +} + +std::vector> InferShapeForPool3d(const std::vector> &inputs_shape, + const framework::NodeAttr &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! 
Please check again."; + auto attr_store = attrs.attr_store; + std::vector kernel_size; // [kernel_d, kernel_h, kernel_w] + std::vector stride_size; // [stride_d, stride_h, stride_w] + std::vector + padding_size; // [padding_front, padding_top, padding_left, padding_bottom, padding_right, padding_back] + std::string pool_type = "max"; + bool ceil_mode = false; + bool exclusive = true; + std::string data_format = "NCDHW"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "kernel_size") { + kernel_size = std::get>(iter.second); + } else if (iter.first == "stride_size") { + stride_size = std::get>(iter.second); + } else if (iter.first == "padding_size") { + padding_size = std::get>(iter.second); + } else if (iter.first == "ceil_mode") { + ceil_mode = std::get(iter.second); + } else if (iter.first == "exclusive") { + exclusive = std::get(iter.second); + } else if (iter.first == "data_format") { + data_format = std::get(iter.second); + } + } + CHECK_EQ(kernel_size.size(), 3U); + CHECK_EQ(stride_size.size(), 3U); + + std::vector output_shape0 = inputs_shape[0]; + std::vector output_shape1 = inputs_shape[0]; + CHECK_EQ(output_shape0.size(), 6U); + int depth_axis = -1; + int height_axis = -1; + int width_axis = -1; + if (data_format == "NCDHW") { + depth_axis = 2; + height_axis = 3; + width_axis = 4; + } else if (data_format == "NDHWC") { + depth_axis = 1; + height_axis = 2; + width_axis = 3; + } else { + LOG(ERROR) << "unsupported data_format: " << data_format << std::endl; + } + + output_shape0[depth_axis] += padding_size[0] + padding_size[3]; + output_shape0[height_axis] += padding_size[1] + padding_size[4]; + output_shape0[width_axis] += padding_size[2] + padding_size[5]; + if (ceil_mode) { + output_shape0[depth_axis] += stride_size[0] - 1; + output_shape0[height_axis] += stride_size[1] - 1; + output_shape0[width_axis] += stride_size[2] - 1; + output_shape1[depth_axis] = + (inputs_shape[0][depth_axis] - kernel_size[0] + padding_size[0] + padding_size[3] + stride_size[0] - 1) / + stride_size[0] + + 1; + output_shape1[height_axis] = + (inputs_shape[0][height_axis] - kernel_size[1] + padding_size[1] + padding_size[4] + stride_size[1] - 1) / + stride_size[1] + + 1; + output_shape1[width_axis] = + (inputs_shape[0][width_axis] - kernel_size[2] + padding_size[2] + padding_size[5] + stride_size[2] - 1) / + stride_size[2] + + 1; + } else { + output_shape1[depth_axis] = + (inputs_shape[0][depth_axis] - kernel_size[0] + padding_size[0] + padding_size[3]) / stride_size[0] + 1; + output_shape1[height_axis] = + (inputs_shape[0][height_axis] - kernel_size[1] + padding_size[1] + padding_size[4]) / stride_size[1] + 1; + output_shape1[width_axis] = + (inputs_shape[0][width_axis] - kernel_size[2] + padding_size[2] + padding_size[5]) / stride_size[2] + 1; + } + + std::vector> res{output_shape0, output_shape1}; + return res; +} + +std::vector InferDtypeForPool(const std::vector &inputs_type, const framework::NodeAttr &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! 
Please check again."; + std::vector res{inputs_type[0], inputs_type[0]}; + return res; +} + } // namespace op } // namespace hlir } // namespace cinn @@ -278,5 +721,32 @@ CINN_REGISTER_HELPER(nn_ops) { .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForBatchNorm)) .set_support_level(4); + CINN_REGISTER_OP(pool1d) + .describe("Do pooling on the width dimension of the input tensor.") + .set_num_inputs(1) + .set_num_outputs(2) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForPool1d) + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForPool1d)) + .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForPool)) + .set_support_level(4); + + CINN_REGISTER_OP(pool2d) + .describe("Do pooling on the height and width dimension of the input tensor.") + .set_num_inputs(1) + .set_num_outputs(2) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForPool2d) + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForPool2d)) + .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForPool)) + .set_support_level(4); + + CINN_REGISTER_OP(pool3d) + .describe("Do pooling on the depth, height and width dimension of the input tensor.") + .set_num_inputs(1) + .set_num_outputs(2) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForPool3d) + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForPool3d)) + .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForPool)) + .set_support_level(4); + return true; } diff --git a/cinn/hlir/op/op_broadcast_test.cc b/cinn/hlir/op/op_broadcast_test.cc index 786fa1c71a..98b9b4e72a 100644 --- a/cinn/hlir/op/op_broadcast_test.cc +++ b/cinn/hlir/op/op_broadcast_test.cc @@ -37,7 +37,7 @@ TEST(Operator, Operator_ElementWise_Add_Test0) { ASSERT_EQ(rets.size(), 2UL); // the last element is a StageMap for (int i = 0; i < rets->size() - 1; i++) { - ir::Expr temp = rets[i]; + Expr temp = rets[i]; inputs.push_back(temp.as_tensor_ref()); } auto func = Lower("add1", rets.back(), inputs); @@ -69,7 +69,7 @@ TEST(Operator, Operator_ElementWise_Add_Test1) { ASSERT_EQ(rets.size(), 2UL); // the last element is a StageMap for (int i = 0; i < rets->size() - 1; i++) { - ir::Expr temp = rets[i]; + Expr temp = rets[i]; inputs.push_back(temp.as_tensor_ref()); } auto func = Lower("add1", rets.back(), inputs); diff --git a/cinn/hlir/op/op_nn_test.cc b/cinn/hlir/op/op_nn_test.cc new file mode 100644 index 0000000000..99f42318a7 --- /dev/null +++ b/cinn/hlir/op/op_nn_test.cc @@ -0,0 +1,308 @@ +#include + +#include +#include + +#include "cinn/backends/llvm/execution_engine.h" +#include "cinn/cinn.h" +#include "cinn/common/target.h" +#include "cinn/common/test_helper.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/op/use_ops.h" +#include "cinn/hlir/pe/nn.h" + +namespace cinn { +namespace hlir { +namespace framework { + +using CCompute = std::function(const std::vector)>; + +TEST(Operator, Operator_Pool2d_Test0) { + auto pool2d = Operator::Get("pool2d"); + Operator temp = *pool2d; + auto strategy = Operator::GetAttrs("CINNStrategy"); + + Expr N(1), C(3), H(8), W(8); + Placeholder A("A", {N, C, H, W}); + + NodeAttr attrs; + std::vector kernel_size = {2, 2}; + std::vector stride_size = {2, 2}; + std::vector padding_size = {1, 1, 1, 1}; + std::string pool_type = "max"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; + attrs.attr_store["padding_size"] = padding_size; + 
attrs.attr_store["pool_type"] = pool_type; + std::vector inputs{A.tensor()}; + std::vector type{Float(32)}; + common::Target target = common::DefaultHostTarget(); + auto impl = OpStrategy::SelectImpl(strategy[pool2d](attrs, inputs, type, target)); + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + rets = impl->fschedule(rets); + ASSERT_EQ(rets.size(), 3UL); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower("pool2d", rets.back(), inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + + Module::Builder builder("module0", target); + builder.AddFunction(func); + auto jit = backends::ExecutionEngine::Create({}); + auto module = builder.Build(); + + jit->Link(module); + auto fn = jit->Lookup("pool2d"); + CHECK(fn); + auto fn_ = reinterpret_cast(fn); + + cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 3, 8, 8}).set_random().Build(); + cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 3, 10, 10}).set_random().Build(); + cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 3, 5, 5}).set_random().Build(); + cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); + cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; + fn_(args, 3); + + ASSERT_EQ(impl->name, "strategy.pool2d.x86"); + ASSERT_EQ(pool2d->description, "Do pooling on the height and width dimension of the input tensor."); +} + +TEST(Operator, Operator_Pool2d_Test1) { + auto pool2d = Operator::Get("pool2d"); + Operator temp = *pool2d; + auto strategy = Operator::GetAttrs("CINNStrategy"); + + Expr N(1), C(3), H(8), W(8); + Placeholder A("A", {N, C, H, W}); + + NodeAttr attrs; + std::vector kernel_size = {2, 2}; + std::vector stride_size = {2, 2}; + std::vector padding_size = {1, 1, 1, 1}; + std::string pool_type = "avg"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; + attrs.attr_store["padding_size"] = padding_size; + attrs.attr_store["pool_type"] = pool_type; + attrs.attr_store["ceil_mode"] = true; + attrs.attr_store["exclusive"] = false; + std::vector inputs{A.tensor()}; + std::vector type{Float(32)}; + common::Target target = common::DefaultHostTarget(); + auto impl = OpStrategy::SelectImpl(strategy[pool2d](attrs, inputs, type, target)); + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + rets = impl->fschedule(rets); + ASSERT_EQ(rets.size(), 3UL); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower("pool2d", rets.back(), inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + + Module::Builder builder("module0", target); + builder.AddFunction(func); + auto jit = backends::ExecutionEngine::Create({}); + auto module = builder.Build(); + + jit->Link(module); + auto fn = jit->Lookup("pool2d"); + CHECK(fn); + auto fn_ = reinterpret_cast(fn); + + cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 3, 8, 8}).set_random().Build(); + cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 3, 11, 11}).set_random().Build(); + cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 3, 5, 5}).set_random().Build(); + cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); + cinn_pod_value_t args[] = {a_arg, b_arg, 
c_arg}; + fn_(args, 3); + + ASSERT_EQ(impl->name, "strategy.pool2d.x86"); + ASSERT_EQ(pool2d->description, "Do pooling on the height and width dimension of the input tensor."); +} + +TEST(Operator, Operator_Pool2d_Test2) { + auto pool2d = Operator::Get("pool2d"); + Operator temp = *pool2d; + auto strategy = Operator::GetAttrs("CINNStrategy"); + + Expr N(1), H(8), W(8), C(3); + Placeholder A("A", {N, H, W, C}); + + NodeAttr attrs; + std::vector kernel_size = {2, 2}; + std::vector stride_size = {2, 2}; + std::vector padding_size = {1, 1, 1, 1}; + std::string pool_type = "avg"; + std::string data_format = "NHWC"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; + attrs.attr_store["padding_size"] = padding_size; + attrs.attr_store["pool_type"] = pool_type; + attrs.attr_store["ceil_mode"] = true; + attrs.attr_store["exclusive"] = true; + attrs.attr_store["data_format"] = data_format; + std::vector inputs{A.tensor()}; + std::vector type{Float(32)}; + common::Target target = common::DefaultHostTarget(); + auto impl = OpStrategy::SelectImpl(strategy[pool2d](attrs, inputs, type, target)); + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + rets = impl->fschedule(rets); + ASSERT_EQ(rets.size(), 3UL); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower("pool2d", rets.back(), inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + + Module::Builder builder("module0", target); + builder.AddFunction(func); + auto jit = backends::ExecutionEngine::Create({}); + auto module = builder.Build(); + + jit->Link(module); + auto fn = jit->Lookup("pool2d"); + CHECK(fn); + auto fn_ = reinterpret_cast(fn); + + cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 8, 8, 3}).set_random().Build(); + cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 11, 11, 3}).set_random().Build(); + cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 5, 5, 3}).set_random().Build(); + cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); + cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; + fn_(args, 3); + + ASSERT_EQ(impl->name, "strategy.pool2d.x86"); + ASSERT_EQ(pool2d->description, "Do pooling on the height and width dimension of the input tensor."); +} + +TEST(Operator, Operator_Pool3d_Test0) { + auto pool3d = Operator::Get("pool3d"); + Operator temp = *pool3d; + auto strategy = Operator::GetAttrs("CINNStrategy"); + + Expr N(1), D(8), H(8), W(8), C(3); + Placeholder A("A", {N, D, H, W, C}); + + NodeAttr attrs; + std::vector kernel_size = {2, 2, 2}; + std::vector stride_size = {2, 2, 2}; + std::vector padding_size = {1, 1, 1, 1, 1, 1}; + std::string pool_type = "max"; + std::string data_format = "NDHWC"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; + attrs.attr_store["padding_size"] = padding_size; + attrs.attr_store["pool_type"] = pool_type; + attrs.attr_store["ceil_mode"] = false; + attrs.attr_store["exclusive"] = true; + attrs.attr_store["data_format"] = data_format; + std::vector inputs{A.tensor()}; + std::vector type{Float(32)}; + common::Target target = common::DefaultHostTarget(); + auto impl = OpStrategy::SelectImpl(strategy[pool3d](attrs, inputs, type, target)); + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; + common::CINNValuePack rets = 
impl->fcompute(cinn_input); + rets = impl->fschedule(rets); + ASSERT_EQ(rets.size(), 3UL); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower("pool3d", rets.back(), inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + + Module::Builder builder("module0", target); + builder.AddFunction(func); + auto jit = backends::ExecutionEngine::Create({}); + auto module = builder.Build(); + + jit->Link(module); + auto fn = jit->Lookup("pool3d"); + CHECK(fn); + auto fn_ = reinterpret_cast(fn); + + cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 8, 8, 8, 3}).set_random().Build(); + cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 11, 11, 11, 3}).set_random().Build(); + cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 5, 5, 5, 3}).set_random().Build(); + cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); + cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; + fn_(args, 3); + + ASSERT_EQ(impl->name, "strategy.pool3d.x86"); + ASSERT_EQ(pool3d->description, "Do pooling on the depth, height and width dimension of the input tensor."); +} + +TEST(Operator, Operator_Pool1d_Test0) { + auto pool1d = Operator::Get("pool1d"); + Operator temp = *pool1d; + auto strategy = Operator::GetAttrs("CINNStrategy"); + + Expr N(1), W(8), C(3); + Placeholder A("A", {N, W, C}); + + NodeAttr attrs; + std::vector kernel_size = {2}; + std::vector stride_size = {2}; + std::vector padding_size = {1, 1}; + std::string pool_type = "max"; + std::string data_format = "NWC"; + attrs.attr_store["kernel_size"] = kernel_size; + attrs.attr_store["stride_size"] = stride_size; + attrs.attr_store["padding_size"] = padding_size; + attrs.attr_store["pool_type"] = pool_type; + attrs.attr_store["ceil_mode"] = false; + attrs.attr_store["exclusive"] = true; + attrs.attr_store["data_format"] = data_format; + std::vector inputs{A.tensor()}; + std::vector type{Float(32)}; + common::Target target = common::DefaultHostTarget(); + auto impl = OpStrategy::SelectImpl(strategy[pool1d](attrs, inputs, type, target)); + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + rets = impl->fschedule(rets); + ASSERT_EQ(rets.size(), 3UL); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower("pool1d", rets.back(), inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + + Module::Builder builder("module0", target); + builder.AddFunction(func); + auto jit = backends::ExecutionEngine::Create({}); + auto module = builder.Build(); + + jit->Link(module); + auto fn = jit->Lookup("pool1d"); + CHECK(fn); + auto fn_ = reinterpret_cast(fn); + + cinn_buffer_t *A_buf = common::BufferBuilder(Float(32), {1, 8, 3}).set_random().Build(); + cinn_buffer_t *B_buf = common::BufferBuilder(Float(32), {1, 11, 3}).set_random().Build(); + cinn_buffer_t *C_buf = common::BufferBuilder(Float(32), {1, 5, 3}).set_random().Build(); + cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf); + cinn_pod_value_t args[] = {a_arg, b_arg, c_arg}; + fn_(args, 3); + + ASSERT_EQ(impl->name, "strategy.pool1d.x86"); + ASSERT_EQ(pool1d->description, "Do pooling on the width dimension of the input tensor."); +} + +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/transform.cc 
b/cinn/hlir/op/transform.cc index 4833207bdd..0739770fe6 100644 --- a/cinn/hlir/op/transform.cc +++ b/cinn/hlir/op/transform.cc @@ -19,10 +19,12 @@ std::shared_ptr StrategyForMul(const framework::NodeAttr &attrs, const std::vector &inputs, const std::vector &out_type, const Target &target) { - framework::CINNCompute add_compute([&attrs](lang::Args args, lang::RetValue *ret) { + framework::CINNCompute mul_compute([&attrs](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of mul compute is empty! Please check.\n"; CINNValuePack a = args[0]; - Expr A = a[0]; - Expr B = a[1]; + CHECK_GE(a.size(), 2U) << "at least 2 input tensors for mul compute\n"; + Expr A = a[0]; + Expr B = a[1]; CHECK(A.as_tensor()); CHECK(B.as_tensor()); auto attr_store = attrs.attr_store; @@ -39,6 +41,8 @@ std::shared_ptr StrategyForMul(const framework::NodeAttr &attrs, x_num_col_dims = std::get(iter.second); } else if (iter.first == "y_num_col_dims") { y_num_col_dims = std::get(iter.second); + } else { + LOG(ERROR) << "unsupported attr: " << iter.first << std::endl; } } @@ -50,39 +54,55 @@ std::shared_ptr StrategyForMul(const framework::NodeAttr &attrs, *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; }); - framework::CINNSchedule add_schedule([](lang::Args args, lang::RetValue *ret) { - CINNValuePack arg_pack = args[0]; - Expr A [[maybe_unused]] = arg_pack[0]; + framework::CINNSchedule mul_schedule([](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of mul schedule is empty! Please check.\n"; + CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL); - *ret = arg_pack; + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; }); auto strategy = std::make_shared(); - strategy->AddImpl(add_compute, add_schedule, "strategy.mul.x86", 1); + strategy->AddImpl(mul_compute, mul_schedule, "strategy.mul.x86", 1); return strategy; } -std::vector InferShapeForMul(const std::vector &inputs_shape, const framework::NodeAttr &attrs) { - VLOG(3) << "Mul shape0: " << utils::Join(inputs_shape[0], ","); - VLOG(3) << "Mul shape1: " << utils::Join(inputs_shape[1], ","); - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape is empty"; - - int x_num_col_dims = -1; - if (attrs.attr_store.count("x_num_col_dims")) { - x_num_col_dims = std::get(attrs.attr_store.at("x_num_col_dims")); +std::vector> InferShapeForMul(const std::vector> &inputs_shape, + const framework::NodeAttr &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! 
Please check again."; + std::vector output_shape; + std::vector shape1_new; + std::vector shape2_new; + bool trans_a = false; + bool trans_b = false; + int x_num_col_dims = 1; + int y_num_col_dims = 1; + for (auto &iter : attrs.attr_store) { + if (iter.first == "trans_a") { + trans_a = std::get(iter.second); + } else if (iter.first == "trans_b") { + trans_b = std::get(iter.second); + } else if (iter.first == "x_num_col_dims") { + x_num_col_dims = std::get(iter.second); + } else if (iter.first == "y_num_col_dims") { + y_num_col_dims = std::get(iter.second); + } } - int y_num_col_dims = -1; - if (attrs.attr_store.count("y_num_col_dims")) { - y_num_col_dims = std::get(attrs.attr_store.at("y_num_col_dims")); + shape1_new = inputs_shape[0]; + shape2_new = inputs_shape[1]; + if (trans_a) { + std::reverse(shape1_new.begin(), shape1_new.end()); } + if (trans_b) { + std::reverse(shape2_new.begin(), shape2_new.end()); + } + output_shape.insert(output_shape.begin(), shape1_new.begin(), shape1_new.begin() + x_num_col_dims); + output_shape.insert(output_shape.end(), shape2_new.begin() + y_num_col_dims, shape2_new.end()); - shape_t out_shape; - for (int i = 0; i < x_num_col_dims; i++) out_shape.push_back(inputs_shape[0][i]); - for (int i = 0; i < y_num_col_dims; i++) out_shape.push_back(inputs_shape[1][inputs_shape.size() - 1 - i]); - - if (out_shape.empty()) return {{1}}; - return {out_shape}; + if (output_shape.empty()) return {{1}}; + std::vector> res{output_shape}; + return res; } std::vector InferDtypeForMul(const std::vector &inputs_type, const framework::NodeAttr &attrs) { diff --git a/cinn/hlir/pe/broadcast.cc b/cinn/hlir/pe/broadcast.cc index dcef8db4f7..f6aa83a69f 100644 --- a/cinn/hlir/pe/broadcast.cc +++ b/cinn/hlir/pe/broadcast.cc @@ -1,5 +1,6 @@ #include "cinn/hlir/pe/broadcast.h" +#include #include #include "cinn/common/ir_util.h" @@ -11,9 +12,9 @@ namespace cinn { namespace hlir { namespace pe { -using namespace cinn::ir; -using cinn::common::make_zero; -using cinn::lang::Compute; +using common::make_zero; +using ir::Tensor; +using lang::Compute; void GetBroadcastShape(const std::vector& shape1, const std::vector& shape2, @@ -24,28 +25,34 @@ void GetBroadcastShape(const std::vector& shape1, CHECK(common_shape); CHECK(broadcast_flag1); CHECK(broadcast_flag2); + int size1 = shape1.size(); std::vector shape2_new = shape2; + int axis_offset = -1; if (axis.defined()) { int axis_val = axis.as_int32(); CHECK_GE(axis_val, -1) << "wrong axis: " << axis_val << std::endl; CHECK_GE(shape1.size(), shape2.size()) << "A's shape should no less than B's when axis is defined\n"; CHECK_LE(axis_val, shape1.size() - shape2.size()) << "wrong axis: " << axis_val << std::endl; if (axis_val >= 0) { - int axis_offset = shape1.size() - shape2.size() - axis_val; - for (int i = 0; i < axis_offset; ++i) { - // specified axis to align, we push the Expr one to align + axis_offset = shape1.size() - shape2.size() - axis_val; + for (int i = 1; i <= axis_offset; ++i) { + // specified axis to align, we push the Expr one in tensor B so as to align right with tensor A. shape2_new.emplace_back(Expr(1)); + common_shape->insert(common_shape->begin(), shape1[size1 - i]); + // flag is used to indicate whether to include the indice or not. 
+ broadcast_flag1->emplace_back(true); + broadcast_flag2->emplace_back(false); } } } - int size1 = shape1.size(); + int size2 = shape2_new.size(); Expr one(1); int i; - // insert common axis from right to left to common_axis - for (i = 1; i <= std::min(size1, size2); ++i) { - auto* var1 = shape1[size1 - i].as_var(); - auto* var2 = shape2_new[size2 - i].as_var(); + i = axis_offset <= 0 ? 1 : axis_offset + 1; + for (; i <= std::min(size1, size2); ++i) { + auto* var1 = shape1[size1 - i].As(); + auto* var2 = shape2_new[size2 - i].As(); if (MathEqual(shape1[size1 - i], shape2_new[size2 - i])) { common_shape->insert(common_shape->begin(), shape1[size1 - i]); broadcast_flag1->emplace_back(true); @@ -61,7 +68,7 @@ void GetBroadcastShape(const std::vector& shape1, broadcast_flag1->emplace_back(true); broadcast_flag2->emplace_back(false); } else if (var1 && var2) { - Expr max_var = Max::Make(shape1[size1 - i], shape2_new[size2 - i]); + Expr max_var = ir::Max::Make(shape1[size1 - i], shape2_new[size2 - i]); common_shape->insert(common_shape->begin(), max_var); broadcast_flag1->emplace_back(true); broadcast_flag2->emplace_back(true); @@ -158,9 +165,9 @@ HLIR_IMP_BC_PE(Divide, return a / b;); HLIR_IMP_BC_PE(FloorDivide, return Floor(a / b);); HLIR_IMP_BC_PE(Mod, return a % b;); HLIR_IMP_BC_PE(FloorMod, return a - Floor(a / b) * b;); -HLIR_IMP_BC_PE(Maximum, return Max::Make(a, b);); -HLIR_IMP_BC_PE(Minimum, return Min::Make(a, b);); -HLIR_IMP_BC_PE(Power, return Power::Make(a, b);); +HLIR_IMP_BC_PE(Maximum, return ir::Max::Make(a, b);); +HLIR_IMP_BC_PE(Minimum, return ir::Min::Make(a, b);); +HLIR_IMP_BC_PE(Power, return ir::Power::Make(a, b);); HLIR_IMP_BC_PE(LeftShift, return a << b;); HLIR_IMP_BC_PE(RightShift, return a >> b;); HLIR_IMP_BC_PE(LogicaAnd, return a && b;); @@ -171,8 +178,8 @@ HLIR_IMP_BC_PE(BitwiseOr, return a | b;); HLIR_IMP_BC_PE(BitwiseXor, return a ^ b;); HLIR_IMP_BC_PE(Greater, return a > b;); HLIR_IMP_BC_PE(Less, return a < b;); -HLIR_IMP_BC_PE(Equal, return EQ::Make(a, b);); -HLIR_IMP_BC_PE(NotEqual, return NE::Make(a, b);); +HLIR_IMP_BC_PE(Equal, return ir::EQ::Make(a, b);); +HLIR_IMP_BC_PE(NotEqual, return ir::NE::Make(a, b);); HLIR_IMP_BC_PE(GreaterEqual, return a >= b;); HLIR_IMP_BC_PE(LessEqual, return a <= b;); diff --git a/cinn/hlir/pe/broadcast.h b/cinn/hlir/pe/broadcast.h index adf1b51d6f..3f18a442c1 100644 --- a/cinn/hlir/pe/broadcast.h +++ b/cinn/hlir/pe/broadcast.h @@ -11,18 +11,25 @@ namespace pe { * * @param A The first Tensor or Expr * @param B The second Tensor or Expr - * @param output_name The name of the output Tensor + * @param axis Tensor B's beginning position of Tensor A. Default is -1(right align) and then axis = rank(X)-rank(Y). + * @param out_name The name of the output Tensor * * @return The result Tensor or Expr. + * @notes Tensor A's shape should no less than Tensor B's. + * e.g. 
+ * shape(A) = (2, 3, 4, 5), shape(B) = (4, 5), with axis=-1(default) or axis=2 + * shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1 + * shape(A) = (2, 3, 4, 5), shape(B) = (2), with axis=0 + * shape(A) = (2, 3, 4, 5), shape(B) = (2, 1), with axis=0 */ -#define HLIR_DCL_BC_PE(name__) \ - ir::Tensor name__(const ir::Tensor& A, \ - const ir::Tensor& B, \ - const std::string& output_name = "T_" #name__ "_out", \ - const Expr& axis = Expr()); \ - ir::Tensor name__(const Expr& A, const ir::Tensor& B, const std::string& output_name = "T_" #name__ "_out"); \ - ir::Tensor name__(const ir::Tensor& A, const Expr& B, const std::string& output_name = "T_" #name__ "_out"); \ - Expr name__(const Expr& A, const Expr& B); +#define HLIR_DCL_BC_PE(name__) \ + Expr name__(const Expr& A, const Expr& B); \ + ir::Tensor name__(const ir::Tensor& A, \ + const ir::Tensor& B, \ + const std::string& out_name = "T_" #name__ "_out", \ + const Expr& axis = Expr()); \ + ir::Tensor name__(const Expr& A, const ir::Tensor& B, const std::string& out_name = "T_" #name__ "_out"); \ + ir::Tensor name__(const ir::Tensor& A, const Expr& B, const std::string& out_name = "T_" #name__ "_out"); //! Compute A + B with auto-broadcasting. HLIR_DCL_BC_PE(Add); diff --git a/cinn/hlir/pe/nn.cc b/cinn/hlir/pe/nn.cc index 1ed51cdf1d..20a73dfeff 100644 --- a/cinn/hlir/pe/nn.cc +++ b/cinn/hlir/pe/nn.cc @@ -3,9 +3,9 @@ #include #include +#include "cinn/common/cas.h" #include "cinn/common/context.h" #include "cinn/hlir/pe/broadcast.h" -#include "cinn/hlir/pe/nn.h" #include "cinn/ir/ir_operators.h" #include "cinn/lang/builtin.h" #include "cinn/lang/compute.h" @@ -16,16 +16,14 @@ namespace hlir { namespace pe { using cinn::lang::Compute; -using namespace ir; - -enum PoolType { - kAvgPool, - kMaxPool, -}; +using ir::Max; +using ir::Min; +using ir::Select; +using ir::Tensor; Tensor LeakyRelu(const Tensor &A, double alpha, const std::string &output_name) { return Compute( - A->shape, [&](const std::vector &indice) { return LeakyRelu(A(indice), alpha); }, output_name); + A->shape, [=](const std::vector &indice) { return LeakyRelu(A(indice), alpha); }, output_name); } Tensor PRelu(const Tensor &A, const Tensor &slope, const int axis, const std::string &output_name) { @@ -33,7 +31,7 @@ Tensor PRelu(const Tensor &A, const Tensor &slope, const int axis, const std::st CHECK(A->shape[axis] == slope->shape[0]) << "Wrong slope shape: " << slope->shape[0] << std::endl; return Compute( A->shape, - [&](const std::vector &indice) { return LeakyRelu(A(indice), slope(indice[axis])); }, + [=](const std::vector &indice) { return LeakyRelu(A(indice), slope(indice[axis])); }, output_name); } @@ -116,6 +114,322 @@ ir::Tensor BatchNorm_NCHW(const ir::Tensor &input, return res; } +/** + * @brief Perform padding operation. + * @param tensor The input tensor. + * @param pad_before Vector of Exprs describing the padding before the respective dimension + * @param pad_after Vector of Exprs describing the padding after the respective dimension + * @param pad_value The value to fill padding elements with. Default is zero. + * @param name The name of the output padding tensor + * @param pad_mode Padding type to use: "constant" pads with constant_value; "edge" pads using the edge values of the + * input array; "reflect" pads by reflecting values with respect to the edges. + * + * @return the output tensor after padding. 
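+ *
+ * A minimal illustrative call (the tensor `t` here is hypothetical): for a
+ * 4-D NCHW tensor, Pad(t, {Expr(0), Expr(0), Expr(1), Expr(1)}) keeps the
+ * batch and channel dimensions and pads H and W by one element on each side,
+ * since an empty pad_after defaults to a copy of pad_before.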
+ * + * @note + * The pad_after vector must either be empty or have the same length as pad_before + * When pad_after is empty, it takes the same values as pad_before (symmetric padding) + * The pad vector applies from the leading dimensions and skips missing trailing dimensions: + * e.g. + * pad(t(i, j, k), {1}, {1}) returns the equivalent operation for + * the following pseudocode: + * for i in [0, t.shape[0] + 2): + * for j in [0, t.shape[0] + 2): + * for k in [0, t.shape[0] + 2): + * name(i,j,k) = + * i < 1 ? 0 : + * ((1 <= i < t.shape[0] + 1) ? + * t(i-1, j, k) : 0)); + * + */ +Tensor Pad(const Tensor &tensor, + const std::vector &pad_before, + std::vector pad_after = std::vector(), + Expr pad_value = Expr(), + const std::string &name = UniqName("T_pad_out"), + const std::string &pad_mode = "constant") { + // When pad_after is empty, it takes the same values as pad_before (symmetric padding) + if (pad_after.size() < pad_before.size()) { + for (size_t i = pad_after.size(); i < pad_before.size(); ++i) { + pad_after.push_back(pad_before[i]); + } + } + CHECK(!pad_before.empty()); + CHECK_EQ(pad_before.size(), pad_after.size()); + std::vector output_shape; + for (auto &ele : pad_before) { + CHECK(ele.type().is_int(32)) << "padding size should be int32\n"; + } + for (auto &ele : pad_after) { + CHECK(ele.type().is_int(32)) << "padding size should be int32\n"; + } + for (size_t i = 0; i < tensor->shape.size(); ++i) { + if (i >= pad_before.size()) { + output_shape.push_back(tensor->shape[i]); + } else { + auto shape = common::AutoSimplify(tensor->shape[i] + pad_before[i] + pad_after[i]); + output_shape.push_back(shape); + } + } + // default value is zero + if (!pad_value.defined()) { + pad_value = make_const(tensor->type(), 0); + } + + auto fn = [=](const std::vector &ovars) { + std::vector indices; + std::vector sel; + std::vector pad_idx; + for (size_t i = 0; i < tensor->shape.size(); ++i) { + if (i >= pad_before.size()) { + indices.emplace_back(ovars[i]); + continue; + } + if (!MathEqual(pad_before[i], Expr(0))) { + sel.push_back(ir::GE::Make(ovars[i], pad_before[i])); + indices.push_back(ovars[i] - pad_before[i]); + } else { + indices.emplace_back(ovars[i]); + } + Expr sel_after; + if (!MathEqual(pad_after[i], Expr(0))) { + sel_after = common::AutoSimplify(ovars[i] < pad_before[i] + tensor->shape[i]); + sel.push_back(sel_after); + } + if (pad_mode == "edge") { + pad_idx.push_back(Select::Make( + ovars[i] < pad_before[i], + 0, + Select::Make( + ovars[i] >= pad_before[i] + tensor->shape[i], tensor->shape[i] - 1, ovars[i] - pad_before[i]))); + } else if (pad_mode == "reflect") { + pad_idx.push_back(Select::Make(ovars[i] < pad_before[i], + pad_before[i] - ovars[i], + Select::Make(ovars[i] >= pad_before[i] + tensor->shape[i], + tensor->shape[i] * 2 - ovars[i] + pad_before[i] - 2, + ovars[i] - pad_before[i]))); + } + } + if (sel.size() != 0) { + auto fn = [](Expr a, Expr b) { return a && b; }; + if (pad_mode == "constant") { + return Select::Make(FoldExpr(fn, sel), tensor(indices), pad_value); + } else if (pad_mode == "edge" || pad_mode == "reflect") { + return Select::Make(FoldExpr(fn, sel), tensor(indices), tensor(pad_idx)); + } + } + return tensor(indices); + }; + return Compute(output_shape, fn, name); +} + +/** + * @brief Perform pooling on N-dimension of data. + * + * @param tensor The input tensor with the shape of {N, C, H, W} or {N, H, W, C}. + * @param kernel_size Vector of N ints that indicates pooling kernel size. If N is 2, then is {pool_kernel_Height, + * pool_kernel_Width}. 
+ * @param stride_size Vector of N ints that indicates pooling stride size. If N is 2, then is {pool_stride_Height, + * pool_stride_Width}. + * @param padding_size Vector of N*2 ints {head_pad_d1, head_pad_d2, ..., head_pad_dN, tail_pad_d1, tail_pad_d2, ..., + * tail_pad_dN}. If N is 2, then is {pad_height_top, pad_width_left, pad_height_bottom, pad_width_right]}. + * @param pool_type The type of pooling operator, currently support "max" and "avg". + * @param axis Vector of axes of the tensor for pooling. + * @param ceil_mode Whether to use ceil when calculating the output size. + * @param exclusive Whether include padding in the calculation'. + * @param output_name the name of the output tensor after padding and pooling. + * + * @return the vector of padding tensor and pooling tensor + */ +std::vector PoolImpl(const Tensor &tensor, + const std::vector &kernel_size, + const std::vector &stride_size, + const std::vector &padding_size, + const std::string &pool_type, + const std::vector &axis, + bool ceil_mode, + bool exclusive, + const std::string &output_name) { + CHECK(!kernel_size.empty()) << "Pooling kernel_size should not be empty\n"; + int k_size = kernel_size.size(); + int x_size = tensor->shape.size(); + CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel\n"; + CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must have double elements as kernel\n"; + CHECK_EQ(axis.size(), k_size) << "Axis must have same elements as kernel\n"; + + std::vector daxis; + std::vector kernel(k_size); + std::vector stride(k_size); + std::vector pad_head(k_size); + std::vector pad_tail(k_size); + std::vector pad_before(x_size, Expr(0)); + std::vector pad_after(x_size, Expr(0)); + std::vector out_shape = tensor->shape; + + bool do_pad = false; + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + kernel[i] = Expr(kernel_size[i]); + stride[i] = Expr(stride_size[i]); + pad_head[i] = Expr(padding_size[i]); + pad_tail[i] = Expr(padding_size[i + k_size]); + do_pad = (do_pad) ? do_pad : (padding_size[i] || padding_size[i + k_size]); + + if (ceil_mode) { + pad_tail[i] = common::AutoSimplify(pad_tail[i] + stride[i] - 1); + } + + daxis.emplace_back(Var(kernel[i], UniqName("kernel_idx"))); + + pad_before[ii] = pad_head[i]; + pad_after[ii] = pad_tail[i]; + + auto out_dim = common::AutoSimplify((tensor->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i]) / stride[i] + 1); + + out_shape[ii] = out_dim; + } + + Tensor temp = tensor; + Tensor res = tensor; + if (pool_type == "max") { + Expr min_value = ir::min_value(tensor->type()); + // Pad the input tensor with the pad_value of type's minimum value + temp = do_pad ? Pad(tensor, pad_before, pad_after, min_value, UniqName("pad_temp")) : tensor; + res = Compute( + out_shape, + [=](const std::vector &output) { + std::vector indices; + for (auto &var : output) indices.push_back(var); + + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + indices[ii] = output[ii] * stride[i] + daxis[i]; + } + + return ReduceMax(temp(indices), min_value); + }, + output_name, + daxis); + } else if (pool_type == "avg") { + // Pad the input tensor with pad_value zero + temp = do_pad ? 
+    res = Compute(
+        out_shape,
+        [=](const std::vector<Expr> &output) {
+          std::vector<Expr> indices;
+          for (const Expr &var : output) indices.push_back(var);
+
+          for (int i = 0; i < k_size; i++) {
+            int ii      = axis[i];
+            indices[ii] = output[ii] * stride[i] + daxis[i];
+          }
+
+          if (exclusive) {
+            // divide by the number of in-bounds elements covered by the window
+            std::vector<Expr> start(k_size);
+            std::vector<Expr> end(k_size);
+            auto kernel_size = make_const(Int(32), 1);
+            for (int i = 0; i < k_size; i++) {
+              int ii      = axis[i];
+              start[i]    = common::AutoSimplify(output[ii] * stride[i] - pad_head[i]);
+              end[i]      = Min::Make(start[i] + kernel[i], tensor->shape[ii]);
+              start[i]    = Max::Make(start[i], make_const(Int(32), 0));
+              kernel_size = kernel_size * (end[i] - start[i]);
+            }
+            kernel_size        = common::AutoSimplify(kernel_size);
+            Expr divide_factor = Max::Make(kernel_size, make_const(Int(32), 1));
+            return ReduceSum(ir::Div::Make(temp(indices), cast(divide_factor, Float(32))), Expr());
+          } else {
+            // divide by the full kernel area
+            auto kernel_size = make_const(Int(32), 1);
+            for (int i = 0; i < k_size; i++) {
+              kernel_size = kernel_size * kernel[i];
+            }
+            kernel_size = common::AutoSimplify(kernel_size);
+            return ReduceSum(ir::Div::Make(temp(indices), cast(kernel_size, Float(32))), Expr());
+          }
+        },
+        output_name,
+        daxis);
+  } else {
+    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
+  }
+  return {temp, res};
+}
+
+std::vector<Tensor> Pool1d(const Tensor &tensor,
+                           const std::vector<int> &kernel_size,
+                           const std::vector<int> &stride_size,
+                           const std::vector<int> &padding_size,
+                           const std::string &pool_type,
+                           bool ceil_mode,
+                           bool exclusive,
+                           const std::string &data_format,
+                           const std::string &output_name) {
+  int width_axis = -1;
+  if (data_format == "NCW") {
+    width_axis = 2;
+  } else if (data_format == "NWC") {
+    width_axis = 1;
+  } else {
+    LOG(FATAL) << "Unsupported data format: " << data_format << std::endl;
+  }
+  CHECK_EQ(tensor->shape.size(), 3U) << "pool1d requires the tensor's shape_size to be 3\n";
+  std::vector<int> axis = {width_axis};
+  return PoolImpl(tensor, kernel_size, stride_size, padding_size, pool_type, axis, ceil_mode, exclusive, output_name);
+}
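PoolImpl's output extent per pooled axis is (in - kernel + pad_head + pad_tail) / stride + 1 in integer arithmetic; ceil_mode is realized by growing pad_tail by stride - 1 so that the floor division rounds up. An editorial sketch of that arithmetic (the helper name is hypothetical, values arbitrary):

    import math

    def pool_out_dim(in_dim, k, pad_head, pad_tail, stride, ceil_mode):
        # hypothetical helper mirroring the out_dim computation in PoolImpl
        if ceil_mode:
            pad_tail += stride - 1  # fold the rounding into the tail padding
        return (in_dim - k + pad_head + pad_tail) // stride + 1

    assert pool_out_dim(7, 2, 0, 0, 2, ceil_mode=False) == 3  # floor((7 - 2) / 2) + 1
    assert pool_out_dim(7, 2, 0, 0, 2, ceil_mode=True) == 4   # ceil((7 - 2) / 2) + 1

Note the padding_size layout {head_d1, ..., head_dN, tail_d1, ..., tail_dN}: pad_head[i] and pad_tail[i] come from padding_size[i] and padding_size[i + k_size] respectively.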
+
+std::vector<Tensor> Pool2d(const Tensor &tensor,
+                           const std::vector<int> &kernel_size,
+                           const std::vector<int> &stride_size,
+                           const std::vector<int> &padding_size,
+                           const std::string &pool_type,
+                           bool ceil_mode,
+                           bool exclusive,
+                           const std::string &data_format,
+                           const std::string &output_name) {
+  int height_axis = -1;
+  int width_axis  = -1;
+  if (data_format == "NCHW") {
+    height_axis = 2;
+    width_axis  = 3;
+  } else if (data_format == "NHWC") {
+    height_axis = 1;
+    width_axis  = 2;
+  } else {
+    LOG(FATAL) << "Unsupported data format: " << data_format << std::endl;
+  }
+  CHECK_EQ(tensor->shape.size(), 4U) << "pool2d requires the tensor's shape_size to be 4\n";
+  std::vector<int> axis = {height_axis, width_axis};
+  return PoolImpl(tensor, kernel_size, stride_size, padding_size, pool_type, axis, ceil_mode, exclusive, output_name);
+}
+
+std::vector<Tensor> Pool3d(const Tensor &tensor,
+                           const std::vector<int> &kernel_size,
+                           const std::vector<int> &stride_size,
+                           const std::vector<int> &padding_size,
+                           const std::string &pool_type,
+                           bool ceil_mode,
+                           bool exclusive,
+                           const std::string &data_format,
+                           const std::string &output_name) {
+  int depth_axis  = -1;
+  int height_axis = -1;
+  int width_axis  = -1;
+  if (data_format == "NCDHW") {
+    depth_axis  = 2;
+    height_axis = 3;
+    width_axis  = 4;
+  } else if (data_format == "NDHWC") {
+    depth_axis  = 1;
+    height_axis = 2;
+    width_axis  = 3;
+  } else {
+    LOG(FATAL) << "Unsupported data format: " << data_format << std::endl;
+  }
+  CHECK_EQ(tensor->shape.size(), 5U) << "pool3d requires the tensor's shape_size to be 5\n";
+  std::vector<int> axis = {depth_axis, height_axis, width_axis};
+  return PoolImpl(tensor, kernel_size, stride_size, padding_size, pool_type, axis, ceil_mode, exclusive, output_name);
+}
+
 }  // namespace pe
 }  // namespace hlir
 }  // namespace cinn
diff --git a/cinn/hlir/pe/nn.h b/cinn/hlir/pe/nn.h
index 3e68d36d4c..34d1b41198 100644
--- a/cinn/hlir/pe/nn.h
+++ b/cinn/hlir/pe/nn.h
@@ -82,6 +82,83 @@ ir::Tensor BatchNorm_NCHW(const ir::Tensor& input,
                           float epsilon,
                           const std::string& output_name);
 
+/**
+ * @brief Perform pooling on the width dimension of the tensor.
+ *        The width axis is determined by the data_format string, in which 'W' means width. Only the NCW and NWC
+ *        data_format layouts are supported.
+ * @param tensor The input tensor with the shape of {N, C, W} or {N, W, C}
+ * @param kernel_size Vector of ints: {pool_kernel_width}
+ * @param stride_size Vector of ints: {pool_stride_width}
+ * @param padding_size Vector of ints: {head_pad_width, tail_pad_width}
+ * @param pool_type The type of pooling operator, currently supporting "max" and "avg". Default is "max".
+ * @param ceil_mode Whether to use ceil when calculating the output size. Default is false.
+ * @param exclusive Whether to exclude padding from the average calculation. Default is true.
+ * @param data_format The input data format. Only the NCW and NWC data_format layouts are supported.
+ * @param output_name the name of the output tensor after padding and pooling.
+ *
+ * @return the vector of the padding tensor and the pooling tensor.
+ */
+std::vector<ir::Tensor> Pool1d(const ir::Tensor& tensor,
+                               const std::vector<int>& kernel_size,
+                               const std::vector<int>& stride_size,
+                               const std::vector<int>& padding_size,
+                               const std::string& pool_type   = "max",
+                               bool ceil_mode                 = false,
+                               bool exclusive                 = true,
+                               const std::string& data_format = "NCW",
+                               const std::string& output_name = "T_Pool1d_out");
+
+/**
+ * @brief Perform pooling on the height and width dimensions of the tensor.
+ *        The height and width axes are determined by the data_format string, in which 'H' means height and 'W' means
+ *        width. Only the NCHW and NHWC data_format layouts are supported.
+ * @param tensor The input tensor with the shape of {N, C, H, W} or {N, H, W, C}
+ * @param kernel_size Vector of ints: {pool_kernel_height, pool_kernel_width}
+ * @param stride_size Vector of ints: {pool_stride_height, pool_stride_width}
+ * @param padding_size Vector of ints: {head_pad_height, head_pad_width, tail_pad_height, tail_pad_width}
+ * @param pool_type The type of pooling operator, currently supporting "max" and "avg". Default is "max".
+ * @param ceil_mode Whether to use ceil when calculating the output size. Default is false.
+ * @param exclusive Whether to exclude padding from the average calculation. Default is true.
+ * @param data_format The input data format. Only the NCHW and NHWC data_format layouts are supported.
+ * @param output_name the name of the output tensor after padding and pooling.
+ *
+ * @return the vector of the padding tensor and the pooling tensor.
+ */
+std::vector<ir::Tensor> Pool2d(const ir::Tensor& tensor,
+                               const std::vector<int>& kernel_size,
+                               const std::vector<int>& stride_size,
+                               const std::vector<int>& padding_size,
+                               const std::string& pool_type   = "max",
+                               bool ceil_mode                 = false,
+                               bool exclusive                 = true,
+                               const std::string& data_format = "NCHW",
+                               const std::string& output_name = "T_Pool2d_out");
+
+/**
+ * @brief Perform pooling on the depth, height and width dimensions of the tensor.
+ *        The depth, height and width axes are determined by the data_format string, in which 'D' means depth, 'H'
+ *        means height and 'W' means width. Only the NCDHW and NDHWC data_format layouts are supported.
+ * @param tensor The input tensor with the shape of {N, C, D, H, W} or {N, D, H, W, C}
+ * @param kernel_size Vector of ints: {pool_kernel_depth, pool_kernel_height, pool_kernel_width}
+ * @param stride_size Vector of ints: {pool_stride_depth, pool_stride_height, pool_stride_width}
+ * @param padding_size Vector of ints: {head_pad_depth, head_pad_height, head_pad_width, tail_pad_depth,
+ *        tail_pad_height, tail_pad_width}
+ * @param pool_type The type of pooling operator, currently supporting "max" and "avg". Default is "max".
+ * @param ceil_mode Whether to use ceil when calculating the output size. Default is false.
+ * @param exclusive Whether to exclude padding from the average calculation. Default is true.
+ * @param data_format The input data format. Only the NCDHW and NDHWC data_format layouts are supported.
+ * @param output_name the name of the output tensor after padding and pooling.
+ *
+ * @return the vector of the padding tensor and the pooling tensor.
+ */
+std::vector<ir::Tensor> Pool3d(const ir::Tensor& tensor,
+                               const std::vector<int>& kernel_size,
+                               const std::vector<int>& stride_size,
+                               const std::vector<int>& padding_size,
+                               const std::string& pool_type   = "max",
+                               bool ceil_mode                 = false,
+                               bool exclusive                 = true,
+                               const std::string& data_format = "NCDHW",
+                               const std::string& output_name = "T_Pool3d_out");
+
 }  // namespace pe
 }  // namespace hlir
 }  // namespace cinn
diff --git a/cinn/hlir/pe/pe_broadcast_test.cc b/cinn/hlir/pe/pe_broadcast_test.cc
index 0eb6c7ad4e..5c94653ffa 100644
--- a/cinn/hlir/pe/pe_broadcast_test.cc
+++ b/cinn/hlir/pe/pe_broadcast_test.cc
@@ -63,20 +63,121 @@ void TestBroadcastPE(
   }
 }
 
+void TestBroadcastPE1(
+    const std::string &fn_name,
+    Tensor (*func_op)(const Tensor &A, const Tensor &B, const std::string &output_name, const Expr &axis),
+    float (*fn_runtime)(float, float),
+    int set_value = 0) {
+  Expr M(100), N(32), K(10);
+  Placeholder<float> A("A", {M, N, K});
+  Placeholder<float> B("B", {N});
+  auto C      = func_op(A.tensor(), B.tensor(), "C", Expr(1));
+  auto stages = CreateStages({C});
+  Target target = common::DefaultHostTarget();
+  Module::Builder builder("module0", target);
+  auto func = Lower("fn", stages, {A, B, C});
+  builder.AddFunction(func);
+  LOG(INFO) << "func:\n" << func;
+  auto jit    = backends::ExecutionEngine::Create({});
+  auto module = builder.Build();
+  jit->Link(module);
+  auto fn = jit->Lookup("fn");
+  CHECK(fn);
+  auto fn_ = reinterpret_cast<void (*)(void *, int32_t)>(fn);
+  cinn_buffer_t *A_buf;
+  cinn_buffer_t *B_buf;
+  if (set_value != 0) {
+    A_buf = common::BufferBuilder(Float(32), {100, 32, 10}).set_val(set_value).Build();
+    B_buf = common::BufferBuilder(Float(32), {32}).set_val(set_value).Build();
+  } else {
+    A_buf = common::BufferBuilder(Float(32), {100, 32, 10}).set_random().Build();
+    B_buf = common::BufferBuilder(Float(32), {32}).set_random().Build();
+  }
+  auto *C_buf = common::BufferBuilder(Float(32), {100, 32, 10}).set_zero().Build();
+  cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf);
+  cinn_pod_value_t args[] = {a_arg, b_arg, c_arg};
+  fn_(args, 3);
+  auto *ad = reinterpret_cast<float *>(A_buf->memory);
+  auto *bd = reinterpret_cast<float *>(B_buf->memory);
+  auto *cd = reinterpret_cast<float *>(C_buf->memory);
+  for (size_t i = 0; i < 100; i++) {
+    for (size_t j = 0; j < 32; j++) {
+      for (size_t k = 0; k < 10; k++) {
+        int index = 32 * 10 * i + 10 * j + k;
+        ASSERT_NEAR(cd[index], fn_runtime(ad[index], bd[j]), 1e-5);
+      }
+    }
+  }
+}
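The expectation these tests encode is ordinary NumPy broadcasting once B is aligned at the given axis; an editorial sketch (not part of the patch) of the TestBroadcastPE1 case:

    import numpy as np

    A = np.random.rand(100, 32, 10).astype("float32")
    B = np.random.rand(32).astype("float32")
    # axis = 1 aligns B's single dimension with A's axis 1, i.e. B acts as shape (1, 32, 1)
    C = A + B.reshape(1, 32, 1)
    assert np.allclose(C[5, 7, 3], A[5, 7, 3] + B[7])

This is the same relation the C++ loop asserts element by element: cd[index] == fn_runtime(ad[index], bd[j]).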
+
+void TestBroadcastPE2(
+    const std::string &fn_name,
+    Tensor (*func_op)(const Tensor &A, const Tensor &B, const std::string &output_name, const Expr &axis),
+    float (*fn_runtime)(float, float),
+    int set_value = 0) {
+  Expr M(100), N(32), K(10), R(1);
+  Placeholder<float> A("A", {M, N, K, R});
+  Placeholder<float> B("B", {N, K});
+  auto C      = func_op(A.tensor(), B.tensor(), "C", Expr(1));
+  auto stages = CreateStages({C});
+  Target target = common::DefaultHostTarget();
+  Module::Builder builder("module0", target);
+  auto func = Lower("fn", stages, {A, B, C});
+  builder.AddFunction(func);
+  LOG(INFO) << "func:\n" << func;
+  auto jit    = backends::ExecutionEngine::Create({});
+  auto module = builder.Build();
+  jit->Link(module);
+  auto fn = jit->Lookup("fn");
+  CHECK(fn);
+  auto fn_ = reinterpret_cast<void (*)(void *, int32_t)>(fn);
+  cinn_buffer_t *A_buf;
+  cinn_buffer_t *B_buf;
+  if (set_value != 0) {
+    A_buf = common::BufferBuilder(Float(32), {100, 32, 10, 1}).set_val(set_value).Build();
+    B_buf = common::BufferBuilder(Float(32), {32, 10}).set_val(set_value).Build();
+  } else {
+    A_buf = common::BufferBuilder(Float(32), {100, 32, 10, 1}).set_random().Build();
+    B_buf = common::BufferBuilder(Float(32), {32, 10}).set_random().Build();
+  }
+  auto *C_buf = common::BufferBuilder(Float(32), {100, 32, 10, 1}).set_zero().Build();
+  cinn_pod_value_t a_arg(A_buf), b_arg(B_buf), c_arg(C_buf);
+  cinn_pod_value_t args[] = {a_arg, b_arg, c_arg};
+  fn_(args, 3);
+  auto *ad = reinterpret_cast<float *>(A_buf->memory);
+  auto *bd = reinterpret_cast<float *>(B_buf->memory);
+  auto *cd = reinterpret_cast<float *>(C_buf->memory);
+  for (size_t i = 0; i < 100; i++) {
+    for (size_t j = 0; j < 32; j++) {
+      for (size_t k = 0; k < 10; k++) {
+        for (size_t r = 0; r < 1; r++) {
+          int index = 32 * 10 * i + 10 * j + k + r;
+          ASSERT_NEAR(cd[index], fn_runtime(ad[index], bd[10 * j + k]), 1e-5);
+        }
+      }
+    }
+  }
+}
+
 #define RULE(test_name__, rule__) \
   float test_name__(float a, float b) { rule__ }
 
-#define TEST_BROADCAST_PE_FP32_BASIC(test_name__) \
-  TEST(elementwise_pe, test_name__) { TestBroadcastPE("PE_Broadcast_" #test_name__ "_fp32", test_name__, test_name__); }
+#define TEST_BROADCAST_PE_FP32_BASIC(test_name__)                                                                       \
+  TEST(broadcast_pe, test_name__) { TestBroadcastPE("PE_Broadcast_" #test_name__ "_fp32", test_name__, test_name__); }  \
+  TEST(broadcast_pe1, test_name__) {                                                                                    \
+    TestBroadcastPE1("PE_Broadcast_" #test_name__ "_fp32", test_name__, test_name__);                                   \
+  }                                                                                                                     \
+  TEST(broadcast_pe2, test_name__) { TestBroadcastPE2("PE_Broadcast_" #test_name__ "_fp32", test_name__, test_name__); }
 
 #define TEST_BROADCAST_PE_FP32_SET_BASIC(test_name__) \
-  TEST(elementwise_pe, test_name__) { TestBroadcastPE("PE_Broadcast_" #test_name__ "_fp32", test_name__, value); }
+  TEST(broadcast_pe, test_name__) { TestBroadcastPE("PE_Broadcast_" #test_name__ "_fp32", test_name__, value); }
 
 #define TEST_BROADCAST_PE_FP32(test_name__, rule__) \
   RULE(test_name__, rule__)                         \
   TEST_BROADCAST_PE_FP32_BASIC(test_name__)
 
 TEST_BROADCAST_PE_FP32(Add, return a + b;)
+TEST_BROADCAST_PE_FP32(Multiply, return a * b;)
 
 }  // namespace pe
 }  // namespace hlir
-}  // namespace cinn
\ No newline at end of file
+}  // namespace cinn
diff --git a/cinn/ir/ir_operators.cc b/cinn/ir/ir_operators.cc
index 8108103291..c1eb53988f 100644
--- a/cinn/ir/ir_operators.cc
+++ b/cinn/ir/ir_operators.cc
@@ -1,6 +1,8 @@
 #include "cinn/ir/ir_operators.h"
 
-#include 
+#include <limits>
+
+#include "cinn/lang/compute.h"
 
 #include "cinn/common/type.h"
@@ -145,5 +147,30 @@ EXTERN_CALL_IMP(Tanh, tanh);
 EXTERN_CALL_IMP(Isfinite, isfinite);
 EXTERN_CALL_IMP(Isinf, isinf);
 
+Expr min_value(const Type& type) {
+  CHECK_EQ(type.lanes(), 1);
+  if (type.is_int()) {
+    if (type.bits() == 64) {
+      return Expr(std::numeric_limits<int64_t>::lowest());
+    } else if (type.bits() < 64) {
+      // two's-complement minimum: -(1 << (bits - 1))
+      int64_t val = 1;
+      val         = -(val << (type.bits() - 1));
+      return Expr(val);
+    }
+  } else if (type.is_uint()) {
+    return Expr(0);
+  } else if (type.is_float()) {
+    if (type.bits() == 64) {
+      return Expr(std::numeric_limits<double>::lowest());
+    } else if (type.bits() == 32) {
+      return Expr(std::numeric_limits<float>::lowest());
+    } else if (type.bits() == 16) {
+      // the lowest finite fp16 value
+      return Expr(-65504.0);
+    }
+  }
+  LOG(FATAL) << "Cannot decide min_value for type " << type;
+  return Expr();
+}
+
 }  // namespace ir
 }  // namespace cinn
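min_value follows the standard lowest-representable values; for signed integers narrower than 64 bits it uses the two's-complement identity min = -(1 << (bits - 1)). An editorial cross-check in Python (not part of the patch):

    import numpy as np

    def int_min(bits):
        # two's-complement minimum of a signed integer of the given width
        return -(1 << (bits - 1))

    assert int_min(32) == np.iinfo(np.int32).min  # -2147483648
    assert int_min(8) == np.iinfo(np.int8).min    # -128
    # the 16-bit float branch returns -65504.0, the lowest finite fp16 value
    assert np.finfo(np.float16).min == -65504.0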
diff --git a/cinn/ir/ir_operators.h b/cinn/ir/ir_operators.h
index 9ca5d53192..bd5d98b886 100644
--- a/cinn/ir/ir_operators.h
+++ b/cinn/ir/ir_operators.h
@@ -1,5 +1,7 @@
 #pragma once
-#include 
+#include 
+
+#include "cinn/common/ir_util.h"
 
 #include "cinn/ir/ir.h"
@@ -208,5 +210,7 @@ inline Expr ReduceMul(Expr e, Expr initial) {
 inline Expr ReduceMax(Expr e, Expr initial) { return Reduce::Make(Reduce::kMax, initial, e); }
 inline Expr ReduceMin(Expr e, Expr initial) { return Reduce::Make(Reduce::kMin, initial, e); }
 
+Expr min_value(const Type& type);
+
 }  // namespace ir
 }  // namespace cinn
diff --git a/cinn/pybind/framework.cc b/cinn/pybind/framework.cc
index cfe935d0c5..96e722f09f 100644
--- a/cinn/pybind/framework.cc
+++ b/cinn/pybind/framework.cc
@@ -51,7 +51,7 @@ void BindFramework(pybind11::module *m) {
       .def_readwrite("attr_store", &NodeAttr::attr_store)
       .def("set_attr",
           [](NodeAttr &self, const std::string &key, NodeAttr::attr_t value) { self.attr_store[key] = value; })
+      .def("get_attr", [](NodeAttr &self, const std::string &key) { return self.attr_store[key]; })
       .def("__str__", [](NodeAttr &self) { return utils::GetStreamCnt(self); });
-
 }  // namespace frontend
 }  // namespace cinn::pybind
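The new get_attr binding is what the NumPy reference implementations below use to read attributes back out of a NodeAttr. A minimal usage sketch (editorial; the cinn.framework import path is assumed from the tests):

    from cinn import framework

    attrs = framework.NodeAttr()
    attrs.attr_store = {"kernel_size": [2, 2], "pool_type": "max"}
    assert attrs.get_attr("pool_type") == "max"  # reads attr_store[key]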
diff --git a/python/tests/pool_utils.py b/python/tests/pool_utils.py
new file mode 100644
index 0000000000..05c09c0c6d
--- /dev/null
+++ b/python/tests/pool_utils.py
@@ -0,0 +1,425 @@
+#!/usr/bin/env python3
+import math
+import numpy as np
+import sys
+
+
+def pool2d(np_data, attrs, dtype="float32"):
+    pool_type = "max"
+    ceil_mode = False
+    exclusive = True
+    data_format = "NCHW"
+    for key in attrs.attr_store:
+        if key == "kernel_size":
+            kernel_size = attrs.get_attr("kernel_size")
+        elif key == "stride_size":
+            stride_size = attrs.get_attr("stride_size")
+        elif key == "padding_size":
+            padding_size = attrs.get_attr("padding_size")
+        elif key == "pool_type":
+            pool_type = attrs.get_attr("pool_type")
+        elif key == "ceil_mode":
+            ceil_mode = attrs.get_attr("ceil_mode")
+        elif key == "exclusive":
+            exclusive = attrs.get_attr("exclusive")
+        elif key == "data_format":
+            data_format = attrs.get_attr("data_format")
+        else:
+            raise ValueError("attr_store {} is not supported".format(key))
+
+    if data_format == "NCHW":
+        in_n, in_c, in_h, in_w = in_shape = np_data.shape
+        height_axis = 2
+        width_axis = 3
+    elif data_format == "NHWC":
+        in_n, in_h, in_w, in_c = in_shape = np_data.shape
+        height_axis = 1
+        width_axis = 2
+    else:
+        raise ValueError("data_format {} is not supported".format(data_format))
+
+    if isinstance(kernel_size, int):
+        k_h = k_w = kernel_size
+    else:
+        k_h, k_w = kernel_size
+    if isinstance(stride_size, int):
+        s_h = s_w = stride_size
+    else:
+        s_h, s_w = stride_size
+    if isinstance(padding_size, int):
+        pt = pl = pb = pr = padding_size
+    else:
+        pt, pl, pb, pr = padding_size
+
+    out_shape0 = list(in_shape)
+    out_shape0[height_axis] = in_shape[height_axis] + pt + pb
+    out_shape0[width_axis] = in_shape[width_axis] + pl + pr
+
+    out_shape = list(in_shape)
+    if ceil_mode:
+        out_shape0[height_axis] += s_h - 1
+        out_shape0[width_axis] += s_w - 1
+        out_shape[height_axis] = int(
+            math.ceil(float(in_shape[height_axis] - k_h + pt + pb) / s_h) + 1)
+        out_shape[width_axis] = int(
+            math.ceil(float(in_shape[width_axis] - k_w + pl + pr) / s_w) + 1)
+    else:
+        out_shape[height_axis] = int(
+            math.floor(float(in_shape[height_axis] - k_h + pt + pb) / s_h) + 1)
+        out_shape[width_axis] = int(
+            math.floor(float(in_shape[width_axis] - k_w + pl + pr) / s_w) + 1)
+
+    fill_value = 0
+    if exclusive and pool_type == 'max':
+        fill_value = sys.float_info.min
+
+    if data_format == "NCHW":
+        pad_np = np.full(
+            shape=(in_n, in_c, in_h + pt + pb, in_w + pl + pr),
+            fill_value=fill_value,
+            dtype=dtype)
+        no_zero = (range(in_n), range(in_c), range(pt, in_h + pt),
+                   range(pl, in_w + pl))
+    else:
+        pad_np = np.full(
+            shape=(in_n, in_h + pt + pb, in_w + pl + pr, in_c),
+            fill_value=fill_value,
+            dtype=dtype)
+        no_zero = (range(in_n), range(pt, in_h + pt), range(pl, in_w + pl),
+                   range(in_c))
+
+    pad_np[np.ix_(*no_zero)] = np_data
+    ret_np = np.zeros(shape=out_shape).astype(dtype)
+    if pool_type == 'avg':
+        for i in range(out_shape[height_axis]):
+            for j in range(out_shape[width_axis]):
+                if exclusive:
+                    pad_exclusive = pad_np.copy()
+                    pad_exclusive[np.ix_(*no_zero)] = 1
+                    if data_format == "NCHW":
+                        pad_count = np.sum(
+                            pad_exclusive[:, :, i * s_h:i * s_h + k_h,
+                                          j * s_w:j * s_w + k_w] == 1,
+                            axis=(height_axis, width_axis))
+                        ret_np[:, :, i, j] = np.sum(
+                            pad_np[:, :, i * s_h:i * s_h + k_h,
+                                   j * s_w:j * s_w + k_w],
+                            axis=(height_axis, width_axis)) / np.maximum(
+                                pad_count, 1)
+                    else:
+                        pad_count = np.sum(
+                            pad_exclusive[:, i * s_h:i * s_h + k_h,
+                                          j * s_w:j * s_w + k_w, :] == 1,
+                            axis=(height_axis, width_axis))
+                        ret_np[:, i, j, :] = np.sum(
+                            pad_np[:, i * s_h:i * s_h + k_h,
+                                   j * s_w:j * s_w + k_w, :],
+                            axis=(height_axis, width_axis)) / np.maximum(
+                                pad_count, 1)
+                else:
+                    if data_format == "NCHW":
+                        ret_np[:, :, i, j] = \
+                            np.mean(pad_np[:, :,
+                                           i * s_h: i * s_h + k_h,
+                                           j * s_w: j * s_w + k_w],
+                                    axis=(height_axis, width_axis))
+                    else:
+                        ret_np[:, i, j, :] = \
+                            np.mean(pad_np[:,
+                                           i * s_h: i * s_h + k_h,
+                                           j * s_w: j * s_w + k_w, :],
+                                    axis=(height_axis, width_axis))
+    elif pool_type == 'max':
+        for i in range(out_shape[height_axis]):
+            for j in range(out_shape[width_axis]):
+                if data_format == "NCHW":
+                    ret_np[:, :, i, j] = np.max(
+                        pad_np[:, :, i * s_h:i * s_h + k_h,
+                               j * s_w:j * s_w + k_w],
+                        axis=(height_axis, width_axis))
+                else:
+                    ret_np[:, i, j, :] = np.max(
+                        pad_np[:, i * s_h:i * s_h + k_h,
+                               j * s_w:j * s_w + k_w, :],
+                        axis=(height_axis, width_axis))
+    else:
+        raise ValueError("pool type {} is not supported".format(pool_type))
+
+    ret_np = np.maximum(ret_np, fill_value)
+    return ret_np, [out_shape0, out_shape]
+
+
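The exclusive flag only matters for average pooling on windows that overlap the padding: the divisor is the count of in-bounds elements rather than the full kernel area. Editorial arithmetic (not part of the patch) for a corner window with a 2x2 kernel and one ring of zero padding over an all-ones input:

    # the corner window covers 1 real element (value 1.0) and 3 padded zeros
    window_sum = 1.0
    print(window_sum / 1)  # exclusive=True: divide by in-bounds count -> 1.0
    print(window_sum / 4)  # exclusive=False: divide by kernel area    -> 0.25

This mirrors the np.maximum(pad_count, 1) divisor above and the divide_factor logic in the C++ PoolImpl.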
+def pool3d(np_data, attrs, dtype="float32"):
+    pool_type = "max"
+    ceil_mode = False
+    exclusive = True
+    data_format = "NCDHW"
+    for key in attrs.attr_store:
+        if key == "kernel_size":
+            kernel_size = attrs.get_attr("kernel_size")
+        elif key == "stride_size":
+            stride_size = attrs.get_attr("stride_size")
+        elif key == "padding_size":
+            padding_size = attrs.get_attr("padding_size")
+        elif key == "pool_type":
+            pool_type = attrs.get_attr("pool_type")
+        elif key == "ceil_mode":
+            ceil_mode = attrs.get_attr("ceil_mode")
+        elif key == "exclusive":
+            exclusive = attrs.get_attr("exclusive")
+        elif key == "data_format":
+            data_format = attrs.get_attr("data_format")
+        else:
+            raise ValueError("attr_store {} is not supported".format(key))
+
+    if data_format == "NCDHW":
+        in_n, in_c, in_d, in_h, in_w = in_shape = np_data.shape
+        depth_axis = 2
+        height_axis = 3
+        width_axis = 4
+    elif data_format == "NDHWC":
+        in_n, in_d, in_h, in_w, in_c = in_shape = np_data.shape
+        depth_axis = 1
+        height_axis = 2
+        width_axis = 3
+    else:
+        raise ValueError("data_format {} is not supported".format(data_format))
+
+    if isinstance(kernel_size, int):
+        k_d = k_h = k_w = kernel_size
+    else:
+        k_d, k_h, k_w = kernel_size
+    if isinstance(stride_size, int):
+        s_d = s_h = s_w = stride_size
+    else:
+        s_d, s_h, s_w = stride_size
+    if isinstance(padding_size, int):
+        pf = pt = pl = pk = pb = pr = padding_size
+    else:
+        pf, pt, pl, pk, pb, pr = padding_size
+
+    out_shape0 = list(in_shape)
+    out_shape0[depth_axis] = in_shape[depth_axis] + pf + pk
+    out_shape0[height_axis] = in_shape[height_axis] + pt + pb
+    out_shape0[width_axis] = in_shape[width_axis] + pl + pr
+
+    out_shape = list(in_shape)
+    if ceil_mode:
+        out_shape0[depth_axis] += s_d - 1
+        out_shape0[height_axis] += s_h - 1
+        out_shape0[width_axis] += s_w - 1
+        out_shape[depth_axis] = int(
+            math.ceil(float(in_shape[depth_axis] - k_d + pf + pk) / s_d) + 1)
+        out_shape[height_axis] = int(
+            math.ceil(float(in_shape[height_axis] - k_h + pt + pb) / s_h) + 1)
+        out_shape[width_axis] = int(
+            math.ceil(float(in_shape[width_axis] - k_w + pl + pr) / s_w) + 1)
+    else:
+        out_shape[depth_axis] = int(
+            math.floor(float(in_shape[depth_axis] - k_d + pf + pk) / s_d) + 1)
+        out_shape[height_axis] = int(
+            math.floor(float(in_shape[height_axis] - k_h + pt + pb) / s_h) + 1)
+        out_shape[width_axis] = int(
+            math.floor(float(in_shape[width_axis] - k_w + pl + pr) / s_w) + 1)
+
+    fill_value = 0
+    if exclusive and pool_type == 'max':
+        fill_value = sys.float_info.min
+
+    if data_format == "NCDHW":
+        pad_np = np.full(
+            shape=(in_n, in_c, in_d + pf + pk, in_h + pt + pb, in_w + pl + pr),
+            fill_value=fill_value,
+            dtype=dtype)
+        no_zero = (range(in_n), range(in_c), range(pf, in_d + pf),
+                   range(pt, in_h + pt), range(pl, in_w + pl))
+    else:
+        pad_np = np.full(
+            shape=(in_n, in_d + pf + pk, in_h + pt + pb, in_w + pl + pr, in_c),
+            fill_value=fill_value,
+            dtype=dtype)
+        no_zero = (range(in_n), range(pf, in_d + pf), range(pt, in_h + pt),
+                   range(pl, in_w + pl), range(in_c))
+
+    pad_np[np.ix_(*no_zero)] = np_data
+    ret_np = np.zeros(shape=out_shape).astype(dtype)
+    if pool_type == 'avg':
+        for i in range(out_shape[depth_axis]):
+            for j in range(out_shape[height_axis]):
+                for k in range(out_shape[width_axis]):
+                    if exclusive:
+                        pad_exclusive = pad_np.copy()
+                        pad_exclusive[np.ix_(*no_zero)] = 1
+                        if data_format == "NCDHW":
+                            pad_count = np.sum(
+                                pad_exclusive[:, :, i * s_d:i * s_d + k_d,
+                                              j * s_h:j * s_h + k_h,
+                                              k * s_w:k * s_w + k_w] == 1,
+                                axis=(depth_axis, height_axis, width_axis))
+                            ret_np[:, :, i, j, k] = np.sum(
+                                pad_np[:, :, i * s_d:i * s_d + k_d,
+                                       j * s_h:j * s_h + k_h,
+                                       k * s_w:k * s_w + k_w],
+                                axis=(depth_axis, height_axis,
+                                      width_axis)) / np.maximum(pad_count, 1)
+                        else:
+                            pad_count = np.sum(
+                                pad_exclusive[:, i * s_d:i * s_d + k_d,
+                                              j * s_h:j * s_h + k_h,
+                                              k * s_w:k * s_w + k_w, :] == 1,
+                                axis=(depth_axis, height_axis, width_axis))
+                            ret_np[:, i, j, k, :] = np.sum(
+                                pad_np[:, i * s_d:i * s_d + k_d,
+                                       j * s_h:j * s_h + k_h,
+                                       k * s_w:k * s_w + k_w, :],
+                                axis=(depth_axis, height_axis,
+                                      width_axis)) / np.maximum(pad_count, 1)
+                    else:
+                        if data_format == "NCDHW":
+                            ret_np[:, :, i, j, k] = \
+                                np.mean(pad_np[:, :,
+                                               i * s_d: i * s_d + k_d,
+                                               j * s_h: j * s_h + k_h,
+                                               k * s_w: k * s_w + k_w],
+                                        axis=(depth_axis, height_axis, width_axis))
+                        else:
+                            ret_np[:, i, j, k, :] = \
+                                np.mean(pad_np[:,
+                                               i * s_d: i * s_d + k_d,
+                                               j * s_h: j * s_h + k_h,
+                                               k * s_w: k * s_w + k_w,
+                                               :],
+                                        axis=(depth_axis, height_axis, width_axis))
+    elif pool_type == 'max':
+        for i in range(out_shape[depth_axis]):
+            for j in range(out_shape[height_axis]):
+                for k in range(out_shape[width_axis]):
+                    if data_format == "NCDHW":
+                        ret_np[:, :, i, j, k] = np.max(
+                            pad_np[:, :, i * s_d:i * s_d + k_d,
+                                   j * s_h:j * s_h + k_h,
+                                   k * s_w:k * s_w + k_w],
+                            axis=(depth_axis, height_axis, width_axis))
+                    else:
+                        ret_np[:, i, j, k, :] = np.max(
+                            pad_np[:, i * s_d:i * s_d + k_d,
+                                   j * s_h:j * s_h + k_h,
+                                   k * s_w:k * s_w + k_w, :],
+                            axis=(depth_axis, height_axis, width_axis))
+    else:
+        raise ValueError("pool type {} is not supported".format(pool_type))
+
+    ret_np = np.maximum(ret_np, fill_value)
+    return ret_np, [out_shape0, out_shape]
+
+
+def pool1d(np_data, attrs, dtype="float32"):
+    pool_type = "max"
+    ceil_mode = False
+    exclusive = True
+    data_format = "NCW"
+    for key in attrs.attr_store:
+        if key == "kernel_size":
+            kernel_size = attrs.get_attr("kernel_size")
+        elif key == "stride_size":
+            stride_size = attrs.get_attr("stride_size")
+        elif key == "padding_size":
+            padding_size = attrs.get_attr("padding_size")
+        elif key == "pool_type":
+            pool_type = attrs.get_attr("pool_type")
+        elif key == "ceil_mode":
+            ceil_mode = attrs.get_attr("ceil_mode")
+        elif key == "exclusive":
+            exclusive = attrs.get_attr("exclusive")
+        elif key == "data_format":
+            data_format = attrs.get_attr("data_format")
+        else:
+            raise ValueError("attr_store {} is not supported".format(key))
+
+    if data_format == "NCW":
+        in_n, in_c, in_w = in_shape = np_data.shape
+        width_axis = 2
+    elif data_format == "NWC":
+        in_n, in_w, in_c = in_shape = np_data.shape
+        width_axis = 1
+    else:
+        raise ValueError("data_format {} is not supported".format(data_format))
+
+    if isinstance(kernel_size, int):
+        k_w = kernel_size
+    else:
+        k_w, = kernel_size
+    if isinstance(stride_size, int):
+        s_w = stride_size
+    else:
+        s_w, = stride_size
+    if isinstance(padding_size, int):
+        pl = pr = padding_size
+    else:
+        pl, pr = padding_size
+
+    out_shape0 = list(in_shape)
+    out_shape0[width_axis] = in_shape[width_axis] + pl + pr
+
+    out_shape = list(in_shape)
+    if ceil_mode:
+        out_shape0[width_axis] += s_w - 1
+        out_shape[width_axis] = int(
+            math.ceil(float(in_shape[width_axis] - k_w + pl + pr) / s_w) + 1)
+    else:
+        out_shape[width_axis] = int(
+            math.floor(float(in_shape[width_axis] - k_w + pl + pr) / s_w) + 1)
+
+    fill_value = 0
+    if exclusive and pool_type == 'max':
+        fill_value = sys.float_info.min
+
+    if data_format == "NCW":
+        pad_np = np.full(
+            shape=(in_n, in_c, in_w + pl + pr),
+            fill_value=fill_value,
+            dtype=dtype)
+        no_zero = (range(in_n), range(in_c), range(pl, in_w + pl))
+    else:
+        pad_np = np.full(
+            shape=(in_n, in_w + pl + pr, in_c),
+            fill_value=fill_value,
+            dtype=dtype)
+        no_zero = (range(in_n), range(pl, in_w + pl), range(in_c))
+
+    pad_np[np.ix_(*no_zero)] = np_data
+    ret_np = np.zeros(shape=out_shape).astype(dtype)
+    if pool_type == 'avg':
+        for i in range(out_shape[width_axis]):
+            if exclusive:
+                pad_exclusive = pad_np.copy()
+                pad_exclusive[np.ix_(*no_zero)] = 1
+                if data_format == "NCW":
+                    pad_count = np.sum(
+                        pad_exclusive[:, :, i * s_w:i * s_w + k_w] == 1,
+                        axis=width_axis)
+                    ret_np[:, :, i] = np.sum(
+                        pad_np[:, :, i * s_w:i * s_w + k_w],
+                        axis=width_axis) / np.maximum(pad_count, 1)
+                else:
+                    pad_count = np.sum(
+                        pad_exclusive[:, i * s_w:i * s_w + k_w, :] == 1,
+                        axis=width_axis)
+                    ret_np[:, i, :] = np.sum(
+                        pad_np[:, i * s_w:i * s_w + k_w, :],
+                        axis=width_axis) / np.maximum(pad_count, 1)
+            else:
+                if data_format == "NCW":
+                    ret_np[:, :, i] = \
+                        np.mean(pad_np[:, :, i * s_w: i * s_w + k_w],
+                                axis=width_axis)
+                else:
+                    ret_np[:, i, :] = \
+                        np.mean(pad_np[:, i * s_w: i * s_w + k_w, :],
+                                axis=width_axis)
+    elif pool_type == 'max':
+        for k in range(out_shape[width_axis]):
+            if data_format == "NCW":
+                ret_np[:, :, k] = np.max(
+                    pad_np[:, :, k * s_w:k * s_w + k_w], axis=width_axis)
+            else:
+                ret_np[:, k, :] = np.max(
+                    pad_np[:, k * s_w:k * s_w + k_w, :], axis=width_axis)
+    else:
+        raise ValueError("pool type {} is not supported".format(pool_type))
+
+    ret_np = np.maximum(ret_np, fill_value)
+    return ret_np, [out_shape0, out_shape]
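All three reference helpers are driven the same way: populate a NodeAttr, call the matching function, and get back both the expected result and the padded/pooled shapes. A sketch mirroring how the tests below use pool1d (editorial; the cinn.framework import path is assumed from the tests):

    import numpy as np
    import pool_utils
    from cinn import framework

    attrs = framework.NodeAttr()
    attrs.attr_store = {
        "kernel_size": [2], "stride_size": [2], "padding_size": [1, 1],
        "pool_type": "max", "ceil_mode": False, "exclusive": True,
        "data_format": "NCW"
    }
    data = np.random.rand(1, 3, 8).astype("float32")
    expected, (padded_shape, out_shape) = pool_utils.pool1d(data, attrs)
    print(padded_shape, out_shape)  # [1, 3, 10] [1, 3, 5]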
diff --git a/python/tests/test_op_nn.py b/python/tests/test_op_nn.py
index 57760bc134..8d978fe430 100644
--- a/python/tests/test_op_nn.py
+++ b/python/tests/test_op_nn.py
@@ -12,6 +12,7 @@ from cinn.poly import create_stages
 import logging
 from test_utils import SingleOpTester
+import pool_utils
 
 
 class OpTest_relu(SingleOpTester):
@@ -50,6 +51,186 @@ def test_op(self):
         [[1, 3, 12, 12], [2, 3, 3, 3], [1, 2, 5, 5]], "conv2d", attrs)
     """
 
+class OpTest_pool1d(SingleOpTester):
+    attrs = framework.NodeAttr()
+    attrs.attr_store = {
+        "kernel_size": [2],
+        "stride_size": [2],
+        "padding_size": [1, 1],
+        "pool_type": "max",
+        "ceil_mode": False,
+        "exclusive": True,
+        "data_format": "NCW"
+    }
+
+    def create_target_data(self, inputs_data):
+        return pool_utils.pool1d(inputs_data[0], self.attrs)
+
+    def test_op(self):
+        input_shape = [1, 3, 8]
+        self.to_test_op([input_shape], None, "pool1d", self.attrs)
+
+
+class OpTest_pool1d_1(SingleOpTester):
+    attrs = framework.NodeAttr()
+    attrs.attr_store = {
+        "kernel_size": [2],
+        "stride_size": [2],
+        "padding_size": [2, 3],
+        "pool_type": "avg",
+        "ceil_mode": False,
+        "exclusive": True,
+        "data_format": "NCW"
+    }
+
+    def create_target_data(self, inputs_data):
+        return pool_utils.pool1d(inputs_data[0], self.attrs)
+
+    def test_op(self):
+        input_shape = [1, 3, 8]
+        self.to_test_op([input_shape], None, "pool1d", self.attrs)
+
+
+class OpTest_pool1d_2(SingleOpTester):
+    attrs = framework.NodeAttr()
+    attrs.attr_store = {
+        "kernel_size": [2],
+        "stride_size": [3],
+        "padding_size": [4, 5],
+        "pool_type": "avg",
+        "ceil_mode": True,
+        "exclusive": False,
+        "data_format": "NWC"
+    }
+
+    def create_target_data(self, inputs_data):
+        return pool_utils.pool1d(inputs_data[0], self.attrs)
+
+    def test_op(self):
+        input_shape = [1, 8, 3]
+        self.to_test_op([input_shape], None, "pool1d", self.attrs)
+
+
+class OpTest_pool2d(SingleOpTester):
+    attrs = framework.NodeAttr()
+    attrs.attr_store = {
+        "kernel_size": [2, 2],
+        "stride_size": [2, 2],
+        "padding_size": [1, 1, 1, 1],
+        "pool_type": "max",
+        "ceil_mode": False,
+        "exclusive": True,
+        "data_format": "NCHW"
+    }
+
+    def create_target_data(self, inputs_data):
+        return pool_utils.pool2d(inputs_data[0], self.attrs)
+
+    def test_op(self):
+        input_shape = [1, 3, 8, 8]
+        self.to_test_op([input_shape], None, "pool2d", self.attrs)
+
+
+class OpTest_pool2d_1(SingleOpTester):
+    attrs = framework.NodeAttr()
+    attrs.attr_store = {
+        "kernel_size": [2, 2],
+        "stride_size": [2, 2],
+        "padding_size": [2, 3, 4, 5],
+        "pool_type": "avg",
+        "ceil_mode": False,
+        "exclusive": True,
+        "data_format": "NCHW"
+    }
+
+    def create_target_data(self, inputs_data):
+        return pool_utils.pool2d(inputs_data[0], self.attrs)
+
+    def test_op(self):
+        input_shape = [1, 3, 8, 8]
+        self.to_test_op([input_shape], None, "pool2d", self.attrs)
+
+
+class OpTest_pool2d_2(SingleOpTester):
+    attrs = framework.NodeAttr()
+    attrs.attr_store = {
+        "kernel_size": [2, 2],
+        "stride_size": [3, 3],
+        "padding_size": [2, 3, 4, 5],
+        "pool_type": "avg",
+        "ceil_mode": True,
+        "exclusive": False,
+        "data_format": "NHWC"
+    }
+
+    def create_target_data(self, inputs_data):
+        return pool_utils.pool2d(inputs_data[0], self.attrs)
+
+    def test_op(self):
+        input_shape = [1, 8, 8, 3]
+        self.to_test_op([input_shape], None, "pool2d", self.attrs)
+
+
+class OpTest_pool3d(SingleOpTester):
+    attrs = framework.NodeAttr()
+    attrs.attr_store = {
+        "kernel_size": [2, 2, 2],
+        "stride_size": [2, 2, 2],
+        "padding_size": [1, 2, 3, 4, 5, 6],
+        "pool_type": "max",
+        "ceil_mode": False,
+        "exclusive": True,
+        "data_format": "NCDHW"
+    }
+
+    def create_target_data(self, inputs_data):
+        return pool_utils.pool3d(inputs_data[0], self.attrs)
+
+    def test_op(self):
+        input_shape = [2, 3, 8, 8, 8]
+        self.to_test_op([input_shape], None, "pool3d", self.attrs)
+
+
+class OpTest_pool3d_1(SingleOpTester):
+    attrs = framework.NodeAttr()
+    attrs.attr_store = {
+        "kernel_size": [2, 2, 2],
+        "stride_size": [2, 2, 2],
+        "padding_size": [1, 1, 1, 1, 1, 1],
+        "pool_type": "avg",
+        "ceil_mode": False,
+        "exclusive": True,
+        "data_format": "NCDHW"
+    }
+
+    def create_target_data(self, inputs_data):
+        return pool_utils.pool3d(inputs_data[0], self.attrs)
+
+    def test_op(self):
+        input_shape = [1, 3, 8, 8, 8]
+        self.to_test_op([input_shape], None, "pool3d", self.attrs)
+
+
+class OpTest_pool3d_2(SingleOpTester):
+    attrs = framework.NodeAttr()
+    attrs.attr_store = {
+        "kernel_size": [2, 2, 2],
+        "stride_size": [2, 2, 2],
+        "padding_size": [1, 2, 3, 4, 5, 6],
+        "pool_type": "avg",
+        "ceil_mode": True,
+        "exclusive": False,
+        "data_format": "NDHWC"
+    }
+
+    def create_target_data(self, inputs_data):
+        return pool_utils.pool3d(inputs_data[0], self.attrs)
+
+    def test_op(self):
+        input_shape = [1, 8, 8, 8, 3]
+        self.to_test_op([input_shape], None, "pool3d", self.attrs)
+
+
 class OpTest_batchnorm(SingleOpTester):
     def create_target_data(self, inputs_data):
         [X, Y] = inputs_data
diff --git a/python/tests/test_utils.py b/python/tests/test_utils.py
index 841631ee2c..6267bb8f97 100644
--- a/python/tests/test_utils.py
+++ b/python/tests/test_utils.py
@@ -63,7 +63,21 @@ def to_test_op(self, input_shapes, output_shape, op_name, attrs):
             inputs.append(
                 lang.Placeholder("float32", self.__gen_var_name(),
                                  expr_shape).to_tensor())
+
+        args = []
+        temp_inputs = []
+        for in_data in inputs_data:
+            temp_inputs.append(
+                runtime.cinn_buffer_t(in_data, runtime.cinn_x86_device))
+        for in_data in temp_inputs:
+            args.append(runtime.cinn_pod_value_t(in_data))
+        if output_shape is None:
+            correct_result, output_shape = self.create_target_data(inputs_data)
+        else:
+            correct_result = self.create_target_data(inputs_data)
+
         module = self.__codegen(op_name, inputs, attrs)
+
         self.compiler.build(module)
         fn = self.compiler.lookup(op_name)
         out = []
@@ -73,20 +87,11 @@ def to_test_op(self, input_shapes, output_shape, op_name, attrs):
                     np.zeros(out_shape).astype("float32"),
                     runtime.cinn_x86_device))
 
-        args = []
-        temp_inputs = []
-        for in_data in inputs_data:
-            temp_inputs.append(
-                runtime.cinn_buffer_t(in_data, runtime.cinn_x86_device))
-        for in_data in temp_inputs:
-            args.append(runtime.cinn_pod_value_t(in_data))
         for out_data in out:
             args.append(runtime.cinn_pod_value_t(out_data))
-
         fn(args)
         out_result = out[len(out) - 1].numpy()
-        correct_result = self.create_target_data(inputs_data)
         self.assertTrue(
             np.allclose(out_result, correct_result, atol=1e-4))
 
     def __codegen(self, op_name, inputs, attrs):