From b008432a6e4fe46f8c326614f3c9462f66567666 Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Mon, 28 Nov 2022 22:39:48 +0800 Subject: [PATCH 1/8] add logical_right_shift op --- cinn/frontend/net_builder.cc | 1 + cinn/frontend/net_builder.h | 3 +- cinn/hlir/op/contrib/CMakeLists.txt | 2 + cinn/hlir/op/contrib/logical_right_shift.cc | 133 ++++++++++++++++++ cinn/hlir/op/contrib/logical_right_shift.h | 32 +++++ .../op/contrib/logical_right_shift_test.cc | 64 +++++++++ cinn/hlir/op/use_ops.h | 1 + .../tests/ops/test_logical_right_shift_op.py | 77 ++++++++++ 8 files changed, 312 insertions(+), 1 deletion(-) create mode 100644 cinn/hlir/op/contrib/logical_right_shift.cc create mode 100644 cinn/hlir/op/contrib/logical_right_shift.h create mode 100644 cinn/hlir/op/contrib/logical_right_shift_test.cc create mode 100644 python/tests/ops/test_logical_right_shift_op.py diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index 7fa38d99ab..43afc92b49 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -174,6 +174,7 @@ NETBUILDER_BINARY_OP_DEF(Equal, equal); NETBUILDER_BINARY_OP_DEF(NotEqual, not_equal); NETBUILDER_BINARY_OP_DEF(GreaterEqual, greater_equal); NETBUILDER_BINARY_OP_DEF(LessEqual, less_equal); +NETBUILDER_BINARY_OP_DEF(LogicalRightShift, logical_right_shift); #undef NETBUILDER_BINARY_OP_DEF diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h index d3ce86bc8d..7f92067443 100644 --- a/cinn/frontend/net_builder.h +++ b/cinn/frontend/net_builder.h @@ -104,7 +104,8 @@ namespace frontend { macro__(GreaterThan) \ macro__(LessThan) \ macro__(GreaterEqual) \ - macro__(LessEqual) + macro__(LessEqual) \ + macro__(LogicalRightShift) // ******************************************* // // Reduce array elements over the given dims. diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt index fde94ad153..6d55d1c7dd 100644 --- a/cinn/hlir/op/contrib/CMakeLists.txt +++ b/cinn/hlir/op/contrib/CMakeLists.txt @@ -19,6 +19,7 @@ gather_srcs(cinnapi_src SRCS clz.cc popc.cc reciprocal.cc + logical_right_shift.cc ) cc_test(test_cast SRCS cast_test.cc DEPS cinncore) @@ -38,3 +39,4 @@ cc_test(test_lookup_table SRCS lookup_table_test.cc DEPS cinncore) cc_test(test_clz SRCS clz_test.cc DEPS cinncore) cc_test(test_popc SRCS popc_test.cc DEPS cinncore) cc_test(test_reciprocal SRCS reciprocal_test.cc DEPS cinncore) +cc_test(test_logical_right_shift SRCS logical_right_shift_test.cc DEPS cinncore) diff --git a/cinn/hlir/op/contrib/logical_right_shift.cc b/cinn/hlir/op/contrib/logical_right_shift.cc new file mode 100644 index 0000000000..404c6c6d5f --- /dev/null +++ b/cinn/hlir/op/contrib/logical_right_shift.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "cinn/common/cas.h" +#include "cinn/common/common.h" +#include "cinn/common/context.h" +#include "cinn/common/macros.h" +#include "cinn/common/target.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/op/contrib/clip.h" +#include "cinn/hlir/pe/ir_schedule_pe.h" +#include "cinn/hlir/pe/nn.h" +#include "cinn/hlir/pe/schedule.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" +#include "gflags/gflags.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace hlir { +namespace op { + +using common::_CINNValuePack_; +using common::CINNValue; +using common::CINNValuePack; +using framework::OpStrategy; +using framework::shape_t; +using framework::StrategyFunction; + +ir::Tensor LogicalRightShift(const ir::Tensor &A, const ir::Tensor &B, const std::string &output_name) { + return Compute( + A->shape, + [=](const std::vector &indices) { + Expr bits = ir::Cast::Make(A->type(), A->type().bits() - 1); + return lang::BitwiseAnd(lang::RightShift(A(indices), B(indices)), lang::BitwiseNot(lang::LeftShift(lang::RightShift(lang::LeftShift(Expr(1), bits), B(indices)), Expr(1)))); + }, + UniqName(output_name)); +} + +std::shared_ptr StrategyForLogicalRightShift(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + std::string op_name("logical_right_shift"); + + framework::CINNCompute logical_right_shift_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 2U) << "2 input tensors for " << op_name << " compute\n"; + + Expr A_expr = pack_args[0]; + Expr B_expr = pack_args[1]; + CHECK(A_expr.as_tensor()); + CHECK(B_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + ir::Tensor B = B_expr.as_tensor_ref(); + + std::string tensor_name = UniqName("T_LogicalRightShift_out"); + + if (FLAGS_cinn_ir_schedule) { + CHECK_EQ(pack_args.size(), 3U); + tensor_name = pack_args[2].operator std::string(); + } + + auto out = LogicalRightShift(A, B, tensor_name); + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(logical_right_shift_compute, + framework::GetInjectiveScheduleFunc(output_shapes, target), + "strategy.logical_right_shift.x86", + 1); + return strategy; +} + +std::vector InferShapeForLogicalRightShift(const std::vector &inputs_shape, + const framework::AttrMapType &attrs) { + CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; + CHECK_EQ(inputs_shape[0].size(), inputs_shape[1].size()) << "The inputs' dims should be equal."; + std::vector res{inputs_shape[0]}; + return res; +} + +std::vector InferDtypeForLogicalRightShift(const std::vector &inputs_type, + const framework::AttrMapType &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; + std::vector res{inputs_type[0]}; + return res; +} + +} // namespace op +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(logical_right_shift_ops) { + CINN_REGISTER_OP(logical_right_shift) + .describe("Logical Right Shift.") + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForLogicalRightShift) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForLogicalRightShift)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForLogicalRightShift)) + .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_support_level(4); + + return true; +} diff --git a/cinn/hlir/op/contrib/logical_right_shift.h b/cinn/hlir/op/contrib/logical_right_shift.h new file mode 100644 index 0000000000..1ef79bb135 --- /dev/null +++ b/cinn/hlir/op/contrib/logical_right_shift.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/tensor.h" + +namespace cinn { +namespace hlir { +namespace op { + +ir::Tensor LogicalRightShift(const ir::Tensor& A, const ir::Tensor& B, const std::string& output_name); + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/logical_right_shift_test.cc b/cinn/hlir/op/contrib/logical_right_shift_test.cc new file mode 100644 index 0000000000..3a67a42c93 --- /dev/null +++ b/cinn/hlir/op/contrib/logical_right_shift_test.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/op/contrib/logical_right_shift.h" + +#include +#include + +#include +#include + +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/codegen_c_x86.h" +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/common/context.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace hlir { +namespace op { + +TEST(GenerateCode_Cpu, LogicalRightShift) { + common::Context::Global().ResetNameId(); + + common::Target target = common::DefaultHostTarget(); + lang::Placeholder x("x", std::vector{10}); + lang::Placeholder y("y", std::vector{10}); + ir::Tensor res = LogicalRightShift(x, y, "test_logical_right_shift"); + + poly::StageMap stages = poly::CreateStages({res}); + std::vector funcs = + lang::LowerVec("TestGenerateCodeCpu_LogicalRightShift", stages, {res}, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CPU codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("LogicalRightShift_Module", target); + for (auto& f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + VLOG(6) << "Cpu Codegen result:"; + VLOG(6) << code << std::endl; +} + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h index 8720b74f41..54aa9c5f8b 100644 --- a/cinn/hlir/op/use_ops.h +++ b/cinn/hlir/op/use_ops.h @@ -41,3 +41,4 @@ CINN_USE_REGISTER(lookup_table_ops) CINN_USE_REGISTER(clz_ops) CINN_USE_REGISTER(popc_ops) CINN_USE_REGISTER(reciprocal_ops) +CINN_USE_REGISTER(logical_right_shift_ops) diff --git a/python/tests/ops/test_logical_right_shift_op.py b/python/tests/ops/test_logical_right_shift_op.py new file mode 100644 index 0000000000..b0c3f5d9fc --- /dev/null +++ b/python/tests/ops/test_logical_right_shift_op.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2022 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest, OpTestTool +import paddle +import paddle.nn.functional as F +import cinn +from cinn.frontend import * +from cinn.common import * + + +@OpTestTool.skip_if(not is_compiled_with_cuda(), + "x86 test will be skipped due to timeout.") +class TestLogicalRightShift(OpTest): + def setUp(self): + self.init_case() + + def init_case(self): + self.inputs = { + # "x": self.random([1, 24], 'int32', low = -2147483648, high=2147483647) + "x": np.array([1690476611, 142184466, -1752569340, 1860589058, -1295695292, + 1912939056, -1416770533, -483282486, 284237925, -2094465968, + -823026780, -1503970769, -535860601, 1515033359, -1212100470, + -2008734407, 704803066, 1861454881, -479224831, 1939718614, + -1903975007, -1197706543, 1327016838, -232019105]).astype(np.int32) + } + self.inputs = { + # "y": self.random([1, 24], 'int32', low = 0, high=32) + "y": np.array([20, 3, 12, 3, 0, 31, 0, 2, 6, 16, 1, 7, 6, 2, 19, 16, + 7, 17, 10, 15, 8, 9, 24, 4]).astype(np.int32) + } + self.outputs = {"out": np.array([[1612, 17773058, 620702, 232573632, -1295695292, + 0, -1416770533, 952921202, 4441217, 33576, + 1735970258, 21804660, 58736042, 378758339, 5880, + 34885, 5506273, 14201, 3726311, 59195, + 9339813, 6049337, 79, 253934261]]).astype(np.int32)} + + def build_paddle_program(self, target): + out = paddle.to_tensor(self.outputs["out"], stop_gradient=False) + self.paddle_outputs = [out] + + def build_cinn_program(self, target): + builder = NetBuilder("logical_right_shift") + x = builder.create_input( + self.nptype2cinntype(self.inputs["x"].dtype), + self.inputs["x"].shape, "x") + y = builder.create_input( + self.nptype2cinntype(self.inputs["y"].dtype), + self.inputs["y"].shape, "y") + out = builder.logical_right_shift(x, y, axis=self.axis) + + prog = builder.build() + res = self.get_cinn_output(prog, target, [x, y], + [self.inputs["x"], self.inputs["y"]], [out]) + + self.cinn_outputs = [res[0]] + + def test_check_results(self): + self.check_outputs_and_grads() + +if __name__ == "__main__": + unittest.main() From b70e9ebe4cbb30db919753b9e369870bd4fe6c01 Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Wed, 30 Nov 2022 22:20:34 +0800 Subject: [PATCH 2/8] correct codestyle --- cinn/hlir/op/contrib/logical_right_shift.cc | 6 ++-- .../tests/ops/test_logical_right_shift_op.py | 35 ++++++++++++------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/cinn/hlir/op/contrib/logical_right_shift.cc b/cinn/hlir/op/contrib/logical_right_shift.cc index 404c6c6d5f..8f8ebee68d 100644 --- a/cinn/hlir/op/contrib/logical_right_shift.cc +++ b/cinn/hlir/op/contrib/logical_right_shift.cc @@ -55,7 +55,9 @@ ir::Tensor LogicalRightShift(const ir::Tensor &A, const ir::Tensor &B, const std A->shape, [=](const std::vector &indices) { Expr bits = ir::Cast::Make(A->type(), A->type().bits() - 1); - return lang::BitwiseAnd(lang::RightShift(A(indices), B(indices)), lang::BitwiseNot(lang::LeftShift(lang::RightShift(lang::LeftShift(Expr(1), bits), B(indices)), Expr(1)))); + return lang::BitwiseAnd( + lang::RightShift(A(indices), B(indices)), + lang::BitwiseNot(lang::LeftShift(lang::RightShift(lang::LeftShift(Expr(1), bits), B(indices)), Expr(1)))); }, UniqName(output_name)); } @@ -100,7 +102,7 @@ std::shared_ptr StrategyForLogicalRightShift(const framework::NodeAt } std::vector InferShapeForLogicalRightShift(const std::vector &inputs_shape, - const framework::AttrMapType &attrs) { + const framework::AttrMapType &attrs) { CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again."; CHECK_EQ(inputs_shape[0].size(), inputs_shape[1].size()) << "The inputs' dims should be equal."; std::vector res{inputs_shape[0]}; diff --git a/python/tests/ops/test_logical_right_shift_op.py b/python/tests/ops/test_logical_right_shift_op.py index b0c3f5d9fc..6ce7ff78c5 100644 --- a/python/tests/ops/test_logical_right_shift_op.py +++ b/python/tests/ops/test_logical_right_shift_op.py @@ -33,22 +33,32 @@ def setUp(self): def init_case(self): self.inputs = { # "x": self.random([1, 24], 'int32', low = -2147483648, high=2147483647) - "x": np.array([1690476611, 142184466, -1752569340, 1860589058, -1295695292, - 1912939056, -1416770533, -483282486, 284237925, -2094465968, - -823026780, -1503970769, -535860601, 1515033359, -1212100470, - -2008734407, 704803066, 1861454881, -479224831, 1939718614, - -1903975007, -1197706543, 1327016838, -232019105]).astype(np.int32) + "x": + np.array([ + 1690476611, 142184466, -1752569340, 1860589058, -1295695292, + 1912939056, -1416770533, -483282486, 284237925, -2094465968, + -823026780, -1503970769, -535860601, 1515033359, -1212100470, + -2008734407, 704803066, 1861454881, -479224831, 1939718614, + -1903975007, -1197706543, 1327016838, -232019105 + ]).astype(np.int32) } self.inputs = { # "y": self.random([1, 24], 'int32', low = 0, high=32) - "y": np.array([20, 3, 12, 3, 0, 31, 0, 2, 6, 16, 1, 7, 6, 2, 19, 16, - 7, 17, 10, 15, 8, 9, 24, 4]).astype(np.int32) + "y": + np.array([ + 20, 3, 12, 3, 0, 31, 0, 2, 6, 16, 1, 7, 6, 2, 19, 16, 7, 17, + 10, 15, 8, 9, 24, 4 + ]).astype(np.int32) + } + self.outputs = { + "out": + np.array([[ + 1612, 17773058, 620702, 232573632, -1295695292, 0, -1416770533, + 952921202, 4441217, 33576, 1735970258, 21804660, 58736042, + 378758339, 5880, 34885, 5506273, 14201, 3726311, 59195, + 9339813, 6049337, 79, 253934261 + ]]).astype(np.int32) } - self.outputs = {"out": np.array([[1612, 17773058, 620702, 232573632, -1295695292, - 0, -1416770533, 952921202, 4441217, 33576, - 1735970258, 21804660, 58736042, 378758339, 5880, - 34885, 5506273, 14201, 3726311, 59195, - 9339813, 6049337, 79, 253934261]]).astype(np.int32)} def build_paddle_program(self, target): out = paddle.to_tensor(self.outputs["out"], stop_gradient=False) @@ -73,5 +83,6 @@ def build_cinn_program(self, target): def test_check_results(self): self.check_outputs_and_grads() + if __name__ == "__main__": unittest.main() From 3cbe321885b7075b76db4875c472a5c53f33c75d Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Thu, 1 Dec 2022 09:41:29 +0800 Subject: [PATCH 3/8] fix bugs --- cinn/hlir/op/contrib/logical_right_shift.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/cinn/hlir/op/contrib/logical_right_shift.cc b/cinn/hlir/op/contrib/logical_right_shift.cc index 8f8ebee68d..f5a0abaa0f 100644 --- a/cinn/hlir/op/contrib/logical_right_shift.cc +++ b/cinn/hlir/op/contrib/logical_right_shift.cc @@ -25,7 +25,6 @@ #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" -#include "cinn/hlir/op/contrib/clip.h" #include "cinn/hlir/pe/ir_schedule_pe.h" #include "cinn/hlir/pe/nn.h" #include "cinn/hlir/pe/schedule.h" From 5400e8088dd7ca534cc08eb4e3dfef9f2e272c1d Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Thu, 1 Dec 2022 12:36:24 +0800 Subject: [PATCH 4/8] fix bugs --- python/tests/ops/test_logical_right_shift_op.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tests/ops/test_logical_right_shift_op.py b/python/tests/ops/test_logical_right_shift_op.py index 6ce7ff78c5..95299038b6 100644 --- a/python/tests/ops/test_logical_right_shift_op.py +++ b/python/tests/ops/test_logical_right_shift_op.py @@ -40,9 +40,7 @@ def init_case(self): -823026780, -1503970769, -535860601, 1515033359, -1212100470, -2008734407, 704803066, 1861454881, -479224831, 1939718614, -1903975007, -1197706543, 1327016838, -232019105 - ]).astype(np.int32) - } - self.inputs = { + ]).astype(np.int32), # "y": self.random([1, 24], 'int32', low = 0, high=32) "y": np.array([ From 4b7a936a23b32c090cc6c7937bcefbaf78cc9420 Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Thu, 1 Dec 2022 17:11:43 +0800 Subject: [PATCH 5/8] fix bugs --- python/tests/ops/test_logical_right_shift_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/ops/test_logical_right_shift_op.py b/python/tests/ops/test_logical_right_shift_op.py index 95299038b6..ad2872b556 100644 --- a/python/tests/ops/test_logical_right_shift_op.py +++ b/python/tests/ops/test_logical_right_shift_op.py @@ -70,7 +70,7 @@ def build_cinn_program(self, target): y = builder.create_input( self.nptype2cinntype(self.inputs["y"].dtype), self.inputs["y"].shape, "y") - out = builder.logical_right_shift(x, y, axis=self.axis) + out = builder.logical_right_shift(x, y) prog = builder.build() res = self.get_cinn_output(prog, target, [x, y], From 8e0fb776c8f808df0d22bfb2b370241e17dba712 Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Sat, 3 Dec 2022 18:23:02 +0800 Subject: [PATCH 6/8] fix support on GPU --- cinn/hlir/op/contrib/logical_right_shift.cc | 33 +++++++++++++++---- cinn/hlir/op/contrib/logical_right_shift.h | 5 ++- .../op/contrib/logical_right_shift_test.cc | 2 +- cinn/runtime/cpu/host_intrinsics.cc | 4 +++ cinn/runtime/cpu/host_intrinsics.h | 2 ++ .../runtime/cuda/cinn_cuda_runtime_source.cuh | 2 +- cinn/runtime/cuda/cuda_intrinsics.cc | 1 + 7 files changed, 39 insertions(+), 10 deletions(-) diff --git a/cinn/hlir/op/contrib/logical_right_shift.cc b/cinn/hlir/op/contrib/logical_right_shift.cc index f5a0abaa0f..afe1b26d76 100644 --- a/cinn/hlir/op/contrib/logical_right_shift.cc +++ b/cinn/hlir/op/contrib/logical_right_shift.cc @@ -49,16 +49,35 @@ using framework::OpStrategy; using framework::shape_t; using framework::StrategyFunction; -ir::Tensor LogicalRightShift(const ir::Tensor &A, const ir::Tensor &B, const std::string &output_name) { +ir::Tensor LogicalRightShift(const ir::Tensor &A, + const ir::Tensor &B, + const Target &target, + const std::string &output_name) { + std::string extern_func = "cinn_"; + if (target == common::DefaultHostTarget()) { + extern_func += "host_"; + } else if (target == common::DefaultNVGPUTarget()) { + extern_func += "nvgpu_"; + } else { + CINN_NOT_IMPLEMENTED + } + + extern_func += "logical_right_shift"; + + if (A->type().is_int(32) || A->type().is_uint(32)) { + extern_func += "_int32"; + } else { + CINN_NOT_IMPLEMENTED + } + return Compute( A->shape, [=](const std::vector &indices) { - Expr bits = ir::Cast::Make(A->type(), A->type().bits() - 1); - return lang::BitwiseAnd( - lang::RightShift(A(indices), B(indices)), - lang::BitwiseNot(lang::LeftShift(lang::RightShift(lang::LeftShift(Expr(1), bits), B(indices)), Expr(1)))); + Expr x = A(indices); + Expr y = B(indices); + return lang::CallExtern(extern_func, {x, y}); }, - UniqName(output_name)); + output_name); } std::shared_ptr StrategyForLogicalRightShift(const framework::NodeAttr &attrs, @@ -87,7 +106,7 @@ std::shared_ptr StrategyForLogicalRightShift(const framework::NodeAt tensor_name = pack_args[2].operator std::string(); } - auto out = LogicalRightShift(A, B, tensor_name); + auto out = LogicalRightShift(A, B, target, tensor_name); auto stages = CreateStages({out}); *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; }); diff --git a/cinn/hlir/op/contrib/logical_right_shift.h b/cinn/hlir/op/contrib/logical_right_shift.h index 1ef79bb135..cf1d0ec522 100644 --- a/cinn/hlir/op/contrib/logical_right_shift.h +++ b/cinn/hlir/op/contrib/logical_right_shift.h @@ -25,7 +25,10 @@ namespace cinn { namespace hlir { namespace op { -ir::Tensor LogicalRightShift(const ir::Tensor& A, const ir::Tensor& B, const std::string& output_name); +ir::Tensor LogicalRightShift(const ir::Tensor& A, + const ir::Tensor& B, + const Target& target, + const std::string& output_name); } // namespace op } // namespace hlir diff --git a/cinn/hlir/op/contrib/logical_right_shift_test.cc b/cinn/hlir/op/contrib/logical_right_shift_test.cc index 3a67a42c93..b93f364e89 100644 --- a/cinn/hlir/op/contrib/logical_right_shift_test.cc +++ b/cinn/hlir/op/contrib/logical_right_shift_test.cc @@ -38,7 +38,7 @@ TEST(GenerateCode_Cpu, LogicalRightShift) { common::Target target = common::DefaultHostTarget(); lang::Placeholder x("x", std::vector{10}); lang::Placeholder y("y", std::vector{10}); - ir::Tensor res = LogicalRightShift(x, y, "test_logical_right_shift"); + ir::Tensor res = LogicalRightShift(x, y, target, "test_logical_right_shift"); poly::StageMap stages = poly::CreateStages({res}); std::vector funcs = diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc index 6ab8207df4..963da90168 100644 --- a/cinn/runtime/cpu/host_intrinsics.cc +++ b/cinn/runtime/cpu/host_intrinsics.cc @@ -130,6 +130,8 @@ inline int FN_INT32(clz)(int x) { return __builtin_clz(x); } inline int FN_INT32(popc)(int x) { return __builtin_popcount(x); } +inline int FN_INT32(logical_right_shift)(int x, int y) { return (x >> y) & ~(((0x1 << 31) >> y) << 1); } + #undef FN_INT32 #define FN_INT64(func) cinn_host_##func##_int64 @@ -188,6 +190,8 @@ CINN_REGISTER_HELPER(host_intrinsics) { REGISTER_EXTERN_FUNC_2_IN_1_INT32(pow) + REGISTER_EXTERN_FUNC_2_IN_1_INT32(logical_right_shift) + #undef REGISTER_EXTERN_FUNC_2_IN_1_INT32 REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int32, host_target, int, int); diff --git a/cinn/runtime/cpu/host_intrinsics.h b/cinn/runtime/cpu/host_intrinsics.h index 50c37900e4..fdffcaf668 100644 --- a/cinn/runtime/cpu/host_intrinsics.h +++ b/cinn/runtime/cpu/host_intrinsics.h @@ -53,6 +53,8 @@ inline int FN_INT32(clz)(int x); inline int FN_INT32(popc)(int x); +inline int FN_INT32(logical_right_shift)(int x, int y); + #undef FN_INT32 #define FN_INT64(func) cinn_host_##func##_int64 diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index da2da2f373..9f12c774f7 100644 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -71,7 +71,7 @@ __device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; } __device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; } __device__ inline int FN_INT32(clz)(int a) { return __clz(a); } __device__ inline int FN_INT32(popc)(int a) { return __popc(a); } - +__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return (a >> b) & ~(((0x1 << 31) >> b) << 1); } // *************************************************************** // diff --git a/cinn/runtime/cuda/cuda_intrinsics.cc b/cinn/runtime/cuda/cuda_intrinsics.cc index 0acf54bde2..ccce3384f9 100644 --- a/cinn/runtime/cuda/cuda_intrinsics.cc +++ b/cinn/runtime/cuda/cuda_intrinsics.cc @@ -113,6 +113,7 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_or) REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_xor) REGISTER_EXTERN_FUNC_2_IN_1_INT32(floor_divide) + REGISTER_EXTERN_FUNC_2_IN_1_INT32(logical_right_shift) #undef REGISTER_EXTERN_FUNC_2_IN_1_INT32 From bea4c205b7c4487743c018a356ca7a717cfefab6 Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Sun, 4 Dec 2022 20:21:41 +0800 Subject: [PATCH 7/8] fix bugs --- cinn/hlir/op/contrib/logical_right_shift.cc | 5 ++++- python/tests/ops/test_logical_right_shift_op.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/cinn/hlir/op/contrib/logical_right_shift.cc b/cinn/hlir/op/contrib/logical_right_shift.cc index afe1b26d76..a0d9936898 100644 --- a/cinn/hlir/op/contrib/logical_right_shift.cc +++ b/cinn/hlir/op/contrib/logical_right_shift.cc @@ -129,7 +129,10 @@ std::vector InferShapeForLogicalRightShift(const std::vector std::vector InferDtypeForLogicalRightShift(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; + CHECK_EQ(inputs_type.size(), 2UL) << "The logical_right_shift op should has two inputs! Please check."; + CHECK_EQ(inputs_type[0], inputs_type[1]) + << "The data type of input tensors of logical_right_shift op should be equal, but here x:" << inputs_type[0] + << " != y:" << inputs_type[1] << "! Please check."; std::vector res{inputs_type[0]}; return res; } diff --git a/python/tests/ops/test_logical_right_shift_op.py b/python/tests/ops/test_logical_right_shift_op.py index ad2872b556..5d88ed34e4 100644 --- a/python/tests/ops/test_logical_right_shift_op.py +++ b/python/tests/ops/test_logical_right_shift_op.py @@ -34,19 +34,19 @@ def init_case(self): self.inputs = { # "x": self.random([1, 24], 'int32', low = -2147483648, high=2147483647) "x": - np.array([ + np.array([[ 1690476611, 142184466, -1752569340, 1860589058, -1295695292, 1912939056, -1416770533, -483282486, 284237925, -2094465968, -823026780, -1503970769, -535860601, 1515033359, -1212100470, -2008734407, 704803066, 1861454881, -479224831, 1939718614, -1903975007, -1197706543, 1327016838, -232019105 - ]).astype(np.int32), + ]]).astype(np.int32), # "y": self.random([1, 24], 'int32', low = 0, high=32) "y": - np.array([ + np.array([[ 20, 3, 12, 3, 0, 31, 0, 2, 6, 16, 1, 7, 6, 2, 19, 16, 7, 17, 10, 15, 8, 9, 24, 4 - ]).astype(np.int32) + ]]).astype(np.int32) } self.outputs = { "out": From 227fe3fc8208dd9c2642f1611a54d8222fa25b17 Mon Sep 17 00:00:00 2001 From: ccsuzzh <1719571694@qq.com> Date: Mon, 5 Dec 2022 22:50:20 +0800 Subject: [PATCH 8/8] fix logical_right_shift op bug --- cinn/hlir/op/contrib/logical_right_shift.cc | 3 ++- cinn/runtime/cpu/host_intrinsics.cc | 2 +- cinn/runtime/cuda/cinn_cuda_runtime_source.cuh | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cinn/hlir/op/contrib/logical_right_shift.cc b/cinn/hlir/op/contrib/logical_right_shift.cc index a0d9936898..6dcfad2b89 100644 --- a/cinn/hlir/op/contrib/logical_right_shift.cc +++ b/cinn/hlir/op/contrib/logical_right_shift.cc @@ -25,6 +25,7 @@ #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/op/op_util.h" #include "cinn/hlir/pe/ir_schedule_pe.h" #include "cinn/hlir/pe/nn.h" #include "cinn/hlir/pe/schedule.h" @@ -113,7 +114,7 @@ std::shared_ptr StrategyForLogicalRightShift(const framework::NodeAt auto strategy = std::make_shared(); strategy->AddImpl(logical_right_shift_compute, - framework::GetInjectiveScheduleFunc(output_shapes, target), + GetInjectiveScheduleFunc(output_shapes, target), "strategy.logical_right_shift.x86", 1); return strategy; diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc index 7d95f80b68..3a3153a2a9 100644 --- a/cinn/runtime/cpu/host_intrinsics.cc +++ b/cinn/runtime/cpu/host_intrinsics.cc @@ -134,7 +134,7 @@ inline int FN_INT32(clz)(int x) { return __builtin_clz(x); } inline int FN_INT32(popc)(int x) { return __builtin_popcount(x); } -inline int FN_INT32(logical_right_shift)(int x, int y) { return (x >> y) & ~(((0x1 << 31) >> y) << 1); } +inline int FN_INT32(logical_right_shift)(int x, int y) { return ((unsigned int)x >> y); } #undef FN_INT32 diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index 6f0628ec59..dc2ccccc90 100644 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -73,7 +73,7 @@ __device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; } __device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; } __device__ inline int FN_INT32(clz)(int a) { return __clz(a); } __device__ inline int FN_INT32(popc)(int a) { return __popc(a); } -__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return (a >> b) & ~(((0x1 << 31) >> b) << 1); } +__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return ((unsigned int)a >> b); } // *************************************************************** //