Skip to content

Commit

Permalink
use weak_pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu Li committed Mar 8, 2024
1 parent 90e5381 commit 696d303
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 14 deletions.
2 changes: 1 addition & 1 deletion include/ion/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class DynamicModule;
/**
* Builder class is used to build graph, compile, run, save and load it.
*/
class Builder {
class Builder : std::enable_shared_from_this<Builder> {
public:

struct Impl;
Expand Down
2 changes: 1 addition & 1 deletion include/ion/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Graph {

Graph();

Graph(Builder & builder , const std::string& name = "");
Graph(std::weak_ptr<Builder> builder , const std::string& name = "");

Graph& operator+=(const Graph& rhs);

Expand Down
5 changes: 3 additions & 2 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Builder::~Builder()

Builder::Impl::~Impl()
{
std::cout<<"Builder pointer is relaesed"<<std::endl;
for (auto [bb_id, disposer] : disposers) {
disposer(bb_id.c_str());
}
Expand All @@ -96,7 +97,7 @@ Node Builder::add(const std::string& name, const GraphID & graph_id)
}

Graph Builder::add_graph(const std::string& name) {
Graph g(*this, name);
Graph g(weak_from_this(), name);
impl_->graphs.push_back(g);
return g;
}
Expand Down Expand Up @@ -305,4 +306,4 @@ const Builder::Impl* Builder::impl_ptr() const {
}


} //namespace ion
} //namespace ion
3 changes: 2 additions & 1 deletion src/c_ion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,8 @@ int ion_port_map_set_buffer_array(ion_port_map_t obj, ion_port_t p, ion_buffer_t
int ion_graph_create(ion_graph_t *ptr, ion_builder_t obj, const char * name)
{
try {
*ptr = reinterpret_cast<ion_graph_t>(new Graph(*reinterpret_cast<Builder*>(obj), name));
std::shared_ptr<Builder> graph_ptr(reinterpret_cast<Builder*>(obj));
*ptr = reinterpret_cast<ion_graph_t>(new Graph(graph_ptr, name));
} catch (const Halide::Error& e) {
log::error(e.what());
return 1;
Expand Down
31 changes: 22 additions & 9 deletions src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace ion {

struct Graph::Impl {
Builder & builder;
std::weak_ptr<Builder> builder;
std::string name;
GraphID id;
std::vector<Node> nodes;
Expand All @@ -17,18 +17,28 @@ struct Graph::Impl {
std::unique_ptr<Halide::JITUserContext> jit_ctx;
Halide::JITUserContext* jit_ctx_ptr;
std::vector<const void*> args;
Impl()
: id(sole::uuid4().str())
{}

Impl(Builder & b, const std::string& n)
Impl(std::weak_ptr<Builder> b, const std::string& n)
: id(sole::uuid4().str()), builder(b), name(n), jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get())
{
}
~Impl();
};

Graph::Graph()
{
}

Graph::Graph(Builder & builder, const std::string& name)

Graph::Impl::~Impl()
{
std::cout<<"Graph pointer is relaesed"<<std::endl;
}

Graph::Graph(std::weak_ptr<Builder> builder, const std::string& name)
: impl_(new Impl(builder, name))
{
}
Expand All @@ -41,23 +51,25 @@ Graph& Graph::operator+=(const Graph& rhs)

Graph operator+(const Graph& lhs, const Graph& rhs)
{
Graph g(lhs.impl_->builder);
Graph g(lhs.impl_->builder.lock());
g += lhs;
g += rhs;
return g;
}

Node Graph::add(const std::string& name)
{
auto n = impl_->builder.add(name,impl_->id);
auto ptr =impl_->builder.lock();
auto n = ptr->add(name,impl_->id);

impl_->nodes.push_back(n);
return n;
}

void Graph::run()
{
if (!impl_->pipeline.defined()) {
impl_->pipeline = lower(impl_->builder, impl_->nodes, false);
impl_->pipeline = lower(*impl_->builder.lock(), impl_->nodes, false);
if (!impl_->pipeline.defined()) {
log::warn("This pipeline doesn't produce any outputs. Please bind a buffer with output port.");
return;
Expand All @@ -66,11 +78,11 @@ void Graph::run()

if (!impl_->callable.defined()) {

impl_->pipeline.set_jit_externs(impl_->builder.jit_externs());
impl_->pipeline.set_jit_externs(impl_->builder.lock()->jit_externs());

auto inferred_args = impl_->pipeline.infer_arguments();

impl_->callable = impl_->pipeline.compile_to_callable(inferred_args, impl_->builder.target());
impl_->callable = impl_->pipeline.compile_to_callable(inferred_args, impl_->builder.lock()->target());

impl_->args.clear();
impl_->args.push_back(&impl_->jit_ctx_ptr);
Expand All @@ -91,4 +103,5 @@ std::vector<Node>& Graph::nodes() {
}


} // namespace ion

} // namespace ion

0 comments on commit 696d303

Please sign in to comment.