diff --git a/include/ion/builder.h b/include/ion/builder.h index b7a099c9..593684c9 100644 --- a/include/ion/builder.h +++ b/include/ion/builder.h @@ -100,47 +100,8 @@ class Builder { Halide::Pipeline build(bool implicit_output = false); - std::vector get_arguments_stub() const { - std::set added_ports; - std::vector args; - for (const auto& node : nodes_) { - for (const auto& port : node.iports()) { - if (port.has_pred()) { - continue; - } - - if (added_ports.count(port.impl_->pred_chan)) { - continue; - } - added_ports.insert(port.impl_->pred_chan); - - const auto& port_args(port.as_argument()); - args.insert(args.end(), port_args.begin(), port_args.end()); - } - } - return args; - } - - std::vector get_arguments_instance() const { - std::set added_ports; - std::vector instances; - for (const auto& node : nodes_) { - for (const auto& port : node.iports()) { - if (port.has_pred()) { - continue; - } - - if (added_ports.count(port.impl_->pred_chan)) { - continue; - } - added_ports.insert(port.impl_->pred_chan); - - const auto& port_instances(port.as_instance()); - instances.insert(instances.end(), port_instances.begin(), port_instances.end()); - } - } - return instances; - } + std::vector get_arguments_stub() const; + std::vector get_arguments_instance() const; void set_jit_externs(const std::map &externs) { pipeline_.set_jit_externs(externs); diff --git a/src/builder.cc b/src/builder.cc index b7c7687b..5d380391 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -83,7 +83,6 @@ using json = nlohmann::json; Builder::Builder() : jit_ctx_(new Halide::JITUserContext), jit_ctx_ptr_(jit_ctx_.get()) { - args_.push_back(&jit_ctx_ptr_); } Builder::~Builder() @@ -212,6 +211,12 @@ void Builder::run(ion::PortMap& pm) { // pipeline_.infer_arguments()) { callable_ = pipeline_.compile_to_callable(get_arguments_stub(), target_); + + args_.clear(); + args_.push_back(&jit_ctx_ptr_); + + const auto& args(get_arguments_instance()); + args_.insert(args_.end(), args.begin(), args.end()); } callable_.call_argv_fast(args_.size(), args_.data()); @@ -247,7 +252,6 @@ Halide::Pipeline Builder::build(bool implicit_output) { } // Assigning ports - std::set added_args; for (size_t i=0; ipred_chan)) { - continue; - } - added_args.insert(port.impl_->pred_chan); - - const auto& port_instances(port.as_instance()); - args_.insert(args_.end(), port_instances.begin(), port_instances.end()); } } bb->build_pipeline(); @@ -346,7 +341,6 @@ Halide::Pipeline Builder::build(bool implicit_output) { auto fs(bbs[port.pred_id()]->output_func(port.pred_name())); output_funcs.insert(output_funcs.end(), fs.begin(), fs.end()); - args_.insert(args_.end(), port_instances.begin(), port_instances.end()); } } } @@ -381,4 +375,57 @@ void Builder::register_disposer(const std::string& bb_id, const std::string& dis } } +std::vector Builder::get_arguments_stub() const { + std::set added_ports; + std::vector args; + for (const auto& node : nodes_) { + for (const auto& port : node.iports()) { + if (port.has_pred()) { + continue; + } + + if (added_ports.count(port.impl_->pred_chan)) { + continue; + } + added_ports.insert(port.impl_->pred_chan); + + const auto& port_args(port.as_argument()); + args.insert(args.end(), port_args.begin(), port_args.end()); + } + } + return args; +} + +std::vector Builder::get_arguments_instance() const { + std::set added_args; + std::vector instances; + + // Input + for (const auto& node : nodes_) { + for (const auto& port : node.iports()) { + if (port.has_pred()) { + continue; + } + + if (added_args.count(port.impl_->pred_chan)) { + continue; + } + added_args.insert(port.impl_->pred_chan); + + const auto& port_instances(port.as_instance()); + instances.insert(instances.end(), port_instances.begin(), port_instances.end()); + } + } + + // Output + for (const auto& node : nodes_) { + for (const auto& port : node.oports()) { + const auto& port_instances(port.as_instance()); + instances.insert(instances.end(), port_instances.begin(), port_instances.end()); + } + } + + return instances; +} + } //namespace ion diff --git a/test/array_input.cc b/test/array_input.cc index 206c4443..a26ba4cd 100644 --- a/test/array_input.cc +++ b/test/array_input.cc @@ -14,34 +14,35 @@ int main() { // Index access { - Port array_input{"array_input", Halide::type_of(), 2}; + Port input{"input", Halide::type_of(), 2}; Builder b; b.set_target(Halide::get_host_target()); Node n; - n = b.add("test_array_copy")(array_input).set_params(Param{"len", std::to_string(len)}); + n = b.add("test_array_copy")(input).set_params(Param{"len", std::to_string(len)}); n = b.add("test_array_input")(n["array_output"]).set_params(Param{"len", std::to_string(len)}); std::vector> ins{ Halide::Buffer{w, h}, - Halide::Buffer{w, h}, - Halide::Buffer{w, h}, - Halide::Buffer{w, h}, - Halide::Buffer{w, h} + Halide::Buffer{w, h}, + Halide::Buffer{w, h}, + Halide::Buffer{w, h}, + Halide::Buffer{w, h} }; + Halide::Buffer out(w, h); + for (int y = 0; y < h; ++y) { for (int x = 0; x < w; ++x) { for (auto &b : ins) { b(x, y) = y * w + x; } + out(x, y) = 0; } } - Halide::Buffer out(w, h); - PortMap pm; for (size_t i=0; i(), 2}; + Port input{"input", Halide::type_of(), 2}; Builder b; b.set_target(Halide::get_host_target()); Node n; - n = b.add("test_array_copy")(array_input).set_params(Param{"len", std::to_string(len)}); + n = b.add("test_array_copy")(input).set_params(Param{"len", std::to_string(len)}); n = b.add("test_array_input")(n["array_output"]).set_params(Param{"len", std::to_string(len)}); Halide::Buffer in0(w, h), in1(w, h), in2(w, h), in3(w, h), in4(w, h); @@ -94,7 +95,7 @@ int main() { Halide::Buffer out(w, h); PortMap pm; - pm.set(array_input, ins); + pm.set(input, ins); pm.set(n["output"], out); b.compile("array_input_array");