Skip to content

Commit

Permalink
use weak pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu Li committed Mar 8, 2024
1 parent 93e352b commit aad9564
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 15 deletions.
2 changes: 1 addition & 1 deletion include/ion/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class DynamicModule;
/**
* Builder class is used to build graph, compile, run, save and load it.
*/
class Builder {
class Builder : std::enable_shared_from_this<Builder> {
public:

struct Impl;
Expand Down
9 changes: 6 additions & 3 deletions include/ion/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ namespace ion {
class Builder;

class Graph {
public:

struct Impl;

public:
struct Impl;

Graph();

Graph(Builder builder, const std::string& name = "");


Graph(std::shared_ptr<Builder> builder , const std::string& name = "");

Graph& operator+=(const Graph& rhs);

Expand Down Expand Up @@ -46,6 +48,7 @@ class Graph {

private:
std::shared_ptr<Impl> impl_;

};

} // namespace ion
Expand Down
4 changes: 2 additions & 2 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ Builder::Builder()

Builder::~Builder()
{
impl_->graphs.clear();
}

Builder::Impl::~Impl()
{
std::cout<<"Builder pointer is relaesed"<<std::endl;
for (auto [bb_id, disposer] : disposers) {
disposer(bb_id.c_str());
}
Expand All @@ -97,7 +97,7 @@ Node Builder::add(const std::string& name, const GraphID & graph_id)
}

Graph Builder::add_graph(const std::string& name) {
Graph g(*this, name);
Graph g(shared_from_this(), name);
impl_->graphs.push_back(g);
return g;
}
Expand Down
3 changes: 2 additions & 1 deletion src/c_ion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,8 @@ int ion_port_map_set_buffer_array(ion_port_map_t obj, ion_port_t p, ion_buffer_t
int ion_graph_create(ion_graph_t *ptr, ion_builder_t obj, const char * name)
{
try {
*ptr = reinterpret_cast<ion_graph_t>(new Graph(*reinterpret_cast<Builder*>(obj), name));
std::shared_ptr<Builder> other_ptr(reinterpret_cast<Builder*>(obj));
*ptr = reinterpret_cast<ion_graph_t>(new Graph(other_ptr, name));
} catch (const Halide::Error& e) {
log::error(e.what());
return 1;
Expand Down
26 changes: 18 additions & 8 deletions src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace ion {

struct Graph::Impl {
Builder builder;
std::weak_ptr<Builder> builder;
std::string name;
GraphID id;
std::vector<Node> nodes;
Expand All @@ -20,17 +20,25 @@ struct Graph::Impl {
Impl()
: id(sole::uuid4().str())
{}
Impl(Builder b, const std::string& n)

Impl(std::shared_ptr<Builder> b, const std::string& n)
: id(sole::uuid4().str()), builder(b), name(n), jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get())
{
}
~Impl();
};

Graph::Graph()
{
}

Graph::Graph(Builder builder, const std::string& name)

Graph::Impl::~Impl()
{
std::cout<<"Graph pointer is relaesed"<<std::endl;
}

Graph::Graph(std::shared_ptr<Builder> builder, const std::string& name)
: impl_(new Impl(builder, name))
{
}
Expand All @@ -43,15 +51,16 @@ Graph& Graph::operator+=(const Graph& rhs)

Graph operator+(const Graph& lhs, const Graph& rhs)
{
Graph g(lhs.impl_->builder);
Graph g(lhs.impl_->builder.lock());
g += lhs;
g += rhs;
return g;
}

Node Graph::add(const std::string& name)
{
auto n = impl_->builder.add(name,impl_->id);
auto ptr =impl_->builder.lock();
auto n = ptr->add(name,impl_->id);

impl_->nodes.push_back(n);
return n;
Expand All @@ -60,7 +69,7 @@ Node Graph::add(const std::string& name)
void Graph::run()
{
if (!impl_->pipeline.defined()) {
impl_->pipeline = lower(impl_->builder, impl_->nodes, false);
impl_->pipeline = lower(*impl_->builder.lock(), impl_->nodes, false);
if (!impl_->pipeline.defined()) {
log::warn("This pipeline doesn't produce any outputs. Please bind a buffer with output port.");
return;
Expand All @@ -69,11 +78,11 @@ void Graph::run()

if (!impl_->callable.defined()) {

impl_->pipeline.set_jit_externs(impl_->builder.jit_externs());
impl_->pipeline.set_jit_externs(impl_->builder.lock()->jit_externs());

auto inferred_args = impl_->pipeline.infer_arguments();

impl_->callable = impl_->pipeline.compile_to_callable(inferred_args, impl_->builder.target());
impl_->callable = impl_->pipeline.compile_to_callable(inferred_args, impl_->builder.lock()->target());

impl_->args.clear();
impl_->args.push_back(&impl_->jit_ctx_ptr);
Expand All @@ -94,4 +103,5 @@ std::vector<Node>& Graph::nodes() {
}



} // namespace ion

0 comments on commit aad9564

Please sign in to comment.