Skip to content

Commit

Permalink
WIP: vparam
Browse files Browse the repository at this point in the history
  • Loading branch information
Fixstars-iizuka committed Dec 28, 2023
1 parent c9b2c71 commit ec5bb36
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 38 deletions.
84 changes: 74 additions & 10 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "util.h"

#define SW 0

namespace ion {

/**
Expand All @@ -27,18 +29,23 @@ class Port {
int32_t dimensions;
std::string node_id;
std::vector<Halide::Internal::Parameter> params;

std::vector<Halide::ImageParam> fparams;

std::unordered_map<int32_t, std::variant<Halide::Internal::Parameter, Halide::ImageParam>> vparams;

Impl() {}

Impl(const std::string& n, const Halide::Type& t, int32_t d, const std::string& nid)
: name(n), type(t), dimensions(d), node_id(nid)
{
if (dimensions == 0) {
params = { Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(node_id, name)) };

vparams[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(node_id, name, 0));
} else {
fparams = { Halide::ImageParam(type, dimensions, argument_name(node_id, name)) };

vparams[0] = Halide::ImageParam(type, dimensions, argument_name(node_id, name, 0));
}
}
};
Expand Down Expand Up @@ -67,24 +74,16 @@ class Port {
Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl(n, t, d, "")), index_(-1) {}

const std::string& name() const { return impl_->name; }
std::string& name() { return impl_->name; }

const Halide::Type& type() const { return impl_->type; }
Halide::Type& type() { return impl_->type; }

int32_t dimensions() const { return impl_->dimensions; }
int32_t& dimensions() { return impl_->dimensions; }

const std::string& node_id() const { return impl_->node_id; }
std::string& node_id() { return impl_->node_id; }

const std::vector<Halide::Internal::Parameter>& params() const { return impl_->params; }
std::vector<Halide::Internal::Parameter>& params() { return impl_->params; }

int32_t size() const { return (impl_->dimensions == 0) ? impl_->params.size() : impl_->fparams.size(); }

int32_t index() const { return index_; }
int32_t& index() { return index_; }

bool is_bound() const {
return !node_id().empty();
Expand All @@ -106,25 +105,40 @@ class Port {
template<typename T>
void bind(T v) {
auto i = index_ == -1 ? 0 : index_;

// Old
if (impl_->params.size() <= i) {
impl_->params.resize(i+1);
impl_->params[i] = Halide::Internal::Parameter{type(), dimensions() != 0, dimensions(), argument_name(node_id(), name(), i)};
}
impl_->params[i].set_scalar(v);

// New
Halide::Internal::Parameter param{type(), dimensions() != 0, dimensions(), argument_name(node_id(), name(), i)};
param.set_scalar(v);
impl_->vparams[i] = param;
}

template<typename T>
void bind(const Halide::Buffer<T>& buf) {
auto i = index_ == -1 ? 0 : index_;

// Old
if (impl_->fparams.size() <= i) {
impl_->fparams.resize(i+1);
impl_->fparams[i] = Halide::ImageParam{type(), dimensions(), argument_name(node_id(), name(), i)};
}
impl_->fparams[i].set(buf);

// New
Halide::ImageParam param{type(), dimensions(), argument_name(node_id(), name(), i)};
param.set(buf);
impl_->vparams[i] = param;
}

template<typename T>
void bind(const std::vector<Halide::Buffer<T>>& bufs) {
// Old
if (impl_->fparams.size() != bufs.size()) {
impl_->fparams.resize(bufs.size());
for (size_t i=0; i<bufs.size(); ++i) {
Expand All @@ -134,6 +148,13 @@ class Port {
for (size_t i=0; i<bufs.size(); ++i) {
impl_->fparams[i].set(bufs[i]);
}

// New
for (size_t i=0; i<bufs.size(); ++i) {
Halide::ImageParam param{type(), dimensions(), argument_name(node_id(), name(), i)};
param.set(bufs[i]);
impl_->vparams[i] = param;
}
}

static std::shared_ptr<Impl> find_impl(uintptr_t ptr) {
Expand All @@ -154,6 +175,7 @@ class Port {

std::vector<Halide::Argument> as_argument() const {
std::vector<Halide::Argument> args;
#if SW
if (dimensions() == 0) {
for (auto i = 0; i<impl_->params.size(); ++i) {
args.push_back(Halide::Argument(argument_name(node_id(), name(), i), Halide::Argument::InputScalar, type(), dimensions(), Halide::ArgumentEstimates()));
Expand All @@ -163,20 +185,43 @@ class Port {
args.push_back(Halide::Argument(argument_name(node_id(), name(), i), Halide::Argument::InputBuffer, type(), dimensions(), Halide::ArgumentEstimates()));
}
}
#else
for (const auto& [i, param] : impl_->vparams) {
if (args.size() <= i) {
args.resize(i+1, Halide::Argument());
}
auto kind = impl_->dimensions == 0 ? Halide::Argument::InputScalar : Halide::Argument::InputBuffer;
args[i] = Halide::Argument(argument_name(impl_->node_id, impl_->name, i), kind, impl_->type, impl_->dimensions, Halide::ArgumentEstimates());
}

#endif
return args;
}

std::vector<const void *> as_instance() const {
std::vector<const void *> instances;
#if SW
if (dimensions() == 0) {
for (const auto& param : params()) {
for (const auto& param : impl_->params) {
instances.push_back(param.scalar_address());
}
} else {
for (const auto& fparam : impl_->fparams) {
instances.push_back(fparam.get().raw_buffer());
}
}
#else
for (const auto& [i, param] : impl_->vparams) {
if (instances.size() <= i) {
instances.resize(i+1, nullptr);
}
if (auto p = std::get_if<Halide::Internal::Parameter>(&param)) {
instances[i] = p->scalar_address();
} else if (auto p = std::get_if<Halide::ImageParam>(&param)) {
instances[i] = p->get().raw_buffer();
}
}
#endif

return instances;
}
Expand All @@ -187,10 +232,20 @@ class Port {
}

std::vector<Halide::Expr> es;
#if SW
int32_t i = 0;
for (const auto& p : impl_->params) {
es.push_back(Halide::Internal::Variable::make(impl_->type, argument_name(impl_->node_id, impl_->name, i++), p));
}
#else
for (const auto& [i, param] : impl_->vparams) {
if (es.size() <= i) {
es.resize(i+1, Halide::Expr());
}
es.push_back(Halide::Internal::Variable::make(impl_->type, argument_name(impl_->node_id, impl_->name, i),
*std::get_if<Halide::Internal::Parameter>(&param)));
}
#endif
return es;
}

Expand All @@ -200,9 +255,18 @@ class Port {
}

std::vector<Halide::Func> fs;
#if SW
for (const auto& p : impl_->fparams ) {
fs.push_back(p);
}
#else
for (const auto& [i, param] : impl_->vparams ) {
if (fs.size() <= i) {
fs.resize(i+1, Halide::Func());
}
fs.push_back(*std::get_if<Halide::ImageParam>(&param));
}
#endif
return fs;
}

Expand Down
30 changes: 15 additions & 15 deletions src/serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,26 @@ template<>
class adl_serializer<ion::Port> {
public:
static void to_json(json& j, const ion::Port& v) {
j["name"] = v.name();
j["type"] = static_cast<halide_type_t>(v.type());
j["dimensions"] = v.dimensions();
j["index"] = v.index();
j["node_id"] = v.node_id();
j["array_size"] = v.params().size();
j["impl"] = reinterpret_cast<uintptr_t>(v.impl_.get());
j["name"] = v.impl_->name;
j["type"] = static_cast<halide_type_t>(v.impl_->type);
j["dimensions"] = v.impl_->dimensions;
j["node_id"] = v.impl_->node_id;
j["array_size"] = v.impl_->params.size();
j["impl_ptr"] = reinterpret_cast<uintptr_t>(v.impl_.get());
j["index"] = v.index_;
}

static void from_json(const json& j, ion::Port& v) {
v = ion::Port(ion::Port::find_impl(j["impl"].get<uintptr_t>()));
v.name() = j["name"].get<std::string>();
v.type() = j["type"].get<halide_type_t>();
v.dimensions() = j["dimensions"];
v.node_id() = j["node_id"].get<std::string>();
v.params() = std::vector<Halide::Internal::Parameter>(
v = ion::Port(ion::Port::find_impl(j["impl_ptr"].get<uintptr_t>()));
v.impl_->name = j["name"].get<std::string>();
v.impl_->type = j["type"].get<halide_type_t>();
v.impl_->dimensions = j["dimensions"];
v.impl_->node_id = j["node_id"].get<std::string>();
v.impl_->params = std::vector<Halide::Internal::Parameter>(
j["array_size"],
Halide::Internal::Parameter(v.type(), v.dimensions() != 0, v.dimensions(), ion::argument_name(v.node_id(), v.name()))
Halide::Internal::Parameter(v.impl_->type, v.impl_->dimensions != 0, v.impl_->dimensions, ion::argument_name(v.impl_->node_id, v.impl_->name))
);
v.index() = j["index"];
v.index_ = j["index"];
}
};

Expand Down
2 changes: 2 additions & 0 deletions test/array_input.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ int main() {
}
pm.set(n["output"], out);

b.compile("array_input_index");
b.run(pm);

for (int y = 0; y < h; ++y) {
Expand Down Expand Up @@ -96,6 +97,7 @@ int main() {
pm.set(input, ins);
pm.set(n["output"], out);

b.compile("array_input_array");
b.run(pm);

for (int y = 0; y < h; ++y) {
Expand Down
13 changes: 6 additions & 7 deletions test/export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,19 @@ int main()
}
{
Halide::Type t = Halide::type_of<int32_t>();
Port min0{"min0", t}, extent0{"extent0", t}, min1{"min1", t}, extent1{"extent1", t}, v{"v", t};
Builder b;
b.load("simple_graph.json");
PortMap pm;
pm.set(min0, 0);
pm.set(extent0, 2);
pm.set(min1, 0);
pm.set(extent1, 2);
pm.set(v, 1);

Halide::Buffer<int32_t> out = Halide::Buffer<int32_t>::make_scalar();

auto nodes = b.nodes();
pm.set(nodes.back()["output"], out);
pm.set(nodes[1]["min0"], 0);
pm.set(nodes[1]["extent0"], 2);
pm.set(nodes[1]["min1"], 0);
pm.set(nodes[1]["extent1"], 2);
pm.set(nodes[1]["v"], 1);
pm.set(nodes[1]["output"], out);

b.run(pm);
}
Expand Down
12 changes: 6 additions & 6 deletions test/inverted_dep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int main()
{
"array_size": 1,
"dimensions": 0,
"impl": 93843319831024,
"impl_ptr": 93843319831024,
"index": -1,
"name": "output",
"node_id": "54f036a3-0b98-4d42-a343-6c7421d15f2f",
Expand All @@ -37,7 +37,7 @@ int main()
{
"array_size": 1,
"dimensions": 0,
"impl": 93843319821872,
"impl_ptr": 93843319821872,
"index": -1,
"name": "min0",
"node_id": "",
Expand All @@ -50,7 +50,7 @@ int main()
{
"array_size": 1,
"dimensions": 0,
"impl": 93843319822176,
"impl_ptr": 93843319822176,
"index": -1,
"name": "extent0",
"node_id": "",
Expand All @@ -63,7 +63,7 @@ int main()
{
"array_size": 1,
"dimensions": 0,
"impl": 93843319822480,
"impl_ptr": 93843319822480,
"index": -1,
"name": "min1",
"node_id": "",
Expand All @@ -76,7 +76,7 @@ int main()
{
"array_size": 1,
"dimensions": 0,
"impl": 93843319822784,
"impl_ptr": 93843319822784,
"index": -1,
"name": "extent1",
"node_id": "",
Expand All @@ -89,7 +89,7 @@ int main()
{
"array_size": 1,
"dimensions": 0,
"impl": 93843319823088,
"impl_ptr": 93843319823088,
"index": -1,
"name": "v",
"node_id": "",
Expand Down

0 comments on commit ec5bb36

Please sign in to comment.