Skip to content

Commit

Permalink
Merge branch 'develop' into disable_e5m2_rocblas
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Dec 9, 2024
2 parents ffc0995 + a30b253 commit d79bbda
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
rocm-docs-core==1.10.0
rocm-docs-core==1.11.0
sphinx-collapse
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.10.0
rocm-docs-core==1.11.0
# via -r requirements.in
smmap==5.0.1
# via gitdb
Expand Down
32 changes: 29 additions & 3 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,40 @@ struct mlir_op
return pack(f(self.op, "op"));
}

// Check if the shape can be created from a transpose/broadcast/slice.
//
// Standard, packed, scalar and one-dimensional shapes are accepted outright.
// Any other shape is reordered by its permutation, and the strides that come
// before the first zero stride are inspected:
//   - the last of those strides must be 1, and
//   - each stride must be an exact multiple of the stride that follows it,
//     with a quotient that is at least the length of the following dimension.
static bool is_mlir_compatible(const shape& s)
{
    const bool trivially_ok = s.standard() or s.packed() or s.scalar() or s.ndim() == 1;
    if(trivially_ok)
        return true;

    auto reordered      = reorder_shape(s, find_permutation(s));
    const auto& strides = reordered.strides();
    const auto& lens    = reordered.lens();

    // Only the strides ahead of the first zero stride take part in the check.
    const auto zero_pos = std::find(strides.begin(), strides.end(), 0);
    if(*std::prev(zero_pos) != 1)
        return false;

    const auto count = static_cast<std::size_t>(std::distance(strides.begin(), zero_pos));
    for(std::size_t i = 1; i < count; ++i)
    {
        const auto outer = strides[i - 1];
        const auto inner = strides[i];
        assert(inner != 0);
        // A stride that is not an exact multiple of the next one yields a ratio of 0.
        const std::size_t ratio = ((outer % inner) != 0) ? 0 : outer / inner;
        if(not(ratio >= lens[i]))
            return false;
    }
    return true;
}

shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>& mods) const
{
module_ref mod = mods[0];
check_shapes{inputs, *this}.packed_or_broadcasted();
check_shapes{inputs, *this}.has_at_least(1);
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");
if(inputs.empty())
MIGRAPHX_THROW("should have at least one input.");

if(not std::all_of(inputs.begin(), inputs.end(), &is_mlir_compatible))
MIGRAPHX_THROW("Shape is not mlir compatible.");

auto result =
mod->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true});
Expand Down
2 changes: 1 addition & 1 deletion src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ struct find_rocblas_gemm_pointwise : gemm_pointwise
shape s = c_ins->get_shape();
// const-fold input if not standard shape since rocblas can't handle it
// Updated for a case where "standard" shape has out-of-sequence strides
if(not s.standard() or s.normalize_standard() != s)
if(not s.standard())
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
Expand Down
9 changes: 5 additions & 4 deletions src/targets/gpu/gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,11 @@ rocblas_datatype get_type(shape::type_t type)
MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
}

void blas_shape(const shape& s)
void blas_shape(const shape& in_shape)
{
if(s.lens().size() < 2)
if(in_shape.lens().size() < 2)
return;
auto s = in_shape.normalize_standard();
if(std::none_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 1; }))
MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
if(std::any_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 0; }))
Expand Down Expand Up @@ -591,7 +592,7 @@ void gemm_compute(context& ctx,
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
[](const argument& x) { return x.get_shape().normalize_standard(); });
auto gemm_item = gemm_impl<float>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
Expand All @@ -608,7 +609,7 @@ void gemm_compute(context& ctx,
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
[](const argument& x) { return x.get_shape().normalize_standard(); });
auto gemm_item = gemm_impl<int32_t>(output_shape, input_shapes, alpha, beta, compute_fp32);
gemm_item.run(ctx, args, solution_idx);
}
Expand Down
4 changes: 4 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#ifndef MIGRAPHX_GUARD_KERNELS_PP_HPP
#define MIGRAPHX_GUARD_KERNELS_PP_HPP

// NOLINTBEGIN(*-macro-to-enum)

#define MIGRAPHX_PP_PRIMITIVE_CAT(x, y) x##y
#define MIGRAPHX_PP_CAT(x, y) MIGRAPHX_PP_PRIMITIVE_CAT(x, y)

Expand Down Expand Up @@ -122,4 +124,6 @@
MIGRAPHX_PP_EXPAND(MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS( \
m, MIGRAPHX_PP_COMMA, __VA_ARGS__, MIGRAPHX_PP_RES_ARGS()))

// NOLINTEND(*-macro-to-enum)

#endif // MIGRAPHX_GUARD_KERNELS_PP_HPP
2 changes: 1 addition & 1 deletion test/onnx/.onnxrt-commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1128882bfd2a97c20f8a2a5ddb26cb0d42d9ebba
d27fecd3d3837864a268bc96f00f2b8dce294697
59 changes: 59 additions & 0 deletions test/verify/test_gemm_add_broadcast3.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>

// verify_program case in the "gemm" section: a {1, 2} x {2, 4} dot whose result
// is added to a {4} parameter that has been multibroadcast to {1, 4}.
template <migraphx::shape::type_t DType>
struct test_gemm_add_broadcast3 : verify_program<test_gemm_add_broadcast3<DType>>
{
    // Builds the program: three parameters, a multibroadcast, a dot and an add.
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        migraphx::shape lhs_shape{DType, {1, 2}};
        migraphx::shape rhs_shape{DType, {2, 4}};
        migraphx::shape bias_shape{DType, {4}};

        auto lhs  = main_mod->add_parameter("1", lhs_shape);
        auto rhs  = main_mod->add_parameter("2", rhs_shape);
        auto bias = main_mod->add_parameter("3", bias_shape);

        // Expand the {4} bias to the {1, 4} shape of the dot result.
        auto bias_bcast = main_mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), bias);

        auto gemm = main_mod->add_instruction(migraphx::make_op("dot"), lhs, rhs);
        main_mod->add_instruction(migraphx::make_op("add"), bias_bcast, gemm);
        return prog;
    }

    // Name of the test section this case belongs to.
    std::string section() const { return "gemm"; }
};

// Explicit instantiations: only the float and half variants are enabled.
template struct test_gemm_add_broadcast3<migraphx::shape::float_type>;
template struct test_gemm_add_broadcast3<migraphx::shape::half_type>;
// The fp8 variants below are left commented out.
// NOTE(review): the reason is not stated in this file - confirm before re-enabling.
// template struct test_gemm_add_broadcast3<migraphx::shape::fp8e4m3fnuz_type>;
// template struct test_gemm_add_broadcast3<migraphx::shape::fp8e5m2fnuz_type>;
// template struct test_gemm_add_broadcast3<migraphx::shape::fp8e4m3fn_type>;
// template struct test_gemm_add_broadcast3<migraphx::shape::fp8e5m2_type>;

0 comments on commit d79bbda

Please sign in to comment.