diff --git a/include/ion/builder.h b/include/ion/builder.h
index 465002d7..5f0f97cf 100644
--- a/include/ion/builder.h
+++ b/include/ion/builder.h
@@ -39,14 +39,21 @@ class Builder {
~Builder();
/**
- * Adding new node to the graph.
+ * Adding new node to the builder.
* @arg k: The key of the node which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK().
*/
Node add(const std::string& name);
+ /**
+ * Adding new node to the specific graph.
+ * @arg k: The key of the node which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK().
+ * @arg id: graph unique identifier
+ */
+ Node add(const std::string& name, const GraphID& graph_id);
/**
- *
+ * Adding new node to the graph.
+ * @arg k: The key of the node which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK().
*/
Graph add_graph(const std::string& name);
diff --git a/include/ion/c_ion.h b/include/ion/c_ion.h
index 9281adc3..6347c455 100644
--- a/include/ion/c_ion.h
+++ b/include/ion/c_ion.h
@@ -31,6 +31,7 @@ typedef struct ion_node_t_ *ion_node_t;
typedef struct ion_builder_t_ *ion_builder_t;
typedef struct ion_buffer_t_ *ion_buffer_t;
typedef struct ion_port_map_t_ *ion_port_map_t;
+typedef struct ion_graph_t_ *ion_graph_t;
int ion_port_create(ion_port_t *, const char *, ion_type_t, int);
int ion_port_create_with_index(ion_port_t *, ion_port_t , int);
@@ -62,6 +63,7 @@ int ion_builder_create(ion_builder_t *);
int ion_builder_destroy(ion_builder_t);
int ion_builder_set_target(ion_builder_t, const char *);
int ion_builder_with_bb_module(ion_builder_t, const char *);
+int ion_builder_add_graph(ion_builder_t, const char *, ion_graph_t *);
int ion_builder_add_node(ion_builder_t, const char *, ion_node_t *);
int ion_builder_compile(ion_builder_t, const char *, ion_builder_compile_option_t option);
int ion_builder_save(ion_builder_t, const char *);
@@ -76,6 +78,12 @@ int ion_buffer_destroy(ion_buffer_t);
int ion_buffer_write(ion_buffer_t, void *, int size);
int ion_buffer_read(ion_buffer_t, void *, int size);
+int ion_graph_create(ion_graph_t *, ion_builder_t, const char *);
+int ion_graph_add_node(ion_graph_t, const char*, ion_node_t *);
+int ion_graph_destroy(ion_graph_t);
+int ion_graph_run(ion_graph_t);
+int ion_graph_create_with_multiple(ion_graph_t * ptr, ion_graph_t* objs, int size);
+
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
int ion_port_map_create(ion_port_map_t *);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]]
diff --git a/include/ion/node.h b/include/ion/node.h
index 70ec314b..71e17fb8 100644
--- a/include/ion/node.h
+++ b/include/ion/node.h
@@ -20,16 +20,17 @@ class Node {
public:
struct Impl {
- std::string id;
+ NodeID id;
std::string name;
+ GraphID graph_id;
Halide::Target target;
std::vector params;
std::vector ports;
std::vector arginfos;
Impl(): id(), name(), target(), params(), ports() {}
-
- Impl(const std::string& id_, const std::string& name_, const Halide::Target& target_);
+ Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_);
+ Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_, const GraphID &graph_id_);
};
public:
@@ -76,7 +77,7 @@ class Node {
*/
template
Node operator()(Args ...args) {
- set_iport(std::vector{args...});
+ set_iport(std::vector{make_iport(args)...});
return *this;
}
@@ -94,7 +95,7 @@ class Node {
Port operator[](const std::string& name);
// Getter
- const std::string& id() const {
+ const NodeID & id() const {
return impl_->id;
}
@@ -121,11 +122,45 @@ class Node {
std::vector> oports() const;
private:
- Node(const std::string& id, const std::string& name, const Halide::Target& target)
+ Node(const NodeID& id, const std::string& name, const Halide::Target& target)
: impl_(new Impl{id, name, target})
{
}
+ Node(const NodeID&& id, const std::string& name, const Halide::Target& target, const GraphID& graph_id)
+ : impl_(new Impl{id, name, target, graph_id})
+ {
+ }
+
+ Port make_iport(Port arg) const {
+ return arg;
+ }
+
+ template
+ Port make_iport(T *vptr) const {
+ if (to_string(impl_->graph_id).empty())
+ return Port(vptr);
+ else
+ return Port(vptr, impl_->graph_id);
+ }
+
+ template
+ Port make_iport(Halide::Buffer& arg) const {
+ if (to_string(impl_->graph_id).empty())
+ return Port(arg);
+ else
+ return Port(arg, impl_->graph_id);
+ }
+
+ template
+ Port make_iport(std::vector>& arg) const {
+ if (to_string(impl_->graph_id).empty())
+ return Port(arg);
+ else
+ return Port(arg, impl_->graph_id);
+ }
+
+
std::shared_ptr impl_;
};
diff --git a/include/ion/port.h b/include/ion/port.h
index e6669b4c..ed36880b 100644
--- a/include/ion/port.h
+++ b/include/ion/port.h
@@ -49,11 +49,12 @@ class Port {
friend class Node;
public:
- using Channel = std::tuple;
+ using Channel = std::tuple;
private:
struct Impl {
- std::string id;
+ PortID id;
+ GraphID graph_id;
Channel pred_chan;
std::set succ_chans;
@@ -64,12 +65,12 @@ class Port {
std::unordered_map instances;
Impl();
- Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d);
+ Impl(const NodeID& nid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID &gid );
};
public:
- Port() : impl_(new Impl("", "", Halide::Type(), 0)), index_(-1) {}
+ Port() : impl_(new Impl(NodeID(""), "", Halide::Type(), 0, GraphID(""))), index_(-1) {}
Port(const std::shared_ptr& impl, int32_t index) : impl_(impl), index_(index) {}
@@ -78,7 +79,7 @@ class Port {
* @arg k: The key of the port which should be matched with BuildingBlock Input/Output name.
* @arg t: The type of the value.
*/
- Port(const std::string& n, Halide::Type t) : impl_(new Impl("", n, t, 0)), index_(-1) {}
+ Port(const std::string& n, Halide::Type t) : impl_(new Impl(NodeID(""), n, t, 0, GraphID(""))), index_(-1) {}
/**
* Construct new port for vector value.
@@ -86,22 +87,40 @@ class Port {
* @arg t: The type of the element value.
* @arg d: The dimension of the port. The range is 1 to 4.
*/
- Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl("", n, t, d)), index_(-1) {}
+ Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl(NodeID(""), n, t, d, GraphID(""))), index_(-1) {}
/**
* Construct new port from scalar pointer
*/
template::value>::type* = nullptr>
- Port(T *vptr) : impl_(new Impl("", Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0)), index_(-1) {
+ Port(T *vptr) : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0, GraphID(""))), index_(-1) {
this->bind(vptr);
}
+ /**
+ * Construct new port from scalar pointer
+ */
+ template::value>::type* = nullptr>
+ Port(T *vptr, const GraphID & gid) : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0, gid)), index_(-1) {
+ this->bind(vptr);
+ }
+
+
/**
* Construct new port from buffer
*/
template
- Port(const Halide::Buffer& buf) : impl_(new Impl("", buf.name(), buf.type(), buf.dimensions())), index_(-1) {
+ Port(const Halide::Buffer& buf) : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), GraphID(""))), index_(-1) {
+ this->bind(buf);
+ }
+
+ /**
+ * Construct new port from buffer and bind graph id to port
+ */
+ template
+ Port(const Halide::Buffer& buf, const GraphID & gid) : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), gid)), index_(-1) {
this->bind(buf);
}
@@ -109,36 +128,45 @@ class Port {
* Construct new port from array of buffer
*/
template
- Port(const std::vector>& bufs) : impl_(new Impl("", unify_name(bufs), Halide::type_of(), unify_dimension(bufs))), index_(-1) {
+ Port(const std::vector>& bufs) : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of(), unify_dimension(bufs), GraphID(""))), index_(-1) {
+ this->bind(bufs);
+ }
+
+ /**
+ * Construct new port from array of buffer and bind graph id to port
+ */
+ template
+ Port(const std::vector>& bufs, const GraphID & gid) : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of(), unify_dimension(bufs), gid)), index_(-1) {
this->bind(bufs);
}
// Getter
- const std::string& id() const { return impl_->id; }
+ const PortID id() const { return impl_->id; }
const Channel& pred_chan() const { return impl_->pred_chan; }
- const std::string& pred_id() const { return std::get<0>(impl_->pred_chan); }
+ const NodeID& pred_id() const { return std::get<0>(impl_->pred_chan); }
const std::string& pred_name() const { return std::get<1>(impl_->pred_chan); }
const std::set& succ_chans() const { return impl_->succ_chans; }
const Halide::Type& type() const { return impl_->type; }
int32_t dimensions() const { return impl_->dimensions; }
int32_t size() const { return static_cast(impl_->params.size()); }
int32_t index() const { return index_; }
+ const GraphID& graph_id() const { return impl_->graph_id; }
// Setter
void set_index(int index) { index_ = index; }
// Util
- bool has_pred() const { return !std::get<0>(impl_->pred_chan).empty(); }
- bool has_pred_by_nid(const std::string& nid) const { return !std::get<0>(impl_->pred_chan).empty(); }
+ bool has_pred() const { return !std::get<0>(impl_->pred_chan).value().empty(); }
+ bool has_pred_by_nid(const NodeID & nid) const { return !to_string(std::get<0>(impl_->pred_chan)).empty(); }
bool has_succ() const { return !impl_->succ_chans.empty(); }
bool has_succ(const Channel& c) const { return impl_->succ_chans.count(c); }
- bool has_succ_by_nid(const std::string& nid) const {
+ bool has_succ_by_nid(const NodeID& nid) const {
return std::count_if(impl_->succ_chans.begin(),
impl_->succ_chans.end(),
[&](const Port::Channel& c) { return std::get<0>(c) == nid; });
}
- void determine_succ(const std::string& nid, const std::string& old_pn, const std::string& new_pn);
+ void determine_succ(const NodeID& nid, const std::string& old_pn, const std::string& new_pn);
/**
* Overloaded operator to set the port index and return a reference to the current port. eg. port[0]
@@ -153,22 +181,21 @@ class Port {
void bind(T *v) {
auto i = index_ == -1 ? 0 : index_;
if (has_pred()) {
- impl_->params[i] = Halide::Internal::Parameter{Halide::type_of(), false, 0, argument_name(pred_id(), pred_name(), i)};
+ impl_->params[i] = Halide::Internal::Parameter{Halide::type_of(), false, 0, argument_name(pred_id(), pred_name(), i, graph_id())};
} else {
- impl_->params[i] = Halide::Internal::Parameter{type(), false, dimensions(), argument_name(pred_id(), pred_name(), i)};
+ impl_->params[i] = Halide::Internal::Parameter{type(), false, dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())};
}
impl_->instances[i] = v;
}
-
template
void bind(const Halide::Buffer& buf) {
auto i = index_ == -1 ? 0 : index_;
if (has_pred()) {
- impl_->params[i] = Halide::Internal::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), pred_name(), i)};
+ impl_->params[i] = Halide::Internal::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), pred_name(), i,graph_id())};
} else {
- impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i)};
+ impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i,graph_id())};
}
impl_->instances[i] = buf.raw_buffer();
@@ -178,9 +205,9 @@ class Port {
void bind(const std::vector>& bufs) {
for (int i=0; i(bufs.size()); ++i) {
if (has_pred()) {
- impl_->params[i] = Halide::Internal::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), pred_name(), i)};
+ impl_->params[i] = Halide::Internal::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())};
} else {
- impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i)};
+ impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())};
}
impl_->instances[i] = bufs[i].raw_buffer();
@@ -199,7 +226,7 @@ class Port {
if (es.size() <= i) {
es.resize(i+1, Halide::Expr());
}
- es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), pred_name(), i), param);
+ es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), pred_name(), i, graph_id()), param);
}
return es;
}
@@ -220,7 +247,7 @@ class Port {
args.push_back(Halide::Var::implicit(i));
args_expr.push_back(Halide::Var::implicit(i));
}
- Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), pred_name(), i) + "_im");
+ Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), pred_name(), i, graph_id()) + "_im");
f(args) = Halide::Internal::Call::make(param, args_expr);
fs[i] = f;
}
@@ -234,7 +261,7 @@ class Port {
args.resize(i+1, Halide::Argument());
}
auto kind = dimensions() == 0 ? Halide::Argument::InputScalar : Halide::Argument::InputBuffer;
- args[i] = Halide::Argument(argument_name(pred_id(), pred_name(), i), kind, type(), dimensions(), Halide::ArgumentEstimates());
+ args[i] = Halide::Argument(argument_name(pred_id(), pred_name(), i, graph_id()), kind, type(), dimensions(), Halide::ArgumentEstimates());
}
return args;
}
@@ -257,7 +284,7 @@ class Port {
* pid and pn is stored in both pred and succ,
* then it will determined through pipeline build process.
*/
- Port(const std::string& pid, const std::string& pn) : impl_(new Impl(pid, pn, Halide::Type(), 0)), index_(-1) {}
+ Port(const NodeID & nid, const std::string& pn) : impl_(new Impl(nid, pn, Halide::Type(), 0, GraphID(""))), index_(-1) {}
std::shared_ptr impl_;
diff --git a/include/ion/port_map.h b/include/ion/port_map.h
index 82a05dc2..ee036910 100644
--- a/include/ion/port_map.h
+++ b/include/ion/port_map.h
@@ -16,7 +16,7 @@ class PortMap {
template
[[deprecated("Port::bind can be used instead of PortMap.")]]
void set(Port port, T v) {
- auto& buf(scalar_buffer_[argument_name(port.pred_id(), port.pred_name(), port.index())]);
+ auto& buf(scalar_buffer_[argument_name(port.pred_id(), port.pred_name(), port.index(), port.graph_id())]);
buf.resize(sizeof(v));
std::memcpy(buf.data(), &v, sizeof(v));
port.bind(reinterpret_cast(buf.data()));
diff --git a/include/ion/util.h b/include/ion/util.h
index 78358457..78ddf3b6 100644
--- a/include/ion/util.h
+++ b/include/ion/util.h
@@ -2,15 +2,61 @@
#define ION_UTIL_H
#include
-
namespace ion {
class Port;
-std::string argument_name(const std::string& node_id, const std::string& name, int32_t index);
-
std::string array_name(const std::string& port_name, size_t i);
+// a string-like identifier that is typed on a tag type
+template
+struct StringID {
+ using tag_type = Tag;
+
+ // needs to be default-constructable because of use in map[] below
+ StringID(std::string s) : _value(std::move(s)) {}
+ StringID() : _value() {}
+ // provide access to the underlying string value
+ const std::string &value() const { return _value; }
+
+ struct StringIDHash {
+ // Use hash of string as hash function.
+ size_t operator()(const StringID& id) const
+ {
+ return std::hash()(id.value());
+ }
+ };
+
+private:
+ std::string _value;
+
+ // will only compare against same type of id.
+ friend bool operator<(const StringID &l, const StringID &r) {
+ return l._value < r._value;
+ }
+
+ friend bool operator==(const StringID &l, const StringID &r) {
+ return l._value == r._value;
+ }
+
+ // and let's go ahead and provide expected free functions
+ friend
+ auto to_string(const StringID &r)
+ -> const std::string & {
+ return r._value;
+ }
+};
+
+struct node_tag {};
+struct graph_tag {};
+struct port_tag {};
+
+using NodeID = StringID;
+using GraphID = StringID;
+using PortID = StringID;
+
+std::string argument_name(const NodeID& node_id, const std::string& name, int32_t index, const GraphID& graph_id);
+
} // namespace ion
#endif
diff --git a/python/ionpy/Builder.py b/python/ionpy/Builder.py
index 05ad4177..7ea9ecae 100644
--- a/python/ionpy/Builder.py
+++ b/python/ionpy/Builder.py
@@ -9,6 +9,7 @@
ion_builder_set_target,
ion_builder_with_bb_module,
+ ion_builder_add_graph,
ion_builder_add_node,
ion_builder_compile,
@@ -19,7 +20,7 @@
ion_builder_run_with_port_map,
)
-
+from .Graph import Graph
from .Node import Node
from .BuilderCompileOption import BuilderCompileOption
from .PortMap import PortMap
@@ -62,6 +63,14 @@ def add(self, key: str) -> Node:
return Node(obj_=c_node)
+ def add_graph(self, name: str) -> Graph:
+ c_graph = c_ion_node_t()
+ ret = ion_builder_add_graph(self.obj, name.encode(), ctypes.byref(c_graph))
+ if ret != 0:
+ raise Exception('Invalid operation')
+
+ return Graph(obj_=c_graph)
+
def compile(self, function_name: str, option: BuilderCompileOption):
ret = ion_builder_compile(self.obj, function_name.encode(), option.to_cobj())
if ret != 0:
diff --git a/python/ionpy/Graph.py b/python/ionpy/Graph.py
new file mode 100644
index 00000000..ff7d1e49
--- /dev/null
+++ b/python/ionpy/Graph.py
@@ -0,0 +1,86 @@
+import ctypes
+from typing import Optional
+
+from .native import (
+ c_ion_node_t,
+ c_ion_graph_t,
+
+ ion_graph_create,
+ ion_graph_destroy,
+ ion_graph_create_with_multiple,
+ ion_graph_run,
+ ion_graph_add_node
+)
+
+from .Node import Node
+
+
+class Graph:
+ def __init__(self,
+ builder = None,
+ name: Optional[str] = None,
+ obj_: Optional[c_ion_graph_t] = None,
+ sub_graphs: [] = None,
+
+ ):
+ if obj_ is None and builder is not None and name is not None:
+ obj_ = c_ion_graph_t()
+ ret = ion_graph_create(ctypes.byref(obj_), builder.obj, name.encode())
+ if ret != 0:
+ raise Exception('Invalid operation')
+ elif obj_ is None and sub_graphs is not None:
+ num_graphs = len(sub_graphs)
+ c_ion_graph_sized_array_t = c_ion_graph_t * num_graphs # arraysize == num_graphs
+ c_graphs = c_ion_graph_sized_array_t() # instance
+ for i in range(num_graphs):
+ c_graphs[i] = sub_graphs[i].obj
+ obj_ = c_ion_graph_t()
+ ret = ion_graph_create_with_multiple(ctypes.byref(obj_), c_graphs, num_graphs)
+ if ret != 0:
+ raise Exception('Invalid operation')
+
+ self.obj = obj_
+
+ def __del__(self):
+ if self.obj: # check not nullptr
+ ion_graph_destroy(self.obj)
+
+ # adding two objects
+
+ def __add__(self, other):
+ if isinstance(other, Graph):
+ c_ion_graph_sized_array_t = c_ion_graph_t * 2 # arraysize == num_graphs
+ c_graphs = c_ion_graph_sized_array_t() # instance
+ c_graphs[0] = self.obj
+ c_graphs[1] = other.obj
+ new_obj = c_ion_graph_t()
+ ret = ion_graph_create_with_multiple(ctypes.byref(new_obj), c_graphs, 2)
+ if ret != 0:
+ raise Exception('Invalid operation')
+ return Graph(obj_=new_obj)
+
+ def __iadd__(self, other):
+ if isinstance(other, Graph):
+ c_ion_graph_sized_array_t = c_ion_graph_t * 2 # arraysize == num_graphs
+ c_graphs = c_ion_graph_sized_array_t() # instance
+ c_graphs[0] = self.obj
+ c_graphs[1] = other.obj
+ ret = ion_graph_create_with_multiple(ctypes.byref(self.obj), c_graphs, 2)
+ if ret != 0:
+ raise Exception('Invalid operation')
+ return self
+
+
+ def run(self):
+ ret = ion_graph_run(self.obj)
+ if ret != 0:
+ raise Exception('Invalid operation')
+
+ def add(self, key: str) -> Node:
+ c_node = c_ion_node_t()
+
+ ret = ion_graph_add_node(self.obj, key.encode(), ctypes.byref(c_node))
+ if ret != 0:
+ raise Exception('Invalid operation')
+
+ return Node(obj_=c_node)
diff --git a/python/ionpy/__init__.py b/python/ionpy/__init__.py
index e2d047ee..d6afedd1 100644
--- a/python/ionpy/__init__.py
+++ b/python/ionpy/__init__.py
@@ -5,7 +5,7 @@
from .Builder import Builder
from .Buffer import Buffer
from .PortMap import PortMap
-
+from .Graph import Graph
from .Type import Type
from .TypeCode import TypeCode
from .BuilderCompileOption import BuilderCompileOption
diff --git a/python/ionpy/native.py b/python/ionpy/native.py
index a551048f..7fa9ae2b 100644
--- a/python/ionpy/native.py
+++ b/python/ionpy/native.py
@@ -38,6 +38,7 @@ class c_builder_compile_option_t(ctypes.Structure):
c_ion_node_t = ctypes.POINTER(ctypes.c_int)
c_ion_builder_t = ctypes.POINTER(ctypes.c_int)
+c_ion_graph_t = ctypes.POINTER(ctypes.c_int)
c_ion_buffer_t = ctypes.POINTER(ctypes.c_int)
c_ion_port_map_t = ctypes.POINTER(ctypes.c_int)
@@ -212,6 +213,11 @@ class c_builder_compile_option_t(ctypes.Structure):
ion_builder_bb_metadata.restype = ctypes.c_int
ion_builder_bb_metadata.argtypes = [ c_ion_builder_t, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_int) ]
+# int ion_builder_add_graph(ion_builder_t, const char , ion_graph_t *)
+ion_builder_add_graph = ion_core.ion_builder_add_graph
+ion_builder_add_graph.restype = ctypes.c_int
+ion_builder_add_graph.argtypes = [ c_ion_builder_t, ctypes.c_char_p, ctypes.POINTER(c_ion_graph_t) ]
+
# int ion_builder_run(ion_builder_t, ion_port_map_t);
ion_builder_run = ion_core.ion_builder_run
@@ -251,6 +257,30 @@ class c_builder_compile_option_t(ctypes.Structure):
ion_buffer_read.restype = ctypes.c_int
ion_buffer_read.argtypes = [ c_ion_buffer_t, ctypes.c_void_p, ctypes.c_int ]
+# int ion_graph_create(ion_graph_t *, ion_builder_t, const char *)
+ion_graph_create = ion_core.ion_graph_create
+ion_graph_create.restype = ctypes.c_int
+ion_graph_create.argtypes =[ ctypes.POINTER(c_ion_graph_t), c_ion_builder_t, ctypes.c_char_p ]
+
+# int ion_graph_create_with_multiple(ion_graph_t*, ion_graph_t*, int size)
+ion_graph_create_with_multiple = ion_core.ion_graph_create_with_multiple
+ion_graph_create_with_multiple.restype = ctypes.c_int
+ion_graph_create_with_multiple.argtypes = [ ctypes.POINTER(c_ion_graph_t), ctypes.POINTER(c_ion_graph_t), ctypes.c_int]
+
+# int ion_graph_add_node(ion_graph_t, const char*, ion_node_t *)
+ion_graph_add_node = ion_core.ion_graph_add_node
+ion_graph_add_node.restype = ctypes.c_int
+ion_graph_add_node.argtypes =[ c_ion_graph_t, ctypes.c_char_p, ctypes.POINTER(c_ion_node_t) ]
+
+# int ion_graph_run(ion_graph_t)
+ion_graph_run=ion_core.ion_graph_run
+ion_graph_run.restype = ctypes.c_int
+ion_graph_run.argtypes =[ c_ion_graph_t]
+
+# ion_graph_destroy(ion_graph_t)
+ion_graph_destroy=ion_core.ion_graph_destroy
+ion_graph_destroy.restype = ctypes.c_int
+ion_graph_destroy.argtypes =[ c_ion_graph_t]
# int ion_port_map_create(ion_port_map_t *);
diff --git a/python/test/test_graph.py b/python/test/test_graph.py
new file mode 100644
index 00000000..6dab4d49
--- /dev/null
+++ b/python/test/test_graph.py
@@ -0,0 +1,52 @@
+from ionpy import Builder, Graph, Port, Buffer, Type
+import numpy as np
+
+def test_graph():
+ input_port0 = Port(name='input', type=Type.from_dtype(np.dtype(np.int32)), dim=2)
+ value_port0 = Port(name='v', type=Type.from_dtype(np.dtype(np.int32)), dim=0)
+
+ builder = Builder()
+ builder.set_target(target='host')
+ builder.with_bb_module(path='ion-bb-test')
+
+ graph0 = builder.add_graph("graph0")
+
+ node0 = graph0.add('test_incx_i32x2').set_iport([input_port0, value_port0])
+
+ idata0 = np.array([[42, 42]], dtype=np.int32)
+ ibuf0 = Buffer(array=idata0)
+
+ odata0 = np.array([[0, 0]], dtype=np.int32)
+ obuf0 = Buffer(array=odata0)
+
+ input_port0.bind(ibuf0)
+ output_port0 = node0.get_port(name='output')
+ output_port0.bind(obuf0)
+ value_port0.bind(0)
+
+ input_port1 = Port(name='input', type=Type.from_dtype(np.dtype(np.int32)), dim=2)
+ value_port1 = Port(name='v', type=Type.from_dtype(np.dtype(np.int32)), dim=0)
+
+ graph1 = builder.add_graph("graph1")
+ # graph1 = Graph(builder =builder, name="graph1") # alternative
+ node1 = graph1.add('test_incx_i32x2').set_iport([input_port1, value_port1])
+
+ idata1 = np.array([[42, 42]], dtype=np.int32)
+ ibuf1 = Buffer(array=idata1)
+
+ odata1 = np.array([[0, 0]], dtype=np.int32)
+ obuf1 = Buffer(array=odata1)
+
+ input_port1.bind(ibuf1)
+ output_port = node1.get_port(name='output')
+ output_port.bind(obuf1)
+ value_port1.bind(1)
+
+ g = graph0 + graph1
+ g.run()
+ # alternative
+ # graph1 += graph0
+ # g = g(sub_graphs=[graph1,graph0])
+
+ assert odata0[0][0] == 42
+ assert odata1[0][0] == 43
diff --git a/src/builder.cc b/src/builder.cc
index 75c6934a..27f342ce 100644
--- a/src/builder.cc
+++ b/src/builder.cc
@@ -88,6 +88,13 @@ Node Builder::add(const std::string& name)
return n;
}
+Node Builder::add(const std::string& name, const GraphID & graph_id)
+{
+ Node n(sole::uuid4().str(), name, impl_->target, graph_id);
+ impl_->nodes.push_back(n);
+ return n;
+}
+
Graph Builder::add_graph(const std::string& name) {
Graph g(*this, name);
impl_->graphs.push_back(g);
diff --git a/src/c_ion.cc b/src/c_ion.cc
index b5f11e3d..35798690 100644
--- a/src/c_ion.cc
+++ b/src/c_ion.cc
@@ -2,7 +2,6 @@
#include
#include
-
#include
#include "log.h"
@@ -437,6 +436,23 @@ int ion_builder_with_bb_module(ion_builder_t obj, const char *module_name)
log::error("Unknown exception was happened");
return 1;
}
+ return 0;
+}
+
+int ion_builder_add_graph(ion_builder_t obj, const char *name, ion_graph_t *graph_ptr)
+{
+ try {
+ *graph_ptr = reinterpret_cast(new Graph(reinterpret_cast(obj)->add_graph(name)));
+ } catch (const Halide::Error& e) {
+ log::error(e.what());
+ return 1;
+ } catch (const std::exception& e) {
+ log::error(e.what());
+ return 1;
+ } catch (...) {
+ log::error("Unknown exception was happened");
+ return 1;
+ }
return 0;
}
@@ -1026,3 +1042,95 @@ int ion_port_map_set_buffer_array(ion_port_map_t obj, ion_port_t p, ion_buffer_t
return 0;
}
+
+
+int ion_graph_create(ion_graph_t *ptr, ion_builder_t obj, const char * name)
+{
+ try {
+ *ptr = reinterpret_cast(new Graph(*reinterpret_cast(obj), name));
+ } catch (const Halide::Error& e) {
+ log::error(e.what());
+ return 1;
+ } catch (const std::exception& e) {
+ log::error(e.what());
+ return 1;
+ } catch (...) {
+ log::error("Unknown exception was happened");
+ return 1;
+ }
+
+ return 0;
+}
+
+int ion_graph_add_node(ion_graph_t obj, const char *name, ion_node_t *node_ptr)
+{
+ try {
+ *node_ptr = reinterpret_cast(new Node(reinterpret_cast(obj)->add(name)));
+ } catch (const Halide::Error& e) {
+ log::error(e.what());
+ return 1;
+ } catch (const std::exception& e) {
+ log::error(e.what());
+ return 1;
+ } catch (...) {
+ log::error("Unknown exception was happened");
+ return 1;
+ }
+
+ return 0;
+}
+
+int ion_graph_create_with_multiple(ion_graph_t * ptr, ion_graph_t* graphs_ptr, int graphs_num) {
+ try {
+ auto sum_graph = *reinterpret_cast(graphs_ptr[0]);
+ for (int i=1; i(graphs_ptr[i]);
+ }
+ *ptr = reinterpret_cast(new Graph(sum_graph));
+ } catch (const Halide::Error& e) {
+ log::error(e.what());
+ return 1;
+ } catch (const std::exception& e) {
+ log::error(e.what());
+ return 1;
+ } catch (...) {
+ log::error("Unknown exception was happened");
+ return 1;
+ }
+
+ return 0;
+}
+
+int ion_graph_run(ion_graph_t obj)
+{
+ try {
+ reinterpret_cast(obj)->run();
+ } catch (const Halide::Error& e) {
+ log::error(e.what());
+ return 1;
+ } catch (const std::exception& e) {
+ log::error(e.what());
+ return 1;
+ } catch (...) {
+ log::error("Unknown exception was happened");
+ return 1;
+ }
+
+ return 0;
+}
+
+int ion_graph_destroy(ion_graph_t obj){
+ try {
+ delete reinterpret_cast(obj);
+ } catch (const Halide::Error& e) {
+ log::error(e.what());
+ return 1;
+ } catch (const std::exception& e) {
+ log::error(e.what());
+ return 1;
+ } catch (...) {
+ log::error("Unknown exception was happened");
+ return 1;
+ }
+ return 0;
+}
diff --git a/src/graph.cc b/src/graph.cc
index 819262f8..b3ddff04 100644
--- a/src/graph.cc
+++ b/src/graph.cc
@@ -3,24 +3,27 @@
#include "log.h"
#include "lower.h"
-
+#include "uuid/sole.hpp"
namespace ion {
struct Graph::Impl {
Builder builder;
std::string name;
+ GraphID id;
std::vector nodes;
-
// Cacheable
Halide::Pipeline pipeline;
Halide::Callable callable;
std::unique_ptr jit_ctx;
Halide::JITUserContext* jit_ctx_ptr;
std::vector args;
-
- Impl(Builder b, const std::string& n)
- : builder(b), name(n), jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get())
+ Impl()
+ : id(sole::uuid4().str())
{}
+ Impl(Builder b, const std::string& n)
+ : id(sole::uuid4().str()), builder(b), name(n), jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get())
+ {
+ }
};
Graph::Graph()
@@ -48,7 +51,8 @@ Graph operator+(const Graph& lhs, const Graph& rhs)
Node Graph::add(const std::string& name)
{
- auto n = impl_->builder.add(name);
+ auto n = impl_->builder.add(name,impl_->id);
+
impl_->nodes.push_back(n);
return n;
}
diff --git a/src/lower.cc b/src/lower.cc
index 406642f4..07725941 100644
--- a/src/lower.cc
+++ b/src/lower.cc
@@ -218,14 +218,14 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_
topological_sort(nodes);
// Constructing Generator object and setting static parameters
- std::unordered_map bbs;
+ std::unordered_map bbs;
for (auto n : nodes) {
auto bb(Halide::Internal::GeneratorRegistry::create(n.name(), Halide::GeneratorContext(n.target())));
// Default parameter
Halide::GeneratorParamsMap params;
params["builder_impl_ptr"] = std::to_string(reinterpret_cast(builder.impl_ptr()));
- params["bb_id"] = n.id();
+ params["bb_id"] = to_string(n.id());
// User defined parameter
for (const auto& p : n.params()) {
@@ -243,7 +243,7 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_
for (const auto& [pn, port] : n.iports()) {
// Find arginfo
- auto it = std::find_if(arginfos.begin(), arginfos.end(), [pn](const ArgInfo& arginfo) { return arginfo.name == pn; });
+ auto it = std::find_if(arginfos.begin(), arginfos.end(), [&pn=pn](const ArgInfo& arginfo) { return arginfo.name == pn; });
if (it == arginfos.end()) {
auto msg = fmt::format("Argument {} is not defined in node {}", pn, n.name());
log::error(msg);
@@ -295,7 +295,7 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_
if (implicit_output) {
// Collects all output which is never referenced.
// This mode is used for AOT compilation
- std::unordered_map> referenced;
+ std::unordered_map, NodeID::StringIDHash> referenced;
for (const auto& n : nodes) {
for (const auto& [pn, port] : n.iports()) {
if (port.has_pred()) {
diff --git a/src/node.cc b/src/node.cc
index 816eaa31..73257a0c 100644
--- a/src/node.cc
+++ b/src/node.cc
@@ -4,7 +4,8 @@
namespace ion {
-Node::Impl::Impl(const std::string& id_, const std::string& name_, const Halide::Target& target_)
+
+Node::Impl::Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_)
: id(id_), name(name_), target(target_), params(), ports()
{
auto bb(Halide::Internal::GeneratorRegistry::create(name_, Halide::GeneratorContext(target_)));
@@ -16,10 +17,22 @@ Node::Impl::Impl(const std::string& id_, const std::string& name_, const Halide:
arginfos = bb->arginfos();
}
+Node::Impl::Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_, const GraphID& graph_id_)
+ : id(id_), name(name_), target(target_), params(), ports(), graph_id(graph_id_)
+{
+ auto bb(Halide::Internal::GeneratorRegistry::create(name_, Halide::GeneratorContext(target_)));
+ if (!bb) {
+ log::error("BuildingBlock {} is not found", name_);
+ throw std::runtime_error("Failed to create building block object");
+ }
+
+ arginfos = bb->arginfos();
+}
+
void Node::set_iport(const std::vector& ports) {
impl_->ports.erase(std::remove_if(impl_->ports.begin(), impl_->ports.end(),
- [&](const Port &p) { return p.has_succ_by_nid(this->id()); }),
+ [&](const Port &p) { return p.has_succ_by_nid(this->id());}),
impl_->ports.end());
size_t i = 0;
@@ -37,7 +50,7 @@ void Node::set_iport(const std::vector& ports) {
// NOTE: Is succ_chans name OK to be just leave as it is?
port.impl_->succ_chans.insert({id(), "_ion_iport_" + std::to_string(i)});
-
+ port.impl_ ->graph_id = impl_->graph_id;
impl_->ports.push_back(port);
i++;
@@ -45,11 +58,13 @@ void Node::set_iport(const std::vector& ports) {
}
void Node::set_iport(Port port) {
+ port.impl_ ->graph_id = impl_->graph_id;
port.impl_->succ_chans.insert({id(), port.pred_name()});
impl_->ports.push_back(port);
}
void Node::set_iport(const std::string& name, Port port) {
+ port.impl_ ->graph_id = impl_->graph_id;
port.impl_->succ_chans.insert({id(), name});
impl_->ports.push_back(port);
}
@@ -60,7 +75,8 @@ Port Node::operator[](const std::string& name) {
if (it == impl_->ports.end()) {
// This is output port which is never referenced.
// Bind myself as a predecessor and register
- Port port(impl_->id, name);
+ Port port(id(), name);
+ port.impl_ ->graph_id = impl_->graph_id;
impl_->ports.push_back(port);
return port;
} else {
diff --git a/src/port.cc b/src/port.cc
index a90a1506..752a8d25 100644
--- a/src/port.cc
+++ b/src/port.cc
@@ -6,24 +6,24 @@
namespace ion {
Port::Impl::Impl()
- : id(sole::uuid4().str()), pred_chan{"", ""}, succ_chans{}, type(), dimensions(-1)
+ : id(PortID(sole::uuid4().str())), pred_chan{"", ""}, succ_chans{}, type(), dimensions(-1)
{
}
-Port::Impl::Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d)
- : id(sole::uuid4().str()), pred_chan{pid, pn}, succ_chans{}, type(t), dimensions(d)
+Port::Impl::Impl(const NodeID & nid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID & gid)
+ : id(PortID(sole::uuid4().str())), pred_chan{nid, pn}, succ_chans{}, type(t), dimensions(d), graph_id(gid)
{
- params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(pid, pn, 0));
+ params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(nid, pn, 0, gid));
}
-void Port::determine_succ(const std::string& nid, const std::string& old_pn, const std::string& new_pn) {
+void Port::determine_succ(const NodeID& nid, const std::string& old_pn, const std::string& new_pn) {
auto it = std::find(impl_->succ_chans.begin(), impl_->succ_chans.end(), Channel{nid, old_pn});
if (it == impl_->succ_chans.end()) {
log::error("fixme");
throw std::runtime_error("fixme");
}
- log::debug("Determine free port {} as {} on Node {}", old_pn, new_pn, nid);
+ log::debug("Determine free port {} as {} on Node {}", old_pn, new_pn, nid.value());
impl_->succ_chans.erase(it);
impl_->succ_chans.insert(Channel{nid, new_pn});
}
diff --git a/src/serializer.h b/src/serializer.h
index c1f42967..92ee377e 100644
--- a/src/serializer.h
+++ b/src/serializer.h
@@ -42,9 +42,14 @@ static void from_json(const json& j, ion::Param& v) {
template<>
struct adl_serializer {
static void to_json(json& j, const ion::Port& v) {
- j["id"] = v.id();
- j["pred_chan"] = v.pred_chan();
- j["succ_chans"] = v.succ_chans();
+ j["id"] = to_string(v.id());
+ std::map stringMap;
+ j["pred_chan"] = std::make_tuple(to_string(std::get<0>(v.pred_chan())), std::get<1>(v.pred_chan()));
+ std::set> succ_chans;
+ for (auto& c:v.succ_chans()){
+ succ_chans.insert(std::make_tuple(to_string(std::get<0>(c)), std::get<1>(c)));
+ }
+ j["succ_chans"] = succ_chans;
j["type"] = static_cast(v.type());
j["dimensions"] = v.dimensions();
j["size"] = v.size();
@@ -54,13 +59,17 @@ struct adl_serializer {
static void from_json(const json& j, ion::Port& v) {
auto [impl, found] = ion::Port::find_impl(j["id"].get());
if (!found) {
- impl->pred_chan = j["pred_chan"].get();
- impl->succ_chans = j["succ_chans"].get>();
+ impl->pred_chan = j["pred_chan"].get>();
+ std::set succ_chans;
+ for (auto & p : j["succ_chans"]){
+ succ_chans.insert(p.get>());
+ }
+ impl->succ_chans = succ_chans;
impl->type = j["type"].get();
impl->dimensions = j["dimensions"];
for (auto i=0; iparams[i] = Halide::Internal::Parameter(impl->type, impl->dimensions != 0, impl->dimensions,
- ion::argument_name(std::get<0>(impl->pred_chan), std::get<1>(impl->pred_chan), i));
+ ion::argument_name(std::get<0>(impl->pred_chan), std::get<1>(impl->pred_chan), i, impl->graph_id.value()));
}
}
v = ion::Port(impl, j["index"]);
@@ -70,7 +79,7 @@ struct adl_serializer {
template <>
struct adl_serializer {
static void to_json(json& j, const ion::Node& v) {
- j["id"] = v.id();
+ j["id"] = to_string(v.id());
j["name"] = v.name();
j["target"] = v.target().to_string();
j["params"] = v.params();
diff --git a/src/util.cc b/src/util.cc
index b17bbcb7..8988cbd4 100644
--- a/src/util.cc
+++ b/src/util.cc
@@ -5,12 +5,12 @@
namespace ion {
-std::string argument_name(const std::string& node_id, const std::string& name, int32_t index) {
+std::string argument_name(const NodeID & node_id, const std::string& name, int32_t index, const GraphID & graph_id) {
if (index == -1) {
index = 0;
}
- std::string s = "_" + node_id + "_" + name + std::to_string(index);;
+ std::string s = "_" + node_id.value() + "_" + name + std::to_string(index) + "_" + graph_id.value();
std::replace(s.begin(), s.end(), '-', '_');
return s;
diff --git a/test/c_api.cc b/test/c_api.cc
index c483b2b8..2088d534 100644
--- a/test/c_api.cc
+++ b/test/c_api.cc
@@ -311,6 +311,200 @@ int main()
if (ret != 0)
return ret;
}
+ {
+
+ ion_type_t t = {.code=ion_type_int, .bits=32, .lanes=1};
+
+
+ ion_builder_t b;
+ ret = ion_builder_create(&b);
+ if (ret != 0)
+ return ret;
+
+ ret = ion_builder_set_target(b, "host");
+ if (ret != 0)
+ return ret;
+
+ ret = ion_builder_with_bb_module(b, "ion-bb-test");
+ if (ret != 0)
+ return ret;
+
+ ion_graph_t g0;
+ ret = ion_builder_add_graph(b, "graph0", &g0);
+ if (ret != 0)
+ return ret;
+
+ ion_param_t v41;
+ ret = ion_param_create(&v41, "v", "41");
+ if (ret != 0)
+ return ret;
+
+ ion_node_t n0;
+ ret = ion_graph_add_node(g0, "test_inc_i32x2" , &n0);
+ if (ret != 0)
+ return ret;
+ int sizes[] = {16, 16};
+
+ ret = ion_node_set_param(n0, &v41, 1);
+ if (ret != 0)
+ return ret;
+
+ ion_port_t ip0;
+ ret = ion_port_create(&ip0, "input0", t, 2);
+ if (ret != 0)
+ return ret;
+ ret = ion_node_set_iport(n0, &ip0, 1);
+ if (ret != 0)
+ return ret;
+
+
+ ion_buffer_t ibuf0;
+ ret = ion_buffer_create(&ibuf0, t, sizes, 2);
+ if (ret != 0)
+ return ret;
+
+ int in0[16*16];
+ for (int i=0; i<16*16; ++i) {
+ in0[i] = 0;
+ }
+ ret = ion_buffer_write(ibuf0, in0, 16*16*sizeof(int));
+ if (ret != 0)
+ return ret;
+
+ ret = ion_port_bind_buffer(ip0, ibuf0);
+ if (ret != 0)
+ return ret;
+
+ ion_port_t op0;
+ ret = ion_node_get_port(n0, "output", &op0);
+ if (ret != 0)
+ return ret;
+
+ ion_buffer_t obuf0;
+ ret = ion_buffer_create(&obuf0, t, sizes, 2);
+ if (ret != 0)
+ return ret;
+
+ ret = ion_port_bind_buffer(op0, obuf0);
+ if (ret != 0)
+ return ret;
+//
+// ret = ion_graph_run(g0);
+// if (ret != 0)
+// return ret;
+
+ int out0[16*16] = {0};
+// ret = ion_buffer_read(obuf0, out0, 16*16*sizeof(int));
+// for (int i=0;i<16*16; ++i) {
+// if (out0[i] != 41) {
+// printf("out0: %d\n", out0[i]);
+//
+// }
+// }
+
+ ion_graph_t g1;
+ ret = ion_builder_add_graph(b, "graph1", &g1);
+ if (ret != 0)
+ return ret;
+
+
+
+ ion_node_t n1;
+ ret = ion_graph_add_node(g1, "test_inc_i32x2" , &n1);
+ if (ret != 0)
+ return ret;
+
+ ret = ion_node_set_param(n1, &v41, 1);
+ if (ret != 0)
+ return ret;
+
+ ion_port_t ip1;
+ ret = ion_port_create(&ip1, "input1", t, 2);
+ if (ret != 0)
+ return ret;
+ ret = ion_node_set_iport(n1, &ip1, 1);
+ if (ret != 0)
+ return ret;
+
+
+ ion_buffer_t ibuf1;
+ ret = ion_buffer_create(&ibuf1, t, sizes, 2);
+ if (ret != 0)
+ return ret;
+
+ int in1[16*16];
+ for (int i=0; i<16*16; ++i) {
+ in1[i] = 1;
+ }
+ ret = ion_buffer_write(ibuf1, in1, 16*16*sizeof(int));
+ if (ret != 0)
+ return ret;
+
+ ret = ion_port_bind_buffer(ip1, ibuf1);
+ if (ret != 0)
+ return ret;
+
+ ion_port_t op1;
+ ret = ion_node_get_port(n1, "output", &op1);
+ if (ret != 0)
+ return ret;
+
+ ion_buffer_t obuf1;
+ ret = ion_buffer_create(&obuf1, t, sizes, 2);
+ if (ret != 0)
+ return ret;
+
+ ret = ion_port_bind_buffer(op1, obuf1);
+ if (ret != 0)
+ return ret;
+
+// ret = ion_graph_run(g1);
+ if (ret != 0)
+ return ret;
+
+ int out1[16*16] = {0};
+// ret = ion_buffer_read(obuf1, out1, 16*16*sizeof(int));
+// for (int i=0;i<16*16; ++i) {
+// if (out1[i] != 42) {
+// printf("out1: %d\n", out1[i]);
+// }
+// }
+
+ for (int i=0;i<16*16; ++i) {
+ out0[i] =0;
+ out1[i] =0;
+ }
+
+ ion_graph_t g2;
+ ion_graph_create(&g2, b,"graph2");
+
+ ion_graph_t *graphs = (ion_graph_t*)malloc(2 * sizeof(ion_graph_t));
+ graphs[0] = g0;
+ graphs[1] = g1;
+
+ ret = ion_graph_create_with_multiple(&g2, graphs, 2);
+ if (ret != 0)
+ return ret;
+ ret = ion_graph_run(g2);
+ if (ret != 0)
+ return ret;
+ ret = ion_buffer_read(obuf0, out0, 16*16*sizeof(int));
+ ret = ion_buffer_read(obuf1, out1, 16*16*sizeof(int));
+
+ for (int i=0;i<16*16; ++i) {
+ if (out0[i] != 41 ) {
+ printf("out0: %d\n", out0[i]);
+ ;
+ }
+ if (out1[i] != 42 ) {
+ printf("out1: %d\n", out1[i]);
+ }
+ }
+ ret = ion_graph_destroy(g0);
+ ret = ion_graph_destroy(g1);
+ ret = ion_graph_destroy(g2);
+
+ }
} catch (Halide::Error &e) {
std::cerr << e.what() << std::endl;
diff --git a/test/graph.cc b/test/graph.cc
index bb5722f9..98f94ab6 100644
--- a/test/graph.cc
+++ b/test/graph.cc
@@ -4,83 +4,173 @@ using namespace ion;
int main()
{
- try {
- Builder b;
- b.with_bb_module("ion-bb-test");
- b.set_target(Halide::get_host_target());
-
- int32_t size = 16;
-
- Buffer in0(size, size);
- in0.fill(1);
-
- Buffer in1(size, size);
- in1.fill(1);
+ {
+ try {
+ Builder b;
+ b.with_bb_module("ion-bb-test");
+ b.set_target(Halide::get_host_target());
+
+ int32_t size = 16;
+
+ Buffer in0(size, size);
+ in0.fill(1);
+
+ Buffer in1(size, size);
+ in1.fill(1);
+
+ Buffer out0(size, size);
+ out0.fill(0);
+
+ Buffer out1(size, size);
+ out1.fill(0);
+
+ Graph g0 = b.add_graph("graph0");
+ Node n0 = g0.add("test_inc_i32x2")(in0).set_param(Param("v", 40));
+ n0["output"].bind(out0);
+ g0.run();
+
+ for (int y=0; y out0(size, size);
- out0.fill(0);
+ Graph g1 = b.add_graph("graph1");
+ Node n1 = g1.add("test_inc_i32x2")(in1).set_param(Param("v", 41));
+ n1["output"].bind(out1);
+ g1.run();
+
+ for (int y=0; y out1(size, size);
- out1.fill(0);
+ out0.fill(0);
+ out1.fill(0);
- Graph g0 = b.add_graph("graph0");
- Node n0 = g0.add("test_inc_i32x2")(in0).set_param(Param("v", 40));
- n0["output"].bind(out0);
- g0.run();
+ Graph g2(g0 + g1);
+ g2.run();
- for (int y=0; y ibuf0(std::vector{1, 1});
+
+
+ Port ip0{"input", Halide::type_of(), 2};
+ Port vp0{"v", Halide::type_of()};
+
+ Graph g0 = b.add_graph("graph0");
+
+ Node n0 = g0.add("test_incx_i32x2")(ip0, vp0);
+
+ ip0.bind(ibuf0);
+ int32_t v0 = 0;
+ vp0.bind(&v0);
+
+ Buffer obuf0(std::vector{1, 1});
+ n0["output"].bind(obuf0);
+
+ ibuf0(0, 0) = 42;
+ v0 = 0;
+ obuf0(0, 0) = 0;
+
+ g0.run();
+ if (obuf0(0, 0) != 42) {
+ std::cerr << "Expected: " << 42 << " Actual:" << obuf0(0, 0) << std::endl;
+ return 1;
}
- }
- out0.fill(0);
- out1.fill(0);
+ // Test 2
+ Port ip1{"input", Halide::type_of(), 2};
+ Port vp1{"v", Halide::type_of()};
- Graph g2(g0 + g1);
- g2.run();
+ Graph g1 = b.add_graph("graph1");
+ Node n1 = g1.add("test_incx_i32x2")(ip1, vp1);
- for (int y=0; y ibuf1(std::vector{1, 1});
+ ip1.bind(ibuf1);
+ Buffer obuf1(std::vector{1, 1});
+ obuf1.fill(0);
+ n1["output"].bind(obuf1);
+
+
+ ibuf1(0, 0) = 42;
+ v1 = 1;
+ obuf1(0, 0) = 0;
+
+ g1.run();
+
+ if (obuf1(0, 0) != 43) {
+ std::cerr << "Expected: " << 43 << " Actual:" << obuf1(0, 0) << std::endl;
+ return 1;
}
- }
- } catch (Halide::Error& e) {
- std::cerr << e.what() << std::endl;
- return 1;
- } catch (const std::exception& e) {
- std::cerr << e.what() << std::endl;
- return 1;
- }
+ // Test 3
+ Graph g2(g0 + g1);
- std::cout << "Passed" << std::endl;
+ g2.run();
+ if (obuf1(0, 0) != 43) {
+ std::cerr << "Expected: " << 43 << " Actual:" << obuf1(0, 0) << std::endl;
+ return 1;
+ }
+ if (obuf0(0, 0) != 42) {
+ std::cerr << "Expected: " << 42 << " Actual:" << obuf0(0, 0) << std::endl;
+ return 1;
+ }
+ std::cout << "second test Passed" << std::endl;
+
+ } catch (Halide::Error &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
+ } catch (const std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ return 1;
+ }
+ }
+ std::cout << "All Passed" << std::endl;
return 0;
+
}