Skip to content

Commit

Permalink
Merge pull request #290 from fixstars/feature/add-cuda-interop-test
Browse files Browse the repository at this point in the history
Improve CUDA interoperability
  • Loading branch information
iitaku authored Jun 28, 2024
2 parents 19f6b51 + dccc453 commit 435701f
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 17 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ endif()
#
find_package(Threads REQUIRED)
find_package(Halide REQUIRED COMPONENTS shared)
find_package(CUDA)
find_package(CUDAToolkit)

#
# Version
Expand Down
6 changes: 6 additions & 0 deletions include/ion/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class Builder {
*/
Builder& set_target(const Target& target);

/**
* Set the user context which will be applied to the pipeline built with this builder.
* @arg user_context_ptr: The pointer to the user context.
*/
Builder& set_jit_context(Halide::JITUserContext *user_context_ptr);

/**
* Load bb module dynamically and enable it to compile your pipeline.
* @arg module_path: DSO path on your filesystem.
Expand Down
24 changes: 11 additions & 13 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class Port {
}

impl_->instances[i] = v;
impl_->bound_address[i] = std::make_tuple(v,false);
impl_->bound_address[i] = std::make_tuple(v, false);
}

template<typename T>
Expand All @@ -202,8 +202,9 @@ class Port {
impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i,graph_id())};
}

impl_->instances[i] = buf.raw_buffer();
impl_->bound_address[i] = std::make_tuple(buf.data(),false);
auto raw_buf = buf.raw_buffer();
impl_->instances[i] = raw_buf;
impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast<const char*>(raw_buf->host) : reinterpret_cast<const char*>(raw_buf->device), false);
}

template<typename T>
Expand All @@ -215,8 +216,9 @@ class Port {
impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())};
}

impl_->instances[i] = bufs[i].raw_buffer();
impl_->bound_address[i] = std::make_tuple(bufs[i].data(),false);
auto raw_buf = bufs[i].raw_buffer();
impl_->instances[i] = raw_buf;
impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast<const char*>(raw_buf->host) : reinterpret_cast<const char*>(raw_buf->device), false);
}

}
Expand All @@ -239,25 +241,21 @@ class Port {
}

std::vector<Halide::Func> as_func() const {
// if (dimensions() == 0) {
// throw std::runtime_error("Unreachable");
// }

std::vector<Halide::Func> fs;
for (const auto& [i, param] : impl_->params ) {
if (fs.size() <= i) {
fs.resize(i+1, Halide::Func());
}
std::vector<Halide::Var> args;
std::vector<Halide::Expr> args_expr;
for (int i = 0; i < dimensions(); ++i) {
args.push_back(Halide::Var::implicit(i));
args_expr.push_back(Halide::Var::implicit(i));
for (int j = 0; j < dimensions(); ++j) {
args.push_back(Halide::Var::implicit(j));
args_expr.push_back(Halide::Var::implicit(j));
}
Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id()) + "_im");
f(args) = Halide::Internal::Call::make(param, args_expr);
fs[i] = f;
if(std::get<1>(impl_->bound_address[i])){
if (std::get<1>(impl_->bound_address[i])) {
f.compute_root();
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ Builder& Builder::set_target(const Halide::Target& target) {
return *this;
}

// Store a user-supplied Halide JIT context on this Builder; the pipeline
// built from it will run with that context (e.g. an application-owned
// CUDA context/stream installed via custom JITUserContext handlers).
// Returns *this to allow chained configuration calls.
Builder& Builder::set_jit_context(Halide::JITUserContext *user_context_ptr) {
impl_->jit_ctx_ptr = user_context_ptr;
return *this;
}

Builder& Builder::with_bb_module(const std::string& module_name_or_path) {
auto bb_module = std::make_shared<DynamicModule>(module_name_or_path);
auto register_extern = bb_module->get_symbol<void (*)(std::map<std::string, Halide::JITExtern>&)>("register_externs");
Expand Down
6 changes: 3 additions & 3 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ ion_jit_executable(port-assign SRCS port-assign.cc)
# zero copy i/o for extern functions
ion_jit_executable(direct-extern SRCS direct-extern.cc)

if (${CUDA_FOUND})
ion_jit_executable(gpu-extern SRCS gpu-extern.cc)
cuda_add_library(gpu-extern-lib SHARED gpu-extern-lib.cu)
if (${CUDAToolkit_FOUND})
ion_jit_executable(gpu-extern SRCS gpu-extern.cc gpu-extern-lib.cu)
ion_jit_executable(cuda-interop SRCS cuda-interop.cc LIBS CUDA::cuda_driver CUDA::cudart)
endif()

# Duplicate name test
Expand Down
184 changes: 184 additions & 0 deletions test/cuda-interop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#include <cassert>
#include <sstream>
#include <vector>

#include <cuda_runtime.h>
#include <cuda.h>

#include "ion/ion.h"

// Check a CUDA runtime API call: throws std::runtime_error carrying the
// error string, numeric code, and call site when the call does not return
// cudaSuccess. Fix: the macro parameter is now parenthesized so arbitrary
// expressions expand safely (macro hygiene).
#define CUDA_SAFE_CALL(x) \
    do { \
        cudaError_t err = (x); \
        if (err != cudaSuccess) { \
            std::stringstream ss; \
            ss << "CUDA error: " << cudaGetErrorString(err) << "(" << err << ") at " << __FILE__ << ":" << __LINE__; \
            throw std::runtime_error(ss.str()); \
        } \
    } while (0)

// Check a CUDA driver API call: throws std::runtime_error carrying a
// human-readable message when the call does not return CUDA_SUCCESS.
// Fixes: the macro parameter is parenthesized (macro hygiene), and err_str
// is pre-initialized because cuGetErrorString leaves the pointer untouched
// (and it would otherwise be read uninitialized) when the error code itself
// is unrecognized.
#define CU_SAFE_CALL(x) \
    do { \
        CUresult err = (x); \
        if (err != CUDA_SUCCESS) { \
            const char *err_str = "unknown error"; \
            cuGetErrorString(err, &err_str); \
            std::stringstream ss; \
            ss << "CUDA error: " << err_str << "(" << err << ") at " << __FILE__ << ":" << __LINE__; \
            throw std::runtime_error(ss.str()); \
        } \
    } while (0)


// Building block: output(x, y) = (int)sqrt(input(x, y)) over a 2-D int32
// image. Scheduling is picked from the target: GPU tiling when a GPU
// feature is present, parallel+vectorized CPU schedule otherwise.
struct Sqrt : ion::BuildingBlock<Sqrt> {
    ion::Input<Halide::Func> input{"input0", Int(32), 2};
    ion::Output<Halide::Func> output{"output0", Int(32), 2};
    Halide::Var x, y;

    // Pure definition: element-wise truncated square root.
    void generate() {
        output(x, y) = cast<int>(sqrt(input(x, y)));
    }

    virtual void schedule() {
        Target target = get_target();

        if (!target.has_gpu_feature()) {
            // Fallback to CPU scheduling
            output.compute_root().parallel(y).vectorize(x, 8);
            return;
        }

        // Report which GPU API the target selected (diagnostics only).
        if (target.has_feature(Target::OpenCL)) {
            std::cout << "Using OpenCL" << std::endl;
        } else if (target.has_feature(Target::CUDA)) {
            std::cout << "Using CUDA" << std::endl;
        } else if (target.has_feature(Target::Metal)) {
            std::cout << "Using Metal" << std::endl;
        }

        Var bx, tx;
        output.compute_root().gpu_tile(x, y, bx, tx, 16, 16);
    }
};

ION_REGISTER_BUILDING_BLOCK(Sqrt, Sqrt_gen)

// JITUserContext that lends Halide an application-owned CUDA context and
// stream instead of letting the runtime create its own. The acquire/release
// counters record how many times the runtime asked for / returned the
// context (inspectable by the test).
struct CudaState : public Halide::JITUserContext {
    void *cuda_context = nullptr, *cuda_stream = nullptr;
    std::atomic<int> acquires = 0, releases = 0;

    // Hands out the application's context; `create` is ignored because the
    // context already exists on the application side.
    static int my_cuda_acquire_context(JITUserContext *ctx, void **cuda_ctx, bool create) {
        auto *self = static_cast<CudaState *>(ctx);
        self->acquires++;
        *cuda_ctx = self->cuda_context;
        return 0;
    }

    // Nothing to free here — only count the release for the test.
    static int my_cuda_release_context(JITUserContext *ctx) {
        static_cast<CudaState *>(ctx)->releases++;
        return 0;
    }

    // Hands out the application's stream for the given context.
    static int my_cuda_get_stream(JITUserContext *ctx, void *cuda_ctx, void **stream) {
        auto *self = static_cast<CudaState *>(ctx);
        *stream = self->cuda_stream;
        return 0;
    }

    CudaState() {
        // Route Halide's CUDA context/stream management through the
        // callbacks above.
        handlers.custom_cuda_acquire_context = my_cuda_acquire_context;
        handlers.custom_cuda_release_context = my_cuda_release_context;
        handlers.custom_cuda_get_stream = my_cuda_get_stream;
    }
};

// Integration test: run an ion pipeline on CUDA device memory owned by the
// application, sharing the application's CUcontext/CUstream with Halide via
// a custom JITUserContext (CudaState).
// Fixes: the device-to-host cudaMemcpy is now error-checked like every
// other CUDA call in this file; removed the unused local `flip`.
// Returns 0 on success, non-zero on failure.
int main() {
    using namespace Halide;

    constexpr int width = 16, height = 16;
    std::vector<int32_t> input_vec(width * height, 2024);
    std::vector<int32_t> output_vec(width * height, 0);
    try {
        // CUDA setup: create the context/stream Halide will borrow through
        // CudaState's custom handlers.
        CudaState state;

        // Ensure to initialize cuda Context under the hood
        CU_SAFE_CALL(cuInit(0));

        CUdevice device;
        CU_SAFE_CALL(cuDeviceGet(&device, 0));

        CU_SAFE_CALL(cuCtxCreate(reinterpret_cast<CUcontext *>(&state.cuda_context), 0, device));

        std::cout << "CUcontext is created on application side : " << state.cuda_context << std::endl;

        CU_SAFE_CALL(cuCtxSetCurrent(reinterpret_cast<CUcontext>(state.cuda_context)));

        // CUstream is interchangeable with cudaStream_t (ref: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html)
        CU_SAFE_CALL(cuStreamCreate(reinterpret_cast<CUstream *>(&state.cuda_stream), CU_STREAM_DEFAULT));

        // Device buffers owned by the application, not by Halide.
        constexpr int size = width * height * sizeof(int32_t);
        void *src;
        CUDA_SAFE_CALL(cudaMalloc(&src, size));
        CUDA_SAFE_CALL(cudaMemcpy(src, input_vec.data(), size, cudaMemcpyHostToDevice));

        void *dst;
        CUDA_SAFE_CALL(cudaMalloc(&dst, size));

        // ION execution
        {
            Target target = get_host_target().with_feature(Target::CUDA).with_feature(Target::TracePipeline).with_feature(Target::Debug);
            auto device_interface = get_device_interface_for_device_api(DeviceAPI::CUDA, target);

            // Wrap the raw device pointers without copying; the host
            // allocation is nullptr, so the pipeline works purely on the
            // device side.
            Halide::Buffer<> inputBuffer(Halide::Int(32), nullptr, height, width);
            inputBuffer.device_wrap_native(device_interface, reinterpret_cast<uint64_t>(src), &state);
            inputBuffer.set_device_dirty(true);

            Halide::Buffer<> outputBuffer(Halide::Int(32), nullptr, height, width);
            outputBuffer.device_wrap_native(device_interface, reinterpret_cast<uint64_t>(dst), &state);

            ion::Builder b;
            b.set_target(target);
            b.set_jit_context(&state);

            ion::Graph graph(b);

            ion::Node cn = graph.add("Sqrt_gen")(inputBuffer);
            cn["output0"].bind(outputBuffer);

            b.run();

            // Blocking copy back to the host; previously this was the only
            // unchecked CUDA API call in the file.
            CUDA_SAFE_CALL(cudaMemcpy(output_vec.data(), dst, size, cudaMemcpyDeviceToHost));

            // floor(sqrt(2024)) == 44, so every output element must be 44.
            for (int i = 0; i < height; i++) {
                for (int j = 0; j < width; j++) {
                    std::cerr << output_vec[i * width + j] << " ";
                    if (44 != output_vec[i * width + j]) {
                        return -1;
                    }
                }
                std::cerr << std::endl;
            }
        }

        // CUDA cleanup
        CUDA_SAFE_CALL(cudaFree(src));
        CUDA_SAFE_CALL(cudaFree(dst));

        CU_SAFE_CALL(cuStreamDestroy(reinterpret_cast<CUstream>(state.cuda_stream)));
        CU_SAFE_CALL(cuCtxDestroy(reinterpret_cast<CUcontext>(state.cuda_context)));

    } catch (const Halide::Error &e) {
        std::cerr << e.what() << std::endl;
        return 1;
    } catch (const std::exception &e) {
        std::cerr << e.what() << std::endl;
        return 1;
    }

    std::cout << "Passed" << std::endl;

    return 0;
}

0 comments on commit 435701f

Please sign in to comment.