From d9bc39fe81551569431d9a685766faee84f032bc Mon Sep 17 00:00:00 2001 From: Takuro Iizuka Date: Wed, 20 Dec 2023 23:35:55 -0800 Subject: [PATCH] Fixed python test --- cmake/IonUtil.cmake | 2 +- python/ionpy/Node.py | 4 +- python/ionpy/Port.py | 8 ++-- python/ionpy/native.py | 8 ++-- python/test/test_all.py | 4 +- python/test/test_node_port.py | 2 +- python/test/test_pipeline.py | 8 ++-- python/test/test_port.py | 2 +- python/test/test_port_access.py | 4 +- python/test/test_port_map_access.py | 4 +- src/builder.cc | 4 +- src/c_ion.cc | 65 +++++++++++++++++++++++++++++ 12 files changed, 90 insertions(+), 25 deletions(-) diff --git a/cmake/IonUtil.cmake b/cmake/IonUtil.cmake index 9b578109..20ee873d 100644 --- a/cmake/IonUtil.cmake +++ b/cmake/IonUtil.cmake @@ -19,7 +19,7 @@ function(ion_compile NAME) add_executable(${NAME} ${IEC_SRCS}) if(UNIX) target_compile_options(${NAME} - PUBLIC -fno-rtti # For Halide::Generator + #PUBLIC -fno-rtti # For Halide::Generator PUBLIC -rdynamic) # For JIT compiling endif() target_include_directories(${NAME} PUBLIC "${PROJECT_SOURCE_DIR}/include") diff --git a/python/ionpy/Node.py b/python/ionpy/Node.py index 5c70f811..948e3684 100644 --- a/python/ionpy/Node.py +++ b/python/ionpy/Node.py @@ -33,10 +33,10 @@ def __del__(self): if self.obj: # check not nullptr ion_node_destroy(self.obj) - def get_port(self, key: str) -> Port: + def get_port(self, name: str) -> Port: c_port = c_ion_port_t() - ret = ion_node_get_port(self.obj, key.encode(), ctypes.byref(c_port)) + ret = ion_node_get_port(self.obj, name.encode(), ctypes.byref(c_port)) if ret != 0: raise Exception('Invalid operation') diff --git a/python/ionpy/Port.py b/python/ionpy/Port.py index 1988250c..e74150be 100644 --- a/python/ionpy/Port.py +++ b/python/ionpy/Port.py @@ -4,8 +4,8 @@ from .native import ( c_ion_port_t, ion_port_create, - ion_port_destroy, ion_port_create_with_index, + ion_port_destroy, ) from .Type import Type @@ -23,18 +23,18 @@ def __init__(self, obj_ = c_ion_port_t() type_cobj = type.to_cobj() - ret = ion_port_create(ctypes.byref(obj_), self.name.encode(), type_cobj, self.dim) + ret = ion_port_create(ctypes.byref(obj_), name.encode(), type_cobj, dim) if ret != 0: raise Exception('Invalid operation') self.obj = obj_ - def __getitem__(self, i): + def __getitem__(self, index): new_obj = c_ion_port_t() ret = ion_port_create_with_index(ctypes.byref(new_obj), self.obj, index) if ret != 0: raise Exception('Invalid operation') - return new_obj + return Port(obj_=new_obj) def __del__(self): if self.obj: # check not nullptr diff --git a/python/ionpy/native.py b/python/ionpy/native.py index b33c6d93..66af1da7 100644 --- a/python/ionpy/native.py +++ b/python/ionpy/native.py @@ -39,10 +39,10 @@ class c_builder_compile_option_t(ctypes.Structure): ion_port_create.restype = ctypes.c_int ion_port_create.argtypes = [ ctypes.POINTER(c_ion_port_t), ctypes.c_char_p, c_ion_type_t, ctypes.c_int ] -# ion_port_index_access(ion_port_t, int); -ion_port_index_access = ion_core.ion_port_index_access -ion_port_index_access.restype = ctypes.c_int -ion_port_index_access.argtypes =[c_ion_port_t, ctypes.c_int ] +# ion_port_create_with_index(ion_port_t*, ion_port_t, int); +ion_port_create_with_index = ion_core.ion_port_create_with_index +ion_port_create_with_index.restype = ctypes.c_int +ion_port_create_with_index.argtypes =[ctypes.POINTER(c_ion_port_t), c_ion_port_t, ctypes.c_int ] # int ion_port_destroy(ion_port_t); ion_port_destroy = ion_core.ion_port_destroy diff --git a/python/test/test_all.py b/python/test/test_all.py index a8137dd9..96ab4b55 100644 --- a/python/test/test_all.py +++ b/python/test/test_all.py @@ -5,7 +5,7 @@ def test_all(): t = Type(code_=TypeCode.Int, bits_=32, lanes_=1) - input_port = Port(key='input', type=t, dim=2) + input_port = Port(name='input', type=t, dim=2) value41 = Param(key='v', val='41') builder = Builder() @@ -32,7 +32,7 @@ def test_all(): obuf.write(data=odata_bytes) port_map.set_buffer(port=input_port, buffer=ibuf) - port_map.set_buffer(port=node.get_port(key='output'), buffer=obuf) + port_map.set_buffer(port=node.get_port(name='output'), buffer=obuf) builder.run(port_map=port_map) diff --git a/python/test/test_node_port.py b/python/test/test_node_port.py index 1c0152ea..00feb47f 100644 --- a/python/test/test_node_port.py +++ b/python/test/test_node_port.py @@ -4,7 +4,7 @@ def test_node_port(): t = Type(code_=TypeCode.Int, bits_=32, lanes_=1) - port_to_set = Port(key='iamkey', type=t, dim=3) + port_to_set = Port(name='iamkey', type=t, dim=3) ports = [ port_to_set, ] diff --git a/python/test/test_pipeline.py b/python/test/test_pipeline.py index be71df0c..0307ef27 100644 --- a/python/test/test_pipeline.py +++ b/python/test/test_pipeline.py @@ -20,8 +20,8 @@ def test_pipeline(): builder.with_bb_module(path='ion-bb') node = builder.add('image_io_cameraN').set_param(params=[width, height, urls]) - node1 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(key='output')[0], ]); - node2 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(key='output')[1], ]); + node1 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(name='output')[0], ]); + node2 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(name='output')[1], ]); port_map = PortMap() @@ -36,8 +36,8 @@ def test_pipeline(): obuf1.write(data=odata_bytes1) obuf2.write(data=odata_bytes2) - port_map.set_buffer(port=node1.get_port(key='output'), buffer=obuf1) - port_map.set_buffer(port=node2.get_port(key='output'), buffer=obuf2) + port_map.set_buffer(port=node1.get_port(name='output'), buffer=obuf1) + port_map.set_buffer(port=node2.get_port(name='output'), buffer=obuf2) builder.run(port_map=port_map) diff --git a/python/test/test_port.py b/python/test/test_port.py index a83b9c1f..d78b11d2 100644 --- a/python/test/test_port.py +++ b/python/test/test_port.py @@ -4,5 +4,5 @@ def test_port(): t = Type(code_=TypeCode.Int, bits_=32, lanes_=1) - p = Port(key='iamkey', type=t, dim=3) + p = Port(name='iamkey', type=t, dim=3) print(p) diff --git a/python/test/test_port_access.py b/python/test/test_port_access.py index 2369895a..431f3877 100644 --- a/python/test/test_port_access.py +++ b/python/test/test_port_access.py @@ -5,6 +5,6 @@ def test_port(): t = Type(code_=TypeCode.Int, bits_=32, lanes_=1) - p = Port(key='iamkey', type=t, dim=3) + p = Port(name='iamkey', type=t, dim=3) p = p[1] - print(p) \ No newline at end of file + print(p) diff --git a/python/test/test_port_map_access.py b/python/test/test_port_map_access.py index 088fc95b..3e062f75 100644 --- a/python/test/test_port_map_access.py +++ b/python/test/test_port_map_access.py @@ -32,8 +32,8 @@ def test_portmap_access(): obuf2.write(data=odata_bytes2) port_map = PortMap() - port_map.set_buffer(port=node.get_port(key='output')[0], buffer=obuf1) - port_map.set_buffer(port=node.get_port(key='output')[1], buffer=obuf2) + port_map.set_buffer(port=node.get_port(name='output')[0], buffer=obuf1) + port_map.set_buffer(port=node.get_port(name='output')[1], buffer=obuf2) builder.run(port_map=port_map) diff --git a/src/builder.cc b/src/builder.cc index 7361bbfb..9e2b01a8 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -287,7 +287,7 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) { if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) { if (pm.is_mapped(argument_name(port.node_id(), port.name()))) { // This block should be executed when g.run is called with appropriate PortMap. - const auto& params(pm.get_params(port.name())); + const auto& params(pm.get_params(argument_name(port.node_id(), port.name()))); // validation // if (arginfo.types.size() != vs.size()) { @@ -322,7 +322,7 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) { } else if (arginfo.kind == Halide::Internal::ArgInfoKind::Function) { if (pm.is_mapped(argument_name(port.node_id(), port.name()))) { // This block should be executed when g.run is called with appropriate PortMap. - const auto& params(pm.get_params(port.name())); + const auto& params(pm.get_params(argument_name(port.node_id(), port.name()))); std::vector fs; for (const auto& p : params) { diff --git a/src/c_ion.cc b/src/c_ion.cc index 53c21c0e..c99e8c15 100644 --- a/src/c_ion.cc +++ b/src/c_ion.cc @@ -7,6 +7,17 @@ using namespace ion; +namespace { +template +std::vector> convert(ion_buffer_t *b, int n) { + std::vector> bs(n); + for (int i=0; i*>(b[i]); + } + return bs; +} +} + // // ion_port_t // @@ -639,3 +650,57 @@ int ion_port_map_set_buffer(ion_port_map_t obj, ion_port_t p, ion_buffer_t b) return 0; } + +int ion_port_map_set_buffer_array(ion_port_map_t obj, ion_port_t p, ion_buffer_t *bs, int n) +{ + try { + // NOTE: Halide::Buffer class layout is safe to call Halide::Buffer::type() + auto type = reinterpret_cast*>(*bs)->type(); + if (type.is_int()) { + if (type.bits() == 8) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else if (type.bits() == 16) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else if (type.bits() == 32) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else if (type.bits() == 64) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else { + throw std::runtime_error("Unsupported bits number"); + } + } else if (type.is_uint()) { + if (type.bits() == 1) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else if (type.bits() == 8) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else if (type.bits() == 16) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else if (type.bits() == 32) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else if (type.bits() == 64) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else { + throw std::runtime_error("Unsupported bits number"); + } + } else if (type.is_float()) { + if (type.bits() == 32) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else if (type.bits() == 64) { + reinterpret_cast(obj)->set(*reinterpret_cast(p), convert(bs, n)); + } else { + throw std::runtime_error("Unsupported bits number"); + } + } else { + throw std::runtime_error("Unsupported type code"); + } + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + return -1; + } catch (...) { + std::cerr << "Unknown exception was happened." << std::endl; + return -1; + } + + return 0; +}