diff --git a/cinn/backends/compiler.cc b/cinn/backends/compiler.cc index 880441ff63..798b0a96a2 100644 --- a/cinn/backends/compiler.cc +++ b/cinn/backends/compiler.cc @@ -25,6 +25,7 @@ #include "cinn/backends/nvrtc/nvrtc_util.h" #include "cinn/runtime/cuda/cuda_module.h" #include "cinn/runtime/cuda/cuda_util.h" +#include "cinn/runtime/flags.h" #endif DECLARE_string(cinn_source_code_save_path); @@ -123,16 +124,13 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code) SourceCodePrint::GetInstance()->write(source_code); using runtime::cuda::CUDAModule; - backends::nvrtc::Compiler compiler; - + nvrtc::Compiler compiler; auto ptx = compiler(source_code); CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << source_code; - cuda_module_.reset( new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); RuntimeSymbols symbols; - for (auto& fn : device_module.functions()) { std::string kernel_fn_name = fn->name; auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name); diff --git a/cinn/backends/nvrtc/nvrtc_util.cc b/cinn/backends/nvrtc/nvrtc_util.cc index 101406984e..4598054701 100644 --- a/cinn/backends/nvrtc/nvrtc_util.cc +++ b/cinn/backends/nvrtc/nvrtc_util.cc @@ -17,12 +17,20 @@ #include #include #include +#include +#include +#include + +#include +#include #include "cinn/backends/cuda_util.h" #include "cinn/backends/nvrtc/header_generator.h" #include "cinn/common/common.h" +#include "cinn/runtime/flags.h" #include "cinn/utils/string.h" +DECLARE_string(cinn_nvcc_cmd_path); DECLARE_bool(nvrtc_compile_to_cubin); namespace cinn { @@ -30,6 +38,9 @@ namespace backends { namespace nvrtc { std::string Compiler::operator()(const std::string& code, bool include_headers) { + if (runtime::CanUseNvccCompiler()) { + return CompileWithNvcc(code); + } return CompileCudaSource(code, include_headers); } @@ -140,6 +151,89 @@ std::string Compiler::CompileCudaSource(const std::string& code, bool include_he return data; } +std::string Compiler::CompileWithNvcc(const std::string& cuda_c) { + // read dir source + std::string dir = "./source"; + if (access(dir.c_str(), 0) == -1) { + CHECK(mkdir(dir.c_str(), 7) != -1) << "Fail to mkdir " << dir; + } + + // get unqiue prefix name + prefix_name_ = dir + "/" + common::UniqName("rtc_tmp"); + + auto cuda_c_file = prefix_name_ + ".cu"; + std::ofstream ofs(cuda_c_file, std::ios::out); + CHECK(ofs.is_open()) << "Fail to open file " << cuda_c_file; + ofs << cuda_c; + ofs.close(); + + CompileToPtx(); + CompileToCubin(); + + return prefix_name_ + ".cubin"; +} + +// std::string Compiler::GetPtx() { return ReadFile(prefix_name_ + ".ptx", std::ios::in); } + +void Compiler::CompileToPtx() { + auto include_dir = common::Context::Global().runtime_include_dir(); + std::string include_dir_str = ""; + for (auto dir : include_dir) { + if (include_dir_str.empty()) { + include_dir_str = dir; + } else { + include_dir_str += ":" + dir; + } + } + + std::string options = std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path + + std::string(":$PATH && nvcc -std=c++14 --ptx -O3 -I ") + include_dir_str; + options += " -arch=" + GetDeviceArch(); + options += " -o " + prefix_name_ + ".ptx"; + options += " " + prefix_name_ + ".cu"; + + VLOG(2) << "Nvcc Compile Options : " << options; + CHECK(system(options.c_str()) == 0) << options; +} + +void Compiler::CompileToCubin() { + std::string options = + std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path + std::string(":$PATH && nvcc --cubin -O3"); + options += " -arch=" + GetDeviceArch(); + options += " -o " + prefix_name_ + ".cubin"; + options += " " + prefix_name_ + ".ptx"; + + VLOG(2) << "Nvcc Compile Options : " << options; + CHECK(system(options.c_str()) == 0) << options; +} + +std::string Compiler::GetDeviceArch() { + int major = 0, minor = 0; + if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0) == cudaSuccess && + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0) == cudaSuccess) { + return "sm_" + std::to_string(major) + std::to_string(minor); + } else { + LOG(WARNING) << "cannot detect compute capability from your device, " + << "fall back to compute_30."; + return "sm_30"; + } +} + +std::string Compiler::ReadFile(const std::string& file_name, std::ios_base::openmode mode) { + // open cubin file + std::ifstream ifs(file_name, mode); + CHECK(ifs.is_open()) << "Fail to open file " << file_name; + ifs.seekg(std::ios::end); + auto len = ifs.tellg(); + ifs.seekg(0); + + // read cubin file + std::string file_data(len, ' '); + ifs.read(&file_data[0], len); + ifs.close(); + return std::move(file_data); +} + } // namespace nvrtc } // namespace backends } // namespace cinn diff --git a/cinn/backends/nvrtc/nvrtc_util.h b/cinn/backends/nvrtc/nvrtc_util.h index a5f8424a31..b13c24c550 100644 --- a/cinn/backends/nvrtc/nvrtc_util.h +++ b/cinn/backends/nvrtc/nvrtc_util.h @@ -70,6 +70,19 @@ class Compiler { * whether to compile the source code into cubin, only works with cuda version > 11.1 */ bool compile_to_cubin_{false}; + + // compile with nvcc + std::string CompileWithNvcc(const std::string&); + + // compile to ptx + void CompileToPtx(); + // compile to cubin + void CompileToCubin(); + std::string GetDeviceArch(); + + std::string ReadFile(const std::string&, std::ios_base::openmode); + + std::string prefix_name_{""}; }; } // namespace nvrtc diff --git a/cinn/hlir/framework/parallel_compiler.cc b/cinn/hlir/framework/parallel_compiler.cc index aa22dfce65..ede13cab04 100644 --- a/cinn/hlir/framework/parallel_compiler.cc +++ b/cinn/hlir/framework/parallel_compiler.cc @@ -28,6 +28,7 @@ #include "cinn/common/context.h" #include "cinn/hlir/framework/pass.h" #include "cinn/ir/module.h" +#include "cinn/runtime/flags.h" DECLARE_int32(cinn_parallel_compile_size); DECLARE_int32(cinn_parallel_compile_thread); @@ -178,10 +179,9 @@ void ParallelCompiler::Task::CodegenAndJit() { backends::nvrtc::Compiler compiler; auto ptx = compiler(cuda_c); CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c; - graph->SavePTXCode(ptx); - // load cumodule cumodule.reset(new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); + // register kernel backends::RuntimeSymbols symbols; for (auto& fn : dmodule.functions()) { diff --git a/cinn/runtime/cuda/cuda_module.cc b/cinn/runtime/cuda/cuda_module.cc index d112f5d4e9..56963e4efb 100644 --- a/cinn/runtime/cuda/cuda_module.cc +++ b/cinn/runtime/cuda/cuda_module.cc @@ -25,6 +25,7 @@ #include "cinn/backends/cuda_util.h" #include "cinn/runtime/cuda/cuda_util.h" +#include "cinn/runtime/flags.h" #include "cinn/utils/profiler.h" namespace cinn { @@ -106,16 +107,11 @@ CUfunction CUDAModule::GetFunction(int device_id, const std::string& func_name) jit_options[4] = CU_JIT_GENERATE_LINE_INFO; jit_opt_vals[4] = reinterpret_cast(value); - CUresult status = cuModuleLoadDataEx( - &module_per_card_[device_id], data_.c_str(), jit_num_options, jit_options.data(), jit_opt_vals.data()); - - if (CUDA_SUCCESS != status) { - RAW_LOG(ERROR, "PTX JIT ERROR LOG: %s\n.", log_buffer.data()); - const char* name; - cuGetErrorName(status, &name); - const char* msg; - cuGetErrorString(status, &msg); - RAW_LOG(FATAL, "The error `%s` occurs while compiling the ptx! And its message is `%s`.", name, msg); + if (runtime::CanUseNvccCompiler()) { + CUDA_DRIVER_CALL(cuModuleLoad(&module_per_card_[device_id], data_.c_str())); + } else { + CUDA_DRIVER_CALL(cuModuleLoadDataEx( + &module_per_card_[device_id], data_.c_str(), jit_num_options, jit_options.data(), jit_opt_vals.data())); } } @@ -127,11 +123,15 @@ CUfunction CUDAModule::GetFunction(int device_id, const std::string& func_name) CUdeviceptr CUDAModule::GetGlobal(int device_id, const std::string& name, size_t nbytes) { if (!module_per_card_[device_id]) { std::lock_guard lock(mutex_); - CUDA_DRIVER_CALL(cuModuleLoadData(&module_per_card_[device_id], data_.c_str())); + if (runtime::CanUseNvccCompiler()) { + CUDA_DRIVER_CALL(cuModuleLoad(&module_per_card_[device_id], data_.c_str())); + } else { + CUDA_DRIVER_CALL(cuModuleLoadData(&module_per_card_[device_id], data_.c_str())); + } } - CUdeviceptr global; size_t _nbytes; + CUdeviceptr global; CUDA_DRIVER_CALL(cuModuleGetGlobal(&global, &_nbytes, module_per_card_[device_id], name.c_str())); return global; } diff --git a/cinn/runtime/flags.cc b/cinn/runtime/flags.cc index 1dc61b3a24..a22ca41b47 100644 --- a/cinn/runtime/flags.cc +++ b/cinn/runtime/flags.cc @@ -16,6 +16,9 @@ #include #include +#include +#include +#include #include @@ -35,6 +38,9 @@ using ::GFLAGS_NAMESPACE::Int64FromEnv; using ::GFLAGS_NAMESPACE::StringFromEnv; DEFINE_string(cinn_x86_builtin_code_root, StringFromEnv("FLAGS_cinn_x86_builtin_code_root", ""), ""); +DEFINE_string(cinn_nvcc_cmd_path, + StringFromEnv("FLAGS_cinn_nvcc_cmd_path", "/usr/local/cuda/bin"), + "Setting nvcc default path!"); DEFINE_int32(cinn_parallel_compile_size, Int32FromEnv("FLAGS_cinn_parallel_compile_size", 16), @@ -82,9 +88,13 @@ DEFINE_bool(cinn_use_dense_merge_pass, "Whether use dense merge pass."); DEFINE_bool(nvrtc_compile_to_cubin, - BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", false), + BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", true), "Whether nvrtc compile cuda source into cubin instead of ptx (only works after cuda-11.1)."); +DEFINE_bool(cinn_compile_with_nvrtc, + BoolFromEnv("FLAGS_cinn_compile_with_nvrtc", true), + "Whether nvrtc compile cuda source with nvrtc(default nvcc)."); + // FLAGS for performance analysis and accuracy debug DEFINE_bool(cinn_sync_run, BoolFromEnv("FLAGS_cinn_sync_run", false), @@ -180,6 +190,11 @@ unsigned long long RandomSeed::Clear() { return old_seed; } +bool CanUseNvccCompiler() { + std::string nvcc_dir = FLAGS_cinn_nvcc_cmd_path + "/nvcc"; + return (access(nvcc_dir.c_str(), 0) == -1 ? false : true) && (!FLAGS_cinn_compile_with_nvrtc); +} + bool IsCompiledWithCUDA() { #if !defined(CINN_WITH_CUDA) return false; diff --git a/cinn/runtime/flags.h b/cinn/runtime/flags.h index 4b4f19f322..6a663d12af 100644 --- a/cinn/runtime/flags.h +++ b/cinn/runtime/flags.h @@ -27,6 +27,8 @@ bool CheckStringFlagFalse(const std::string &flag); void SetCinnCudnnDeterministic(bool state); bool GetCinnCudnnDeterministic(); +bool CanUseNvccCompiler(); + class RandomSeed { public: static unsigned long long GetOrSet(unsigned long long seed = 0);