Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add clz op (#1059)
Browse files Browse the repository at this point in the history
* add clz op

* fix data type

* fix datatype
  • Loading branch information
zzk0 authored Nov 21, 2022
1 parent 3357aa2 commit 11e0fba
Show file tree
Hide file tree
Showing 14 changed files with 425 additions and 2 deletions.
2 changes: 2 additions & 0 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ std::string CodeGenC::GetTypeName(Type type) {
GET_SCALAR_TYPE(type.is_int(8), "int8_t");
GET_SCALAR_TYPE(type.is_int(32), "int32_t");
GET_SCALAR_TYPE(type.is_int(64), "int64_t");
GET_SCALAR_TYPE(type.is_uint(32), "uint32_t");
GET_SCALAR_TYPE(type.is_uint(64), "uint64_t");
GET_SCALAR_TYPE(type.is_float(32), "float")
GET_SCALAR_TYPE(type.is_float(64), "double")
#undef GET_SCALAR_TYPE
Expand Down
1 change: 1 addition & 0 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ NETBUILDER_UNARY_OP_DEF(BitwiseNot, bitwise_not)
NETBUILDER_UNARY_OP_DEF(Negative, negative)
NETBUILDER_UNARY_OP_DEF(Sign, sign)
NETBUILDER_UNARY_OP_DEF(Abs, abs)
NETBUILDER_UNARY_OP_DEF(Clz, clz)

#undef NETBUILDER_UNARY_OP_DEF

Expand Down
3 changes: 2 additions & 1 deletion cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ namespace frontend {
macro__(BitwiseNot) \
macro__(Negative) \
macro__(Sign) \
macro__(Abs)
macro__(Abs) \
macro__(Clz)

// ******************************************* //
// The op has two input and one output, with a attribute [axis]
Expand Down
2 changes: 2 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ gather_srcs(cinnapi_src SRCS
repeat.cc
lookup_table.cc
one_hot.cc
clz.cc
)

cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
Expand All @@ -32,3 +33,4 @@ cc_test(test_flip SRCS flip_test.cc DEPS cinncore)
cc_test(test_repeat SRCS repeat_test.cc DEPS cinncore)
cc_test(test_one_hot SRCS one_hot_test.cc DEPS cinncore)
cc_test(test_lookup_table SRCS lookup_table_test.cc DEPS cinncore)
cc_test(test_clz SRCS clz_test.cc DEPS cinncore)
143 changes: 143 additions & 0 deletions cinn/hlir/op/contrib/clz.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cinn/common/cas.h"
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/macros.h"
#include "cinn/common/target.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/op/contrib/clip.h"
#include "cinn/hlir/op/contrib/clz.h"
#include "cinn/hlir/pe/ir_schedule_pe.h"
#include "cinn/hlir/pe/nn.h"
#include "cinn/hlir/pe/schedule.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/ir/tensor.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"
#include "gflags/gflags.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace hlir {
namespace op {

using common::_CINNValuePack_;
using common::CINNValue;
using common::CINNValuePack;
using framework::OpStrategy;
using framework::shape_t;
using framework::StrategyFunction;

// Builds an element-wise "count leading zeros" tensor computation.
// The actual bit counting is delegated to a target-specific extern function,
// e.g. "cinn_host_clz_int32" or "cinn_nvgpu_clz_int64", selected from the
// target (host vs. nvgpu) and the input's integer width (32 vs. 64 bit).
ir::Tensor Clz(const ir::Tensor &input, const Target &target, const std::string &output_name) {
  // Pick the backend prefix from the compilation target.
  std::string backend;
  if (target == common::DefaultHostTarget()) {
    backend = "host_";
  } else if (target == common::DefaultNVGPUTarget()) {
    backend = "nvgpu_";
  } else {
    CINN_NOT_IMPLEMENTED
  }

  // Pick the width suffix from the input element type; both signed and
  // unsigned integers of the same width share one extern function.
  std::string width;
  if (input->type().is_int(32) || input->type().is_uint(32)) {
    width = "_int32";
  } else if (input->type().is_int(64) || input->type().is_uint(64)) {
    width = "_int64";
  } else {
    CINN_NOT_IMPLEMENTED
  }

  const std::string extern_func = "cinn_" + backend + "clz" + width;

  // Apply the extern function element-wise; output shape equals input shape.
  return Compute(
      input->shape,
      [=](const std::vector<Expr> &indices) { return lang::CallExtern(extern_func, {input(indices)}); },
      output_name);
}

// Builds the op strategy for `clz`: a single compute implementation paired
// with the generic injective schedule for the given target.
std::shared_ptr<OpStrategy> StrategyForClz(const framework::NodeAttr &attrs,
                                           const std::vector<ir::Tensor> &inputs,
                                           const std::vector<Type> &out_type,
                                           const std::vector<std::vector<int>> &output_shapes,
                                           const Target &target) {
  std::string op_name("clz");

  // Compute body: unpack the input tensor (and, under the new IR schedule
  // path, the explicitly supplied output tensor name), then emit the
  // element-wise clz computation together with its stages.
  framework::CINNCompute clz_compute([=](lang::Args args, lang::RetValue *ret) {
    CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check.\n";
    CINNValuePack pack_args = args[0];
    CHECK(!pack_args.empty()) << "at least one input tensor for " << op_name << " compute\n";

    std::string tensor_name = UniqName(op_name + "_Out");
    if (FLAGS_cinn_ir_schedule) {
      // The IR-schedule path passes [input, output_name]; take the given name.
      CHECK_EQ(pack_args.size(), 2);
      CHECK(pack_args[1].is_string());
      tensor_name = pack_args[1].operator std::string();
    }

    Expr tensor_expr = pack_args[0];
    CHECK(tensor_expr.as_tensor());
    auto out    = Clz(tensor_expr.as_tensor_ref(), target, tensor_name);
    auto stages = CreateStages({out});
    *ret        = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
  });

  auto strategy = std::make_shared<framework::OpStrategy>();
  strategy->AddImpl(clz_compute, framework::GetInjectiveScheduleFunc(output_shapes, target), "strategy.clz.x86", 1);
  return strategy;
}

// Shape inference for `clz`: element-wise, so the output shape is exactly
// the shape of the single input.
std::vector<framework::shape_t> InferShapeForClz(const std::vector<framework::shape_t> &inputs_shape,
                                                 const framework::AttrMapType &attrs) {
  CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again.";
  return {inputs_shape[0]};
}

// Dtype inference for `clz`: the output element type is propagated unchanged
// from the input element type.
std::vector<Type> InferDtypeForClz(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
  CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
  return {inputs_type[0]};
}

} // namespace op
} // namespace hlir
} // namespace cinn

// Registers the `clz` operator with the op registry: one input, one output,
// element-wise pattern, wired to its strategy, shape- and dtype-inference
// functions defined above.
CINN_REGISTER_HELPER(clz_ops) {
CINN_REGISTER_OP(clz)
.describe("Counting Leading Zeros.")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForClz)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForClz))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForClz))
// kElementWise lets the framework fuse clz with neighbouring element-wise ops.
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise)
.set_support_level(4);

return true;
}
32 changes: 32 additions & 0 deletions cinn/hlir/op/contrib/clz.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"

namespace cinn {
namespace hlir {
namespace op {

// Builds an element-wise "count leading zeros" computation over `input`.
// Dispatches to a target-specific extern function (host or nvgpu) selected by
// the input's integer width (32/64 bit); the result has the same shape as
// `input`. Non-integer inputs and other targets are not implemented.
ir::Tensor Clz(const ir::Tensor& input, const Target& target, const std::string& output_name = "T_Clz_out");

} // namespace op
} // namespace hlir
} // namespace cinn
99 changes: 99 additions & 0 deletions cinn/hlir/op/contrib/clz_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/clz.h"

#include <glog/logging.h>
#include <gtest/gtest.h>

#include <string>
#include <vector>

#include "cinn/backends/codegen_c.h"
#include "cinn/backends/codegen_c_x86.h"
#include "cinn/backends/codegen_cuda_dev.h"
#include "cinn/common/context.h"
#include "cinn/lang/lower.h"
#include "cinn/lang/placeholder.h"
#include "cinn/poly/stage.h"

namespace cinn {
namespace hlir {
namespace op {
namespace {
// True when this test binary was built with CUDA support enabled.
bool IsCompiledWithCUDA() {
#ifdef CINN_WITH_CUDA
  return true;
#else
  return false;
#endif
}
}  // namespace

// Lowers a 1-D int32 clz computation for the host target and generates x86 C
// code for it; the generated source is only logged, not asserted on.
TEST(GenerateCode_Cpu, Clz) {
  common::Context::Global().ResetNameId();

  common::Target target = common::DefaultHostTarget();
  lang::Placeholder<int> input("in", std::vector<int>{10});
  ir::Tensor output = Clz(input, target, "test_clz");

  // Lower the computation into functions.
  poly::StageMap stages = poly::CreateStages({output});
  std::vector<ir::LoweredFunc> lowered_funcs =
      lang::LowerVec("TestGenerateCodeCpu_Clz", stages, {output}, {}, {}, nullptr, target, true);

  VLOG(6) << "Expr before CPU codegen:";
  VLOG(6) << lowered_funcs[0]->body;

  ir::Module::Builder module_builder("Clz_Module", target);
  for (auto& fn : lowered_funcs) {
    module_builder.AddFunction(fn);
  }

  // Generate and log the x86 source for a manual smoke check.
  backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512);
  codegen.SetInlineBuiltinCodes(false);
  std::string source_code = codegen.Compile(module_builder.Build(), backends::CodeGenC::OutputKind::CImpl);
  VLOG(6) << "Cpu Codegen result:";
  VLOG(6) << source_code;
}

// Lowers a 1-D int64 clz computation for the NVGPU target; skipped entirely
// when the binary was built without CUDA support.
TEST(GenerateCode_Cuda, Clz) {
  if (!IsCompiledWithCUDA()) {
    return;
  }
  common::Context::Global().ResetNameId();

  common::Target target = common::DefaultNVGPUTarget();

  lang::Placeholder<int64_t> input("in", std::vector<int>{10});
  ir::Tensor output = Clz(input, target, "test_clz");

  // Bind the outer loop to CUDA blocks and place the result in global memory.
  poly::StageMap stages = poly::CreateStages({output});
  stages[output]->Bind(0, "blockIdx.x");
  stages[output]->SetBuffer("global");

  std::vector<ir::LoweredFunc> lowered_funcs =
      lang::LowerVec("TestGenerateCodeCuda_Clz", stages, {output}, {}, {}, nullptr, target, true);

  VLOG(6) << "Expr before CUDA codegen:";
  VLOG(6) << lowered_funcs[0]->body;

  ir::Module::Builder module_builder("Clz_Module", target);
  for (auto& fn : lowered_funcs) {
    module_builder.AddFunction(fn);
  }
}

} // namespace op
} // namespace hlir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/hlir/op/use_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ CINN_USE_REGISTER(gelu_ops)
CINN_USE_REGISTER(repeat_ops)
CINN_USE_REGISTER(one_hot_ops)
CINN_USE_REGISTER(lookup_table_ops)
CINN_USE_REGISTER(clz_ops)
3 changes: 2 additions & 1 deletion cinn/pybind/frontend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,8 @@ void BindFrontend(pybind11::module *m) {
.def("clip", &NetBuilder::Clip, py::arg("x"), py::arg("max"), py::arg("min"))
.def("arange", &NetBuilder::Arange, py::arg("start"), py::arg("end"), py::arg("step"), py::arg("dtype"))
.def("gather", &NetBuilder::Gather, py::arg("x"), py::arg("index"), py::arg("axis"))
.def("gather_nd", &NetBuilder::GatherNd, py::arg("x"), py::arg("index"), py::arg("axes"));
.def("gather_nd", &NetBuilder::GatherNd, py::arg("x"), py::arg("index"), py::arg("axes"))
.def("clz", &NetBuilder::Clz, py::arg("x"));

auto computation = py::class_<CinnComputation, std::shared_ptr<CinnComputation>>(*m, "Computation");
py::class_<CinnComputation::CompileOptions>(computation, "CompileOptions")
Expand Down
12 changes: 12 additions & 0 deletions cinn/runtime/cpu/host_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,15 @@ inline int FN_INT32(pow)(int x, int y) {
return res;
}

// Count of leading zero bits in the 32-bit value `x`.
// `__builtin_clz(0)` is undefined behavior per the GCC documentation, so zero
// is special-cased to return the full bit width (32), matching CUDA's __clz(0)
// so host and device kernels agree.
inline int FN_INT32(clz)(int x) { return x == 0 ? 32 : __builtin_clz(x); }

#undef FN_INT32

#define FN_INT64(func) cinn_host_##func##_int64

// Count of leading zero bits in the 64-bit value `x`.
// `__builtin_clzll(0)` is undefined behavior per the GCC documentation, so
// zero is special-cased to return the full bit width (64), matching CUDA's
// __clzll(0) so host and device kernels agree.
inline int64_t FN_INT64(clz)(int64_t x) { return x == 0 ? 64 : __builtin_clzll(x); }

#undef FN_INT64
}

CINN_REGISTER_HELPER(host_intrinsics) {
Expand Down Expand Up @@ -178,6 +186,10 @@ CINN_REGISTER_HELPER(host_intrinsics) {

#undef REGISTER_EXTERN_FUNC_2_IN_1_INT32

REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int32, host_target, int, int);

REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int64, host_target, int64_t, int64_t);

REGISTER_EXTERN_FUNC_HELPER(cinn_host_find_int, host_target)
.SetRetType<int>()
.AddInputType<cinn_buffer_t*>()
Expand Down
8 changes: 8 additions & 0 deletions cinn/runtime/cpu/host_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,13 @@ inline int cinn_host_gt_num_int(

inline int FN_INT32(pow)(int x, int y);

inline int FN_INT32(clz)(int x);

#undef FN_INT32

#define FN_INT64(func) cinn_host_##func##_int64

inline int64_t FN_INT64(clz)(int64_t x);

#undef FN_INT64
}
9 changes: 9 additions & 0 deletions cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ __device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; }
__device__ inline int FN_INT32(bitwise_or)(int a, int b) { return a | b; }
__device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; }
__device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; }
// Leading-zero count via the CUDA intrinsic; per the CUDA Math API, __clz is
// defined for all inputs (__clz(0) == 32), unlike the host __builtin_clz.
__device__ inline int FN_INT32(clz)(int a) { return __clz(a); }


// *************************************************************** //
// int64 unary and binary operator
#define FN_INT64(func) cinn_nvgpu_##func##_int64

// 64-bit leading-zero count via the CUDA intrinsic; per the CUDA Math API,
// __clzll is defined for all inputs (__clzll(0) == 64).
__device__ inline long long int FN_INT64(clz)(long long int a) { return __clzll(a); }


// *************************************************************** //
Expand Down Expand Up @@ -399,6 +407,7 @@ __device__ inline float cinn_cuda_index_add(const float x,
#undef FN_FP32
#undef FN_FP64
#undef FN_INT32
#undef FN_INT64

#ifdef CINN_CUDA_FP16
#undef FN_FP16
Expand Down
Loading

0 comments on commit 11e0fba

Please sign in to comment.