Skip to content

Commit

Permalink
Fixed instruction::replace() logic - Fix CI. (#3574)
Browse files Browse the repository at this point in the history
  • Loading branch information
tcgu-amd authored Nov 11, 2024
1 parent a435c28 commit 2f97579
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 5 deletions.
6 changes: 6 additions & 0 deletions src/include/migraphx/output_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ auto join_back_inserter(Container& c)
[&](const auto& r) { c.insert(c.end(), r.begin(), r.end()); });
}

template <class Container>
auto push_inserter(Container& c)
{
return make_function_output_iterator([&](const auto& x) { c.push(x); });
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
32 changes: 27 additions & 5 deletions src/instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
#include <migraphx/erase.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <deque>
#include <migraphx/output_iterator.hpp>
#include <queue>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -58,22 +59,43 @@ instruction::instruction(literal l)
{
}

struct replace_shape_order
{
instruction_ref start;

std::size_t location(instruction_ref x) const { return std::distance(start, x); }

bool operator()(instruction_ref x, instruction_ref y) const
{
return location(x) > location(y);
}
};

void instruction::replace(const shape& r)
{
if(r != result)
{
result = r;
std::deque<instruction_ref> q(output.begin(), output.end());
if(output.empty())
{
return;
}
auto start = std::find_if(output.front()->inputs().begin(),
output.front()->inputs().end(),
[&](instruction_ref x) { return this == as_address(x); });
assert(as_address(*start) == this);
std::priority_queue<instruction_ref, std::vector<instruction_ref>, replace_shape_order> q(
output.begin(), output.end(), replace_shape_order{*start});
while(not q.empty())
{
instruction_ref ins = q.front();
q.pop_front();
instruction_ref ins = q.top();
q.pop();
assert(ins->name() == "@return" or ins->name().front() != '@');
shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
if(new_r != ins->result)
{
ins->result = new_r;
std::copy(ins->output.begin(), ins->output.end(), std::back_inserter(q));
std::copy(ins->output.begin(), ins->output.end(), migraphx::push_inserter(q));
}
}
}
Expand Down
20 changes: 20 additions & 0 deletions test/instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,24 @@ TEST_CASE(check_replace_shape)
EXPECT(add->get_shape() == r);
}

TEST_CASE(check_replace_dag)
{
migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {3, 2}};
auto input = m.add_parameter("x", s);
auto reduce = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), input);
auto abs = m.add_instruction(migraphx::make_op("abs"), reduce);
auto sin = m.add_instruction(migraphx::make_op("sin"), reduce);
auto add = m.add_instruction(migraphx::make_op("add"), abs, sin);
auto add2 = m.add_instruction(migraphx::make_op("add"), add, reduce);

reduce->replace(migraphx::make_op("reduce_sum", {{"axes", {1}}}));

migraphx::shape r{migraphx::shape::float_type, {3, 1}};
EXPECT(reduce->get_shape() == r);
EXPECT(abs->get_shape() == r);
EXPECT(sin->get_shape() == r);
EXPECT(add->get_shape() == r);
EXPECT(add2->get_shape() == r);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit 2f97579

Please sign in to comment.