From 11e0fbac1a6bc9811e5a1069db0b2921cf035ac8 Mon Sep 17 00:00:00 2001
From: ZeKai Zhou <30856589+zzk0@users.noreply.github.com>
Date: Mon, 21 Nov 2022 17:19:17 +0800
Subject: [PATCH] add clz op (#1059)

* add clz op
* fix data type
* fix datatype
---
 cinn/backends/codegen_c.cc                    |   2 +
 cinn/frontend/net_builder.cc                  |   1 +
 cinn/frontend/net_builder.h                   |   3 +-
 cinn/hlir/op/contrib/CMakeLists.txt           |   2 +
 cinn/hlir/op/contrib/clz.cc                   | 143 ++++++++++++++++++
 cinn/hlir/op/contrib/clz.h                    |  32 ++++
 cinn/hlir/op/contrib/clz_test.cc              |  99 ++++++++++++
 cinn/hlir/op/use_ops.h                        |   1 +
 cinn/pybind/frontend.cc                       |   3 +-
 cinn/runtime/cpu/host_intrinsics.cc           |  12 ++
 cinn/runtime/cpu/host_intrinsics.h            |   8 +
 .../runtime/cuda/cinn_cuda_runtime_source.cuh |   9 ++
 cinn/runtime/cuda/cuda_intrinsics.cc          |   8 +
 python/tests/ops/test_clz_op.py               | 104 +++++++++++++
 14 files changed, 425 insertions(+), 2 deletions(-)
 create mode 100644 cinn/hlir/op/contrib/clz.cc
 create mode 100644 cinn/hlir/op/contrib/clz.h
 create mode 100644 cinn/hlir/op/contrib/clz_test.cc
 create mode 100644 python/tests/ops/test_clz_op.py

diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc
index fd3504ef14..4cb9a57f1f 100644
--- a/cinn/backends/codegen_c.cc
+++ b/cinn/backends/codegen_c.cc
@@ -101,6 +101,8 @@ std::string CodeGenC::GetTypeName(Type type) {
   GET_SCALAR_TYPE(type.is_int(8), "int8_t");
   GET_SCALAR_TYPE(type.is_int(32), "int32_t");
   GET_SCALAR_TYPE(type.is_int(64), "int64_t");
+  GET_SCALAR_TYPE(type.is_uint(32), "uint32_t");
+  GET_SCALAR_TYPE(type.is_uint(64), "uint64_t");
   GET_SCALAR_TYPE(type.is_float(32), "float")
   GET_SCALAR_TYPE(type.is_float(64), "double")
 #undef GET_SCALAR_TYPE
diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc
index dfae2206de..31ffea4d64 100644
--- a/cinn/frontend/net_builder.cc
+++ b/cinn/frontend/net_builder.cc
@@ -139,6 +139,7 @@
 NETBUILDER_UNARY_OP_DEF(BitwiseNot, bitwise_not)
 NETBUILDER_UNARY_OP_DEF(Negative, negative)
 NETBUILDER_UNARY_OP_DEF(Sign, sign)
 NETBUILDER_UNARY_OP_DEF(Abs, abs)
+NETBUILDER_UNARY_OP_DEF(Clz, clz)
 
 #undef NETBUILDER_UNARY_OP_DEF
diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h
index f707802e6d..1544941580 100644
--- a/cinn/frontend/net_builder.h
+++ b/cinn/frontend/net_builder.h
@@ -71,7 +71,8 @@ namespace frontend {
   macro__(BitwiseNot)  \
   macro__(Negative)    \
   macro__(Sign)        \
-  macro__(Abs)
+  macro__(Abs)         \
+  macro__(Clz)
 
 // ******************************************* //
 // The op has two input and one output, with a attribute [axis]
diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt
index ce62df9a76..061ba8f5dc 100644
--- a/cinn/hlir/op/contrib/CMakeLists.txt
+++ b/cinn/hlir/op/contrib/CMakeLists.txt
@@ -16,6 +16,7 @@ gather_srcs(cinnapi_src SRCS
     repeat.cc
     lookup_table.cc
     one_hot.cc
+    clz.cc
     )
 
 cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
@@ -32,3 +33,4 @@ cc_test(test_flip SRCS flip_test.cc DEPS cinncore)
 cc_test(test_repeat SRCS repeat_test.cc DEPS cinncore)
 cc_test(test_one_hot SRCS one_hot_test.cc DEPS cinncore)
 cc_test(test_lookup_table SRCS lookup_table_test.cc DEPS cinncore)
+cc_test(test_clz SRCS clz_test.cc DEPS cinncore)
diff --git a/cinn/hlir/op/contrib/clz.cc b/cinn/hlir/op/contrib/clz.cc
new file mode 100644
index 0000000000..50cb66001b
--- /dev/null
+++ b/cinn/hlir/op/contrib/clz.cc
@@ -0,0 +1,143 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+
+#include "cinn/common/cas.h"
+#include "cinn/common/common.h"
+#include "cinn/common/context.h"
+#include "cinn/common/macros.h"
+#include "cinn/common/target.h"
+#include "cinn/hlir/framework/node.h"
+#include "cinn/hlir/framework/op.h"
+#include "cinn/hlir/framework/op_strategy.h"
+#include "cinn/hlir/op/contrib/clip.h"
+#include "cinn/hlir/pe/ir_schedule_pe.h"
+#include "cinn/hlir/pe/nn.h"
+#include "cinn/hlir/pe/schedule.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/ir_operators.h"
+#include "cinn/ir/tensor.h"
+#include "cinn/lang/builtin.h"
+#include "cinn/lang/compute.h"
+#include "gflags/gflags.h"
+
+DECLARE_bool(cinn_ir_schedule);
+
+namespace cinn {
+namespace hlir {
+namespace op {
+
+using common::_CINNValuePack_;
+using common::CINNValue;
+using common::CINNValuePack;
+using framework::OpStrategy;
+using framework::shape_t;
+using framework::StrategyFunction;
+
+ir::Tensor Clz(const ir::Tensor &input, const Target &target, const std::string &output_name) {
+  std::string extern_func = "cinn_";
+  if (target == common::DefaultHostTarget()) {
+    extern_func += "host_";
+  } else if (target == common::DefaultNVGPUTarget()) {
+    extern_func += "nvgpu_";
+  } else {
+    CINN_NOT_IMPLEMENTED
+  }
+
+  extern_func += "clz";
+
+  if (input->type().is_int(32) || input->type().is_uint(32)) {
+    extern_func += "_int32";
+  } else if (input->type().is_int(64) || input->type().is_uint(64)) {
+    extern_func += "_int64";
+  } else {
+    CINN_NOT_IMPLEMENTED
+  }
+
+  return Compute(
+      input->shape,
+      [=](const std::vector<Expr> &indices) {
+        Expr e = input(indices);
+        return lang::CallExtern(extern_func, {e});
+      },
+      output_name);
+}
+
+std::shared_ptr<OpStrategy> StrategyForClz(const framework::NodeAttr &attrs,
+                                           const std::vector<ir::Tensor> &inputs,
+                                           const std::vector<Type> &out_type,
+                                           const std::vector<std::vector<int>> &output_shapes,
+                                           const Target &target) {
+  std::string op_name("clz");
+
+  framework::CINNCompute clz_compute([=](lang::Args args, lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check.\n";
+    CINNValuePack pack_args = args[0];
+    CHECK(!pack_args.empty()) << "at least one input tensor for " << op_name << " compute\n";
+
+    std::string tensor_name = UniqName(op_name + "_Out");
+    if (FLAGS_cinn_ir_schedule) {
+      CHECK_EQ(pack_args.size(), 2);
+      CHECK(pack_args[1].is_string());
+      tensor_name = pack_args[1].operator std::string();
+    }
+
+    Expr A_expr = pack_args[0];
+    CHECK(A_expr.as_tensor());
+    ir::Tensor A = A_expr.as_tensor_ref();
+    auto out     = Clz(A, target, tensor_name);
+    auto stages  = CreateStages({out});
+    *ret         = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
+  });
+
+  auto strategy = std::make_shared<OpStrategy>();
+  strategy->AddImpl(clz_compute, framework::GetInjectiveScheduleFunc(output_shapes, target), "strategy.clz.x86", 1);
+  return strategy;
+}
+
+std::vector<shape_t> InferShapeForClz(const std::vector<shape_t> &inputs_shape,
+                                      const framework::AttrMapType &attrs) {
+  CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again.";
+  std::vector<shape_t> res{inputs_shape[0]};
+  return res;
+}
+
+std::vector<Type> InferDtypeForClz(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
+  CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
+  std::vector<Type> res{inputs_type[0]};
+  return res;
+}
+
+}  // namespace op
+}  // namespace hlir
+}  // namespace cinn
+
+CINN_REGISTER_HELPER(clz_ops) {
+  CINN_REGISTER_OP(clz)
+      .describe("Counting Leading Zeros.")
+      .set_num_inputs(1)
+      .set_num_outputs(1)
+      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForClz)
+      .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForClz))
+      .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForClz))
+      .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise)
+      .set_support_level(4);
+
+  return true;
+}
diff --git a/cinn/hlir/op/contrib/clz.h b/cinn/hlir/op/contrib/clz.h
new file mode 100644
index 0000000000..90bf2b6a6b
--- /dev/null
+++ b/cinn/hlir/op/contrib/clz.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/tensor.h"
+
+namespace cinn {
+namespace hlir {
+namespace op {
+
+ir::Tensor Clz(const ir::Tensor& input, const Target& target, const std::string& output_name = "T_Clz_out");
+
+}  // namespace op
+}  // namespace hlir
+}  // namespace cinn
diff --git a/cinn/hlir/op/contrib/clz_test.cc b/cinn/hlir/op/contrib/clz_test.cc
new file mode 100644
index 0000000000..7bf233bb0e
--- /dev/null
+++ b/cinn/hlir/op/contrib/clz_test.cc
@@ -0,0 +1,99 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/hlir/op/contrib/clz.h"
+
+#include
+#include
+
+#include
+#include
+
+#include "cinn/backends/codegen_c.h"
+#include "cinn/backends/codegen_c_x86.h"
+#include "cinn/backends/codegen_cuda_dev.h"
+#include "cinn/common/context.h"
+#include "cinn/lang/lower.h"
+#include "cinn/lang/placeholder.h"
+#include "cinn/poly/stage.h"
+
+namespace cinn {
+namespace hlir {
+namespace op {
+namespace {
+bool IsCompiledWithCUDA() {
+#if !defined(CINN_WITH_CUDA)
+  return false;
+#else
+  return true;
+#endif
+}
+}  // namespace
+
+TEST(GenerateCode_Cpu, Clz) {
+  common::Context::Global().ResetNameId();
+
+  common::Target target = common::DefaultHostTarget();
+  lang::Placeholder<int32_t> in("in", std::vector<int>{10});
+  ir::Tensor res = Clz(in, target, "test_clz");
+
+  poly::StageMap stages = poly::CreateStages({res});
+  std::vector<ir::LoweredFunc> funcs =
+      lang::LowerVec("TestGenerateCodeCpu_Clz", stages, {res}, {}, {}, nullptr, target, true);
+
+  VLOG(6) << "Expr before CPU codegen:";
+  VLOG(6) << funcs[0]->body;
+
+  ir::Module::Builder builder("Clz_Module", target);
+  for (auto& f : funcs) {
+    builder.AddFunction(f);
+  }
+
+  backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512);
+  codegen.SetInlineBuiltinCodes(false);
+  std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl);
+  VLOG(6) << "Cpu Codegen result:";
+  VLOG(6) << code;
+}
+
+TEST(GenerateCode_Cuda, Clz) {
+  if (!IsCompiledWithCUDA()) {
+    return;
+  }
+  common::Context::Global().ResetNameId();
+
+  common::Target target = common::DefaultNVGPUTarget();
+
+  lang::Placeholder<int32_t> in("in", std::vector<int>{10});
+  ir::Tensor res = Clz(in, target, "test_clz");
+
+  poly::StageMap stages = poly::CreateStages({res});
+  stages[res]->Bind(0, "blockIdx.x");
+  stages[res]->SetBuffer("global");
+
+  std::vector<ir::LoweredFunc> funcs =
+      lang::LowerVec("TestGenerateCodeCuda_Clz", stages, {res}, {}, {}, nullptr, target, true);
+
+  VLOG(6) << "Expr before CUDA codegen:";
+  VLOG(6) << funcs[0]->body;
+
+  ir::Module::Builder builder("Clz_Module", target);
+  for (auto& f : funcs) {
+    builder.AddFunction(f);
+  }
+}
+
+}  // namespace op
+}  // namespace hlir
+}  // namespace cinn
diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h
index 0371b31623..10681d0fa3 100644
--- a/cinn/hlir/op/use_ops.h
+++ b/cinn/hlir/op/use_ops.h
@@ -38,3 +38,4 @@ CINN_USE_REGISTER(gelu_ops)
 CINN_USE_REGISTER(repeat_ops)
 CINN_USE_REGISTER(one_hot_ops)
 CINN_USE_REGISTER(lookup_table_ops)
+CINN_USE_REGISTER(clz_ops)
diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc
index 4b6ac26ff8..533d1b39a6 100644
--- a/cinn/pybind/frontend.cc
+++ b/cinn/pybind/frontend.cc
@@ -620,7 +620,8 @@ void BindFrontend(pybind11::module *m) {
       .def("clip", &NetBuilder::Clip, py::arg("x"), py::arg("max"), py::arg("min"))
       .def("arange", &NetBuilder::Arange, py::arg("start"), py::arg("end"), py::arg("step"), py::arg("dtype"))
       .def("gather", &NetBuilder::Gather, py::arg("x"), py::arg("index"), py::arg("axis"))
-      .def("gather_nd", &NetBuilder::GatherNd, py::arg("x"), py::arg("index"), py::arg("axes"));
+      .def("gather_nd", &NetBuilder::GatherNd, py::arg("x"), py::arg("index"), py::arg("axes"))
+      .def("clz", &NetBuilder::Clz, py::arg("x"));
 
   auto computation = py::class_>(*m, "Computation");
   py::class_(computation, "CompileOptions")
diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc
index c135ab82fb..92f1954d9f 100644
--- a/cinn/runtime/cpu/host_intrinsics.cc
+++ b/cinn/runtime/cpu/host_intrinsics.cc
@@ -126,7 +126,15 @@ inline int FN_INT32(pow)(int x, int y) {
   return res;
 }
 
+inline int FN_INT32(clz)(int x) { return __builtin_clz(x); }
+
 #undef FN_INT32
+
+#define FN_INT64(func) cinn_host_##func##_int64
+
+inline int64_t FN_INT64(clz)(int64_t x) { return __builtin_clzll(x); }
+
+#undef FN_INT64
 }
 
 CINN_REGISTER_HELPER(host_intrinsics) {
@@ -178,6 +186,10 @@ CINN_REGISTER_HELPER(host_intrinsics) {
 
 #undef REGISTER_EXTERN_FUNC_2_IN_1_INT32
 
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int32, host_target, int, int);
+
+  REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int64, host_target, int64_t, int64_t);
+
   REGISTER_EXTERN_FUNC_HELPER(cinn_host_find_int, host_target)
       .SetRetType()
       .AddInputType()
diff --git a/cinn/runtime/cpu/host_intrinsics.h b/cinn/runtime/cpu/host_intrinsics.h
index 7242b8a2e3..ef3211b9d9 100644
--- a/cinn/runtime/cpu/host_intrinsics.h
+++ b/cinn/runtime/cpu/host_intrinsics.h
@@ -49,5 +49,13 @@ inline int cinn_host_gt_num_int(
 
 inline int FN_INT32(pow)(int x, int y);
 
+inline int FN_INT32(clz)(int x);
+
 #undef FN_INT32
+
+#define FN_INT64(func) cinn_host_##func##_int64
+
+inline int64_t FN_INT64(clz)(int64_t x);
+
+#undef FN_INT64
 }
diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
index 8743d27e81..866f63d903 100644
--- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
+++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
@@ -67,6 +67,14 @@ __device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; }
 __device__ inline int FN_INT32(bitwise_or)(int a, int b) { return a | b; }
 __device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; }
 __device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; }
+__device__ inline int FN_INT32(clz)(int a) { return __clz(a); }
+
+
+// *************************************************************** //
+// int64 unary and binary operator
+#define FN_INT64(func) cinn_nvgpu_##func##_int64
+
+__device__ inline long long int FN_INT64(clz)(long long int a) { return __clzll(a); }
 
 
 // *************************************************************** //
@@ -399,6 +407,7 @@ __device__ inline float cinn_cuda_index_add(const float x,
 #undef FN_FP32
 #undef FN_FP64
 #undef FN_INT32
+#undef FN_INT64
 
 #ifdef CINN_CUDA_FP16
 #undef FN_FP16
diff --git a/cinn/runtime/cuda/cuda_intrinsics.cc b/cinn/runtime/cuda/cuda_intrinsics.cc
index 82db00de62..34f1c2b08c 100644
--- a/cinn/runtime/cuda/cuda_intrinsics.cc
+++ b/cinn/runtime/cuda/cuda_intrinsics.cc
@@ -90,9 +90,17 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
   REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_int32, target, int, int);
 
   REGISTER_EXTERN_FUNC_1_IN_1_INT32(bitwise_not)
+  REGISTER_EXTERN_FUNC_1_IN_1_INT32(clz)
 
 #undef REGISTER_EXTERN_FUNC_1_IN_1_INT32
 
+#define REGISTER_EXTERN_FUNC_1_IN_1_INT64(func__) \
+  REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_int64, target, int64_t, int64_t);
+
+  REGISTER_EXTERN_FUNC_1_IN_1_INT64(clz)
+
+#undef REGISTER_EXTERN_FUNC_1_IN_1_INT64
+
 #define REGISTER_EXTERN_FUNC_2_IN_1_INT32(func__) \
   REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_int32, target, int, int, int);
diff --git a/python/tests/ops/test_clz_op.py b/python/tests/ops/test_clz_op.py
new file mode 100644
index 0000000000..dd5c3ca45c
--- /dev/null
+++ b/python/tests/ops/test_clz_op.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 CINN Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest, OpTestTool
+import paddle
+import cinn
+from cinn.frontend import *
+from cinn.common import *
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "x86 test will be skipped due to timeout.")
+class TestClzOp(OpTest):
+    def setUp(self):
+        self.init_case()
+
+    def init_case(self):
+        self.inputs = {
+            # "x": self.random([32, 64], 'int32', low = -2147483648, high=2147483647)
+            "x":
+            np.array([
+                -1591895863, -1770335025, -1290313501, 478042597, 189030958,
+                -935228100, 718518127, -2066013593, -1028229638, -1930307001,
+                -858478166, -282304333
+            ]).astype(np.int32)
+        }
+        self.outputs = {
+            "y": np.array([0, 0, 0, 3, 4, 0, 2, 0, 0, 0, 0,
+                           0]).astype(np.int32)
+        }
+
+    def build_paddle_program(self, target):
+        y = paddle.to_tensor(self.outputs["y"], stop_gradient=False)
+        self.paddle_outputs = [y]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("clz")
+        x = builder.create_input(
+            self.nptype2cinntype(self.inputs["x"].dtype),
+            self.inputs["x"].shape, "x")
+        out = builder.clz(x)
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]],
+                                   [out])
+        self.cinn_outputs = [res[0]]
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestClzCase1(TestClzOp):
+    def init_case(self):
+        self.inputs = {
+            # "x": self.random([48, 36], 'int32', low = -2147483648, high=2147483647)
+            "x":
+            np.array([[
+                -780762106, 2088944770, 1793870564, 995233974, -1566864405,
+                -1550063384
+            ],
+                      [
+                          58189437, -585656506, 1058816786, -1676158651,
+                          -175192886, 2129254990
+                      ]]).astype(np.int32)
+        }
+        self.outputs = {
+            "y":
+            np.array([[0, 1, 1, 2, 0, 0], [6, 0, 2, 0, 0, 1]]).astype(np.int32)
+        }
+
+
+class TestClzCase2(TestClzOp):
+    def init_case(self):
+        self.inputs = {
+            # "x": self.random([4, 3, 5, 8], 'int64', low = -9223372036854775808, high=9223372036854775807)
+            "x":
+            np.array([
+                -2603587548323400654, 5370659515557365091,
+                -2051413160116828951, 9015154622229049624,
+                -8328245342679021727, -8113334794330105534,
+                7187230222985732039, 1835610600500058242
+            ]).astype(np.int64)
+        }
+        self.outputs = {
+            "y": np.array([0, 1, 0, 1, 0, 0, 1, 3]).astype(np.int64)
+        }
+
+
+if __name__ == "__main__":
+    unittest.main()
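
For reference: the expected "y" arrays in test_clz_op.py follow from the two's complement bit patterns of the inputs. A negative input has its sign bit set, so its clz is 0; a non-negative input whose bit_length is b has clz = width - b. A minimal pure-Python sketch that reproduces those expectations follows; the helper name clz_reference is illustrative and not part of CINN, and the clz(0) == width convention is only an assumption here (the host __builtin_clz leaves clz(0) undefined).

def clz_reference(values, bits):
    # Interpret each signed value as a `bits`-wide two's complement pattern
    # and count its leading zero bits; clz(0) is taken to be `bits`.
    return [bits - (int(v) % (1 << bits)).bit_length() for v in values]

# Spot checks against the expected outputs in the tests above.
assert clz_reference([478042597, 189030958, -935228100], 32) == [3, 4, 0]
assert clz_reference([1835610600500058242, -2603587548323400654], 64) == [3, 0]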