Skip to content

Commit

Permalink
Improvements to verify --reduce (#2206)
Browse files Browse the repository at this point in the history
Skip verifying literals and parameters
Print the correct trim value
Print failure for an exception and continue reduce
Print if shapes do not match
  • Loading branch information
pfultz2 authored Sep 22, 2023
1 parent 52eb36f commit 8298423
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions src/driver/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ranges.hpp>

namespace migraphx {
namespace driver {
Expand Down Expand Up @@ -84,7 +85,17 @@ void verify_program(const std::string& name,
std::size_t output_num = x.size();
for(std::size_t i = 0; i < output_num; ++i)
{
verify_args(name, x[i], y[i], tolerance);
if(x[i].get_shape().type() != y[i].get_shape().type() or
x[i].get_shape().lens() != y[i].get_shape().lens())
{
std::cout << "FAILED: " << name << std::endl;
std::cout << "Shape mismatch {" << x[i].get_shape() << "} != {" << y[i].get_shape()
<< "}" << std::endl;
}
else
{
verify_args(name, x[i], y[i], tolerance);
}
}
}

Expand Down Expand Up @@ -143,11 +154,19 @@ void verify_reduced(program p,
double tolerance)
{
auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1);
auto last = std::prev(mm->end(), n);
mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
try
{
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
}
catch(const std::exception& e)
{
std::cout << "FAILED: " << n << std::endl;
std::cout << "Exception: " << e.what() << std::endl;
}
}

void verify_reduced_program(const program& p,
Expand All @@ -160,8 +179,14 @@ void verify_reduced_program(const program& p,
const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++)
for(std::size_t i = 1; i < n; i++)
{
auto last = std::prev(mm->end(), i + 1);
if(contains({"@literal", "@param"}, last->name()))
{
std::cout << "Skip: " << i << std::endl;
continue;
}
verify_reduced(p, i, t, options, quantize, inputs, tolerance);
}
}
Expand Down

0 comments on commit 8298423

Please sign in to comment.