Skip to content

Commit

Permalink
remove splitk support (#3286)
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Jul 20, 2024
1 parent b8c33b0 commit 05c9d1b
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 32 deletions.
2 changes: 0 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ Full documentation for MIGraphX is available at
* Added a `--test` flag in migraphx-driver to validate the installation
* Added support for ONNX Operator: Einsum
* Added uint8 support in ONNX Operators
* Enabled Split-k kernel configurations for performance improvements
* Added fusion for group convolutions
* Added rocMLIR conv3d support
* Added rocgdb to the Dockerfile
Expand Down Expand Up @@ -46,7 +45,6 @@ Full documentation for MIGraphX is available at
* Added support for multi outputs in pointwise ops
* Improve reduction fusion with reshape operators
* Use the quantized output when an operator is used again
* Enabled Split-k GEMM perf configs for rocMLIR based GEMM kernels for better performance on all Hardware


### Fixes
Expand Down
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
}
}, mlir_debug: rocmnode('mi100+') { cmake_build ->
stage('MLIR Debug') {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1']) {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1', 'MIGRAPHX_MLIR_ENABLE_SPLITK=1']) {
def sanitizers = "undefined"
// Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS.
def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}"
Expand Down
5 changes: 5 additions & 0 deletions docs/dev/env_vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,11 @@ Limits the number of solutions available to MLIR for tuning.
Set to "1", "enable", "enabled", "yes", or "true" to use.
Enable input fusions in MLIR.

.. envvar:: MIGRAPHX_MLIR_ENABLE_SPLITK

Set to "1", "enable", "enabled", "yes", or "true" to use.
Enable Split-k perf configs when tuning with MLIR.

CK vars
-----------

Expand Down
8 changes: 6 additions & 2 deletions src/targets/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_ENABLE_SPLITK);

#ifdef MIGRAPHX_MLIR
template <class T, class F, F f> // NOLINT
Expand Down Expand Up @@ -595,8 +596,11 @@ struct mlir_program
{"sym_name", sym_name},
{"kernel", std::string("mixr")},
{"arch", target_arch},
{"num_cu", num_cu},
{"enable_splitk_for_tuning", mlirUnitAttrGet(ctx.get())}});
{"num_cu", num_cu}});
if(enabled(MIGRAPHX_MLIR_ENABLE_SPLITK{}))
{
ops.add_attributes({{"enable_splitk_for_tuning", mlirUnitAttrGet(ctx.get())}});
}
ops.add_region(std::move(region));
insert(body, std::move(ops));

Expand Down
86 changes: 59 additions & 27 deletions test/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/write_literals.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
Expand All @@ -37,6 +38,8 @@
#include <migraphx/functional.hpp>
#include <test.hpp>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_ENABLE_SPLITK);

struct mlir_gpu_target : migraphx::gpu::target
{
std::string name() const { return "mlir"; }
Expand Down Expand Up @@ -154,11 +157,20 @@ bool verify_mlir(const migraphx::module& mmlir)
"mlir", run_gpu(mlir, inputs), migraphx::verify::expected{run_ref(ref, inputs)});
}

std::string get_attrs()
{
if(migraphx::enabled(MIGRAPHX_MLIR_ENABLE_SPLITK{}))
{
return R"({arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64})";
}
return R"({arch = "", kernel = "mixr", num_cu = 0 : i64})";
}

TEST_CASE(conv)
{
const std::string mlir_output = R"__migraphx__(
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} {
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes ${attrs} {
%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
return %0 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
}
Expand All @@ -173,15 +185,17 @@ module {
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));
EXPECT(verify_mlir(m));
}

TEST_CASE(conv_nhwc)
{
const std::string mlir_output = R"__migraphx__(
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} {
func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes ${attrs} {
%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x1x32x8>, <2x8x3x3xf32, 72x1x24x8> -> <1x2x2x2xf32, 8x1x4x2>
return %0 : !migraphx.shaped<1x2x2x2xf32, 8x1x4x2>
}
Expand All @@ -196,15 +210,17 @@ module {
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));
EXPECT(verify_mlir(m));
}

TEST_CASE(conv_add_relu)
{
const std::string mlir_output = R"__migraphx__(
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} {
func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes ${attrs} {
%0 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
%1 = migraphx.add %0, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
%2 = migraphx.relu %1 : <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
Expand All @@ -224,16 +240,19 @@ module {
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));

EXPECT(verify_mlir(m));
}

// The following test checks that a dimension -1, within reshape operator is handled properly..
TEST_CASE(conv_reshape_dim_minus_one)
{
const std::string mlir_output = R"__migraphx__(
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_convolution_reshape(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x4x1x2xf32, 8x2x2x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} {
func.func @mlir_convolution_reshape(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x4x1x2xf32, 8x2x2x1> attributes ${attrs} {
%0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
%1 = migraphx.reshape %0 {dims = [1, 4, 1, 2]} : <1x2x2x2xf32, 8x4x2x1> -> <1x4x1x2xf32, 8x2x2x1>
return %1 : !migraphx.shaped<1x4x1x2xf32, 8x2x2x1>
Expand All @@ -250,15 +269,17 @@ module {
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));
EXPECT(verify_mlir(m));
}

TEST_CASE(quant_dot_add)
{
const std::string mlir_output = R"__migraphx__(
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} {
func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes ${attrs} {
%0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1>
%1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1>
return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1>
Expand All @@ -277,15 +298,17 @@ module {
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));
EXPECT(verify_mlir(m));
}

TEST_CASE(dot_add)
{
const std::string mlir_output = R"__migraphx__(
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} {
func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} {
%0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
Expand All @@ -303,15 +326,17 @@ module {
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));
EXPECT(verify_mlir(m));
}

TEST_CASE(conv_int8_dequantize_quantize)
{
const std::string mlir_output = R"__migraphx__(
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} {
func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes ${attrs} {
%0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1>
%1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
%2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1>
Expand All @@ -336,15 +361,17 @@ module {
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));
EXPECT(verify_mlir(m));
}

TEST_CASE(dot_convert)
{
const std::string mlir_output = R"__migraphx__(
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} {
func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes ${attrs} {
%0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1>
return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1>
Expand All @@ -362,15 +389,17 @@ module {
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));
EXPECT(verify_mlir(m));
}

TEST_CASE(dot_where)
{
const std::string mlir_output = R"__migraphx__(
std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", enable_splitk_for_tuning, kernel = "mixr", num_cu = 0 : i64} {
func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} {
%0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
%1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
Expand All @@ -389,7 +418,10 @@ module {
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
auto mlir_output_with_attrs =
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));

EXPECT(verify_mlir(m));
}

Expand Down

0 comments on commit 05c9d1b

Please sign in to comment.