Skip to content

Commit

Permalink
implement 'get_default_context' as a free function
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev committed Nov 18, 2024
1 parent 03a32c5 commit d132f22
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 24 deletions.
35 changes: 23 additions & 12 deletions third_party/intel/backend/driver.c
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,28 @@ struct BuildFlags {
}
};

#ifdef WIN32
sycl::context get_default_context(const sycl::device &sycl_device) {
auto platform = sycl_device.get_platform();
sycl::context ctx;
try {
ctx = platform.ext_oneapi_get_default_context();
} catch (const std::runtime_error &ex) {
// This exception is thrown on Windows because
// ext_oneapi_get_default_context is not implemented. But it can be safely
// ignored it seems.
#if _DEBUG
std::cout << "ERROR: " << ex.what() << std::endl;
#endif
}
return ctx;
}
#else
sycl::context get_default_context(const sycl::device &sycl_device) {
return sycl_device.get_platform().ext_oneapi_get_default_context();
}
#endif

static PyObject *loadBinary(PyObject *self, PyObject *args) {
const char *name, *build_flags_ptr;
int shared;
Expand Down Expand Up @@ -194,18 +216,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
const size_t binary_size = PyBytes_Size(py_bytes);

uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes);
auto platform = sycl_device.get_platform();
sycl::context ctx;
try {
ctx = platform.ext_oneapi_get_default_context();
} catch (const std::runtime_error &ex) {
// This exception is thrown on Windows because
// ext_oneapi_get_default_context is not implemented. But it can be safely
// ignored it seems.
#if _DEBUG
std::cout << "ERROR: " << ex.what() << std::endl;
#endif
}
sycl::context ctx = get_default_context(sycl_device);
const auto l0_device =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
const auto l0_context =
Expand Down
36 changes: 24 additions & 12 deletions utils/SPIRVRunner/SPIRVRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,28 @@ static inline T checkSyclErrors(const std::tuple<T, ze_result_t> tuple) {
return std::get<0>(tuple);
}

#ifdef WIN32
sycl::context get_default_context(const sycl::device &sycl_device) {
auto platform = sycl_device.get_platform();
sycl::context ctx;
try {
ctx = platform.ext_oneapi_get_default_context();
} catch (const std::runtime_error &ex) {
// This exception is thrown on Windows because
// ext_oneapi_get_default_context is not implemented. But it can be safely
// ignored it seems.
#if _DEBUG
std::cout << "ERROR: " << ex.what() << std::endl;
#endif
}
return ctx;
}
#else
sycl::context get_default_context(const sycl::device &sycl_device) {
return sycl_device.get_platform().ext_oneapi_get_default_context();
}
#endif

/** SYCL Functions **/
std::tuple<sycl::kernel_bundle<sycl::bundle_state::executable>, sycl::kernel,
int32_t, int32_t>
Expand All @@ -138,18 +160,8 @@ loadBinary(const std::string &kernel_name, const std::string &build_flags,
const auto &sycl_l0_device_pair = g_sycl_l0_device_list[deviceId];
const sycl::device sycl_device = sycl_l0_device_pair.first;

auto platform = sycl_device.get_platform();
sycl::context ctx;
try {
ctx = platform.ext_oneapi_get_default_context();
} catch (const std::runtime_error &ex) {
// This exception is thrown on Windows because
// ext_oneapi_get_default_context is not implemented. But it can be safely
// ignored it seems.
#if _DEBUG
std::cout << "ERROR: " << ex.what() << std::endl;
#endif
}
sycl::context ctx = get_default_context(sycl_device);

const auto l0_device =
sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
const auto l0_context =
Expand Down

0 comments on commit d132f22

Please sign in to comment.