Skip to content

Commit

Permalink
replace NodeId to string
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu Li authored and Xinyu Li committed Mar 6, 2024
1 parent 42d5caa commit 4e48349
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 73 deletions.
28 changes: 12 additions & 16 deletions include/ion/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class Node {
std::vector<Halide::Internal::AbstractGenerator::ArgInfo> arginfos;

Impl(): id(), name(), target(), params(), ports() {}
Impl(const std::string& id_, const std::string& name_, const Halide::Target& target_);
Impl(const std::string& id_, const std::string& name_, const Halide::Target& target_, const GraphID &graph_id_);
Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_);
Impl(const NodeID& id_, const std::string& name_, const Halide::Target& target_, const GraphID &graph_id_);
};

public:
Expand Down Expand Up @@ -77,7 +77,7 @@ class Node {
*/
template<typename... Args>
Node operator()(Args ...args) {
set_iport(std::vector<Port>{get_iport(args)...});
set_iport(std::vector<Port>{make_iport(args)...});
return *this;
}

Expand All @@ -99,10 +99,6 @@ class Node {
return impl_->id;
}

const std::string& id_to_string() const {
return to_string(impl_->id);
}

const std::string& name() const {
return impl_->name;
}
Expand All @@ -127,38 +123,38 @@ class Node {

private:
Node(const std::string& id, const std::string& name, const Halide::Target& target)
: impl_(new Impl{id, name, target})
: impl_(new Impl{NodeID(id), name, target})
{
}

Node(const std::string& id, const std::string& name, const Halide::Target& target, const GraphID& graph_id)
: impl_(new Impl{id, name, target, graph_id})
: impl_(new Impl{NodeID(id), name, target, graph_id})
{
}

Port get_iport(Port arg) const {
Port make_iport(Port arg) const {
return arg;
}

template<typename T>
Port get_iport(T *vptr) const {
if (impl_->graph_id.value().empty())
Port make_iport(T *vptr) const {
if (to_string(impl_->graph_id).empty())
return Port(vptr);
else
return Port(vptr, impl_->graph_id);
}

template<typename T>
Port get_iport(Halide::Buffer<T>& arg) const {
if (impl_->graph_id.value().empty())
Port make_iport(Halide::Buffer<T>& arg) const {
if (to_string(impl_->graph_id).empty())
return Port(arg);
else
return Port(arg, impl_->graph_id);
}

template<typename T>
Port get_iport(std::vector<Halide::Buffer<T>>& arg) const {
if (impl_->graph_id.value().empty())
Port make_iport(std::vector<Halide::Buffer<T>>& arg) const {
if (to_string(impl_->graph_id).empty())
return Port(arg);
else
return Port(arg, impl_->graph_id);
Expand Down
50 changes: 24 additions & 26 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ class Port {
std::unordered_map<uint32_t, const void *> instances;

Impl();
Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID &gid );
Impl(const NodeID& nid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID &gid );
};

public:

Port() : impl_(new Impl("", "", Halide::Type(), 0, GraphID())), index_(-1) {}
Port() : impl_(new Impl(NodeID(""), "", Halide::Type(), 0, GraphID(""))), index_(-1) {}

Port(const std::shared_ptr<Impl>& impl, int32_t index) : impl_(impl), index_(index) {}

Expand All @@ -79,22 +79,22 @@ class Port {
* @arg k: The key of the port which should be matched with BuildingBlock Input/Output name.
* @arg t: The type of the value.
*/
Port(const std::string& n, Halide::Type t) : impl_(new Impl("", n, t, 0, GraphID(""))), index_(-1) {}
Port(const std::string& n, Halide::Type t) : impl_(new Impl(NodeID(""), n, t, 0, GraphID(""))), index_(-1) {}

/**
* Construct new port for vector value.
* @arg k: The key of the port which should be matched with BuildingBlock Input/Output name.
* @arg t: The type of the element value.
* @arg d: The dimension of the port. The range is 1 to 4.
*/
Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl("", n, t, d, GraphID(""))), index_(-1) {}
Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl(NodeID(""), n, t, d, GraphID(""))), index_(-1) {}

/**
* Construct new port from scalar pointer
*/
template<typename T,
typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
Port(T *vptr) : impl_(new Impl("", Halide::Internal::unique_name("_ion_port_"), Halide::type_of<T>(), 0, GraphID(""))), index_(-1) {
Port(T *vptr) : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of<T>(), 0, GraphID(""))), index_(-1) {
this->bind(vptr);
}

Expand All @@ -103,7 +103,7 @@ class Port {
*/
template<typename T,
typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
Port(T *vptr, const GraphID & gid) : impl_(new Impl("", Halide::Internal::unique_name("_ion_port_"), Halide::type_of<T>(), 0, gid)), index_(-1) {
Port(T *vptr, const GraphID & gid) : impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of<T>(), 0, gid)), index_(-1) {
this->bind(vptr);
}

Expand All @@ -112,63 +112,61 @@ class Port {
* Construct new port from buffer
*/
template<typename T>
Port(const Halide::Buffer<T>& buf) : impl_(new Impl("", buf.name(), buf.type(), buf.dimensions(), GraphID(""))), index_(-1) {
Port(const Halide::Buffer<T>& buf) : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), GraphID(""))), index_(-1) {
this->bind(buf);
}

/**
* Construct new port from buffer and bind graph id to port
*/
template<typename T>
Port(const Halide::Buffer<T>& buf, const GraphID & gid) : impl_(new Impl("", buf.name(), buf.type(), buf.dimensions(), gid)), index_(-1) {
Port(const Halide::Buffer<T>& buf, const GraphID & gid) : impl_(new Impl(NodeID(""), buf.name(), buf.type(), buf.dimensions(), gid)), index_(-1) {
this->bind(buf);
}

/**
* Construct new port from array of buffer
*/
template<typename T>
Port(const std::vector<Halide::Buffer<T>>& bufs) : impl_(new Impl("", unify_name(bufs), Halide::type_of<T>(), unify_dimension(bufs), GraphID(""))), index_(-1) {
Port(const std::vector<Halide::Buffer<T>>& bufs) : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of<T>(), unify_dimension(bufs), GraphID(""))), index_(-1) {
this->bind(bufs);
}

/**
* Construct new port from array of buffer and bind graph id to port
*/
template<typename T>
Port(const std::vector<Halide::Buffer<T>>& bufs, const GraphID & gid) : impl_(new Impl("", unify_name(bufs), Halide::type_of<T>(), unify_dimension(bufs), gid)), index_(-1) {
Port(const std::vector<Halide::Buffer<T>>& bufs, const GraphID & gid) : impl_(new Impl(NodeID(""), unify_name(bufs), Halide::type_of<T>(), unify_dimension(bufs), gid)), index_(-1) {
this->bind(bufs);
}

// Getter
const PortID id() const { return impl_->id; }
const std::string& id_to_string() const { return to_string(impl_->id); }
const Channel& pred_chan() const { return impl_->pred_chan; }
const NodeID& pred_id() const { return std::get<0>(impl_->pred_chan); }
const std::string& pred_id_to_string() const { return to_string(std::get<0>(impl_->pred_chan)); }
const std::string& pred_name() const { return std::get<1>(impl_->pred_chan); }
const std::set<Channel>& succ_chans() const { return impl_->succ_chans; }
const Halide::Type& type() const { return impl_->type; }
int32_t dimensions() const { return impl_->dimensions; }
int32_t size() const { return static_cast<int32_t>(impl_->params.size()); }
int32_t index() const { return index_; }
const std::string& graph_id_to_string() const { return to_string(impl_->graph_id); }
const GraphID& graph_id() const { return impl_->graph_id; }

// Setter
void set_index(int index) { index_ = index; }

// Util
bool has_pred() const { return !std::get<0>(impl_->pred_chan).value().empty(); }
bool has_pred_by_nid(const std::string& nid) const { return !to_string(std::get<0>(impl_->pred_chan)).empty(); }
bool has_pred_by_nid(const NodeID & nid) const { return !to_string(std::get<0>(impl_->pred_chan)).empty(); }
bool has_succ() const { return !impl_->succ_chans.empty(); }
bool has_succ(const Channel& c) const { return impl_->succ_chans.count(c); }
bool has_succ_by_nid(const std::string& nid) const {
bool has_succ_by_nid(const NodeID& nid) const {
return std::count_if(impl_->succ_chans.begin(),
impl_->succ_chans.end(),
[&](const Port::Channel& c) { return std::get<0>(c) == nid; });
}

void determine_succ(const std::string& nid, const std::string& old_pn, const std::string& new_pn);
void determine_succ(const NodeID& nid, const std::string& old_pn, const std::string& new_pn);

/**
* Overloaded operator to set the port index and return a reference to the current port. eg. port[0]
Expand All @@ -183,9 +181,9 @@ class Port {
void bind(T *v) {
auto i = index_ == -1 ? 0 : index_;
if (has_pred()) {
impl_->params[i] = Halide::Internal::Parameter{Halide::type_of<T>(), false, 0, argument_name(pred_id(), pred_name(), i, graph_id_to_string())};
impl_->params[i] = Halide::Internal::Parameter{Halide::type_of<T>(), false, 0, argument_name(pred_id(), pred_name(), i, graph_id())};
} else {
impl_->params[i] = Halide::Internal::Parameter{type(), false, dimensions(), argument_name(pred_id(), pred_name(), i,graph_id_to_string())};
impl_->params[i] = Halide::Internal::Parameter{type(), false, dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())};
}

impl_->instances[i] = v;
Expand All @@ -196,9 +194,9 @@ class Port {
void bind(const Halide::Buffer<T>& buf) {
auto i = index_ == -1 ? 0 : index_;
if (has_pred()) {
impl_->params[i] = Halide::Internal::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), pred_name(), i,graph_id_to_string())};
impl_->params[i] = Halide::Internal::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), pred_name(), i,graph_id())};
} else {
impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i,graph_id_to_string())};
impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i,graph_id())};
}

impl_->instances[i] = buf.raw_buffer();
Expand All @@ -208,9 +206,9 @@ class Port {
void bind(const std::vector<Halide::Buffer<T>>& bufs) {
for (int i=0; i<static_cast<int>(bufs.size()); ++i) {
if (has_pred()) {
impl_->params[i] = Halide::Internal::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), pred_name(), i, graph_id_to_string())};
impl_->params[i] = Halide::Internal::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())};
} else {
impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i, graph_id_to_string())};
impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())};
}

impl_->instances[i] = bufs[i].raw_buffer();
Expand All @@ -229,7 +227,7 @@ class Port {
if (es.size() <= i) {
es.resize(i+1, Halide::Expr());
}
es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), pred_name(), i, graph_id_to_string()), param);
es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), pred_name(), i, graph_id()), param);
}
return es;
}
Expand All @@ -250,7 +248,7 @@ class Port {
args.push_back(Halide::Var::implicit(i));
args_expr.push_back(Halide::Var::implicit(i));
}
Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), pred_name(), i, graph_id_to_string()) + "_im");
Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), pred_name(), i, graph_id()) + "_im");
f(args) = Halide::Internal::Call::make(param, args_expr);
fs[i] = f;
}
Expand All @@ -264,7 +262,7 @@ class Port {
args.resize(i+1, Halide::Argument());
}
auto kind = dimensions() == 0 ? Halide::Argument::InputScalar : Halide::Argument::InputBuffer;
args[i] = Halide::Argument(argument_name(pred_id(), pred_name(), i, graph_id_to_string()), kind, type(), dimensions(), Halide::ArgumentEstimates());
args[i] = Halide::Argument(argument_name(pred_id(), pred_name(), i, graph_id()), kind, type(), dimensions(), Halide::ArgumentEstimates());
}
return args;
}
Expand All @@ -287,7 +285,7 @@ class Port {
* pid and pn is stored in both pred and succ,
* then it will determined through pipeline build process.
*/
Port(const std::string& pid, const std::string& pn) : impl_(new Impl(pid, pn, Halide::Type(), 0, GraphID(""))), index_(-1) {}
Port(const NodeID & nid, const std::string& pn) : impl_(new Impl(nid, pn, Halide::Type(), 0, GraphID(""))), index_(-1) {}

std::shared_ptr<Impl> impl_;

Expand Down
2 changes: 1 addition & 1 deletion include/ion/port_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class PortMap {
template<typename T>
[[deprecated("Port::bind can be used instead of PortMap.")]]
void set(Port port, T v) {
auto& buf(scalar_buffer_[argument_name(port.pred_id_to_string(), port.pred_name(), port.index(), port.graph_id_to_string())]);
auto& buf(scalar_buffer_[argument_name(port.pred_id(), port.pred_name(), port.index(), port.graph_id())]);
buf.resize(sizeof(v));
std::memcpy(buf.data(), &v, sizeof(v));
port.bind(reinterpret_cast<T*>(buf.data()));
Expand Down
10 changes: 9 additions & 1 deletion include/ion/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,20 @@ template<class Tag>
struct StringID {
using tag_type = Tag;

// needs to be default-constuctable because of use in map[] below
// needs to be default-constructable because of use in map[] below
StringID(std::string s) : _value(std::move(s)) {}
StringID() : _value() {}
// provide access to the underlying string value
const std::string &value() const { return _value; }

struct StringIDHash {
// Use hash of string as hash function.
size_t operator()(const StringID& id) const
{
return std::hash<std::string>()(id.value());
}
};

private:
std::string _value;

Expand Down
24 changes: 12 additions & 12 deletions src/lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void determine_and_validate(std::vector<Node>& nodes) {
throw std::runtime_error(msg);
}

port.determine_succ(n.id_to_string(), pn, arginfo.name);
port.determine_succ(n.id(), pn, arginfo.name);
pn = arginfo.name;
}

Expand Down Expand Up @@ -218,27 +218,27 @@ Halide::Pipeline lower(Builder builder, std::vector<Node>& nodes, bool implicit_
topological_sort(nodes);

// Constructing Generator object and setting static parameters
std::unordered_map<std::string, Halide::Internal::AbstractGeneratorPtr> bbs;
std::unordered_map<NodeID, Halide::Internal::AbstractGeneratorPtr, NodeID::StringIDHash> bbs;
for (auto n : nodes) {
auto bb(Halide::Internal::GeneratorRegistry::create(n.name(), Halide::GeneratorContext(n.target())));

// Default parameter
Halide::GeneratorParamsMap params;
params["builder_impl_ptr"] = std::to_string(reinterpret_cast<uint64_t>(builder.impl_ptr()));
params["bb_id"] = n.id_to_string();
params["bb_id"] = to_string(n.id());

// User defined parameter
for (const auto& p : n.params()) {
params[p.key()] = p.val();
}
bb->set_generatorparam_values(params);
bbs[n.id_to_string()] = std::move(bb);
bbs[n.id()] = std::move(bb);
}

// Assigning ports and build pipeline
for (size_t i=0; i<nodes.size(); ++i) {
auto n = nodes[i];
const auto& bb = bbs[n.id_to_string()];
const auto& bb = bbs[n.id()];
auto arginfos = bb->arginfos();
for (const auto& [pn, port] : n.iports()) {

Expand All @@ -254,7 +254,7 @@ Halide::Pipeline lower(Builder builder, std::vector<Node>& nodes, bool implicit_
auto index = port.index();

if (port.has_pred()) {
const auto& pred_bb(bbs[port.pred_id_to_string()]);
const auto& pred_bb(bbs[port.pred_id()]);
auto fs = pred_bb->output_func(port.pred_name());
if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) {
bb->bind_input(arginfo.name, fs);
Expand Down Expand Up @@ -295,19 +295,19 @@ Halide::Pipeline lower(Builder builder, std::vector<Node>& nodes, bool implicit_
if (implicit_output) {
// Collects all output which is never referenced.
// This mode is used for AOT compilation
std::unordered_map<std::string, std::vector<std::string>> referenced;
std::unordered_map<NodeID , std::vector<std::string>, NodeID::StringIDHash> referenced;
for (const auto& n : nodes) {
for (const auto& [pn, port] : n.iports()) {
if (port.has_pred()) {
for (const auto &f : bbs[port.pred_id_to_string()]->output_func(port.pred_name())) {
referenced[port.pred_id_to_string()].emplace_back(f.name());
for (const auto &f : bbs[port.pred_id()]->output_func(port.pred_name())) {
referenced[port.pred_id()].emplace_back(f.name());
}
}
}
}

for (const auto& node : nodes) {
auto node_id = node.id_to_string();
auto node_id = node.id();
for (auto arginfo : bbs[node_id]->arginfos()) {
if (arginfo.dir != Halide::Internal::ArgInfoDirection::Output) {
continue;
Expand Down Expand Up @@ -336,7 +336,7 @@ Halide::Pipeline lower(Builder builder, std::vector<Node>& nodes, bool implicit_
continue;
}

const auto& pred_bb(bbs[port.pred_id_to_string()]);
const auto& pred_bb(bbs[port.pred_id()]);

// Validate port exists
const auto& port_(port); // This is workaround for Clang-14 (MacOS)
Expand All @@ -349,7 +349,7 @@ Halide::Pipeline lower(Builder builder, std::vector<Node>& nodes, bool implicit_
}


auto fs(bbs[port.pred_id_to_string()]->output_func(port.pred_name()));
auto fs(bbs[port.pred_id()]->output_func(port.pred_name()));
output_funcs.insert(output_funcs.end(), fs.begin(), fs.end());
}
}
Expand Down
Loading

0 comments on commit 4e48349

Please sign in to comment.