diff --git a/include/ion/builder.h b/include/ion/builder.h index 465002d7..5f0f97cf 100644 --- a/include/ion/builder.h +++ b/include/ion/builder.h @@ -39,14 +39,21 @@ class Builder { ~Builder(); /** - * Adding new node to the graph. + * Adding new node to the builder. * @arg k: The key of the node which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK(). */ Node add(const std::string& name); + /** + * Adding new node to the specific graph. + * @arg k: The key of the node which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK(). + * @arg id: graph unique identifier + */ + Node add(const std::string& name, const GraphID& graph_id); /** - * + * Adding new node to the graph. + * @arg k: The key of the node which should be matched with second argument of ION_REGISTER_BUILDING_BLOCK(). */ Graph add_graph(const std::string& name); diff --git a/include/ion/c_ion.h b/include/ion/c_ion.h index 9281adc3..6347c455 100644 --- a/include/ion/c_ion.h +++ b/include/ion/c_ion.h @@ -31,6 +31,7 @@ typedef struct ion_node_t_ *ion_node_t; typedef struct ion_builder_t_ *ion_builder_t; typedef struct ion_buffer_t_ *ion_buffer_t; typedef struct ion_port_map_t_ *ion_port_map_t; +typedef struct ion_graph_t_ *ion_graph_t; int ion_port_create(ion_port_t *, const char *, ion_type_t, int); int ion_port_create_with_index(ion_port_t *, ion_port_t , int); @@ -62,6 +63,7 @@ int ion_builder_create(ion_builder_t *); int ion_builder_destroy(ion_builder_t); int ion_builder_set_target(ion_builder_t, const char *); int ion_builder_with_bb_module(ion_builder_t, const char *); +int ion_builder_add_graph(ion_builder_t, const char *, ion_graph_t *); int ion_builder_add_node(ion_builder_t, const char *, ion_node_t *); int ion_builder_compile(ion_builder_t, const char *, ion_builder_compile_option_t option); int ion_builder_save(ion_builder_t, const char *); @@ -76,6 +78,12 @@ int ion_buffer_destroy(ion_buffer_t); int ion_buffer_write(ion_buffer_t, void *, int size); int ion_buffer_read(ion_buffer_t, void *, int size); +int ion_graph_create(ion_graph_t *, ion_builder_t, const char *); +int ion_graph_add_node(ion_graph_t, const char*, ion_node_t *); +int ion_graph_destroy(ion_graph_t); +int ion_graph_run(ion_graph_t); +int ion_graph_create_with_multiple(ion_graph_t * ptr, ion_graph_t* objs, int size); + [[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_create(ion_port_map_t *); [[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] diff --git a/include/ion/node.h b/include/ion/node.h index 70ec314b..71e17fb8 100644 --- a/include/ion/node.h +++ b/include/ion/node.h @@ -20,16 +20,17 @@ class Node { public: struct Impl { - std::string id; + NodeID id; std::string name; + GraphID graph_id; Halide::Target target; std::vector params; std::vector ports; std::vector arginfos; Impl(): id(), name(), target(), params(), ports() {} - - Impl(const std::string& id_, const std::string& name_, const Halide::Target& target_); + Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_); + Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_, const GraphID &graph_id_); }; public: @@ -76,7 +77,7 @@ class Node { */ template Node operator()(Args ...args) { - set_iport(std::vector{args...}); + set_iport(std::vector{make_iport(args)...}); return *this; } @@ -94,7 +95,7 @@ class Node { Port operator[](const std::string& name); // Getter - const std::string& id() const { + const NodeID & id() const { return impl_->id; } @@ -121,11 +122,45 @@ class Node { std::vector> oports() const; private: - Node(const std::string& id, const std::string& name, const Halide::Target& target) + Node(const NodeID& id, const std::string& name, const Halide::Target& target) : impl_(new Impl{id, name, target}) { } + Node(const NodeID&& id, const std::string& name, const Halide::Target& target, const GraphID& graph_id) + : impl_(new Impl{id, name, target, graph_id}) + { + } + + Port make_iport(Port arg) const { + return arg; + } + + template + Port make_iport(T *vptr) const { + if (to_string(impl_->graph_id).empty()) + return Port(vptr); + else + return Port(vptr, impl_->graph_id); + } + + template + Port make_iport(Halide::Buffer& arg) const { + if (to_string(impl_->graph_id).empty()) + return Port(arg); + else + return Port(arg, impl_->graph_id); + } + + template + Port make_iport(std::vector>& arg) const { + if (to_string(impl_->graph_id).empty()) + return Port(arg); + else + return Port(arg, impl_->graph_id); + } + + std::shared_ptr impl_; }; diff --git a/include/ion/port.h b/include/ion/port.h index e6669b4c..ed36880b 100644 --- a/include/ion/port.h +++ b/include/ion/port.h @@ -49,11 +49,12 @@ class Port { friend class Node; public: - using Channel = std::tuple; + using Channel = std::tuple; private: struct Impl { - std::string id; + PortID id; + GraphID graph_id; Channel pred_chan; std::set succ_chans; @@ -64,12 +65,12 @@ class Port { std::unordered_map instances; Impl(); - Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d); + Impl(const NodeID& nid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID &gid ); }; public: - Port() : impl_(new Impl("", "", Halide::Type(), 0)), index_(-1) {} + Port() : impl_(new Impl(NodeID(""), "", Halide::Type(), 0, GraphID(""))), index_(-1) {} Port(const std::shared_ptr& impl, int32_t index) : impl_(impl), index_(index) {} @@ -78,7 +79,7 @@ class Port { * @arg k: The key of the port which should be matched with BuildingBlock Input/Output name. * @arg t: The type of the value. */ - Port(const std::string& n, Halide::Type t) : impl_(new Impl("", n, t, 0)), index_(-1) {} + Port(const std::string& n, Halide::Type t) : impl_(new Impl(NodeID(""), n, t, 0, GraphID(""))), index_(-1) {} /** * Construct new port for vector value. @@ -86,22 +87,40 @@ class Port { * @arg t: The type of the element value. * @arg d: The dimension of the port. The range is 1 to 4. */ - Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl("", n, t, d)), index_(-1) {} + Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl(NodeID(""), n, t, d, GraphID(""))), index_(-1) {} /** * Construct new port from scalar pointer */ template::value>::type* = nullptr> - Port(T *vptr) : impl_(new Impl("", Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0)), index_(-1) { + Port(T *vptr) : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0, GraphID(""))), index_(-1) { this->bind(vptr); } + /** + * Construct new port from scalar pointer + */ + template::value>::type* = nullptr> + Port(T *vptr, const GraphID & gid) : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0, gid)), index_(-1) { + this->bind(vptr); + } + + /** * Construct new port from buffer */ template - Port(const Halide::Buffer& buf) : impl_(new Impl("", buf.name(), buf.type(), buf.dimensions())), index_(-1) { + Port(const Halide::Buffer& buf) : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), GraphID(""))), index_(-1) { + this->bind(buf); + } + + /** + * Construct new port from buffer and bind graph id to port + */ + template + Port(const Halide::Buffer& buf, const GraphID & gid) : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), gid)), index_(-1) { this->bind(buf); } @@ -109,36 +128,45 @@ class Port { * Construct new port from array of buffer */ template - Port(const std::vector>& bufs) : impl_(new Impl("", unify_name(bufs), Halide::type_of(), unify_dimension(bufs))), index_(-1) { + Port(const std::vector>& bufs) : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of(), unify_dimension(bufs), GraphID(""))), index_(-1) { + this->bind(bufs); + } + + /** + * Construct new port from array of buffer and bind graph id to port + */ + template + Port(const std::vector>& bufs, const GraphID & gid) : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of(), unify_dimension(bufs), gid)), index_(-1) { this->bind(bufs); } // Getter - const std::string& id() const { return impl_->id; } + const PortID id() const { return impl_->id; } const Channel& pred_chan() const { return impl_->pred_chan; } - const std::string& pred_id() const { return std::get<0>(impl_->pred_chan); } + const NodeID& pred_id() const { return std::get<0>(impl_->pred_chan); } const std::string& pred_name() const { return std::get<1>(impl_->pred_chan); } const std::set& succ_chans() const { return impl_->succ_chans; } const Halide::Type& type() const { return impl_->type; } int32_t dimensions() const { return impl_->dimensions; } int32_t size() const { return static_cast(impl_->params.size()); } int32_t index() const { return index_; } + const GraphID& graph_id() const { return impl_->graph_id; } // Setter void set_index(int index) { index_ = index; } // Util - bool has_pred() const { return !std::get<0>(impl_->pred_chan).empty(); } - bool has_pred_by_nid(const std::string& nid) const { return !std::get<0>(impl_->pred_chan).empty(); } + bool has_pred() const { return !std::get<0>(impl_->pred_chan).value().empty(); } + bool has_pred_by_nid(const NodeID & nid) const { return !to_string(std::get<0>(impl_->pred_chan)).empty(); } bool has_succ() const { return !impl_->succ_chans.empty(); } bool has_succ(const Channel& c) const { return impl_->succ_chans.count(c); } - bool has_succ_by_nid(const std::string& nid) const { + bool has_succ_by_nid(const NodeID& nid) const { return std::count_if(impl_->succ_chans.begin(), impl_->succ_chans.end(), [&](const Port::Channel& c) { return std::get<0>(c) == nid; }); } - void determine_succ(const std::string& nid, const std::string& old_pn, const std::string& new_pn); + void determine_succ(const NodeID& nid, const std::string& old_pn, const std::string& new_pn); /** * Overloaded operator to set the port index and return a reference to the current port. eg. port[0] @@ -153,22 +181,21 @@ class Port { void bind(T *v) { auto i = index_ == -1 ? 0 : index_; if (has_pred()) { - impl_->params[i] = Halide::Internal::Parameter{Halide::type_of(), false, 0, argument_name(pred_id(), pred_name(), i)}; + impl_->params[i] = Halide::Internal::Parameter{Halide::type_of(), false, 0, argument_name(pred_id(), pred_name(), i, graph_id())}; } else { - impl_->params[i] = Halide::Internal::Parameter{type(), false, dimensions(), argument_name(pred_id(), pred_name(), i)}; + impl_->params[i] = Halide::Internal::Parameter{type(), false, dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())}; } impl_->instances[i] = v; } - template void bind(const Halide::Buffer& buf) { auto i = index_ == -1 ? 0 : index_; if (has_pred()) { - impl_->params[i] = Halide::Internal::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), pred_name(), i)}; + impl_->params[i] = Halide::Internal::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), pred_name(), i,graph_id())}; } else { - impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i)}; + impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i,graph_id())}; } impl_->instances[i] = buf.raw_buffer(); @@ -178,9 +205,9 @@ class Port { void bind(const std::vector>& bufs) { for (int i=0; i(bufs.size()); ++i) { if (has_pred()) { - impl_->params[i] = Halide::Internal::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), pred_name(), i)}; + impl_->params[i] = Halide::Internal::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())}; } else { - impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i)}; + impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())}; } impl_->instances[i] = bufs[i].raw_buffer(); @@ -199,7 +226,7 @@ class Port { if (es.size() <= i) { es.resize(i+1, Halide::Expr()); } - es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), pred_name(), i), param); + es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), pred_name(), i, graph_id()), param); } return es; } @@ -220,7 +247,7 @@ class Port { args.push_back(Halide::Var::implicit(i)); args_expr.push_back(Halide::Var::implicit(i)); } - Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), pred_name(), i) + "_im"); + Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), pred_name(), i, graph_id()) + "_im"); f(args) = Halide::Internal::Call::make(param, args_expr); fs[i] = f; } @@ -234,7 +261,7 @@ class Port { args.resize(i+1, Halide::Argument()); } auto kind = dimensions() == 0 ? Halide::Argument::InputScalar : Halide::Argument::InputBuffer; - args[i] = Halide::Argument(argument_name(pred_id(), pred_name(), i), kind, type(), dimensions(), Halide::ArgumentEstimates()); + args[i] = Halide::Argument(argument_name(pred_id(), pred_name(), i, graph_id()), kind, type(), dimensions(), Halide::ArgumentEstimates()); } return args; } @@ -257,7 +284,7 @@ class Port { * pid and pn is stored in both pred and succ, * then it will determined through pipeline build process. */ - Port(const std::string& pid, const std::string& pn) : impl_(new Impl(pid, pn, Halide::Type(), 0)), index_(-1) {} + Port(const NodeID & nid, const std::string& pn) : impl_(new Impl(nid, pn, Halide::Type(), 0, GraphID(""))), index_(-1) {} std::shared_ptr impl_; diff --git a/include/ion/port_map.h b/include/ion/port_map.h index 82a05dc2..ee036910 100644 --- a/include/ion/port_map.h +++ b/include/ion/port_map.h @@ -16,7 +16,7 @@ class PortMap { template [[deprecated("Port::bind can be used instead of PortMap.")]] void set(Port port, T v) { - auto& buf(scalar_buffer_[argument_name(port.pred_id(), port.pred_name(), port.index())]); + auto& buf(scalar_buffer_[argument_name(port.pred_id(), port.pred_name(), port.index(), port.graph_id())]); buf.resize(sizeof(v)); std::memcpy(buf.data(), &v, sizeof(v)); port.bind(reinterpret_cast(buf.data())); diff --git a/include/ion/util.h b/include/ion/util.h index 78358457..78ddf3b6 100644 --- a/include/ion/util.h +++ b/include/ion/util.h @@ -2,15 +2,61 @@ #define ION_UTIL_H #include - namespace ion { class Port; -std::string argument_name(const std::string& node_id, const std::string& name, int32_t index); - std::string array_name(const std::string& port_name, size_t i); +// a string-like identifier that is typed on a tag type +template +struct StringID { + using tag_type = Tag; + + // needs to be default-constructable because of use in map[] below + StringID(std::string s) : _value(std::move(s)) {} + StringID() : _value() {} + // provide access to the underlying string value + const std::string &value() const { return _value; } + + struct StringIDHash { + // Use hash of string as hash function. + size_t operator()(const StringID& id) const + { + return std::hash()(id.value()); + } + }; + +private: + std::string _value; + + // will only compare against same type of id. + friend bool operator<(const StringID &l, const StringID &r) { + return l._value < r._value; + } + + friend bool operator==(const StringID &l, const StringID &r) { + return l._value == r._value; + } + + // and let's go ahead and provide expected free functions + friend + auto to_string(const StringID &r) + -> const std::string & { + return r._value; + } +}; + +struct node_tag {}; +struct graph_tag {}; +struct port_tag {}; + +using NodeID = StringID; +using GraphID = StringID; +using PortID = StringID; + +std::string argument_name(const NodeID& node_id, const std::string& name, int32_t index, const GraphID& graph_id); + } // namespace ion #endif diff --git a/python/ionpy/Builder.py b/python/ionpy/Builder.py index 05ad4177..7ea9ecae 100644 --- a/python/ionpy/Builder.py +++ b/python/ionpy/Builder.py @@ -9,6 +9,7 @@ ion_builder_set_target, ion_builder_with_bb_module, + ion_builder_add_graph, ion_builder_add_node, ion_builder_compile, @@ -19,7 +20,7 @@ ion_builder_run_with_port_map, ) - +from .Graph import Graph from .Node import Node from .BuilderCompileOption import BuilderCompileOption from .PortMap import PortMap @@ -62,6 +63,14 @@ def add(self, key: str) -> Node: return Node(obj_=c_node) + def add_graph(self, name: str) -> Graph: + c_graph = c_ion_node_t() + ret = ion_builder_add_graph(self.obj, name.encode(), ctypes.byref(c_graph)) + if ret != 0: + raise Exception('Invalid operation') + + return Graph(obj_=c_graph) + def compile(self, function_name: str, option: BuilderCompileOption): ret = ion_builder_compile(self.obj, function_name.encode(), option.to_cobj()) if ret != 0: diff --git a/python/ionpy/Graph.py b/python/ionpy/Graph.py new file mode 100644 index 00000000..ff7d1e49 --- /dev/null +++ b/python/ionpy/Graph.py @@ -0,0 +1,86 @@ +import ctypes +from typing import Optional + +from .native import ( + c_ion_node_t, + c_ion_graph_t, + + ion_graph_create, + ion_graph_destroy, + ion_graph_create_with_multiple, + ion_graph_run, + ion_graph_add_node +) + +from .Node import Node + + +class Graph: + def __init__(self, + builder = None, + name: Optional[str] = None, + obj_: Optional[c_ion_graph_t] = None, + sub_graphs: [] = None, + + ): + if obj_ is None and builder is not None and name is not None: + obj_ = c_ion_graph_t() + ret = ion_graph_create(ctypes.byref(obj_), builder.obj, name.encode()) + if ret != 0: + raise Exception('Invalid operation') + elif obj_ is None and sub_graphs is not None: + num_graphs = len(sub_graphs) + c_ion_graph_sized_array_t = c_ion_graph_t * num_graphs # arraysize == num_graphs + c_graphs = c_ion_graph_sized_array_t() # instance + for i in range(num_graphs): + c_graphs[i] = sub_graphs[i].obj + obj_ = c_ion_graph_t() + ret = ion_graph_create_with_multiple(ctypes.byref(obj_), c_graphs, num_graphs) + if ret != 0: + raise Exception('Invalid operation') + + self.obj = obj_ + + def __del__(self): + if self.obj: # check not nullptr + ion_graph_destroy(self.obj) + + # adding two objects + + def __add__(self, other): + if isinstance(other, Graph): + c_ion_graph_sized_array_t = c_ion_graph_t * 2 # arraysize == num_graphs + c_graphs = c_ion_graph_sized_array_t() # instance + c_graphs[0] = self.obj + c_graphs[1] = other.obj + new_obj = c_ion_graph_t() + ret = ion_graph_create_with_multiple(ctypes.byref(new_obj), c_graphs, 2) + if ret != 0: + raise Exception('Invalid operation') + return Graph(obj_=new_obj) + + def __iadd__(self, other): + if isinstance(other, Graph): + c_ion_graph_sized_array_t = c_ion_graph_t * 2 # arraysize == num_graphs + c_graphs = c_ion_graph_sized_array_t() # instance + c_graphs[0] = self.obj + c_graphs[1] = other.obj + ret = ion_graph_create_with_multiple(ctypes.byref(self.obj), c_graphs, 2) + if ret != 0: + raise Exception('Invalid operation') + return self + + + def run(self): + ret = ion_graph_run(self.obj) + if ret != 0: + raise Exception('Invalid operation') + + def add(self, key: str) -> Node: + c_node = c_ion_node_t() + + ret = ion_graph_add_node(self.obj, key.encode(), ctypes.byref(c_node)) + if ret != 0: + raise Exception('Invalid operation') + + return Node(obj_=c_node) diff --git a/python/ionpy/__init__.py b/python/ionpy/__init__.py index e2d047ee..d6afedd1 100644 --- a/python/ionpy/__init__.py +++ b/python/ionpy/__init__.py @@ -5,7 +5,7 @@ from .Builder import Builder from .Buffer import Buffer from .PortMap import PortMap - +from .Graph import Graph from .Type import Type from .TypeCode import TypeCode from .BuilderCompileOption import BuilderCompileOption diff --git a/python/ionpy/native.py b/python/ionpy/native.py index a551048f..7fa9ae2b 100644 --- a/python/ionpy/native.py +++ b/python/ionpy/native.py @@ -38,6 +38,7 @@ class c_builder_compile_option_t(ctypes.Structure): c_ion_node_t = ctypes.POINTER(ctypes.c_int) c_ion_builder_t = ctypes.POINTER(ctypes.c_int) +c_ion_graph_t = ctypes.POINTER(ctypes.c_int) c_ion_buffer_t = ctypes.POINTER(ctypes.c_int) c_ion_port_map_t = ctypes.POINTER(ctypes.c_int) @@ -212,6 +213,11 @@ class c_builder_compile_option_t(ctypes.Structure): ion_builder_bb_metadata.restype = ctypes.c_int ion_builder_bb_metadata.argtypes = [ c_ion_builder_t, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_int) ] +# int ion_builder_add_graph(ion_builder_t, const char , ion_graph_t *) +ion_builder_add_graph = ion_core.ion_builder_add_graph +ion_builder_add_graph.restype = ctypes.c_int +ion_builder_add_graph.argtypes = [ c_ion_builder_t, ctypes.c_char_p, ctypes.POINTER(c_ion_graph_t) ] + # int ion_builder_run(ion_builder_t, ion_port_map_t); ion_builder_run = ion_core.ion_builder_run @@ -251,6 +257,30 @@ class c_builder_compile_option_t(ctypes.Structure): ion_buffer_read.restype = ctypes.c_int ion_buffer_read.argtypes = [ c_ion_buffer_t, ctypes.c_void_p, ctypes.c_int ] +# int ion_graph_create(ion_graph_t *, ion_builder_t, const char *) +ion_graph_create = ion_core.ion_graph_create +ion_graph_create.restype = ctypes.c_int +ion_graph_create.argtypes =[ ctypes.POINTER(c_ion_graph_t), c_ion_builder_t, ctypes.c_char_p ] + +# int ion_graph_create_with_multiple(ion_graph_t*, ion_graph_t*, int size) +ion_graph_create_with_multiple = ion_core.ion_graph_create_with_multiple +ion_graph_create_with_multiple.restype = ctypes.c_int +ion_graph_create_with_multiple.argtypes = [ ctypes.POINTER(c_ion_graph_t), ctypes.POINTER(c_ion_graph_t), ctypes.c_int] + +# int ion_graph_add_node(ion_graph_t, const char*, ion_node_t *) +ion_graph_add_node = ion_core.ion_graph_add_node +ion_graph_add_node.restype = ctypes.c_int +ion_graph_add_node.argtypes =[ c_ion_graph_t, ctypes.c_char_p, ctypes.POINTER(c_ion_node_t) ] + +# int ion_graph_run(ion_graph_t) +ion_graph_run=ion_core.ion_graph_run +ion_graph_run.restype = ctypes.c_int +ion_graph_run.argtypes =[ c_ion_graph_t] + +# ion_graph_destroy(ion_graph_t) +ion_graph_destroy=ion_core.ion_graph_destroy +ion_graph_destroy.restype = ctypes.c_int +ion_graph_destroy.argtypes =[ c_ion_graph_t] # int ion_port_map_create(ion_port_map_t *); diff --git a/python/test/test_graph.py b/python/test/test_graph.py new file mode 100644 index 00000000..6dab4d49 --- /dev/null +++ b/python/test/test_graph.py @@ -0,0 +1,52 @@ +from ionpy import Builder, Graph, Port, Buffer, Type +import numpy as np + +def test_graph(): + input_port0 = Port(name='input', type=Type.from_dtype(np.dtype(np.int32)), dim=2) + value_port0 = Port(name='v', type=Type.from_dtype(np.dtype(np.int32)), dim=0) + + builder = Builder() + builder.set_target(target='host') + builder.with_bb_module(path='ion-bb-test') + + graph0 = builder.add_graph("graph0") + + node0 = graph0.add('test_incx_i32x2').set_iport([input_port0, value_port0]) + + idata0 = np.array([[42, 42]], dtype=np.int32) + ibuf0 = Buffer(array=idata0) + + odata0 = np.array([[0, 0]], dtype=np.int32) + obuf0 = Buffer(array=odata0) + + input_port0.bind(ibuf0) + output_port0 = node0.get_port(name='output') + output_port0.bind(obuf0) + value_port0.bind(0) + + input_port1 = Port(name='input', type=Type.from_dtype(np.dtype(np.int32)), dim=2) + value_port1 = Port(name='v', type=Type.from_dtype(np.dtype(np.int32)), dim=0) + + graph1 = builder.add_graph("graph1") + # graph1 = Graph(builder =builder, name="graph1") # alternative + node1 = graph1.add('test_incx_i32x2').set_iport([input_port1, value_port1]) + + idata1 = np.array([[42, 42]], dtype=np.int32) + ibuf1 = Buffer(array=idata1) + + odata1 = np.array([[0, 0]], dtype=np.int32) + obuf1 = Buffer(array=odata1) + + input_port1.bind(ibuf1) + output_port = node1.get_port(name='output') + output_port.bind(obuf1) + value_port1.bind(1) + + g = graph0 + graph1 + g.run() + # alternative + # graph1 += graph0 + # g = g(sub_graphs=[graph1,graph0]) + + assert odata0[0][0] == 42 + assert odata1[0][0] == 43 diff --git a/src/builder.cc b/src/builder.cc index 75c6934a..27f342ce 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -88,6 +88,13 @@ Node Builder::add(const std::string& name) return n; } +Node Builder::add(const std::string& name, const GraphID & graph_id) +{ + Node n(sole::uuid4().str(), name, impl_->target, graph_id); + impl_->nodes.push_back(n); + return n; +} + Graph Builder::add_graph(const std::string& name) { Graph g(*this, name); impl_->graphs.push_back(g); diff --git a/src/c_ion.cc b/src/c_ion.cc index b5f11e3d..35798690 100644 --- a/src/c_ion.cc +++ b/src/c_ion.cc @@ -2,7 +2,6 @@ #include #include - #include #include "log.h" @@ -437,6 +436,23 @@ int ion_builder_with_bb_module(ion_builder_t obj, const char *module_name) log::error("Unknown exception was happened"); return 1; } + return 0; +} + +int ion_builder_add_graph(ion_builder_t obj, const char *name, ion_graph_t *graph_ptr) +{ + try { + *graph_ptr = reinterpret_cast(new Graph(reinterpret_cast(obj)->add_graph(name))); + } catch (const Halide::Error& e) { + log::error(e.what()); + return 1; + } catch (const std::exception& e) { + log::error(e.what()); + return 1; + } catch (...) { + log::error("Unknown exception was happened"); + return 1; + } return 0; } @@ -1026,3 +1042,95 @@ int ion_port_map_set_buffer_array(ion_port_map_t obj, ion_port_t p, ion_buffer_t return 0; } + + +int ion_graph_create(ion_graph_t *ptr, ion_builder_t obj, const char * name) +{ + try { + *ptr = reinterpret_cast(new Graph(*reinterpret_cast(obj), name)); + } catch (const Halide::Error& e) { + log::error(e.what()); + return 1; + } catch (const std::exception& e) { + log::error(e.what()); + return 1; + } catch (...) { + log::error("Unknown exception was happened"); + return 1; + } + + return 0; +} + +int ion_graph_add_node(ion_graph_t obj, const char *name, ion_node_t *node_ptr) +{ + try { + *node_ptr = reinterpret_cast(new Node(reinterpret_cast(obj)->add(name))); + } catch (const Halide::Error& e) { + log::error(e.what()); + return 1; + } catch (const std::exception& e) { + log::error(e.what()); + return 1; + } catch (...) { + log::error("Unknown exception was happened"); + return 1; + } + + return 0; +} + +int ion_graph_create_with_multiple(ion_graph_t * ptr, ion_graph_t* graphs_ptr, int graphs_num) { + try { + auto sum_graph = *reinterpret_cast(graphs_ptr[0]); + for (int i=1; i(graphs_ptr[i]); + } + *ptr = reinterpret_cast(new Graph(sum_graph)); + } catch (const Halide::Error& e) { + log::error(e.what()); + return 1; + } catch (const std::exception& e) { + log::error(e.what()); + return 1; + } catch (...) { + log::error("Unknown exception was happened"); + return 1; + } + + return 0; +} + +int ion_graph_run(ion_graph_t obj) +{ + try { + reinterpret_cast(obj)->run(); + } catch (const Halide::Error& e) { + log::error(e.what()); + return 1; + } catch (const std::exception& e) { + log::error(e.what()); + return 1; + } catch (...) { + log::error("Unknown exception was happened"); + return 1; + } + + return 0; +} + +int ion_graph_destroy(ion_graph_t obj){ + try { + delete reinterpret_cast(obj); + } catch (const Halide::Error& e) { + log::error(e.what()); + return 1; + } catch (const std::exception& e) { + log::error(e.what()); + return 1; + } catch (...) { + log::error("Unknown exception was happened"); + return 1; + } + return 0; +} diff --git a/src/graph.cc b/src/graph.cc index 819262f8..b3ddff04 100644 --- a/src/graph.cc +++ b/src/graph.cc @@ -3,24 +3,27 @@ #include "log.h" #include "lower.h" - +#include "uuid/sole.hpp" namespace ion { struct Graph::Impl { Builder builder; std::string name; + GraphID id; std::vector nodes; - // Cacheable Halide::Pipeline pipeline; Halide::Callable callable; std::unique_ptr jit_ctx; Halide::JITUserContext* jit_ctx_ptr; std::vector args; - - Impl(Builder b, const std::string& n) - : builder(b), name(n), jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get()) + Impl() + : id(sole::uuid4().str()) {} + Impl(Builder b, const std::string& n) + : id(sole::uuid4().str()), builder(b), name(n), jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get()) + { + } }; Graph::Graph() @@ -48,7 +51,8 @@ Graph operator+(const Graph& lhs, const Graph& rhs) Node Graph::add(const std::string& name) { - auto n = impl_->builder.add(name); + auto n = impl_->builder.add(name,impl_->id); + impl_->nodes.push_back(n); return n; } diff --git a/src/lower.cc b/src/lower.cc index 406642f4..07725941 100644 --- a/src/lower.cc +++ b/src/lower.cc @@ -218,14 +218,14 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ topological_sort(nodes); // Constructing Generator object and setting static parameters - std::unordered_map bbs; + std::unordered_map bbs; for (auto n : nodes) { auto bb(Halide::Internal::GeneratorRegistry::create(n.name(), Halide::GeneratorContext(n.target()))); // Default parameter Halide::GeneratorParamsMap params; params["builder_impl_ptr"] = std::to_string(reinterpret_cast(builder.impl_ptr())); - params["bb_id"] = n.id(); + params["bb_id"] = to_string(n.id()); // User defined parameter for (const auto& p : n.params()) { @@ -243,7 +243,7 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ for (const auto& [pn, port] : n.iports()) { // Find arginfo - auto it = std::find_if(arginfos.begin(), arginfos.end(), [pn](const ArgInfo& arginfo) { return arginfo.name == pn; }); + auto it = std::find_if(arginfos.begin(), arginfos.end(), [&pn=pn](const ArgInfo& arginfo) { return arginfo.name == pn; }); if (it == arginfos.end()) { auto msg = fmt::format("Argument {} is not defined in node {}", pn, n.name()); log::error(msg); @@ -295,7 +295,7 @@ Halide::Pipeline lower(Builder builder, std::vector& nodes, bool implicit_ if (implicit_output) { // Collects all output which is never referenced. // This mode is used for AOT compilation - std::unordered_map> referenced; + std::unordered_map, NodeID::StringIDHash> referenced; for (const auto& n : nodes) { for (const auto& [pn, port] : n.iports()) { if (port.has_pred()) { diff --git a/src/node.cc b/src/node.cc index 816eaa31..73257a0c 100644 --- a/src/node.cc +++ b/src/node.cc @@ -4,7 +4,8 @@ namespace ion { -Node::Impl::Impl(const std::string& id_, const std::string& name_, const Halide::Target& target_) + +Node::Impl::Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_) : id(id_), name(name_), target(target_), params(), ports() { auto bb(Halide::Internal::GeneratorRegistry::create(name_, Halide::GeneratorContext(target_))); @@ -16,10 +17,22 @@ Node::Impl::Impl(const std::string& id_, const std::string& name_, const Halide: arginfos = bb->arginfos(); } +Node::Impl::Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_, const GraphID& graph_id_) + : id(id_), name(name_), target(target_), params(), ports(), graph_id(graph_id_) +{ + auto bb(Halide::Internal::GeneratorRegistry::create(name_, Halide::GeneratorContext(target_))); + if (!bb) { + log::error("BuildingBlock {} is not found", name_); + throw std::runtime_error("Failed to create building block object"); + } + + arginfos = bb->arginfos(); +} + void Node::set_iport(const std::vector& ports) { impl_->ports.erase(std::remove_if(impl_->ports.begin(), impl_->ports.end(), - [&](const Port &p) { return p.has_succ_by_nid(this->id()); }), + [&](const Port &p) { return p.has_succ_by_nid(this->id());}), impl_->ports.end()); size_t i = 0; @@ -37,7 +50,7 @@ void Node::set_iport(const std::vector& ports) { // NOTE: Is succ_chans name OK to be just leave as it is? port.impl_->succ_chans.insert({id(), "_ion_iport_" + std::to_string(i)}); - + port.impl_ ->graph_id = impl_->graph_id; impl_->ports.push_back(port); i++; @@ -45,11 +58,13 @@ void Node::set_iport(const std::vector& ports) { } void Node::set_iport(Port port) { + port.impl_ ->graph_id = impl_->graph_id; port.impl_->succ_chans.insert({id(), port.pred_name()}); impl_->ports.push_back(port); } void Node::set_iport(const std::string& name, Port port) { + port.impl_ ->graph_id = impl_->graph_id; port.impl_->succ_chans.insert({id(), name}); impl_->ports.push_back(port); } @@ -60,7 +75,8 @@ Port Node::operator[](const std::string& name) { if (it == impl_->ports.end()) { // This is output port which is never referenced. // Bind myself as a predecessor and register - Port port(impl_->id, name); + Port port(id(), name); + port.impl_ ->graph_id = impl_->graph_id; impl_->ports.push_back(port); return port; } else { diff --git a/src/port.cc b/src/port.cc index a90a1506..752a8d25 100644 --- a/src/port.cc +++ b/src/port.cc @@ -6,24 +6,24 @@ namespace ion { Port::Impl::Impl() - : id(sole::uuid4().str()), pred_chan{"", ""}, succ_chans{}, type(), dimensions(-1) + : id(PortID(sole::uuid4().str())), pred_chan{"", ""}, succ_chans{}, type(), dimensions(-1) { } -Port::Impl::Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d) - : id(sole::uuid4().str()), pred_chan{pid, pn}, succ_chans{}, type(t), dimensions(d) +Port::Impl::Impl(const NodeID & nid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID & gid) + : id(PortID(sole::uuid4().str())), pred_chan{nid, pn}, succ_chans{}, type(t), dimensions(d), graph_id(gid) { - params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(pid, pn, 0)); + params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(nid, pn, 0, gid)); } -void Port::determine_succ(const std::string& nid, const std::string& old_pn, const std::string& new_pn) { +void Port::determine_succ(const NodeID& nid, const std::string& old_pn, const std::string& new_pn) { auto it = std::find(impl_->succ_chans.begin(), impl_->succ_chans.end(), Channel{nid, old_pn}); if (it == impl_->succ_chans.end()) { log::error("fixme"); throw std::runtime_error("fixme"); } - log::debug("Determine free port {} as {} on Node {}", old_pn, new_pn, nid); + log::debug("Determine free port {} as {} on Node {}", old_pn, new_pn, nid.value()); impl_->succ_chans.erase(it); impl_->succ_chans.insert(Channel{nid, new_pn}); } diff --git a/src/serializer.h b/src/serializer.h index c1f42967..92ee377e 100644 --- a/src/serializer.h +++ b/src/serializer.h @@ -42,9 +42,14 @@ static void from_json(const json& j, ion::Param& v) { template<> struct adl_serializer { static void to_json(json& j, const ion::Port& v) { - j["id"] = v.id(); - j["pred_chan"] = v.pred_chan(); - j["succ_chans"] = v.succ_chans(); + j["id"] = to_string(v.id()); + std::map stringMap; + j["pred_chan"] = std::make_tuple(to_string(std::get<0>(v.pred_chan())), std::get<1>(v.pred_chan())); + std::set> succ_chans; + for (auto& c:v.succ_chans()){ + succ_chans.insert(std::make_tuple(to_string(std::get<0>(c)), std::get<1>(c))); + } + j["succ_chans"] = succ_chans; j["type"] = static_cast(v.type()); j["dimensions"] = v.dimensions(); j["size"] = v.size(); @@ -54,13 +59,17 @@ struct adl_serializer { static void from_json(const json& j, ion::Port& v) { auto [impl, found] = ion::Port::find_impl(j["id"].get()); if (!found) { - impl->pred_chan = j["pred_chan"].get(); - impl->succ_chans = j["succ_chans"].get>(); + impl->pred_chan = j["pred_chan"].get>(); + std::set succ_chans; + for (auto & p : j["succ_chans"]){ + succ_chans.insert(p.get>()); + } + impl->succ_chans = succ_chans; impl->type = j["type"].get(); impl->dimensions = j["dimensions"]; for (auto i=0; iparams[i] = Halide::Internal::Parameter(impl->type, impl->dimensions != 0, impl->dimensions, - ion::argument_name(std::get<0>(impl->pred_chan), std::get<1>(impl->pred_chan), i)); + ion::argument_name(std::get<0>(impl->pred_chan), std::get<1>(impl->pred_chan), i, impl->graph_id.value())); } } v = ion::Port(impl, j["index"]); @@ -70,7 +79,7 @@ struct adl_serializer { template <> struct adl_serializer { static void to_json(json& j, const ion::Node& v) { - j["id"] = v.id(); + j["id"] = to_string(v.id()); j["name"] = v.name(); j["target"] = v.target().to_string(); j["params"] = v.params(); diff --git a/src/util.cc b/src/util.cc index b17bbcb7..8988cbd4 100644 --- a/src/util.cc +++ b/src/util.cc @@ -5,12 +5,12 @@ namespace ion { -std::string argument_name(const std::string& node_id, const std::string& name, int32_t index) { +std::string argument_name(const NodeID & node_id, const std::string& name, int32_t index, const GraphID & graph_id) { if (index == -1) { index = 0; } - std::string s = "_" + node_id + "_" + name + std::to_string(index);; + std::string s = "_" + node_id.value() + "_" + name + std::to_string(index) + "_" + graph_id.value(); std::replace(s.begin(), s.end(), '-', '_'); return s; diff --git a/test/c_api.cc b/test/c_api.cc index c483b2b8..2088d534 100644 --- a/test/c_api.cc +++ b/test/c_api.cc @@ -311,6 +311,200 @@ int main() if (ret != 0) return ret; } + { + + ion_type_t t = {.code=ion_type_int, .bits=32, .lanes=1}; + + + ion_builder_t b; + ret = ion_builder_create(&b); + if (ret != 0) + return ret; + + ret = ion_builder_set_target(b, "host"); + if (ret != 0) + return ret; + + ret = ion_builder_with_bb_module(b, "ion-bb-test"); + if (ret != 0) + return ret; + + ion_graph_t g0; + ret = ion_builder_add_graph(b, "graph0", &g0); + if (ret != 0) + return ret; + + ion_param_t v41; + ret = ion_param_create(&v41, "v", "41"); + if (ret != 0) + return ret; + + ion_node_t n0; + ret = ion_graph_add_node(g0, "test_inc_i32x2" , &n0); + if (ret != 0) + return ret; + int sizes[] = {16, 16}; + + ret = ion_node_set_param(n0, &v41, 1); + if (ret != 0) + return ret; + + ion_port_t ip0; + ret = ion_port_create(&ip0, "input0", t, 2); + if (ret != 0) + return ret; + ret = ion_node_set_iport(n0, &ip0, 1); + if (ret != 0) + return ret; + + + ion_buffer_t ibuf0; + ret = ion_buffer_create(&ibuf0, t, sizes, 2); + if (ret != 0) + return ret; + + int in0[16*16]; + for (int i=0; i<16*16; ++i) { + in0[i] = 0; + } + ret = ion_buffer_write(ibuf0, in0, 16*16*sizeof(int)); + if (ret != 0) + return ret; + + ret = ion_port_bind_buffer(ip0, ibuf0); + if (ret != 0) + return ret; + + ion_port_t op0; + ret = ion_node_get_port(n0, "output", &op0); + if (ret != 0) + return ret; + + ion_buffer_t obuf0; + ret = ion_buffer_create(&obuf0, t, sizes, 2); + if (ret != 0) + return ret; + + ret = ion_port_bind_buffer(op0, obuf0); + if (ret != 0) + return ret; +// +// ret = ion_graph_run(g0); +// if (ret != 0) +// return ret; + + int out0[16*16] = {0}; +// ret = ion_buffer_read(obuf0, out0, 16*16*sizeof(int)); +// for (int i=0;i<16*16; ++i) { +// if (out0[i] != 41) { +// printf("out0: %d\n", out0[i]); +// +// } +// } + + ion_graph_t g1; + ret = ion_builder_add_graph(b, "graph1", &g1); + if (ret != 0) + return ret; + + + + ion_node_t n1; + ret = ion_graph_add_node(g1, "test_inc_i32x2" , &n1); + if (ret != 0) + return ret; + + ret = ion_node_set_param(n1, &v41, 1); + if (ret != 0) + return ret; + + ion_port_t ip1; + ret = ion_port_create(&ip1, "input1", t, 2); + if (ret != 0) + return ret; + ret = ion_node_set_iport(n1, &ip1, 1); + if (ret != 0) + return ret; + + + ion_buffer_t ibuf1; + ret = ion_buffer_create(&ibuf1, t, sizes, 2); + if (ret != 0) + return ret; + + int in1[16*16]; + for (int i=0; i<16*16; ++i) { + in1[i] = 1; + } + ret = ion_buffer_write(ibuf1, in1, 16*16*sizeof(int)); + if (ret != 0) + return ret; + + ret = ion_port_bind_buffer(ip1, ibuf1); + if (ret != 0) + return ret; + + ion_port_t op1; + ret = ion_node_get_port(n1, "output", &op1); + if (ret != 0) + return ret; + + ion_buffer_t obuf1; + ret = ion_buffer_create(&obuf1, t, sizes, 2); + if (ret != 0) + return ret; + + ret = ion_port_bind_buffer(op1, obuf1); + if (ret != 0) + return ret; + +// ret = ion_graph_run(g1); + if (ret != 0) + return ret; + + int out1[16*16] = {0}; +// ret = ion_buffer_read(obuf1, out1, 16*16*sizeof(int)); +// for (int i=0;i<16*16; ++i) { +// if (out1[i] != 42) { +// printf("out1: %d\n", out1[i]); +// } +// } + + for (int i=0;i<16*16; ++i) { + out0[i] =0; + out1[i] =0; + } + + ion_graph_t g2; + ion_graph_create(&g2, b,"graph2"); + + ion_graph_t *graphs = (ion_graph_t*)malloc(2 * sizeof(ion_graph_t)); + graphs[0] = g0; + graphs[1] = g1; + + ret = ion_graph_create_with_multiple(&g2, graphs, 2); + if (ret != 0) + return ret; + ret = ion_graph_run(g2); + if (ret != 0) + return ret; + ret = ion_buffer_read(obuf0, out0, 16*16*sizeof(int)); + ret = ion_buffer_read(obuf1, out1, 16*16*sizeof(int)); + + for (int i=0;i<16*16; ++i) { + if (out0[i] != 41 ) { + printf("out0: %d\n", out0[i]); + ; + } + if (out1[i] != 42 ) { + printf("out1: %d\n", out1[i]); + } + } + ret = ion_graph_destroy(g0); + ret = ion_graph_destroy(g1); + ret = ion_graph_destroy(g2); + + } } catch (Halide::Error &e) { std::cerr << e.what() << std::endl; diff --git a/test/graph.cc b/test/graph.cc index bb5722f9..98f94ab6 100644 --- a/test/graph.cc +++ b/test/graph.cc @@ -4,83 +4,173 @@ using namespace ion; int main() { - try { - Builder b; - b.with_bb_module("ion-bb-test"); - b.set_target(Halide::get_host_target()); - - int32_t size = 16; - - Buffer in0(size, size); - in0.fill(1); - - Buffer in1(size, size); - in1.fill(1); + { + try { + Builder b; + b.with_bb_module("ion-bb-test"); + b.set_target(Halide::get_host_target()); + + int32_t size = 16; + + Buffer in0(size, size); + in0.fill(1); + + Buffer in1(size, size); + in1.fill(1); + + Buffer out0(size, size); + out0.fill(0); + + Buffer out1(size, size); + out1.fill(0); + + Graph g0 = b.add_graph("graph0"); + Node n0 = g0.add("test_inc_i32x2")(in0).set_param(Param("v", 40)); + n0["output"].bind(out0); + g0.run(); + + for (int y=0; y out0(size, size); - out0.fill(0); + Graph g1 = b.add_graph("graph1"); + Node n1 = g1.add("test_inc_i32x2")(in1).set_param(Param("v", 41)); + n1["output"].bind(out1); + g1.run(); + + for (int y=0; y out1(size, size); - out1.fill(0); + out0.fill(0); + out1.fill(0); - Graph g0 = b.add_graph("graph0"); - Node n0 = g0.add("test_inc_i32x2")(in0).set_param(Param("v", 40)); - n0["output"].bind(out0); - g0.run(); + Graph g2(g0 + g1); + g2.run(); - for (int y=0; y ibuf0(std::vector{1, 1}); + + + Port ip0{"input", Halide::type_of(), 2}; + Port vp0{"v", Halide::type_of()}; + + Graph g0 = b.add_graph("graph0"); + + Node n0 = g0.add("test_incx_i32x2")(ip0, vp0); + + ip0.bind(ibuf0); + int32_t v0 = 0; + vp0.bind(&v0); + + Buffer obuf0(std::vector{1, 1}); + n0["output"].bind(obuf0); + + ibuf0(0, 0) = 42; + v0 = 0; + obuf0(0, 0) = 0; + + g0.run(); + if (obuf0(0, 0) != 42) { + std::cerr << "Expected: " << 42 << " Actual:" << obuf0(0, 0) << std::endl; + return 1; } - } - out0.fill(0); - out1.fill(0); + // Test 2 + Port ip1{"input", Halide::type_of(), 2}; + Port vp1{"v", Halide::type_of()}; - Graph g2(g0 + g1); - g2.run(); + Graph g1 = b.add_graph("graph1"); + Node n1 = g1.add("test_incx_i32x2")(ip1, vp1); - for (int y=0; y ibuf1(std::vector{1, 1}); + ip1.bind(ibuf1); + Buffer obuf1(std::vector{1, 1}); + obuf1.fill(0); + n1["output"].bind(obuf1); + + + ibuf1(0, 0) = 42; + v1 = 1; + obuf1(0, 0) = 0; + + g1.run(); + + if (obuf1(0, 0) != 43) { + std::cerr << "Expected: " << 43 << " Actual:" << obuf1(0, 0) << std::endl; + return 1; } - } - } catch (Halide::Error& e) { - std::cerr << e.what() << std::endl; - return 1; - } catch (const std::exception& e) { - std::cerr << e.what() << std::endl; - return 1; - } + // Test 3 + Graph g2(g0 + g1); - std::cout << "Passed" << std::endl; + g2.run(); + if (obuf1(0, 0) != 43) { + std::cerr << "Expected: " << 43 << " Actual:" << obuf1(0, 0) << std::endl; + return 1; + } + if (obuf0(0, 0) != 42) { + std::cerr << "Expected: " << 42 << " Actual:" << obuf0(0, 0) << std::endl; + return 1; + } + std::cout << "second test Passed" << std::endl; + + } catch (Halide::Error &e) { + std::cerr << e.what() << std::endl; + return 1; + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; + } + } + std::cout << "All Passed" << std::endl; return 0; + }