From ef3f3c988f82e01b7e5062112eff214491c021cf Mon Sep 17 00:00:00 2001 From: Takuro Iizuka Date: Thu, 27 Jun 2024 13:03:17 -0700 Subject: [PATCH 1/4] Backported test --- CMakeLists.txt | 2 +- test/CMakeLists.txt | 6 +-- test/cuda-interop.cc | 119 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 test/cuda-interop.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index e63dd9f2..ae834f2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,7 +30,7 @@ endif() # find_package(Threads REQUIRED) find_package(Halide REQUIRED COMPONENTS shared) -find_package(CUDA) +find_package(CUDAToolkit) # # Version diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2b90c1ce..3df35f70 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -54,9 +54,9 @@ ion_jit_executable(port-assign SRCS port-assign.cc) # zero copy i/o for extern functions ion_jit_executable(direct-extern SRCS direct-extern.cc) -if (${CUDA_FOUND}) - ion_jit_executable(gpu-extern SRCS gpu-extern.cc) - cuda_add_library(gpu-extern-lib SHARED gpu-extern-lib.cu) +if (${CUDAToolkit_FOUND}) + ion_jit_executable(gpu-extern SRCS gpu-extern.cc gpu-extern-lib.cu) + ion_jit_executable(cuda-interop SRCS cuda-interop.cc LIBS CUDA::cudart) endif() # Duplicate name test diff --git a/test/cuda-interop.cc b/test/cuda-interop.cc new file mode 100644 index 00000000..cdef9fb1 --- /dev/null +++ b/test/cuda-interop.cc @@ -0,0 +1,119 @@ +#include +#include + +#include + +#include "ion/ion.h" + +#define CUDA_SAFE_CALL(x) \ + do { \ + cudaError_t err = x; \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl; \ + exit(1); \ + } \ + } while (0) + +struct Sqrt : ion::BuildingBlock { + ion::Input input{"input0", Int(32), 2}; + ion::Output output{"output0", Int(32), 2}; + Halide::Var x, y; + + void generate() { + output(x, y) = cast(sqrt(input(x, y))); + } + + virtual void schedule() { + Target target = get_target(); + if (target.has_gpu_feature()) { + Var block, thread; + if (target.has_feature(Target::OpenCL)) { + std::cout << "Using OpenCL" << std::endl; + + } else if (target.has_feature(Target::CUDA)) { + std::cout << "Using CUDA" << std::endl; + } else if (target.has_feature(Target::Metal)) { + std::cout << "Using Metal" << std::endl; + } + output.compute_root().gpu_tile(x, y, block, thread, 16, 16); + } + + // Fallback to CPU scheduling + else { + output.compute_root().parallel(y).vectorize(x, 8); + } + } +}; + +ION_REGISTER_BUILDING_BLOCK(Sqrt, Sqrt_gen) + +int main() { + using namespace Halide; + + constexpr int width = 16, height = 16; + bool flip = true; + std::vector input_vec(width * height, 2024); + std::vector output_vec(width * height, 0); + try { + constexpr int size = width * height * sizeof(int32_t); + // void *src = subsystem->malloc(size); + void *src; + CUDA_SAFE_CALL(cudaMalloc(&src, size)); + + // void *dst = subsystem->malloc(size); + void *dst; + CUDA_SAFE_CALL(cudaMalloc(&dst, size)); + + Target target = get_host_target().with_feature(Target::CUDA); + // CudaState state(subsystem->getContext(), subsystem->getStream()); + + // subsystem->memcpy(src, input_vec.data(), size, RocaMemcpyKind::kMemcpyHostToDevice); + auto device_interface = get_device_interface_for_device_api(DeviceAPI::CUDA, target); + + assert(device_interface); + + Halide::Buffer<> inputBuffer(Halide::Int(32), nullptr, height, width); + inputBuffer.device_wrap_native(device_interface, (uintptr_t)src); + // 
inputBuffer.set_device_dirty(true); + inputBuffer.set_host_dirty(false); + Halide::Buffer<> outputBuffer(Halide::Int(32), nullptr, height, width); + outputBuffer.device_wrap_native(device_interface, (uintptr_t)dst); + outputBuffer.set_device_dirty(true); + + ion::Port input{"input0", Int(32), 2}; + ion::Builder b; + b.set_target(target); + ion::Graph graph(b); + + ion::Node cn = graph.add("Sqrt_gen")(input); + cn(inputBuffer); + cn["output0"].bind(outputBuffer); + + b.run(); + outputBuffer.device_sync(); // Figure out how to replace halide's stream + + cudaMemcpy(output_vec.data(), dst, size, cudaMemcpyDeviceToHost); + for (int i = 0; i < height; i++) { + for (int j = 0; j < width; j++) { + assert(44 == output_vec[i * width + j]); // 44 IS sqrt of 2024 + } + } + + // subsystem->free(src); + CUDA_SAFE_CALL(cudaFree(src)); + + // subsystem->free(dst); + CUDA_SAFE_CALL(cudaFree(dst)); + + } catch (const Halide::Error &e) { + std::cerr << e.what() << std::endl; + return 1; + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; + } + + std::cout << "Passed" << std::endl; + + return 0; +} From 13a9430c6e1028a60e64a0ee574a3dd169e6c92a Mon Sep 17 00:00:00 2001 From: Takuro Iizuka Date: Thu, 27 Jun 2024 19:30:07 -0700 Subject: [PATCH 2/4] WIP --- test/cuda-interop.cc | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/test/cuda-interop.cc b/test/cuda-interop.cc index cdef9fb1..c4545428 100644 --- a/test/cuda-interop.cc +++ b/test/cuda-interop.cc @@ -59,12 +59,13 @@ int main() { // void *src = subsystem->malloc(size); void *src; CUDA_SAFE_CALL(cudaMalloc(&src, size)); + CUDA_SAFE_CALL(cudaMemcpy(src, input_vec.data(), size, cudaMemcpyHostToDevice)); // void *dst = subsystem->malloc(size); void *dst; CUDA_SAFE_CALL(cudaMalloc(&dst, size)); - Target target = get_host_target().with_feature(Target::CUDA); + Target target = get_host_target().with_feature(Target::CUDA).with_feature(Target::TracePipeline).with_feature(Target::Debug); // CudaState state(subsystem->getContext(), subsystem->getStream()); // subsystem->memcpy(src, input_vec.data(), size, RocaMemcpyKind::kMemcpyHostToDevice); @@ -73,30 +74,42 @@ int main() { assert(device_interface); Halide::Buffer<> inputBuffer(Halide::Int(32), nullptr, height, width); - inputBuffer.device_wrap_native(device_interface, (uintptr_t)src); + inputBuffer.device_wrap_native(device_interface, reinterpret_cast(src)); + + // inputBuffer.device_wrap_native(device_interface, (uintptr_t)src); // inputBuffer.set_device_dirty(true); - inputBuffer.set_host_dirty(false); - Halide::Buffer<> outputBuffer(Halide::Int(32), nullptr, height, width); - outputBuffer.device_wrap_native(device_interface, (uintptr_t)dst); - outputBuffer.set_device_dirty(true); + // inputBuffer.set_host_dirty(false); + // Halide::Buffer outputBuffer(Halide::Int(32), height, width); + // outputBuffer.device_malloc(device_interface); + // Halide::Buffer<> outputBuffer(Halide::Int(32), reinterpret_cast(1), height, width); + Halide::Buffer<> outputBuffer(Halide::Int(32), nullptr, height, width); + outputBuffer.device_wrap_native(device_interface, reinterpret_cast(dst)); + // outputBuffer.set_device_dirty(true); ion::Port input{"input0", Int(32), 2}; ion::Builder b; b.set_target(target); ion::Graph graph(b); - ion::Node cn = graph.add("Sqrt_gen")(input); - cn(inputBuffer); + ion::Node cn = graph.add("Sqrt_gen")(inputBuffer); + // input.bind(inputBuffer); + // cn(inputBuffer); cn["output0"].bind(outputBuffer); b.run(); 
outputBuffer.device_sync(); // Figure out how to replace halide's stream + // outputBuffer.copy_to_host(); cudaMemcpy(output_vec.data(), dst, size, cudaMemcpyDeviceToHost); + // memcpy(output_vec.data(), outputBuffer.data(), size); for (int i = 0; i < height; i++) { for (int j = 0; j < width; j++) { - assert(44 == output_vec[i * width + j]); // 44 IS sqrt of 2024 + std::cerr << output_vec[i * width + j] << " "; + if (44 != output_vec[i * width +j]) { + return -1; + } } + std::cerr << std::endl; } // subsystem->free(src); From cd87b3879208d7f2b98fc0ef86e790dd7ec82be4 Mon Sep 17 00:00:00 2001 From: Takuro Iizuka Date: Fri, 28 Jun 2024 13:49:26 -0700 Subject: [PATCH 3/4] WIP --- include/ion/port.h | 24 ++++++------ test/CMakeLists.txt | 2 +- test/cuda-interop.cc | 90 ++++++++++++++++++++++++++++++++++---------- 3 files changed, 82 insertions(+), 34 deletions(-) diff --git a/include/ion/port.h b/include/ion/port.h index e199d620..817d2970 100644 --- a/include/ion/port.h +++ b/include/ion/port.h @@ -190,7 +190,7 @@ class Port { } impl_->instances[i] = v; - impl_->bound_address[i] = std::make_tuple(v,false); + impl_->bound_address[i] = std::make_tuple(v, false); } template @@ -202,8 +202,9 @@ class Port { impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i,graph_id())}; } - impl_->instances[i] = buf.raw_buffer(); - impl_->bound_address[i] = std::make_tuple(buf.data(),false); + auto raw_buf = buf.raw_buffer(); + impl_->instances[i] = raw_buf; + impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast(raw_buf->host) : reinterpret_cast(raw_buf->device), false); } template @@ -215,8 +216,9 @@ class Port { impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())}; } - impl_->instances[i] = bufs[i].raw_buffer(); - impl_->bound_address[i] = std::make_tuple(bufs[i].data(),false); + auto raw_buf = bufs[i].raw_buffer(); + impl_->instances[i] = raw_buf; + impl_->bound_address[i] = std::make_tuple(raw_buf->host ? 
reinterpret_cast(raw_buf->host) : reinterpret_cast(raw_buf->device), false); } } @@ -239,10 +241,6 @@ class Port { } std::vector as_func() const { -// if (dimensions() == 0) { -// throw std::runtime_error("Unreachable"); -// } - std::vector fs; for (const auto& [i, param] : impl_->params ) { if (fs.size() <= i) { @@ -250,14 +248,14 @@ class Port { } std::vector args; std::vector args_expr; - for (int i = 0; i < dimensions(); ++i) { - args.push_back(Halide::Var::implicit(i)); - args_expr.push_back(Halide::Var::implicit(i)); + for (int j = 0; j < dimensions(); ++j) { + args.push_back(Halide::Var::implicit(j)); + args_expr.push_back(Halide::Var::implicit(j)); } Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id()) + "_im"); f(args) = Halide::Internal::Call::make(param, args_expr); fs[i] = f; - if(std::get<1>(impl_->bound_address[i])){ + if (std::get<1>(impl_->bound_address[i])) { f.compute_root(); } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3df35f70..83d3e890 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -56,7 +56,7 @@ ion_jit_executable(direct-extern SRCS direct-extern.cc) if (${CUDAToolkit_FOUND}) ion_jit_executable(gpu-extern SRCS gpu-extern.cc gpu-extern-lib.cu) - ion_jit_executable(cuda-interop SRCS cuda-interop.cc LIBS CUDA::cudart) + ion_jit_executable(cuda-interop SRCS cuda-interop.cc LIBS CUDA::cuda_driver CUDA::cudart) endif() # Duplicate name test diff --git a/test/cuda-interop.cc b/test/cuda-interop.cc index c4545428..3ff730b8 100644 --- a/test/cuda-interop.cc +++ b/test/cuda-interop.cc @@ -1,19 +1,35 @@ #include +#include #include #include +#include #include "ion/ion.h" -#define CUDA_SAFE_CALL(x) \ - do { \ - cudaError_t err = x; \ - if (err != cudaSuccess) { \ - std::cerr << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__ << std::endl; \ - exit(1); \ - } \ +#define CUDA_SAFE_CALL(x) \ + do { \ + cudaError_t err = x; \ + if (err != cudaSuccess) { \ + std::stringstream ss; \ + ss << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(ss.str()); \ + } \ } while (0) +#define CU_SAFE_CALL(x) \ + do { \ + CUresult err = x; \ + if (err != CUDA_SUCCESS) { \ + const char *err_str; \ + cuGetErrorString(err, &err_str); \ + std::stringstream ss; \ + ss << "CUDA error: " << err_str << " at " << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(ss.str()); \ + } \ + } while (0) + + struct Sqrt : ion::BuildingBlock { ion::Input input{"input0", Int(32), 2}; ion::Output output{"output0", Int(32), 2}; @@ -47,6 +63,36 @@ struct Sqrt : ion::BuildingBlock { ION_REGISTER_BUILDING_BLOCK(Sqrt, Sqrt_gen) +struct CudaState : public Halide::JITUserContext { + void *cuda_context = nullptr, *cuda_stream = nullptr; + std::atomic acquires = 0, releases = 0; + + static int my_cuda_acquire_context(JITUserContext *ctx, void **cuda_ctx, bool create) { + CudaState *state = (CudaState *)ctx; + *cuda_ctx = state->cuda_context; + state->acquires++; + return 0; + } + + static int my_cuda_release_context(JITUserContext *ctx) { + CudaState *state = (CudaState *)ctx; + state->releases++; + return 0; + } + + static int my_cuda_get_stream(JITUserContext *ctx, void *cuda_ctx, void **stream) { + CudaState *state = (CudaState *)ctx; + *stream = state->cuda_stream; + return 0; + } + + CudaState() { + handlers.custom_cuda_acquire_context = my_cuda_acquire_context; + handlers.custom_cuda_release_context = my_cuda_release_context; + 
handlers.custom_cuda_get_stream = my_cuda_get_stream; + } +}; + int main() { using namespace Halide; @@ -55,6 +101,19 @@ int main() { std::vector input_vec(width * height, 2024); std::vector output_vec(width * height, 0); try { + // Ensure to initialize cuda Context under the hood + CUDA_SAFE_CALL(cudaSetDevice(0)); + + CudaState state; + + CUstream stream; // This is interchangeable with cudaStream_t (ref: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html) + CU_SAFE_CALL(cuStreamCreate(&stream, CU_STREAM_DEFAULT)); + state.cuda_stream = reinterpret_cast(stream); + + CUcontext ctx; + CU_SAFE_CALL(cuStreamGetCtx(stream, &ctx)); + state.cuda_context = reinterpret_cast(ctx); + constexpr int size = width * height * sizeof(int32_t); // void *src = subsystem->malloc(size); void *src; @@ -75,33 +134,24 @@ int main() { Halide::Buffer<> inputBuffer(Halide::Int(32), nullptr, height, width); inputBuffer.device_wrap_native(device_interface, reinterpret_cast(src)); - - // inputBuffer.device_wrap_native(device_interface, (uintptr_t)src); - // inputBuffer.set_device_dirty(true); - // inputBuffer.set_host_dirty(false); - // Halide::Buffer outputBuffer(Halide::Int(32), height, width); - // outputBuffer.device_malloc(device_interface); - // Halide::Buffer<> outputBuffer(Halide::Int(32), reinterpret_cast(1), height, width); + inputBuffer.set_device_dirty(true); + Halide::Buffer<> outputBuffer(Halide::Int(32), nullptr, height, width); outputBuffer.device_wrap_native(device_interface, reinterpret_cast(dst)); - // outputBuffer.set_device_dirty(true); - ion::Port input{"input0", Int(32), 2}; ion::Builder b; b.set_target(target); + ion::Graph graph(b); ion::Node cn = graph.add("Sqrt_gen")(inputBuffer); - // input.bind(inputBuffer); - // cn(inputBuffer); cn["output0"].bind(outputBuffer); b.run(); - outputBuffer.device_sync(); // Figure out how to replace halide's stream + // outputBuffer.device_sync(); // Figure out how to replace halide's stream // outputBuffer.copy_to_host(); cudaMemcpy(output_vec.data(), dst, size, cudaMemcpyDeviceToHost); - // memcpy(output_vec.data(), outputBuffer.data(), size); for (int i = 0; i < height; i++) { for (int j = 0; j < width; j++) { std::cerr << output_vec[i * width + j] << " "; From dccc453636846a0d15816d5af5b6962c7022e451 Mon Sep 17 00:00:00 2001 From: Takuro Iizuka Date: Fri, 28 Jun 2024 15:24:18 -0700 Subject: [PATCH 4/4] Injected CUcontext and CUstream --- include/ion/builder.h | 6 +++ src/builder.cc | 5 ++ test/cuda-interop.cc | 116 +++++++++++++++++++++--------------------- 3 files changed, 70 insertions(+), 57 deletions(-) diff --git a/include/ion/builder.h b/include/ion/builder.h index a7be4463..074c54a9 100644 --- a/include/ion/builder.h +++ b/include/ion/builder.h @@ -63,6 +63,12 @@ class Builder { */ Builder& set_target(const Target& target); + /** + * Set the user context which will be applied the pipeline built with this builder. + * @arg user_context_ptr: The pointer to the user context. + */ + Builder& set_jit_context(Halide::JITUserContext *user_context_ptr); + /** * Load bb module dynamically and enable it to compile your pipeline. * @arg module_path: DSO path on your filesystem. 
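The include/ion/builder.h hunk above only declares the new setter; the src/builder.cc hunk below stores the pointer in impl_->jit_ctx_ptr. The call site that forwards the stored context into Halide is not shown in this series. A minimal sketch of what that forwarding could look like, assuming Halide's JITUserContext-aware realize overload; the member names (impl_->pipeline, impl_->target) and the Realization variable "outputs" are illustrative, not ion-kit's actual internals:

    // Sketch only: assumed call site inside Builder, not part of this patch series.
    // Passing the stored context makes the custom_cuda_acquire_context /
    // custom_cuda_get_stream handlers fire instead of Halide's default
    // context and stream management.
    if (impl_->jit_ctx_ptr) {
        impl_->pipeline.realize(impl_->jit_ctx_ptr, outputs, impl_->target);
    } else {
        impl_->pipeline.realize(outputs, impl_->target);
    }
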
diff --git a/src/builder.cc b/src/builder.cc index 3f6774a7..3227194f 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -106,6 +106,11 @@ Builder& Builder::set_target(const Halide::Target& target) { return *this; } +Builder& Builder::set_jit_context(Halide::JITUserContext *user_context_ptr) { + impl_->jit_ctx_ptr = user_context_ptr; + return *this; +} + Builder& Builder::with_bb_module(const std::string& module_name_or_path) { auto bb_module = std::make_shared(module_name_or_path); auto register_extern = bb_module->get_symbol&)>("register_externs"); diff --git a/test/cuda-interop.cc b/test/cuda-interop.cc index 3ff730b8..c128b36b 100644 --- a/test/cuda-interop.cc +++ b/test/cuda-interop.cc @@ -7,26 +7,26 @@ #include "ion/ion.h" -#define CUDA_SAFE_CALL(x) \ - do { \ - cudaError_t err = x; \ - if (err != cudaSuccess) { \ - std::stringstream ss; \ - ss << "CUDA error: " << cudaGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__; \ - throw std::runtime_error(ss.str()); \ - } \ +#define CUDA_SAFE_CALL(x) \ + do { \ + cudaError_t err = x; \ + if (err != cudaSuccess) { \ + std::stringstream ss; \ + ss << "CUDA error: " << cudaGetErrorString(err) << "(" << err << ") at " << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(ss.str()); \ + } \ } while (0) -#define CU_SAFE_CALL(x) \ - do { \ - CUresult err = x; \ - if (err != CUDA_SUCCESS) { \ - const char *err_str; \ - cuGetErrorString(err, &err_str); \ - std::stringstream ss; \ - ss << "CUDA error: " << err_str << " at " << __FILE__ << ":" << __LINE__; \ - throw std::runtime_error(ss.str()); \ - } \ +#define CU_SAFE_CALL(x) \ + do { \ + CUresult err = x; \ + if (err != CUDA_SUCCESS) { \ + const char *err_str; \ + cuGetErrorString(err, &err_str); \ + std::stringstream ss; \ + ss << "CUDA error: " << err_str << "(" << err << ") at " << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(ss.str()); \ + } \ } while (0) @@ -101,72 +101,74 @@ int main() { std::vector input_vec(width * height, 2024); std::vector output_vec(width * height, 0); try { + // CUDA setup + CudaState state; + // Ensure to initialize cuda Context under the hood - CUDA_SAFE_CALL(cudaSetDevice(0)); + CU_SAFE_CALL(cuInit(0)); - CudaState state; + CUdevice device; + CU_SAFE_CALL(cuDeviceGet(&device, 0)); - CUstream stream; // This is interchangeable with cudaStream_t (ref: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html) - CU_SAFE_CALL(cuStreamCreate(&stream, CU_STREAM_DEFAULT)); - state.cuda_stream = reinterpret_cast(stream); + CU_SAFE_CALL(cuCtxCreate(reinterpret_cast(&state.cuda_context), 0, device)); - CUcontext ctx; - CU_SAFE_CALL(cuStreamGetCtx(stream, &ctx)); - state.cuda_context = reinterpret_cast(ctx); + std::cout << "CUcontext is created on application side : " << state.cuda_context << std::endl; + + CU_SAFE_CALL(cuCtxSetCurrent(reinterpret_cast(state.cuda_context))); + + // CUstream is interchangeable with cudaStream_t (ref: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html) + CU_SAFE_CALL(cuStreamCreate(reinterpret_cast(&state.cuda_stream), CU_STREAM_DEFAULT)); constexpr int size = width * height * sizeof(int32_t); - // void *src = subsystem->malloc(size); void *src; CUDA_SAFE_CALL(cudaMalloc(&src, size)); CUDA_SAFE_CALL(cudaMemcpy(src, input_vec.data(), size, cudaMemcpyHostToDevice)); - // void *dst = subsystem->malloc(size); void *dst; CUDA_SAFE_CALL(cudaMalloc(&dst, size)); - Target target = 
get_host_target().with_feature(Target::CUDA).with_feature(Target::TracePipeline).with_feature(Target::Debug); - // CudaState state(subsystem->getContext(), subsystem->getStream()); - - // subsystem->memcpy(src, input_vec.data(), size, RocaMemcpyKind::kMemcpyHostToDevice); - auto device_interface = get_device_interface_for_device_api(DeviceAPI::CUDA, target); + // ION execution + { - assert(device_interface); + Target target = get_host_target().with_feature(Target::CUDA).with_feature(Target::TracePipeline).with_feature(Target::Debug); + auto device_interface = get_device_interface_for_device_api(DeviceAPI::CUDA, target); - Halide::Buffer<> inputBuffer(Halide::Int(32), nullptr, height, width); - inputBuffer.device_wrap_native(device_interface, reinterpret_cast(src)); - inputBuffer.set_device_dirty(true); + Halide::Buffer<> inputBuffer(Halide::Int(32), nullptr, height, width); + inputBuffer.device_wrap_native(device_interface, reinterpret_cast(src), &state); + inputBuffer.set_device_dirty(true); - Halide::Buffer<> outputBuffer(Halide::Int(32), nullptr, height, width); - outputBuffer.device_wrap_native(device_interface, reinterpret_cast(dst)); + Halide::Buffer<> outputBuffer(Halide::Int(32), nullptr, height, width); + outputBuffer.device_wrap_native(device_interface, reinterpret_cast(dst), &state); - ion::Builder b; - b.set_target(target); + ion::Builder b; + b.set_target(target); + b.set_jit_context(&state); - ion::Graph graph(b); + ion::Graph graph(b); - ion::Node cn = graph.add("Sqrt_gen")(inputBuffer); - cn["output0"].bind(outputBuffer); + ion::Node cn = graph.add("Sqrt_gen")(inputBuffer); + cn["output0"].bind(outputBuffer); - b.run(); - // outputBuffer.device_sync(); // Figure out how to replace halide's stream - // outputBuffer.copy_to_host(); + b.run(); - cudaMemcpy(output_vec.data(), dst, size, cudaMemcpyDeviceToHost); - for (int i = 0; i < height; i++) { - for (int j = 0; j < width; j++) { - std::cerr << output_vec[i * width + j] << " "; - if (44 != output_vec[i * width +j]) { - return -1; + cudaMemcpy(output_vec.data(), dst, size, cudaMemcpyDeviceToHost); + for (int i = 0; i < height; i++) { + for (int j = 0; j < width; j++) { + std::cerr << output_vec[i * width + j] << " "; + if (44 != output_vec[i * width + j]) { + return -1; + } } + std::cerr << std::endl; } - std::cerr << std::endl; } - // subsystem->free(src); + // CUDA cleanup CUDA_SAFE_CALL(cudaFree(src)); - - // subsystem->free(dst); CUDA_SAFE_CALL(cudaFree(dst)); + + CU_SAFE_CALL(cuStreamDestroy(reinterpret_cast(state.cuda_stream))); + CU_SAFE_CALL(cuCtxDestroy(reinterpret_cast(state.cuda_context))); } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl;
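
Two notes on the finished test. First, the link line: CUDA::cudart covers the runtime-API calls (cudaMalloc, cudaMemcpy, cudaFree), while CUDA::cuda_driver covers the driver-API calls (cuInit, cuCtxCreate, cuStreamCreate) that create the context and stream handed to Halide through the JITUserContext; that is why patch 3 adds both imported targets from find_package(CUDAToolkit). Second, the expected value in the output check: every input element is 2024, the building block outputs Int(32) and takes the square root of its input, and since 44^2 = 1936 <= 2024 < 2025 = 45^2, sqrt(2024) is about 44.99, which truncates to 44 when cast back to a 32-bit integer, the value compared against every element of output_vec.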