diff --git a/cinn/backends/CMakeLists.txt b/cinn/backends/CMakeLists.txt
index 3dd2c6c23d..70b6d9e767 100755
--- a/cinn/backends/CMakeLists.txt
+++ b/cinn/backends/CMakeLists.txt
@@ -40,7 +40,7 @@ endif()
 
 if (WITH_CUDA)
   nv_test(test_codegen_cuda_generate SRCS codegen_cuda_generate_test.cc DEPS cinncore)
-  nv_test(test_codegen_debug SRCS codegen_debug_test.cc DEPS cinncore cinn_runtime)
+  nv_test(test_codegen_debug SRCS codegen_debug_test.cc DEPS cinncore)
 
   if (WITH_TESTING)
     cc_library(generated1_cuda SRCS generated1.cu DEPS cinncore)
diff --git a/cinn/backends/codegen_debug_test.cc b/cinn/backends/codegen_debug_test.cc
index e47d325019..a464f5b2fe 100644
--- a/cinn/backends/codegen_debug_test.cc
+++ b/cinn/backends/codegen_debug_test.cc
@@ -60,9 +60,9 @@ TEST(CodeGenDebug, RunCudaSourceCode) {
   common::Context::Global().ResetNameId();
 
   std::string source_code = R"ROC(
+extern "C" {
 #include "cinn_cuda_runtime_source.cuh"
 
-extern "C" {
 
 #ifdef __CUDACC_RTC__
 typedef int int32_t;
diff --git a/cinn/backends/extern_func_protos.cc b/cinn/backends/extern_func_protos.cc
index a54a883431..58472677b3 100644
--- a/cinn/backends/extern_func_protos.cc
+++ b/cinn/backends/extern_func_protos.cc
@@ -27,9 +27,7 @@ ExternFunctionProtoRegistry::ExternFunctionProtoRegistry() {
   static const std::vector<std::string> extern_funcs_float_bool_unary = {"isnan", "isfinite", "isinf"};
   static const std::vector<std::string> extern_funcs_int_binary = {
       "left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"};
-  static const std::vector<std::string> extern_funcs_int_int_unary = {"bitwise_not"};
-  static const std::vector<std::string> extern_funcs_int_float_call = {"cinn_nvgpu_uniform_random_fp32"};
-  static const std::vector<std::string> extern_funcs_int_double_call = {"cinn_nvgpu_uniform_random_fp64"};
+  static const std::vector<std::string> extern_funcs_int_int_unary = {"bitwise_not"};
   for (int i = 0; i < extern_funcs_fp32_unary.size(); ++i) {
     auto* proto = new FunctionProto(extern_funcs_fp32_unary[i], {Float(32)}, Float(32));
     Register(proto->name, proto);
@@ -46,14 +44,6 @@ ExternFunctionProtoRegistry::ExternFunctionProtoRegistry() {
     auto* proto = new FunctionProto(extern_funcs_int_int_unary[i], {Int(32)}, Int(32));
     Register(proto->name, proto);
   }
-  for (int i = 0; i < extern_funcs_int_float_call.size(); ++i) {
-    auto* proto = new FunctionProto(extern_funcs_int_float_call[i], {Int(32)}, Float(32));
-    Register(proto->name, proto);
-  }
-  for (int i = 0; i < extern_funcs_int_double_call.size(); ++i) {
-    auto* proto = new FunctionProto(extern_funcs_int_double_call[i], {Int(32)}, Float(64));
-    Register(proto->name, proto);
-  }
 
   auto* n = detail::CreateTanhVProto();
   Register(n->name, n);
diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt
index 8252634e36..fde84a733e 100644
--- a/cinn/hlir/op/contrib/CMakeLists.txt
+++ b/cinn/hlir/op/contrib/CMakeLists.txt
@@ -31,6 +31,3 @@ cc_test(test_repeat SRCS repeat_test.cc DEPS cinncore)
 cc_test(test_one_hot SRCS one_hot_test.cc DEPS cinncore)
 cc_test(test_lookup_table SRCS lookup_table_test.cc DEPS cinncore)
 cc_test(test_reciprocal SRCS reciprocal_test.cc DEPS cinncore)
-if (WITH_CUDA)
-  cc_test(test_uniform_random_gpu SRCS uniform_random_test.cc DEPS cinncore)
-endif()
diff --git a/cinn/hlir/op/contrib/uniform_random.cc b/cinn/hlir/op/contrib/uniform_random.cc
index e74e3a5fb7..d4b2f9c8a9 100644
--- a/cinn/hlir/op/contrib/uniform_random.cc
+++ b/cinn/hlir/op/contrib/uniform_random.cc
@@ -11,7 +11,6 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "cinn/hlir/op/contrib/uniform_random.h"
 
 #include <gflags/gflags.h>
 
@@ -46,8 +45,6 @@
 #include "cinn/poly/stage.h"
 #include "glog/logging.h"
 
-DECLARE_bool(cinn_ir_schedule);
-
 namespace cinn {
 namespace hlir {
 namespace op {
@@ -55,34 +52,6 @@ namespace op {
 using common::CINNValue;
 using common::CINNValuePack;
 
-// Only for min = 0. and max = 1.
-ir::Tensor UniformRandom(const std::vector<int> &shape,
-                         int seed,
-                         const std::string &dtype,
-                         const Target &target,
-                         const std::string &tensor_name) {
-  std::string extern_func = "cinn_nvgpu_uniform_random_";
-  if (target != common::DefaultNVGPUTarget()) {
-    LOG(FATAL) << "Not Implemented UniformRandom for target: " << target;
-  }
-
-  if (dtype == "float32") {
-    extern_func += "fp32";
-  } else if (dtype == "float64") {
-    extern_func += "fp64";
-  } else {
-    LOG(FATAL) << "Not Implemented UniformRandom for dtype: " << dtype;
-  }
-
-  std::vector<Expr> new_shape;
-  for (auto item : shape) {
-    new_shape.push_back(Expr(item));
-  }
-
-  return lang::Compute(
-      new_shape, [=]() { return lang::CallExtern(extern_func, {Expr(seed)}); }, tensor_name);
-}
-
 std::shared_ptr<OpStrategy> StrategyForUniformRandom(const framework::NodeAttr &attrs,
                                                      const std::vector<ir::Tensor> &inputs,
                                                      const std::vector<Type> &out_type,
@@ -91,22 +60,9 @@ std::shared_ptr<OpStrategy> StrategyForUniformRandom(const framework
   framework::CINNCompute uniform_random_compute([=](lang::Args args, lang::RetValue *ret) {
     CHECK(attrs.attr_store.count("shape"));
     ir::Tensor shape_tensor;
-    CHECK(output_shapes.size() == 1UL);
-    CHECK(attrs.attr_store.count("seed"));
-    int seed = absl::get<int>(attrs.attr_store.at("seed"));
-    std::string dtype = "float32";
-    if (attrs.attr_store.find("dtype") != attrs.attr_store.end()) {
-      dtype = absl::get<std::string>(attrs.attr_store.at("dtype"));
-    }
-    CINNValuePack arg_pack = args[0];
-    std::string tensor_name = UniqName("uniform_random_out");
-    if (FLAGS_cinn_ir_schedule) {
-      CHECK_EQ(arg_pack.size(), 1U);
-      CHECK(arg_pack[0].is_string());
-      tensor_name = arg_pack[0].operator std::string();
-    }
-    auto out = UniformRandom(output_shapes[0], seed, dtype, target, tensor_name);
-    auto stages = CreateStages({out});
+    std::string tensor_name = "uniform_random_out";
+    auto out = pe::Identity(shape_tensor, tensor_name).front();
+    auto stages = CreateStages({out});
     std::vector<CINNValue> res{CINNValue(out), CINNValue(stages)};
     *ret = CINNValuePack{res};
   });
@@ -148,7 +104,7 @@ CINN_REGISTER_HELPER(uniform_random_ops) {
       .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForUniformRandom)
      .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForUniformRandom))
      .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForUniformRandom))
-      .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise)
+      .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
       .set_support_level(4);
 
   return true;
diff --git a/cinn/hlir/op/contrib/uniform_random.h b/cinn/hlir/op/contrib/uniform_random.h
deleted file mode 100644
index 244605afd4..0000000000
--- a/cinn/hlir/op/contrib/uniform_random.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright (c) 2022 CINN Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <string>
-#include <vector>
-
-#include "cinn/ir/ir.h"
-#include "cinn/ir/ir_base.h"
-#include "cinn/ir/tensor.h"
-
-namespace cinn {
-namespace hlir {
-namespace op {
-
-// Only for min = 0. and max = 1.
-ir::Tensor UniformRandom(const std::vector<int>& shape,
-                         int seed,
-                         const std::string& dtype,
-                         const Target& target,
-                         const std::string& tensor_name);
-
-}  // namespace op
-}  // namespace hlir
-}  // namespace cinn
diff --git a/cinn/hlir/op/contrib/uniform_random_test.cc b/cinn/hlir/op/contrib/uniform_random_test.cc
deleted file mode 100644
index fe07261c60..0000000000
--- a/cinn/hlir/op/contrib/uniform_random_test.cc
+++ /dev/null
@@ -1,167 +0,0 @@
-// Copyright (c) 2022 CINN Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#include "cinn/hlir/op/contrib/uniform_random.h"
-
-#include <glog/logging.h>
-#include <gtest/gtest.h>
-
-#include <string>
-#include <vector>
-
-#include "cinn/backends/codegen_c.h"
-#include "cinn/backends/codegen_c_x86.h"
-#include "cinn/backends/codegen_cuda_dev.h"
-#include "cinn/backends/codegen_cuda_util.h"
-#include "cinn/common/context.h"
-#include "cinn/frontend/net_builder.h"
-#include "cinn/frontend/optimize.h"
-#include "cinn/hlir/framework/graph.h"
-#include "cinn/hlir/framework/graph_compiler.h"
-#include "cinn/lang/lower.h"
-#include "cinn/lang/placeholder.h"
-#include "cinn/poly/stage.h"
-#include "cinn/utils/data_util.h"
-
-namespace cinn {
-namespace hlir {
-namespace op {
-
-#ifdef CINN_WITH_CUDA
-TEST(GenerateCode_CUDA, UniformRandomGPU) {
-  common::Context::Global().ResetNameId();
-
-  common::Target target = common::DefaultNVGPUTarget();
-
-  std::vector<int> shape = {128, 12};
-  int seed = 2023;
-  std::string dtype = "float32";
-
-  ir::Tensor res = UniformRandom(shape, seed, dtype, target, "uniform_random_out");
-
-  poly::StageMap stages = poly::CreateStages({res});
-  std::vector<ir::LoweredFunc> funcs =
-      lang::LowerVec("TestGenerateCodeGPU_UniformRandom", stages, {res}, {}, {}, nullptr, target, true);
-
-  VLOG(6) << "Expr before CUDA codegen:";
-  VLOG(6) << funcs[0]->body;
-
-  ir::Module::Builder builder("UniformRandom_Module", target);
-  for (auto& f : funcs) {
-    builder.AddFunction(f);
-  }
-
-  auto module = builder.Build();
-  auto host_module_device_module = backends::SplitCudaAndHostModule(module);  // NOLINT
-  auto& host_module = std::get<0>(host_module_device_module);
-  auto& device_module = std::get<1>(host_module_device_module);
-
-  backends::CodeGenCUDA_Dev codegen(target);
-  std::string source_code = codegen.Compile(device_module);
-  LOG(INFO) << "compiled code:\n" << source_code;
-}
-
-}  // namespace op
-}  // namespace hlir
-
-namespace frontend {
-
-TEST(Builder, UniformRandomFP32) {
-  NetBuilder builder("net_builder");
-
-  std::vector<int> shape = {128, 12, 128, 128};
-  int seed = 2023;
-  std::string dtype = "float32";
-  auto out = builder.UniformRandom(shape, 0., 1., seed, dtype);
-  auto program = builder.Build();
-
-  for (int i = 0; i < program.size(); ++i) {
-    LOG(INFO) << "instruction: " << program[i];
-  }
-
-  Target target = common::DefaultNVGPUTarget();
-  std::unordered_set<std::string> fetch_ids;
-  auto graph = Optimize(&program, fetch_ids, target);
-
-  LOG(INFO) << "graph: \n" << graph->Visualize();
-
-  auto scope = BuildScope(target, graph);
-
-  hlir::framework::GraphCompiler gc(target, scope, graph);
-  auto runtime_program = gc.Build();
-
-  auto out_ten = scope->GetTensor(std::string(out->id));
-  runtime_program->Execute();
-
-  EXPECT_EQ(out_ten->type(), Float(32));
-
-  std::vector<float> data = GetTensorData<float>(out_ten, target);
-
-  int cnt = 0;
-  for (int i = 0; i < 128 * 12 * 128 * 128; ++i) {
-    if (data[i] > 0.5) cnt++;
-  }
-  float ratio = (float)cnt / (128 * 12 * 128 * 128);
-  LOG(INFO) << "count: " << cnt;
-  LOG(INFO) << "x > 0.5f ratio: " << ratio;
-  EXPECT_LE(ratio, 0.501f);
-  EXPECT_GE(ratio, 0.499f);
-}
-
-TEST(Builder, UniformRandomFP64) {
-  NetBuilder builder("net_builder");
-
-  std::vector<int> shape = {128, 12, 128, 128};
-  int seed = 2023;
-  std::string dtype = "float64";
-  auto out = builder.UniformRandom(shape, 0., 1., seed, dtype);
-  auto program = builder.Build();
-
-  for (int i = 0; i < program.size(); ++i) {
-    LOG(INFO) << "instruction: " << program[i];
-  }
-
-  Target target = common::DefaultNVGPUTarget();
-  std::unordered_set<std::string> fetch_ids;
-  auto graph = Optimize(&program, fetch_ids, target);
-
-  LOG(INFO) << "graph: \n" << graph->Visualize();
-
-  auto scope = BuildScope(target, graph);
-
-  hlir::framework::GraphCompiler gc(target, scope, graph);
-  auto runtime_program = gc.Build();
-
-  auto out_ten = scope->GetTensor(std::string(out->id));
-  runtime_program->Execute();
-
-  EXPECT_EQ(out_ten->type(), Float(64));
-
-  std::vector<double> data = GetTensorData<double>(out_ten, target);
-
-  int cnt = 0;
-  for (int i = 0; i < 128 * 12 * 128 * 128; ++i) {
-    if (data[i] > 0.5) cnt++;
-  }
-
-  float ratio = (float)cnt / (128 * 12 * 128 * 128);
-  LOG(INFO) << "count: " << cnt;
-  LOG(INFO) << "x > 0.5f ratio: " << ratio;
-  EXPECT_LE(ratio, 0.501f);
-  EXPECT_GE(ratio, 0.499f);
-}
-#endif
-
-}  // namespace frontend
-
-}  // namespace cinn
diff --git a/cinn/hlir/op/external_api_registry.cc b/cinn/hlir/op/external_api_registry.cc
index 700377c05a..8928078be7 100644
--- a/cinn/hlir/op/external_api_registry.cc
+++ b/cinn/hlir/op/external_api_registry.cc
@@ -55,6 +55,7 @@ CINN_REGISTER_HELPER(op_external_api) {
   CINN_OP_REGISTER_EXTERNAL_API(cublas_gemm, default_nvgpu).set_api_name("cinn_call_cublas");
   CINN_OP_REGISTER_EXTERNAL_API(cublas_matmul, default_nvgpu).set_api_name("cinn_call_cublas");
   CINN_OP_REGISTER_EXTERNAL_API(gaussian_random, default_nvgpu).set_api_name("cinn_call_gaussian_random");
+  CINN_OP_REGISTER_EXTERNAL_API(uniform_random, default_nvgpu).set_api_name("cinn_call_uniform_random");
   CINN_OP_REGISTER_EXTERNAL_API(randint, default_nvgpu).set_api_name("cinn_call_randint");
   CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_nvgpu).set_api_name("cinn_call_cholesky_nvgpu");
   CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_host).set_api_name("cinn_call_cholesky_host");
diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
index bd032d2295..a5d751cfa4 100644
--- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
+++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
@@ -1,10 +1,6 @@
 /**
  * \file This file contains all the intrinsics available to be used in CUDA code generated by CodeGen.
  */
-
-#include <curand.h>
-#include <curand_kernel.h>
-
 extern "C" {
 // *************************************************************** //
 // float32 unary and binary operator
@@ -346,20 +342,6 @@ __device__ inline bool cinn_any(const bool left, const bool right) { return left
     shfl_res = __shfl_down_sync(mask, tmp_val, offset, 32);                         \
     tmp_val = op((threadIdx.x & 0x1f) + offset < lane ? shfl_res : init, tmp_val);
 
-__device__ inline float cinn_nvgpu_uniform_random_fp32(int seed) {
-  curandStatePhilox4_32_10_t state;
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  curand_init(seed, idx, 1, &state);
-  return curand_uniform(&state);
-}
-
-__device__ inline double cinn_nvgpu_uniform_random_fp64(int seed) {
-  curandStatePhilox4_32_10_t state;
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  curand_init(seed, idx, 1, &state);
-  return curand_uniform_double(&state);
-}
-
 #define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE)                \
   __device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal(const DTYPE value) { \
     DTYPE tmp_val = value, shfl_res;                                                      \
diff --git a/cinn/utils/data_util.cc b/cinn/utils/data_util.cc
index 963130f44d..907d931f72 100644
--- a/cinn/utils/data_util.cc
+++ b/cinn/utils/data_util.cc
@@ -116,7 +116,6 @@ std::vector<T> GetTensorData(const hlir::framework::Tensor& tensor, const common
 }
 
 template std::vector<float> GetTensorData<float>(const hlir::framework::Tensor& tensor, const common::Target& target);
-template std::vector<double> GetTensorData<double>(const hlir::framework::Tensor& tensor, const common::Target& target);
 template std::vector<int> GetTensorData<int>(const hlir::framework::Tensor& tensor, const common::Target& target);
 
 }  // namespace cinn
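
Reviewer note: the deleted cinn_nvgpu_uniform_random_fp32/fp64 intrinsics above re-initialized a curand Philox state on every call, i.e. once per generated element. For reference, a minimal standalone sketch of that removed device-side pattern; the kernel wrapper and launch shape here are illustrative additions, only the state setup and draw mirror the deleted code:

  // Sketch of the removed per-element generator (not part of this PR's final state).
  // Each call builds a fresh Philox4x32-10 state keyed by (seed, global thread id)
  // and draws a single sample.
  #include <curand_kernel.h>

  __global__ void fill_uniform_fp32(float* out, int n, int seed) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= n) return;
    curandStatePhilox4_32_10_t state;
    curand_init(seed, idx, 1, &state);  // subsequence = thread id, offset = 1
    out[idx] = curand_uniform(&state);  // uniform float in (0, 1]
  }

Paying curand_init per element is the costly part. After this PR the op is marked kNonFusible and lowered through the cinn_call_uniform_random external API (registered next to cinn_call_gaussian_random above), so the runtime can fill the whole output buffer with one host-side generator call instead, presumably a curandGenerateUniform-style batch API mirroring the gaussian_random path.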