Skip to content

Commit

Permalink
use reference
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu Li committed Mar 8, 2024
1 parent aad9564 commit 9a1d5fa
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 25 deletions.
2 changes: 1 addition & 1 deletion include/ion/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class DynamicModule;
/**
* Builder class is used to build graph, compile, run, save and load it.
*/
class Builder : std::enable_shared_from_this<Builder> {
class Builder {
public:

struct Impl;
Expand Down
2 changes: 1 addition & 1 deletion include/ion/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Graph {



Graph(std::shared_ptr<Builder> builder , const std::string& name = "");
Graph(Builder & builder , const std::string& name = "");

Graph& operator+=(const Graph& rhs);

Expand Down
3 changes: 1 addition & 2 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ Builder::~Builder()

Builder::Impl::~Impl()
{
std::cout<<"Builder pointer is relaesed"<<std::endl;
for (auto [bb_id, disposer] : disposers) {
disposer(bb_id.c_str());
}
Expand All @@ -97,7 +96,7 @@ Node Builder::add(const std::string& name, const GraphID & graph_id)
}

Graph Builder::add_graph(const std::string& name) {
Graph g(shared_from_this(), name);
Graph g(*this, name);
impl_->graphs.push_back(g);
return g;
}
Expand Down
3 changes: 1 addition & 2 deletions src/c_ion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1047,8 +1047,7 @@ int ion_port_map_set_buffer_array(ion_port_map_t obj, ion_port_t p, ion_buffer_t
int ion_graph_create(ion_graph_t *ptr, ion_builder_t obj, const char * name)
{
try {
std::shared_ptr<Builder> other_ptr(reinterpret_cast<Builder*>(obj));
*ptr = reinterpret_cast<ion_graph_t>(new Graph(other_ptr, name));
*ptr = reinterpret_cast<ion_graph_t>(new Graph(*reinterpret_cast<Builder*>(obj), name));
} catch (const Halide::Error& e) {
log::error(e.what());
return 1;
Expand Down
28 changes: 9 additions & 19 deletions src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace ion {

struct Graph::Impl {
std::weak_ptr<Builder> builder;
Builder & builder;
std::string name;
GraphID id;
std::vector<Node> nodes;
Expand All @@ -17,28 +17,18 @@ struct Graph::Impl {
std::unique_ptr<Halide::JITUserContext> jit_ctx;
Halide::JITUserContext* jit_ctx_ptr;
std::vector<const void*> args;
Impl()
: id(sole::uuid4().str())
{}

Impl(std::shared_ptr<Builder> b, const std::string& n)
Impl(Builder & b, const std::string& n)
: id(sole::uuid4().str()), builder(b), name(n), jit_ctx(new Halide::JITUserContext), jit_ctx_ptr(jit_ctx.get())
{
}
~Impl();
};

Graph::Graph()
{
}


Graph::Impl::~Impl()
{
std::cout<<"Graph pointer is relaesed"<<std::endl;
}

Graph::Graph(std::shared_ptr<Builder> builder, const std::string& name)
Graph::Graph(Builder & builder, const std::string& name)
: impl_(new Impl(builder, name))
{
}
Expand All @@ -51,16 +41,16 @@ Graph& Graph::operator+=(const Graph& rhs)

Graph operator+(const Graph& lhs, const Graph& rhs)
{
Graph g(lhs.impl_->builder.lock());
Graph g(lhs.impl_->builder);
g += lhs;
g += rhs;
return g;
}

Node Graph::add(const std::string& name)
{
auto ptr =impl_->builder.lock();
auto n = ptr->add(name,impl_->id);
auto ptr =impl_->builder;
auto n = ptr.add(name,impl_->id);

impl_->nodes.push_back(n);
return n;
Expand All @@ -69,7 +59,7 @@ Node Graph::add(const std::string& name)
void Graph::run()
{
if (!impl_->pipeline.defined()) {
impl_->pipeline = lower(*impl_->builder.lock(), impl_->nodes, false);
impl_->pipeline = lower(impl_->builder, impl_->nodes, false);
if (!impl_->pipeline.defined()) {
log::warn("This pipeline doesn't produce any outputs. Please bind a buffer with output port.");
return;
Expand All @@ -78,11 +68,11 @@ void Graph::run()

if (!impl_->callable.defined()) {

impl_->pipeline.set_jit_externs(impl_->builder.lock()->jit_externs());
impl_->pipeline.set_jit_externs(impl_->builder.jit_externs());

auto inferred_args = impl_->pipeline.infer_arguments();

impl_->callable = impl_->pipeline.compile_to_callable(inferred_args, impl_->builder.lock()->target());
impl_->callable = impl_->pipeline.compile_to_callable(inferred_args, impl_->builder.target());

impl_->args.clear();
impl_->args.push_back(&impl_->jit_ctx_ptr);
Expand Down

0 comments on commit 9a1d5fa

Please sign in to comment.