Skip to content

Commit

Permalink
Fixed output binding #2
Browse files Browse the repository at this point in the history
  • Loading branch information
Fixstars-iizuka committed Dec 30, 2023
1 parent 5d5603c commit 5871f04
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 65 deletions.
43 changes: 2 additions & 41 deletions include/ion/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,47 +100,8 @@ class Builder {

Halide::Pipeline build(bool implicit_output = false);

std::vector<Halide::Argument> get_arguments_stub() const {
std::set<Port::Channel> added_ports;
std::vector<Halide::Argument> args;
for (const auto& node : nodes_) {
for (const auto& port : node.iports()) {
if (port.has_pred()) {
continue;
}

if (added_ports.count(port.impl_->pred_chan)) {
continue;
}
added_ports.insert(port.impl_->pred_chan);

const auto& port_args(port.as_argument());
args.insert(args.end(), port_args.begin(), port_args.end());
}
}
return args;
}

std::vector<const void*> get_arguments_instance() const {
std::set<Port::Channel> added_ports;
std::vector<const void*> instances;
for (const auto& node : nodes_) {
for (const auto& port : node.iports()) {
if (port.has_pred()) {
continue;
}

if (added_ports.count(port.impl_->pred_chan)) {
continue;
}
added_ports.insert(port.impl_->pred_chan);

const auto& port_instances(port.as_instance());
instances.insert(instances.end(), port_instances.begin(), port_instances.end());
}
}
return instances;
}
std::vector<Halide::Argument> get_arguments_stub() const;
std::vector<const void*> get_arguments_instance() const;

void set_jit_externs(const std::map<std::string, Halide::JITExtern> &externs) {
pipeline_.set_jit_externs(externs);
Expand Down
71 changes: 59 additions & 12 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ using json = nlohmann::json;
Builder::Builder()
: jit_ctx_(new Halide::JITUserContext), jit_ctx_ptr_(jit_ctx_.get())
{
args_.push_back(&jit_ctx_ptr_);
}

Builder::~Builder()
Expand Down Expand Up @@ -212,6 +211,12 @@ void Builder::run(ion::PortMap& pm) {
// pipeline_.infer_arguments()) {

callable_ = pipeline_.compile_to_callable(get_arguments_stub(), target_);

args_.clear();
args_.push_back(&jit_ctx_ptr_);

const auto& args(get_arguments_instance());
args_.insert(args_.end(), args.begin(), args.end());
}

callable_.call_argv_fast(args_.size(), args_.data());
Expand Down Expand Up @@ -247,7 +252,6 @@ Halide::Pipeline Builder::build(bool implicit_output) {
}

// Assigning ports
std::set<Port::Channel> added_args;
for (size_t i=0; i<nodes_.size(); ++i) {
auto n = nodes_[i];
const auto& bb = bbs[n.id()];
Expand Down Expand Up @@ -284,15 +288,6 @@ Halide::Pipeline Builder::build(bool implicit_output) {
} else {
throw std::runtime_error("fixme");
}

// Adding input args
if (added_args.count(port.impl_->pred_chan)) {
continue;
}
added_args.insert(port.impl_->pred_chan);

const auto& port_instances(port.as_instance());
args_.insert(args_.end(), port_instances.begin(), port_instances.end());
}
}
bb->build_pipeline();
Expand Down Expand Up @@ -346,7 +341,6 @@ Halide::Pipeline Builder::build(bool implicit_output) {

auto fs(bbs[port.pred_id()]->output_func(port.pred_name()));
output_funcs.insert(output_funcs.end(), fs.begin(), fs.end());
args_.insert(args_.end(), port_instances.begin(), port_instances.end());
}
}
}
Expand Down Expand Up @@ -381,4 +375,57 @@ void Builder::register_disposer(const std::string& bb_id, const std::string& dis
}
}

std::vector<Halide::Argument> Builder::get_arguments_stub() const {
std::set<Port::Channel> added_ports;
std::vector<Halide::Argument> args;
for (const auto& node : nodes_) {
for (const auto& port : node.iports()) {
if (port.has_pred()) {
continue;
}

if (added_ports.count(port.impl_->pred_chan)) {
continue;
}
added_ports.insert(port.impl_->pred_chan);

const auto& port_args(port.as_argument());
args.insert(args.end(), port_args.begin(), port_args.end());
}
}
return args;
}

std::vector<const void*> Builder::get_arguments_instance() const {
std::set<Port::Channel> added_args;
std::vector<const void*> instances;

// Input
for (const auto& node : nodes_) {
for (const auto& port : node.iports()) {
if (port.has_pred()) {
continue;
}

if (added_args.count(port.impl_->pred_chan)) {
continue;
}
added_args.insert(port.impl_->pred_chan);

const auto& port_instances(port.as_instance());
instances.insert(instances.end(), port_instances.begin(), port_instances.end());
}
}

// Output
for (const auto& node : nodes_) {
for (const auto& port : node.oports()) {
const auto& port_instances(port.as_instance());
instances.insert(instances.end(), port_instances.begin(), port_instances.end());
}
}

return instances;
}

} //namespace ion
25 changes: 13 additions & 12 deletions test/array_input.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,35 @@ int main() {

// Index access
{
Port array_input{"array_input", Halide::type_of<int32_t>(), 2};
Port input{"input", Halide::type_of<int32_t>(), 2};
Builder b;
b.set_target(Halide::get_host_target());
Node n;
n = b.add("test_array_copy")(array_input).set_params(Param{"len", std::to_string(len)});
n = b.add("test_array_copy")(input).set_params(Param{"len", std::to_string(len)});
n = b.add("test_array_input")(n["array_output"]).set_params(Param{"len", std::to_string(len)});

std::vector<Halide::Buffer<int32_t>> ins{
Halide::Buffer<int32_t>{w, h},
Halide::Buffer<int32_t>{w, h},
Halide::Buffer<int32_t>{w, h},
Halide::Buffer<int32_t>{w, h},
Halide::Buffer<int32_t>{w, h}
Halide::Buffer<int32_t>{w, h},
Halide::Buffer<int32_t>{w, h},
Halide::Buffer<int32_t>{w, h},
Halide::Buffer<int32_t>{w, h}
};

Halide::Buffer<int32_t> out(w, h);

for (int y = 0; y < h; ++y) {
for (int x = 0; x < w; ++x) {
for (auto &b : ins) {
b(x, y) = y * w + x;
}
out(x, y) = 0;
}
}

Halide::Buffer<int32_t> out(w, h);

PortMap pm;
for (size_t i=0; i<len; ++i) {
pm.set(array_input[i], ins[i]);
pm.set(input[i], ins[i]);
}
pm.set(n["output"], out);

Expand All @@ -66,11 +67,11 @@ int main() {

// Array access
{
Port array_input{"array_input", Halide::type_of<int32_t>(), 2};
Port input{"input", Halide::type_of<int32_t>(), 2};
Builder b;
b.set_target(Halide::get_host_target());
Node n;
n = b.add("test_array_copy")(array_input).set_params(Param{"len", std::to_string(len)});
n = b.add("test_array_copy")(input).set_params(Param{"len", std::to_string(len)});
n = b.add("test_array_input")(n["array_output"]).set_params(Param{"len", std::to_string(len)});

Halide::Buffer<int32_t> in0(w, h), in1(w, h), in2(w, h), in3(w, h), in4(w, h);
Expand All @@ -94,7 +95,7 @@ int main() {
Halide::Buffer<int32_t> out(w, h);

PortMap pm;
pm.set(array_input, ins);
pm.set(input, ins);
pm.set(n["output"], out);

b.compile("array_input_array");
Expand Down

0 comments on commit 5871f04

Please sign in to comment.