From ec5bb3603df974879b44c6826f7246c9f6d6d95b Mon Sep 17 00:00:00 2001 From: Takuro Iizuka Date: Wed, 27 Dec 2023 18:06:45 -0800 Subject: [PATCH] WIP: vparam --- include/ion/port.h | 84 ++++++++++++++++++++++++++++++++++++++------ src/serializer.h | 30 ++++++++-------- test/array_input.cc | 2 ++ test/export.cc | 13 ++++--- test/inverted_dep.cc | 12 +++---- 5 files changed, 103 insertions(+), 38 deletions(-) diff --git a/include/ion/port.h b/include/ion/port.h index 61c6ca9f..feabca55 100644 --- a/include/ion/port.h +++ b/include/ion/port.h @@ -14,6 +14,8 @@ #include "util.h" +#define SW 0 + namespace ion { /** @@ -27,9 +29,10 @@ class Port { int32_t dimensions; std::string node_id; std::vector params; - std::vector fparams; + std::unordered_map> vparams; + Impl() {} Impl(const std::string& n, const Halide::Type& t, int32_t d, const std::string& nid) @@ -37,8 +40,12 @@ class Port { { if (dimensions == 0) { params = { Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(node_id, name)) }; + + vparams[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(node_id, name, 0)); } else { fparams = { Halide::ImageParam(type, dimensions, argument_name(node_id, name)) }; + + vparams[0] = Halide::ImageParam(type, dimensions, argument_name(node_id, name, 0)); } } }; @@ -67,24 +74,16 @@ class Port { Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl(n, t, d, "")), index_(-1) {} const std::string& name() const { return impl_->name; } - std::string& name() { return impl_->name; } const Halide::Type& type() const { return impl_->type; } - Halide::Type& type() { return impl_->type; } int32_t dimensions() const { return impl_->dimensions; } - int32_t& dimensions() { return impl_->dimensions; } const std::string& node_id() const { return impl_->node_id; } - std::string& node_id() { return impl_->node_id; } - - const std::vector& params() const { return impl_->params; } - std::vector& params() { return impl_->params; } int32_t size() const { return (impl_->dimensions == 0) ? impl_->params.size() : impl_->fparams.size(); } int32_t index() const { return index_; } - int32_t& index() { return index_; } bool is_bound() const { return !node_id().empty(); @@ -106,25 +105,40 @@ class Port { template void bind(T v) { auto i = index_ == -1 ? 0 : index_; + + // Old if (impl_->params.size() <= i) { impl_->params.resize(i+1); impl_->params[i] = Halide::Internal::Parameter{type(), dimensions() != 0, dimensions(), argument_name(node_id(), name(), i)}; } impl_->params[i].set_scalar(v); + + // New + Halide::Internal::Parameter param{type(), dimensions() != 0, dimensions(), argument_name(node_id(), name(), i)}; + param.set_scalar(v); + impl_->vparams[i] = param; } template void bind(const Halide::Buffer& buf) { auto i = index_ == -1 ? 0 : index_; + + // Old if (impl_->fparams.size() <= i) { impl_->fparams.resize(i+1); impl_->fparams[i] = Halide::ImageParam{type(), dimensions(), argument_name(node_id(), name(), i)}; } impl_->fparams[i].set(buf); + + // New + Halide::ImageParam param{type(), dimensions(), argument_name(node_id(), name(), i)}; + param.set(buf); + impl_->vparams[i] = param; } template void bind(const std::vector>& bufs) { + // Old if (impl_->fparams.size() != bufs.size()) { impl_->fparams.resize(bufs.size()); for (size_t i=0; ifparams[i].set(bufs[i]); } + + // New + for (size_t i=0; ivparams[i] = param; + } } static std::shared_ptr find_impl(uintptr_t ptr) { @@ -154,6 +175,7 @@ class Port { std::vector as_argument() const { std::vector args; +#if SW if (dimensions() == 0) { for (auto i = 0; iparams.size(); ++i) { args.push_back(Halide::Argument(argument_name(node_id(), name(), i), Halide::Argument::InputScalar, type(), dimensions(), Halide::ArgumentEstimates())); @@ -163,13 +185,24 @@ class Port { args.push_back(Halide::Argument(argument_name(node_id(), name(), i), Halide::Argument::InputBuffer, type(), dimensions(), Halide::ArgumentEstimates())); } } +#else + for (const auto& [i, param] : impl_->vparams) { + if (args.size() <= i) { + args.resize(i+1, Halide::Argument()); + } + auto kind = impl_->dimensions == 0 ? Halide::Argument::InputScalar : Halide::Argument::InputBuffer; + args[i] = Halide::Argument(argument_name(impl_->node_id, impl_->name, i), kind, impl_->type, impl_->dimensions, Halide::ArgumentEstimates()); + } + +#endif return args; } std::vector as_instance() const { std::vector instances; +#if SW if (dimensions() == 0) { - for (const auto& param : params()) { + for (const auto& param : impl_->params) { instances.push_back(param.scalar_address()); } } else { @@ -177,6 +210,18 @@ class Port { instances.push_back(fparam.get().raw_buffer()); } } +#else + for (const auto& [i, param] : impl_->vparams) { + if (instances.size() <= i) { + instances.resize(i+1, nullptr); + } + if (auto p = std::get_if(¶m)) { + instances[i] = p->scalar_address(); + } else if (auto p = std::get_if(¶m)) { + instances[i] = p->get().raw_buffer(); + } + } +#endif return instances; } @@ -187,10 +232,20 @@ class Port { } std::vector es; +#if SW int32_t i = 0; for (const auto& p : impl_->params) { es.push_back(Halide::Internal::Variable::make(impl_->type, argument_name(impl_->node_id, impl_->name, i++), p)); } +#else + for (const auto& [i, param] : impl_->vparams) { + if (es.size() <= i) { + es.resize(i+1, Halide::Expr()); + } + es.push_back(Halide::Internal::Variable::make(impl_->type, argument_name(impl_->node_id, impl_->name, i), + *std::get_if(¶m))); + } +#endif return es; } @@ -200,9 +255,18 @@ class Port { } std::vector fs; +#if SW for (const auto& p : impl_->fparams ) { fs.push_back(p); } +#else + for (const auto& [i, param] : impl_->vparams ) { + if (fs.size() <= i) { + fs.resize(i+1, Halide::Func()); + } + fs.push_back(*std::get_if(¶m)); + } +#endif return fs; } diff --git a/src/serializer.h b/src/serializer.h index 8ca31c74..75e35935 100644 --- a/src/serializer.h +++ b/src/serializer.h @@ -43,26 +43,26 @@ template<> class adl_serializer { public: static void to_json(json& j, const ion::Port& v) { - j["name"] = v.name(); - j["type"] = static_cast(v.type()); - j["dimensions"] = v.dimensions(); - j["index"] = v.index(); - j["node_id"] = v.node_id(); - j["array_size"] = v.params().size(); - j["impl"] = reinterpret_cast(v.impl_.get()); + j["name"] = v.impl_->name; + j["type"] = static_cast(v.impl_->type); + j["dimensions"] = v.impl_->dimensions; + j["node_id"] = v.impl_->node_id; + j["array_size"] = v.impl_->params.size(); + j["impl_ptr"] = reinterpret_cast(v.impl_.get()); + j["index"] = v.index_; } static void from_json(const json& j, ion::Port& v) { - v = ion::Port(ion::Port::find_impl(j["impl"].get())); - v.name() = j["name"].get(); - v.type() = j["type"].get(); - v.dimensions() = j["dimensions"]; - v.node_id() = j["node_id"].get(); - v.params() = std::vector( + v = ion::Port(ion::Port::find_impl(j["impl_ptr"].get())); + v.impl_->name = j["name"].get(); + v.impl_->type = j["type"].get(); + v.impl_->dimensions = j["dimensions"]; + v.impl_->node_id = j["node_id"].get(); + v.impl_->params = std::vector( j["array_size"], - Halide::Internal::Parameter(v.type(), v.dimensions() != 0, v.dimensions(), ion::argument_name(v.node_id(), v.name())) + Halide::Internal::Parameter(v.impl_->type, v.impl_->dimensions != 0, v.impl_->dimensions, ion::argument_name(v.impl_->node_id, v.impl_->name)) ); - v.index() = j["index"]; + v.index_ = j["index"]; } }; diff --git a/test/array_input.cc b/test/array_input.cc index 21b996a8..180b4b6a 100644 --- a/test/array_input.cc +++ b/test/array_input.cc @@ -45,6 +45,7 @@ int main() { } pm.set(n["output"], out); + b.compile("array_input_index"); b.run(pm); for (int y = 0; y < h; ++y) { @@ -96,6 +97,7 @@ int main() { pm.set(input, ins); pm.set(n["output"], out); + b.compile("array_input_array"); b.run(pm); for (int y = 0; y < h; ++y) { diff --git a/test/export.cc b/test/export.cc index 22d8fb8b..8d18b31e 100644 --- a/test/export.cc +++ b/test/export.cc @@ -23,20 +23,19 @@ int main() } { Halide::Type t = Halide::type_of(); - Port min0{"min0", t}, extent0{"extent0", t}, min1{"min1", t}, extent1{"extent1", t}, v{"v", t}; Builder b; b.load("simple_graph.json"); PortMap pm; - pm.set(min0, 0); - pm.set(extent0, 2); - pm.set(min1, 0); - pm.set(extent1, 2); - pm.set(v, 1); Halide::Buffer out = Halide::Buffer::make_scalar(); auto nodes = b.nodes(); - pm.set(nodes.back()["output"], out); + pm.set(nodes[1]["min0"], 0); + pm.set(nodes[1]["extent0"], 2); + pm.set(nodes[1]["min1"], 0); + pm.set(nodes[1]["extent1"], 2); + pm.set(nodes[1]["v"], 1); + pm.set(nodes[1]["output"], out); b.run(pm); } diff --git a/test/inverted_dep.cc b/test/inverted_dep.cc index ea7f56a1..7961749d 100644 --- a/test/inverted_dep.cc +++ b/test/inverted_dep.cc @@ -24,7 +24,7 @@ int main() { "array_size": 1, "dimensions": 0, - "impl": 93843319831024, + "impl_ptr": 93843319831024, "index": -1, "name": "output", "node_id": "54f036a3-0b98-4d42-a343-6c7421d15f2f", @@ -37,7 +37,7 @@ int main() { "array_size": 1, "dimensions": 0, - "impl": 93843319821872, + "impl_ptr": 93843319821872, "index": -1, "name": "min0", "node_id": "", @@ -50,7 +50,7 @@ int main() { "array_size": 1, "dimensions": 0, - "impl": 93843319822176, + "impl_ptr": 93843319822176, "index": -1, "name": "extent0", "node_id": "", @@ -63,7 +63,7 @@ int main() { "array_size": 1, "dimensions": 0, - "impl": 93843319822480, + "impl_ptr": 93843319822480, "index": -1, "name": "min1", "node_id": "", @@ -76,7 +76,7 @@ int main() { "array_size": 1, "dimensions": 0, - "impl": 93843319822784, + "impl_ptr": 93843319822784, "index": -1, "name": "extent1", "node_id": "", @@ -89,7 +89,7 @@ int main() { "array_size": 1, "dimensions": 0, - "impl": 93843319823088, + "impl_ptr": 93843319823088, "index": -1, "name": "v", "node_id": "",