-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #171 from clEsperanto/add-execute
Add-execute
- Loading branch information
Showing
6 changed files
with
187 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from ._core import ( | ||
gpu_info, | ||
execute, | ||
select_backend, | ||
select_device, | ||
get_device, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
|
||
#include "pycle_wrapper.hpp" | ||
|
||
#include "device.hpp" | ||
#include "array.hpp" | ||
#include "utils.hpp" | ||
#include "execution.hpp" | ||
|
||
#include <pybind11/functional.h> | ||
#include <pybind11/numpy.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
namespace py = pybind11; | ||
|
||
auto py_execute(const cle::Device::Pointer &device, const std::string &kernel_name, const std::string &kernel_source, const py::dict ¶meters, const py::tuple &range, const py::dict &constants) -> void | ||
{ | ||
cle::RangeArray global_range = {1, 1, 1}; | ||
switch (range.size()) | ||
{ | ||
case 1: | ||
global_range[0] = range[0].cast<size_t>(); | ||
break; | ||
case 2: | ||
global_range[0] = range[1].cast<size_t>(); | ||
global_range[1] = range[0].cast<size_t>(); | ||
break; | ||
case 3: | ||
global_range[0] = range[2].cast<size_t>(); | ||
global_range[1] = range[1].cast<size_t>(); | ||
global_range[2] = range[0].cast<size_t>(); | ||
break; | ||
default: | ||
throw std::invalid_argument("Error: range tuple must have 3 elements or less. Received " + std::to_string(range.size()) + " elements."); | ||
break; | ||
} | ||
|
||
// manage kernel name and code | ||
const cle::KernelInfo kernel_info = {kernel_name, kernel_source}; | ||
|
||
// convert py::dict paramter to vector<pair<string, array>> | ||
cle::ParameterList clic_parameters; | ||
for (auto item : parameters) | ||
{ | ||
// check if item.second is cle::Array::Pointer | ||
if (py::isinstance<cle::Array>(item.second)) | ||
{ | ||
clic_parameters.push_back({item.first.cast<std::string>(), item.second.cast<cle::Array::Pointer>()}); | ||
} | ||
else if (py::isinstance<py::int_>(item.second)) | ||
{ | ||
// convert py::int to int | ||
clic_parameters.push_back({item.first.cast<std::string>(), item.second.cast<int>()}); | ||
} | ||
else if (py::isinstance<py::float_>(item.second)) | ||
{ | ||
// convert py::float to float | ||
clic_parameters.push_back({item.first.cast<std::string>(), item.second.cast<float>()}); | ||
} | ||
else | ||
{ | ||
throw std::invalid_argument("Error: parameter type not supported. Received " + std::string(py::str(item.second.get_type()).cast<std::string>())); | ||
} | ||
} | ||
|
||
// convert py::dict constant to vector<pair<string, int>> | ||
cle::ConstantList clic_constants; | ||
if (!constants.empty()) | ||
{ | ||
for (auto item : constants) | ||
{ | ||
clic_constants.push_back({item.first.cast<std::string>(), item.second.cast<int>()}); | ||
} | ||
} | ||
|
||
// execute | ||
cle::execute(device, kernel_info, clic_parameters, global_range, clic_constants); | ||
} | ||
|
||
auto execute_(py::module_ &m) -> void | ||
{ | ||
m.def("_execute", &py_execute, "Call execute function from C++.", | ||
py::arg("device"), py::arg("kernel_name"), py::arg("kernel_source"), py::arg("parameters"), py::arg("range"), py::arg("constants")); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ PYBIND11_MODULE(_pyclesperanto, m) | |
types_(m); | ||
core_(m); | ||
array_(m); | ||
execute_(m); | ||
|
||
tier1_(m); | ||
tier2_(m); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import pyclesperanto as cle | ||
import numpy as np | ||
|
||
absolute_ocl = """ | ||
__constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST; | ||
__kernel void absolute( | ||
IMAGE_src_TYPE src, | ||
IMAGE_dst_TYPE dst | ||
) | ||
{ | ||
const int x = get_global_id(0); | ||
const int y = get_global_id(1); | ||
const int z = get_global_id(2); | ||
IMAGE_src_PIXEL_TYPE value = READ_IMAGE(src, sampler, POS_src_INSTANCE(x,y,z,0)).x; | ||
if ( value < 0 ) { | ||
value = -1 * value; | ||
} | ||
WRITE_IMAGE(dst, POS_dst_INSTANCE(x,y,z,0), CONVERT_dst_PIXEL_TYPE(value)); | ||
} | ||
""" | ||
|
||
def test_execute_absolute(): | ||
input = cle.push(np.asarray([ | ||
[1, -1], | ||
[1, -1] | ||
])) | ||
output = cle.create(input) | ||
|
||
param = {'src': input, 'dst': output} | ||
cle.execute(kernel_source=absolute_ocl, kernel_name="absolute", global_size=input.shape, parameters=param) | ||
|
||
print(output) | ||
|
||
a = cle.pull(output) | ||
assert (np.min(a) == 1) | ||
assert (np.max(a) == 1) | ||
assert (np.mean(a) == 1) |