Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve CUDA interoperability #290

Merged
merged 4 commits on Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ endif()
#
find_package(Threads REQUIRED)
find_package(Halide REQUIRED COMPONENTS shared)
find_package(CUDA)
find_package(CUDAToolkit)

#
# Version
Expand Down
6 changes: 6 additions & 0 deletions include/ion/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class Builder {
*/
Builder& set_target(const Target& target);

/**
* Set the user context which will be applied to the pipeline built with this builder.
* @arg user_context_ptr: The pointer to the user context.
*/
Builder& set_jit_context(Halide::JITUserContext *user_context_ptr);

/**
* Load bb module dynamically and enable it to compile your pipeline.
* @arg module_path: DSO path on your filesystem.
Expand Down
24 changes: 11 additions & 13 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class Port {
}

impl_->instances[i] = v;
impl_->bound_address[i] = std::make_tuple(v,false);
impl_->bound_address[i] = std::make_tuple(v, false);
}

template<typename T>
Expand All @@ -202,8 +202,9 @@ class Port {
impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i,graph_id())};
}

impl_->instances[i] = buf.raw_buffer();
impl_->bound_address[i] = std::make_tuple(buf.data(),false);
auto raw_buf = buf.raw_buffer();
impl_->instances[i] = raw_buf;
impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast<const char*>(raw_buf->host) : reinterpret_cast<const char*>(raw_buf->device), false);
}

template<typename T>
Expand All @@ -215,8 +216,9 @@ class Port {
impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())};
}

impl_->instances[i] = bufs[i].raw_buffer();
impl_->bound_address[i] = std::make_tuple(bufs[i].data(),false);
auto raw_buf = bufs[i].raw_buffer();
impl_->instances[i] = raw_buf;
impl_->bound_address[i] = std::make_tuple(raw_buf->host ? reinterpret_cast<const char*>(raw_buf->host) : reinterpret_cast<const char*>(raw_buf->device), false);
}

}
Expand All @@ -239,25 +241,21 @@ class Port {
}

std::vector<Halide::Func> as_func() const {
// if (dimensions() == 0) {
// throw std::runtime_error("Unreachable");
// }

std::vector<Halide::Func> fs;
for (const auto& [i, param] : impl_->params ) {
if (fs.size() <= i) {
fs.resize(i+1, Halide::Func());
}
std::vector<Halide::Var> args;
std::vector<Halide::Expr> args_expr;
for (int i = 0; i < dimensions(); ++i) {
args.push_back(Halide::Var::implicit(i));
args_expr.push_back(Halide::Var::implicit(i));
for (int j = 0; j < dimensions(); ++j) {
args.push_back(Halide::Var::implicit(j));
args_expr.push_back(Halide::Var::implicit(j));
}
Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id()) + "_im");
f(args) = Halide::Internal::Call::make(param, args_expr);
fs[i] = f;
if(std::get<1>(impl_->bound_address[i])){
if (std::get<1>(impl_->bound_address[i])) {
f.compute_root();
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ Builder& Builder::set_target(const Halide::Target& target) {
return *this;
}

// Store the user-supplied Halide JIT context on the builder implementation.
// Only the raw pointer is retained (no copy), so presumably the caller must
// keep the context alive for the lifetime of the built pipeline — confirm.
// Returns *this to allow call chaining.
Builder& Builder::set_jit_context(Halide::JITUserContext *user_context_ptr) {
    impl_->jit_ctx_ptr = user_context_ptr;
    return *this;
}

Builder& Builder::with_bb_module(const std::string& module_name_or_path) {
auto bb_module = std::make_shared<DynamicModule>(module_name_or_path);
auto register_extern = bb_module->get_symbol<void (*)(std::map<std::string, Halide::JITExtern>&)>("register_externs");
Expand Down
6 changes: 3 additions & 3 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ ion_jit_executable(port-assign SRCS port-assign.cc)
# zero copy i/o for extern functions
ion_jit_executable(direct-extern SRCS direct-extern.cc)

if (${CUDA_FOUND})
ion_jit_executable(gpu-extern SRCS gpu-extern.cc)
cuda_add_library(gpu-extern-lib SHARED gpu-extern-lib.cu)
# Build the GPU tests only when the CUDA toolkit was found.
# Use the variable name directly: if(${VAR}) dereferences the value and
# evaluates to if() with an empty argument when the variable is undefined.
if (CUDAToolkit_FOUND)
    ion_jit_executable(gpu-extern SRCS gpu-extern.cc gpu-extern-lib.cu)
    ion_jit_executable(cuda-interop SRCS cuda-interop.cc LIBS CUDA::cuda_driver CUDA::cudart)
endif()

# Duplicate name test
Expand Down
184 changes: 184 additions & 0 deletions test/cuda-interop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#include <atomic>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>

#include <cuda_runtime.h>
#include <cuda.h>

#include "ion/ion.h"

// CUDA runtime API guard: evaluates `x` and, for any result other than
// cudaSuccess, throws std::runtime_error with the error string and call site.
#define CUDA_SAFE_CALL(x) \
    do { \
        cudaError_t err = x; \
        if (err != cudaSuccess) { \
            std::stringstream ss; \
            ss << "CUDA error: " << cudaGetErrorString(err) << "(" << err << ") at " << __FILE__ << ":" << __LINE__; \
            throw std::runtime_error(ss.str()); \
        } \
    } while (0)

// CUDA driver API guard: same as CUDA_SAFE_CALL but for CUresult values,
// resolving the message text via cuGetErrorString.
#define CU_SAFE_CALL(x) \
    do { \
        CUresult err = x; \
        if (err != CUDA_SUCCESS) { \
            const char *err_str; \
            cuGetErrorString(err, &err_str); \
            std::stringstream ss; \
            ss << "CUDA error: " << err_str << "(" << err << ") at " << __FILE__ << ":" << __LINE__; \
            throw std::runtime_error(ss.str()); \
        } \
    } while (0)


// Building block that computes output(x, y) = cast<int>(sqrt(input(x, y)))
// over a 2-D Int(32) image.
struct Sqrt : ion::BuildingBlock<Sqrt> {
    ion::Input<Halide::Func> input{"input0", Int(32), 2};
    ion::Output<Halide::Func> output{"output0", Int(32), 2};
    Halide::Var x, y;

    void generate() {
        // Element-wise: floating-point sqrt, then cast back to int.
        output(x, y) = cast<int>(sqrt(input(x, y)));
    }

    virtual void schedule() {
        Target target = get_target();

        if (!target.has_gpu_feature()) {
            // Fallback to CPU scheduling
            output.compute_root().parallel(y).vectorize(x, 8);
            return;
        }

        // Report which GPU backend is active (informational only).
        if (target.has_feature(Target::OpenCL)) {
            std::cout << "Using OpenCL" << std::endl;
        } else if (target.has_feature(Target::CUDA)) {
            std::cout << "Using CUDA" << std::endl;
        } else if (target.has_feature(Target::Metal)) {
            std::cout << "Using Metal" << std::endl;
        }

        Var block, thread;
        output.compute_root().gpu_tile(x, y, block, thread, 16, 16);
    }
};

ION_REGISTER_BUILDING_BLOCK(Sqrt, Sqrt_gen)

// Halide JIT user context that injects an application-owned CUDA context and
// stream into Halide's CUDA runtime through the custom handler hooks.
// `acquires`/`releases` count handler invocations so the caller can verify
// that Halide actually routed through these callbacks.
struct CudaState : public Halide::JITUserContext {
    void *cuda_context = nullptr;
    void *cuda_stream = nullptr;
    std::atomic<int> acquires{0};
    std::atomic<int> releases{0};

    // Hands Halide the externally created CUcontext instead of letting it
    // create its own. `create` is ignored: the context always exists here.
    static int my_cuda_acquire_context(JITUserContext *ctx, void **cuda_ctx, bool create) {
        auto *state = static_cast<CudaState *>(ctx);
        *cuda_ctx = state->cuda_context;
        state->acquires++;
        return 0;
    }

    static int my_cuda_release_context(JITUserContext *ctx) {
        auto *state = static_cast<CudaState *>(ctx);
        state->releases++;
        return 0;
    }

    // Supplies the externally created CUstream for the given context.
    static int my_cuda_get_stream(JITUserContext *ctx, void *cuda_ctx, void **stream) {
        auto *state = static_cast<CudaState *>(ctx);
        *stream = state->cuda_stream;
        return 0;
    }

    CudaState() {
        handlers.custom_cuda_acquire_context = my_cuda_acquire_context;
        handlers.custom_cuda_release_context = my_cuda_release_context;
        handlers.custom_cuda_get_stream = my_cuda_get_stream;
    }
};

// End-to-end CUDA interop test: the application owns the CUcontext, CUstream
// and the device buffers; ion/Halide runs the Sqrt building block directly on
// them via the custom JIT handlers, with no extra host<->device copies inside
// the pipeline. Returns 0 on success, nonzero on failure.
int main() {
    using namespace Halide;

    constexpr int width = 16, height = 16;
    std::vector<int32_t> input_vec(width * height, 2024);
    std::vector<int32_t> output_vec(width * height, 0);
    try {
        // CUDA setup: create the context/stream the application owns.
        CudaState state;

        // Ensure to initialize cuda Context under the hood
        CU_SAFE_CALL(cuInit(0));

        CUdevice device;
        CU_SAFE_CALL(cuDeviceGet(&device, 0));

        CU_SAFE_CALL(cuCtxCreate(reinterpret_cast<CUcontext *>(&state.cuda_context), 0, device));

        std::cout << "CUcontext is created on application side : " << state.cuda_context << std::endl;

        CU_SAFE_CALL(cuCtxSetCurrent(reinterpret_cast<CUcontext>(state.cuda_context)));

        // CUstream is interchangeable with cudaStream_t (ref: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html)
        CU_SAFE_CALL(cuStreamCreate(reinterpret_cast<CUstream *>(&state.cuda_stream), CU_STREAM_DEFAULT));

        constexpr int size = width * height * sizeof(int32_t);
        void *src;
        CUDA_SAFE_CALL(cudaMalloc(&src, size));
        CUDA_SAFE_CALL(cudaMemcpy(src, input_vec.data(), size, cudaMemcpyHostToDevice));

        void *dst;
        CUDA_SAFE_CALL(cudaMalloc(&dst, size));

        // ION execution
        {
            Target target = get_host_target().with_feature(Target::CUDA).with_feature(Target::TracePipeline).with_feature(Target::Debug);
            auto device_interface = get_device_interface_for_device_api(DeviceAPI::CUDA, target);

            // Wrap the raw device pointers in Halide buffers without copying.
            Halide::Buffer<> inputBuffer(Halide::Int(32), nullptr, height, width);
            inputBuffer.device_wrap_native(device_interface, reinterpret_cast<uint64_t>(src), &state);
            inputBuffer.set_device_dirty(true);  // data is already valid on the device

            Halide::Buffer<> outputBuffer(Halide::Int(32), nullptr, height, width);
            outputBuffer.device_wrap_native(device_interface, reinterpret_cast<uint64_t>(dst), &state);

            ion::Builder b;
            b.set_target(target);
            b.set_jit_context(&state);

            ion::Graph graph(b);

            ion::Node cn = graph.add("Sqrt_gen")(inputBuffer);
            cn["output0"].bind(outputBuffer);

            b.run();

            // Every input element is 2024, so every output must be
            // cast<int>(sqrt(2024)) == 44. The copy back is now checked like
            // every other CUDA call (it was previously unchecked).
            CUDA_SAFE_CALL(cudaMemcpy(output_vec.data(), dst, size, cudaMemcpyDeviceToHost));
            for (int i = 0; i < height; i++) {
                for (int j = 0; j < width; j++) {
                    std::cerr << output_vec[i * width + j] << " ";
                    if (44 != output_vec[i * width + j]) {
                        // NOTE(review): early failure returns leak src/dst and
                        // the context/stream — acceptable for a test binary.
                        std::cerr << "Unexpected value at (" << i << ", " << j << ")" << std::endl;
                        return 1;  // consistent with the other error paths (was -1)
                    }
                }
                std::cerr << std::endl;
            }
        }

        // CUDA cleanup
        CUDA_SAFE_CALL(cudaFree(src));
        CUDA_SAFE_CALL(cudaFree(dst));

        CU_SAFE_CALL(cuStreamDestroy(reinterpret_cast<CUstream>(state.cuda_stream)));
        CU_SAFE_CALL(cuCtxDestroy(reinterpret_cast<CUcontext>(state.cuda_context)));

    } catch (const Halide::Error &e) {
        std::cerr << e.what() << std::endl;
        return 1;
    } catch (const std::exception &e) {
        std::cerr << e.what() << std::endl;
        return 1;
    }

    std::cout << "Passed" << std::endl;

    return 0;
}
Loading