diff --git a/CMakeLists.txt b/CMakeLists.txt
index e63dd9f2..ae834f2d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -30,7 +30,7 @@ endif()
 #
 find_package(Threads REQUIRED)
 find_package(Halide REQUIRED COMPONENTS shared)
-find_package(CUDA)
+find_package(CUDAToolkit)
 
 #
 # Version
diff --git a/include/ion/builder.h b/include/ion/builder.h
index a7be4463..074c54a9 100644
--- a/include/ion/builder.h
+++ b/include/ion/builder.h
@@ -63,6 +63,12 @@ class Builder {
      */
     Builder& set_target(const Target& target);
 
+    /**
+     * Set the user context which will be applied to the pipeline built with this builder.
+     * @arg user_context_ptr: The pointer to the user context.
+     */
+    Builder& set_jit_context(Halide::JITUserContext *user_context_ptr);
+
     /**
      * Load bb module dynamically and enable it to compile your pipeline.
      * @arg module_path: DSO path on your filesystem.
diff --git a/include/ion/port.h b/include/ion/port.h
index e199d620..817d2970 100644
--- a/include/ion/port.h
+++ b/include/ion/port.h
@@ -190,7 +190,7 @@ class Port {
         }
 
         impl_->instances[i] = v;
-        impl_->bound_address[i] = std::make_tuple(v,false);
+        impl_->bound_address[i] = std::make_tuple(v, false);
     }
 
     template<typename T>
@@ -202,8 +202,9 @@ class Port {
             impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i,graph_id())};
         }
 
-        impl_->instances[i] = buf.raw_buffer();
-        impl_->bound_address[i] = std::make_tuple(buf.data(),false);
+        auto raw_buf = buf.raw_buffer();
+        impl_->instances[i] = raw_buf;
+        impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast<void *>(raw_buf->host) : reinterpret_cast<void *>(raw_buf->device), false);
     }
 
     template<typename T>
@@ -215,8 +216,9 @@ class Port {
                 impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())};
             }
 
-            impl_->instances[i] = bufs[i].raw_buffer();
-            impl_->bound_address[i] = std::make_tuple(bufs[i].data(),false);
+            auto raw_buf = bufs[i].raw_buffer();
+            impl_->instances[i] = raw_buf;
+            impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast<void *>(raw_buf->host) : reinterpret_cast<void *>(raw_buf->device), false);
         }
     }
 
@@ -239,10 +241,6 @@ class Port {
     }
 
     std::vector<Halide::Func> as_func() const {
-//        if (dimensions() == 0) {
-//            throw std::runtime_error("Unreachable");
-//        }
-
         std::vector<Halide::Func> fs;
         for (const auto& [i, param] : impl_->params ) {
             if (fs.size() <= i) {
@@ -250,14 +248,14 @@ class Port {
             }
             std::vector<Halide::Var> args;
             std::vector<Halide::Expr> args_expr;
-            for (int i = 0; i < dimensions(); ++i) {
-                args.push_back(Halide::Var::implicit(i));
-                args_expr.push_back(Halide::Var::implicit(i));
+            for (int j = 0; j < dimensions(); ++j) {
+                args.push_back(Halide::Var::implicit(j));
+                args_expr.push_back(Halide::Var::implicit(j));
             }
             Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id()) + "_im");
             f(args) = Halide::Internal::Call::make(param, args_expr);
             fs[i] = f;
-            if(std::get<1>(impl_->bound_address[i])){
+            if (std::get<1>(impl_->bound_address[i])) {
                 f.compute_root();
             }
         }
diff --git a/src/builder.cc b/src/builder.cc
index 3f6774a7..3227194f 100644
--- a/src/builder.cc
+++ b/src/builder.cc
@@ -106,6 +106,11 @@ Builder& Builder::set_target(const Halide::Target& target) {
     return *this;
 }
 
+Builder& Builder::set_jit_context(Halide::JITUserContext *user_context_ptr) {
+    impl_->jit_ctx_ptr = user_context_ptr;
+    return *this;
+}
+
 Builder& Builder::with_bb_module(const std::string& module_name_or_path) {
     auto bb_module = std::make_shared<DynamicModule>(module_name_or_path);
     auto register_extern = bb_module->get_symbol<void (*)(std::map<std::string, Halide::JITExtern>&)>("register_externs");
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 2b90c1ce..83d3e890 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -54,9 +54,9 @@ ion_jit_executable(port-assign SRCS port-assign.cc)
 # zero copy i/o for extern functions
 ion_jit_executable(direct-extern SRCS direct-extern.cc)
 
-if (${CUDA_FOUND})
-    ion_jit_executable(gpu-extern SRCS gpu-extern.cc)
-    cuda_add_library(gpu-extern-lib SHARED gpu-extern-lib.cu)
+if (${CUDAToolkit_FOUND})
+    ion_jit_executable(gpu-extern SRCS gpu-extern.cc gpu-extern-lib.cu)
+    ion_jit_executable(cuda-interop SRCS cuda-interop.cc LIBS CUDA::cuda_driver CUDA::cudart)
 endif()
 
 # Duplicate name test
diff --git a/test/cuda-interop.cc b/test/cuda-interop.cc
new file mode 100644
index 00000000..c128b36b
--- /dev/null
+++ b/test/cuda-interop.cc
@@ -0,0 +1,184 @@
+#include <iostream>
+#include <sstream>
+#include <vector>
+
+#include <cuda_runtime.h>
+#include <cuda.h>
+
+#include "ion/ion.h"
+
+#define CUDA_SAFE_CALL(x)                                                                                             \
+    do {                                                                                                              \
+        cudaError_t err = x;                                                                                          \
+        if (err != cudaSuccess) {                                                                                     \
+            std::stringstream ss;                                                                                     \
+            ss << "CUDA error: " << cudaGetErrorString(err) << "(" << err << ") at " << __FILE__ << ":" << __LINE__;  \
+            throw std::runtime_error(ss.str());                                                                       \
+        }                                                                                                             \
+    } while (0)
+
+#define CU_SAFE_CALL(x)                                                                                               \
+    do {                                                                                                              \
+        CUresult err = x;                                                                                             \
+        if (err != CUDA_SUCCESS) {                                                                                    \
+            const char *err_str;                                                                                      \
+            cuGetErrorString(err, &err_str);                                                                          \
+            std::stringstream ss;                                                                                     \
+            ss << "CUDA error: " << err_str << "(" << err << ") at " << __FILE__ << ":" << __LINE__;                  \
+            throw std::runtime_error(ss.str());                                                                       \
+        }                                                                                                             \
+    } while (0)
+
+struct Sqrt : ion::BuildingBlock<Sqrt> {
+    ion::Input<Halide::Func> input{"input0", Int(32), 2};
+    ion::Output<Halide::Func> output{"output0", Int(32), 2};
+    Halide::Var x, y;
+
+    void generate() {
+        output(x, y) = cast<int32_t>(sqrt(input(x, y)));
+    }
+
+    virtual void schedule() {
+        Target target = get_target();
+        if (target.has_gpu_feature()) {
+            Var block, thread;
+            if (target.has_feature(Target::OpenCL)) {
+                std::cout << "Using OpenCL" << std::endl;
+
+            } else if (target.has_feature(Target::CUDA)) {
+                std::cout << "Using CUDA" << std::endl;
+            } else if (target.has_feature(Target::Metal)) {
+                std::cout << "Using Metal" << std::endl;
+            }
+            output.compute_root().gpu_tile(x, y, block, thread, 16, 16);
+        }
+
+        // Fallback to CPU scheduling
+        else {
+            output.compute_root().parallel(y).vectorize(x, 8);
+        }
+    }
+};
+
+ION_REGISTER_BUILDING_BLOCK(Sqrt, Sqrt_gen)
+
+struct CudaState : public Halide::JITUserContext {
+    void *cuda_context = nullptr, *cuda_stream = nullptr;
+    std::atomic<int> acquires = 0, releases = 0;
+
+    static int my_cuda_acquire_context(JITUserContext *ctx, void **cuda_ctx, bool create) {
+        CudaState *state = (CudaState *)ctx;
+        *cuda_ctx = state->cuda_context;
+        state->acquires++;
+        return 0;
+    }
+
+    static int my_cuda_release_context(JITUserContext *ctx) {
+        CudaState *state = (CudaState *)ctx;
+        state->releases++;
+        return 0;
+    }
+
+    static int my_cuda_get_stream(JITUserContext *ctx, void *cuda_ctx, void **stream) {
+        CudaState *state = (CudaState *)ctx;
+        *stream = state->cuda_stream;
+        return 0;
+    }
+
+    CudaState() {
+        handlers.custom_cuda_acquire_context = my_cuda_acquire_context;
+        handlers.custom_cuda_release_context = my_cuda_release_context;
+        handlers.custom_cuda_get_stream = my_cuda_get_stream;
+    }
+};
+
+int main() {
+    using namespace Halide;
+
+    constexpr int width = 16, height = 16;
+    bool flip = true;
+    std::vector<int32_t> input_vec(width * height, 2024);
+    std::vector<int32_t> output_vec(width * height, 0);
+
+    try {
+        // CUDA setup
+        CudaState state;
+
+        // Initialize the CUDA driver API
+        CU_SAFE_CALL(cuInit(0));
+
+        CUdevice device;
+        CU_SAFE_CALL(cuDeviceGet(&device, 0));
+
+        CU_SAFE_CALL(cuCtxCreate(reinterpret_cast<CUcontext *>(&state.cuda_context), 0, device));
+
+        std::cout << "CUcontext is created on the application side: " << state.cuda_context << std::endl;
+
+        CU_SAFE_CALL(cuCtxSetCurrent(reinterpret_cast<CUcontext>(state.cuda_context)));
+
+        // CUstream is interchangeable with cudaStream_t (ref: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html)
+        CU_SAFE_CALL(cuStreamCreate(reinterpret_cast<CUstream *>(&state.cuda_stream), CU_STREAM_DEFAULT));
+
+        constexpr int size = width * height * sizeof(int32_t);
+        void *src;
+        CUDA_SAFE_CALL(cudaMalloc(&src, size));
+        CUDA_SAFE_CALL(cudaMemcpy(src, input_vec.data(), size, cudaMemcpyHostToDevice));
+
+        void *dst;
+        CUDA_SAFE_CALL(cudaMalloc(&dst, size));
+
+        // ION execution
+        {
+            Target target = get_host_target().with_feature(Target::CUDA).with_feature(Target::TracePipeline).with_feature(Target::Debug);
+            auto device_interface = get_device_interface_for_device_api(DeviceAPI::CUDA, target);
+
+            Halide::Buffer<> inputBuffer(Halide::Int(32), nullptr, height, width);
+            inputBuffer.device_wrap_native(device_interface, reinterpret_cast<uint64_t>(src), &state);
+            inputBuffer.set_device_dirty(true);
+
+            Halide::Buffer<> outputBuffer(Halide::Int(32), nullptr, height, width);
+            outputBuffer.device_wrap_native(device_interface, reinterpret_cast<uint64_t>(dst), &state);
+
+            ion::Builder b;
+            b.set_target(target);
+            b.set_jit_context(&state);
+
+            ion::Graph graph(b);
+
+            ion::Node cn = graph.add("Sqrt_gen")(inputBuffer);
+            cn["output0"].bind(outputBuffer);
+
+            b.run();
+
+            CUDA_SAFE_CALL(cudaMemcpy(output_vec.data(), dst, size, cudaMemcpyDeviceToHost));
+            for (int i = 0; i < height; i++) {
+                for (int j = 0; j < width; j++) {
+                    std::cerr << output_vec[i * width + j] << " ";
+                    if (44 != output_vec[i * width + j]) {
+                        return -1;
+                    }
+                }
+                std::cerr << std::endl;
+            }
+        }
+
+        // CUDA cleanup
+        CUDA_SAFE_CALL(cudaFree(src));
+        CUDA_SAFE_CALL(cudaFree(dst));
+
+        CU_SAFE_CALL(cuStreamDestroy(reinterpret_cast<CUstream>(state.cuda_stream)));
+        CU_SAFE_CALL(cuCtxDestroy(reinterpret_cast<CUcontext>(state.cuda_context)));
+
+    } catch (const Halide::Error &e) {
+        std::cerr << e.what() << std::endl;
+        return 1;
+    } catch (const std::exception &e) {
+        std::cerr << e.what() << std::endl;
+        return 1;
+    }
+
+    std::cout << "Passed" << std::endl;
+
+    return 0;
+}
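For reference, the sketch below condenses the interop flow exercised by test/cuda-interop.cc: wrap device memory the application already owns with device_wrap_native(), hand the application's Halide::JITUserContext to the builder through the new set_jit_context(), and run the graph. This is a minimal sketch, not part of the change: the helper name run_on_existing_cuda_memory is hypothetical, it assumes the Sqrt_gen building block is registered exactly as in the test, and it omits CUDA setup, teardown, and error checking.

// Hypothetical helper distilled from test/cuda-interop.cc above.
// Assumes ION_REGISTER_BUILDING_BLOCK(Sqrt, Sqrt_gen) as in the test.
#include <cstdint>

#include "ion/ion.h"

int run_on_existing_cuda_memory(Halide::JITUserContext *jit_ctx, void *src, void *dst, int width, int height) {
    using namespace Halide;

    // JIT for CUDA and resolve the CUDA device interface used for wrapping.
    Target target = get_host_target().with_feature(Target::CUDA);
    auto device_interface = get_device_interface_for_device_api(DeviceAPI::CUDA, target);

    // Wrap the application-owned device allocations instead of letting Halide allocate.
    Halide::Buffer<> in(Halide::Int(32), nullptr, height, width);
    in.device_wrap_native(device_interface, reinterpret_cast<uint64_t>(src), jit_ctx);
    in.set_device_dirty(true);  // input data already lives on the device

    Halide::Buffer<> out(Halide::Int(32), nullptr, height, width);
    out.device_wrap_native(device_interface, reinterpret_cast<uint64_t>(dst), jit_ctx);

    // The new API added by this change: hand the user context (and thereby the
    // application's CUDA context/stream handlers) to the builder.
    ion::Builder b;
    b.set_target(target);
    b.set_jit_context(jit_ctx);

    ion::Graph graph(b);
    ion::Node n = graph.add("Sqrt_gen")(in);
    n["output0"].bind(out);

    b.run();  // kernels execute on the caller's CUDA context/stream

    return 0;
}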