Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/restructure-port-map' in…
Browse files Browse the repository at this point in the history
…to feature/restructure-port-map
  • Loading branch information
Fixstars-iizuka committed Dec 21, 2023
2 parents 024d9c5 + d9bc39f commit 95c13a6
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 25 deletions.
1 change: 0 additions & 1 deletion cmake/IonUtil.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ function(ion_compile NAME)
# Build compile
add_executable(${NAME} ${IEC_SRCS})
if(UNIX)
target_compile_options(${NAME} PUBLIC -fno-rtti) # For Halide::Generator
target_link_options(${NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling
endif()
target_include_directories(${NAME} PUBLIC "${PROJECT_SOURCE_DIR}/include")
Expand Down
4 changes: 2 additions & 2 deletions python/ionpy/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def __del__(self):
if self.obj: # check not nullptr
ion_node_destroy(self.obj)

def get_port(self, key: str) -> Port:
def get_port(self, name: str) -> Port:
c_port = c_ion_port_t()

ret = ion_node_get_port(self.obj, key.encode(), ctypes.byref(c_port))
ret = ion_node_get_port(self.obj, name.encode(), ctypes.byref(c_port))
if ret != 0:
raise Exception('Invalid operation')

Expand Down
8 changes: 4 additions & 4 deletions python/ionpy/Port.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from .native import (
c_ion_port_t,
ion_port_create,
ion_port_destroy,
ion_port_create_with_index,
ion_port_destroy,
)

from .Type import Type
Expand All @@ -23,18 +23,18 @@ def __init__(self,
obj_ = c_ion_port_t()
type_cobj = type.to_cobj()

ret = ion_port_create(ctypes.byref(obj_), self.name.encode(), type_cobj, self.dim)
ret = ion_port_create(ctypes.byref(obj_), name.encode(), type_cobj, dim)
if ret != 0:
raise Exception('Invalid operation')

self.obj = obj_

def __getitem__(self, i):
def __getitem__(self, index):
new_obj = c_ion_port_t()
ret = ion_port_create_with_index(ctypes.byref(new_obj), self.obj, index)
if ret != 0:
raise Exception('Invalid operation')
return new_obj
return Port(obj_=new_obj)

def __del__(self):
if self.obj: # check not nullptr
Expand Down
8 changes: 4 additions & 4 deletions python/ionpy/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class c_builder_compile_option_t(ctypes.Structure):
ion_port_create.restype = ctypes.c_int
ion_port_create.argtypes = [ ctypes.POINTER(c_ion_port_t), ctypes.c_char_p, c_ion_type_t, ctypes.c_int ]

# ion_port_index_access(ion_port_t, int);
ion_port_index_access = ion_core.ion_port_index_access
ion_port_index_access.restype = ctypes.c_int
ion_port_index_access.argtypes =[c_ion_port_t, ctypes.c_int ]
# ion_port_create_with_index(ion_port_t*, ion_port_t, int);
ion_port_create_with_index = ion_core.ion_port_create_with_index
ion_port_create_with_index.restype = ctypes.c_int
ion_port_create_with_index.argtypes =[ctypes.POINTER(c_ion_port_t), c_ion_port_t, ctypes.c_int ]

# int ion_port_destroy(ion_port_t);
ion_port_destroy = ion_core.ion_port_destroy
Expand Down
4 changes: 2 additions & 2 deletions python/test/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def test_all():
t = Type(code_=TypeCode.Int, bits_=32, lanes_=1)
input_port = Port(key='input', type=t, dim=2)
input_port = Port(name='input', type=t, dim=2)
value41 = Param(key='v', val='41')

builder = Builder()
Expand All @@ -32,7 +32,7 @@ def test_all():
obuf.write(data=odata_bytes)

port_map.set_buffer(port=input_port, buffer=ibuf)
port_map.set_buffer(port=node.get_port(key='output'), buffer=obuf)
port_map.set_buffer(port=node.get_port(name='output'), buffer=obuf)

builder.run(port_map=port_map)

Expand Down
2 changes: 1 addition & 1 deletion python/test/test_node_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def test_node_port():
t = Type(code_=TypeCode.Int, bits_=32, lanes_=1)

port_to_set = Port(key='iamkey', type=t, dim=3)
port_to_set = Port(name='iamkey', type=t, dim=3)

ports = [ port_to_set, ]

Expand Down
8 changes: 4 additions & 4 deletions python/test/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def test_pipeline():
builder.with_bb_module(path='ion-bb')

node = builder.add('image_io_cameraN').set_param(params=[width, height, urls])
node1 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(key='output')[0], ]);
node2 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(key='output')[1], ]);
node1 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(name='output')[0], ]);
node2 = builder.add("base_normalize_3d_uint8").set_port(ports=[node.get_port(name='output')[1], ]);

port_map = PortMap()

Expand All @@ -36,8 +36,8 @@ def test_pipeline():
obuf1.write(data=odata_bytes1)
obuf2.write(data=odata_bytes2)

port_map.set_buffer(port=node1.get_port(key='output'), buffer=obuf1)
port_map.set_buffer(port=node2.get_port(key='output'), buffer=obuf2)
port_map.set_buffer(port=node1.get_port(name='output'), buffer=obuf1)
port_map.set_buffer(port=node2.get_port(name='output'), buffer=obuf2)

builder.run(port_map=port_map)

Expand Down
2 changes: 1 addition & 1 deletion python/test/test_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
def test_port():
t = Type(code_=TypeCode.Int, bits_=32, lanes_=1)

p = Port(key='iamkey', type=t, dim=3)
p = Port(name='iamkey', type=t, dim=3)
print(p)
4 changes: 2 additions & 2 deletions python/test/test_port_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
def test_port():
t = Type(code_=TypeCode.Int, bits_=32, lanes_=1)

p = Port(key='iamkey', type=t, dim=3)
p = Port(name='iamkey', type=t, dim=3)
p = p[1]
print(p)
print(p)
4 changes: 2 additions & 2 deletions python/test/test_port_map_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def test_portmap_access():
obuf2.write(data=odata_bytes2)

port_map = PortMap()
port_map.set_buffer(port=node.get_port(key='output')[0], buffer=obuf1)
port_map.set_buffer(port=node.get_port(key='output')[1], buffer=obuf2)
port_map.set_buffer(port=node.get_port(name='output')[0], buffer=obuf1)
port_map.set_buffer(port=node.get_port(name='output')[1], buffer=obuf2)

builder.run(port_map=port_map)

Expand Down
4 changes: 2 additions & 2 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) {
if (pm.is_mapped(argument_name(port.node_id(), port.name()))) {
// This block should be executed when g.run is called with appropriate PortMap.
const auto& params(pm.get_params(port.name()));
const auto& params(pm.get_params(argument_name(port.node_id(), port.name())));

// validation
// if (arginfo.types.size() != vs.size()) {
Expand Down Expand Up @@ -333,7 +333,7 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
} else if (arginfo.kind == Halide::Internal::ArgInfoKind::Function) {
if (pm.is_mapped(argument_name(port.node_id(), port.name()))) {
// This block should be executed when g.run is called with appropriate PortMap.
const auto& params(pm.get_params(port.name()));
const auto& params(pm.get_params(argument_name(port.node_id(), port.name())));

std::vector<Halide::Func> fs;
for (const auto& p : params) {
Expand Down
65 changes: 65 additions & 0 deletions src/c_ion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@

using namespace ion;

namespace {
template<typename T>
std::vector<Halide::Buffer<T>> convert(ion_buffer_t *b, int n) {
std::vector<Halide::Buffer<T>> bs(n);
for (int i=0; i<n; ++i) {
bs[i] = *reinterpret_cast<Halide::Buffer<T>*>(b[i]);
}
return bs;
}
}

//
// ion_port_t
//
Expand Down Expand Up @@ -639,3 +650,57 @@ int ion_port_map_set_buffer(ion_port_map_t obj, ion_port_t p, ion_buffer_t b)

return 0;
}

int ion_port_map_set_buffer_array(ion_port_map_t obj, ion_port_t p, ion_buffer_t *bs, int n)
{
try {
// NOTE: Halide::Buffer class layout is safe to call Halide::Buffer<void>::type()
auto type = reinterpret_cast<Halide::Buffer<void>*>(*bs)->type();
if (type.is_int()) {
if (type.bits() == 8) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<int8_t>(bs, n));
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<int8_t>(bs, n));
} else if (type.bits() == 16) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<int16_t>(bs, n));
} else if (type.bits() == 32) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<int32_t>(bs, n));
} else if (type.bits() == 64) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<int64_t>(bs, n));
} else {
throw std::runtime_error("Unsupported bits number");
}
} else if (type.is_uint()) {
if (type.bits() == 1) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<bool>(bs, n));
} else if (type.bits() == 8) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<uint8_t>(bs, n));
} else if (type.bits() == 16) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<uint16_t>(bs, n));
} else if (type.bits() == 32) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<uint32_t>(bs, n));
} else if (type.bits() == 64) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<uint64_t>(bs, n));
} else {
throw std::runtime_error("Unsupported bits number");
}
} else if (type.is_float()) {
if (type.bits() == 32) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<float>(bs, n));
} else if (type.bits() == 64) {
reinterpret_cast<PortMap*>(obj)->set(*reinterpret_cast<Port*>(p), convert<double>(bs, n));
} else {
throw std::runtime_error("Unsupported bits number");
}
} else {
throw std::runtime_error("Unsupported type code");
}
} catch (const std::exception& e) {
std::cerr << e.what() << std::endl;
return -1;
} catch (...) {
std::cerr << "Unknown exception was happened." << std::endl;
return -1;
}

return 0;
}

0 comments on commit 95c13a6

Please sign in to comment.