From 7fd3357c9f81826e1d9314e9bc8e77db7c4b3166 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Tue, 3 Dec 2024 21:54:12 -0600 Subject: [PATCH 1/7] Trim down tests (#3674) --- test/verify/test_add_mixed_layout.cpp | 1 - test/verify/test_arg_ops.cpp | 420 ------------------------ test/verify/test_logsoftmax.cpp | 8 - test/verify/test_pad.cpp | 4 - test/verify/test_prefix_scan_sum_2d.cpp | 4 - test/verify/test_reduce_add.cpp | 4 - test/verify/test_reduce_mean_nhwc.cpp | 4 - test/verify/test_reduce_op_large.cpp | 35 -- test/verify/test_roialign.cpp | 4 - test/verify/test_softmax.cpp | 8 - test/verify/test_topk_0.cpp | 4 - 11 files changed, 496 deletions(-) diff --git a/test/verify/test_add_mixed_layout.cpp b/test/verify/test_add_mixed_layout.cpp index 3920df069d8..c4e94feda45 100644 --- a/test/verify/test_add_mixed_layout.cpp +++ b/test/verify/test_add_mixed_layout.cpp @@ -43,6 +43,5 @@ struct test_add_mixed_layout : verify_program> } }; -template struct test_add_mixed_layout; template struct test_add_mixed_layout; template struct test_add_mixed_layout; diff --git a/test/verify/test_arg_ops.cpp b/test/verify/test_arg_ops.cpp index 61dd37c06f0..ac52da56a23 100644 --- a/test/verify/test_arg_ops.cpp +++ b/test/verify/test_arg_ops.cpp @@ -162,423 +162,3 @@ template struct test_arg_ops; template struct test_arg_ops; template struct test_arg_ops; - -// transpose argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// transpose argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; - -// transpose argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// transpose argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; - -// transpose argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// transpose argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; - -// transpose argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// transpose argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// broadcast argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// slice argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -// default case, standard shape argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; diff --git a/test/verify/test_logsoftmax.cpp b/test/verify/test_logsoftmax.cpp index e126f24f209..0586227a7c7 100644 --- a/test/verify/test_logsoftmax.cpp +++ b/test/verify/test_logsoftmax.cpp @@ -52,22 +52,14 @@ template struct test_logsoftmax<0, migraphx::shape::half_type>; template struct test_logsoftmax<2, migraphx::shape::half_type>; template struct test_logsoftmax<3, migraphx::shape::half_type>; -template struct test_logsoftmax<0, migraphx::shape::fp8e4m3fnuz_type>; template struct test_logsoftmax<1, migraphx::shape::fp8e4m3fnuz_type>; -template struct test_logsoftmax<2, migraphx::shape::fp8e4m3fnuz_type>; template struct test_logsoftmax<3, migraphx::shape::fp8e4m3fnuz_type>; -template struct test_logsoftmax<0, migraphx::shape::fp8e5m2fnuz_type>; template struct test_logsoftmax<1, migraphx::shape::fp8e5m2fnuz_type>; -template struct test_logsoftmax<2, migraphx::shape::fp8e5m2fnuz_type>; template struct test_logsoftmax<3, migraphx::shape::fp8e5m2fnuz_type>; -template struct test_logsoftmax<0, migraphx::shape::fp8e4m3fn_type>; template struct test_logsoftmax<1, migraphx::shape::fp8e4m3fn_type>; -template struct test_logsoftmax<2, migraphx::shape::fp8e4m3fn_type>; template struct test_logsoftmax<3, migraphx::shape::fp8e4m3fn_type>; -template struct test_logsoftmax<0, migraphx::shape::fp8e5m2_type>; template struct test_logsoftmax<1, migraphx::shape::fp8e5m2_type>; -template struct test_logsoftmax<2, migraphx::shape::fp8e5m2_type>; template struct test_logsoftmax<3, migraphx::shape::fp8e5m2_type>; diff --git a/test/verify/test_pad.cpp b/test/verify/test_pad.cpp index 0e1a94849de..d4351d32fd4 100644 --- a/test/verify/test_pad.cpp +++ b/test/verify/test_pad.cpp @@ -51,7 +51,3 @@ struct test_pad : verify_program> template struct test_pad; template struct test_pad; template struct test_pad; -template struct test_pad; -template struct test_pad; -template struct test_pad; -template struct test_pad; diff --git a/test/verify/test_prefix_scan_sum_2d.cpp b/test/verify/test_prefix_scan_sum_2d.cpp index 2364c2392cd..a7530e08009 100644 --- a/test/verify/test_prefix_scan_sum_2d.cpp +++ b/test/verify/test_prefix_scan_sum_2d.cpp @@ -68,7 +68,3 @@ struct test_prefix_scan_sum_2d_large : verify_program; template struct test_prefix_scan_sum_2d_large; -template struct test_prefix_scan_sum_2d_large; -template struct test_prefix_scan_sum_2d_large; -template struct test_prefix_scan_sum_2d_large; -template struct test_prefix_scan_sum_2d_large; diff --git a/test/verify/test_reduce_add.cpp b/test/verify/test_reduce_add.cpp index 26417732d35..c0e2b5ddf93 100644 --- a/test/verify/test_reduce_add.cpp +++ b/test/verify/test_reduce_add.cpp @@ -50,7 +50,3 @@ struct test_reduce_add : verify_program> }; template struct test_reduce_add; -template struct test_reduce_add; -template struct test_reduce_add; -template struct test_reduce_add; -template struct test_reduce_add; diff --git a/test/verify/test_reduce_mean_nhwc.cpp b/test/verify/test_reduce_mean_nhwc.cpp index 72ebbd15806..5ba5f617cde 100644 --- a/test/verify/test_reduce_mean_nhwc.cpp +++ b/test/verify/test_reduce_mean_nhwc.cpp @@ -47,7 +47,3 @@ struct test_reduce_mean_nhwc : verify_program> template struct test_reduce_mean_nhwc; template struct test_reduce_mean_nhwc; -template struct test_reduce_mean_nhwc; -template struct test_reduce_mean_nhwc; -template struct test_reduce_mean_nhwc; -template struct test_reduce_mean_nhwc; diff --git a/test/verify/test_reduce_op_large.cpp b/test/verify/test_reduce_op_large.cpp index aed457d33cb..473c762d23a 100644 --- a/test/verify/test_reduce_op_large.cpp +++ b/test/verify/test_reduce_op_large.cpp @@ -58,48 +58,13 @@ template struct test_reduce_op_large; template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; template struct test_reduce_op_large; - -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; template struct test_reduce_op_large; - -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; template struct test_reduce_op_large; - -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; -template struct test_reduce_op_large; template struct test_reduce_op_large; struct test_reduce_mean_1 : verify_program diff --git a/test/verify/test_roialign.cpp b/test/verify/test_roialign.cpp index 22231b62b05..9f5271f1c5d 100644 --- a/test/verify/test_roialign.cpp +++ b/test/verify/test_roialign.cpp @@ -60,7 +60,3 @@ struct test_roialign : verify_program> template struct test_roialign; template struct test_roialign; -template struct test_roialign; -template struct test_roialign; -template struct test_roialign; -template struct test_roialign; diff --git a/test/verify/test_softmax.cpp b/test/verify/test_softmax.cpp index 0d1c8c56328..af9b02a2903 100644 --- a/test/verify/test_softmax.cpp +++ b/test/verify/test_softmax.cpp @@ -50,22 +50,14 @@ template struct test_softmax<1, migraphx::shape::half_type>; template struct test_softmax<2, migraphx::shape::half_type>; template struct test_softmax<3, migraphx::shape::half_type>; -template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>; template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>; -template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>; template struct test_softmax<3, migraphx::shape::fp8e4m3fnuz_type>; -template struct test_softmax<0, migraphx::shape::fp8e5m2fnuz_type>; template struct test_softmax<1, migraphx::shape::fp8e5m2fnuz_type>; -template struct test_softmax<2, migraphx::shape::fp8e5m2fnuz_type>; template struct test_softmax<3, migraphx::shape::fp8e5m2fnuz_type>; -template struct test_softmax<0, migraphx::shape::fp8e4m3fn_type>; template struct test_softmax<1, migraphx::shape::fp8e4m3fn_type>; -template struct test_softmax<2, migraphx::shape::fp8e4m3fn_type>; template struct test_softmax<3, migraphx::shape::fp8e4m3fn_type>; -template struct test_softmax<0, migraphx::shape::fp8e5m2_type>; template struct test_softmax<1, migraphx::shape::fp8e5m2_type>; -template struct test_softmax<2, migraphx::shape::fp8e5m2_type>; template struct test_softmax<3, migraphx::shape::fp8e5m2_type>; diff --git a/test/verify/test_topk_0.cpp b/test/verify/test_topk_0.cpp index 243b234af72..625e19b6cac 100644 --- a/test/verify/test_topk_0.cpp +++ b/test/verify/test_topk_0.cpp @@ -47,7 +47,3 @@ struct test_topk_0 : verify_program> template struct test_topk_0; template struct test_topk_0; -template struct test_topk_0; -template struct test_topk_0; -template struct test_topk_0; -template struct test_topk_0; From ea04f367b4a235656a1f7c8b55a86dd58eeab87e Mon Sep 17 00:00:00 2001 From: spolifroni-amd Date: Wed, 4 Dec 2024 00:00:46 -0500 Subject: [PATCH 2/7] Added onnx_operators to TOC and landing page (#3668) --- docs/index.rst | 1 + docs/sphinx/_toc.yml.in | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 55403e64ba7..e9ba8f40e6d 100755 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,6 +33,7 @@ The MIGraphX public repository is located at `https://github.com/ROCm/AMDMIGraph * :ref:`cpp-api-reference` * :ref:`python-api-reference` * :ref:`migraphx-driver` + * :doc:`Supported ONNX Operators <./dev/onnx_operators>` .. grid-item-card:: Contributing to the MIGraphX code base diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index bab661618f7..528ef7cc371 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -16,6 +16,7 @@ subtrees: subtrees: - entries: - file: reference/driver-options + - file: dev/onnx_operators - caption: Developing for MIGraphX entries: From 6b886e3ae440492d24971677d50ce4147cb4a42e Mon Sep 17 00:00:00 2001 From: spolifroni-amd Date: Wed, 4 Dec 2024 00:01:48 -0500 Subject: [PATCH 3/7] updated metadata (#3667) --- docs/dev/data.rst | 4 ++++ docs/dev/dev_intro.rst | 5 +++++ docs/dev/env_vars.rst | 4 ++++ docs/dev/matchers.rst | 4 ++++ docs/dev/onnx_operators.rst | 22 ++++++++++++---------- docs/dev/operators.rst | 4 ++++ docs/dev/pass.rst | 4 ++++ docs/dev/program.rst | 4 ++++ docs/dev/quantization.rst | 4 ++++ docs/dev/targets.rst | 4 ++++ docs/dev/tools.rst | 4 ++++ docs/dev/triage-rocmlir.rst | 4 ++++ docs/reference/cpp.rst | 4 ++++ docs/reference/driver-options.rst | 4 ++-- docs/reference/py.rst | 4 ++++ 15 files changed, 67 insertions(+), 12 deletions(-) diff --git a/docs/dev/data.rst b/docs/dev/data.rst index ce3e77e04b8..217a6c6b81b 100755 --- a/docs/dev/data.rst +++ b/docs/dev/data.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal data types + :keywords: MIGraphX, code base, contribution, developing, data types + Data types ========== diff --git a/docs/dev/dev_intro.rst b/docs/dev/dev_intro.rst index f22f1d36f17..7454821f5e0 100644 --- a/docs/dev/dev_intro.rst +++ b/docs/dev/dev_intro.rst @@ -1,3 +1,8 @@ +.. meta:: + :description: MIGraphX introduction to developing for the code base + :keywords: MIGraphX, code base, contribution, developing, introduction, developers + + Developer Introduction ====================== diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index 06e9624741c..cc7915df879 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal environment variables + :keywords: MIGraphX, code base, contribution, developing, env vars, environment variables + Environment Variables ===================== diff --git a/docs/dev/matchers.rst b/docs/dev/matchers.rst index 32c5b075d84..01d4ae6e35d 100644 --- a/docs/dev/matchers.rst +++ b/docs/dev/matchers.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal matchers + :keywords: MIGraphX, code base, contribution, developing, matchers + Matchers ======== diff --git a/docs/dev/onnx_operators.rst b/docs/dev/onnx_operators.rst index fc621b4f894..3e21f1172bb 100644 --- a/docs/dev/onnx_operators.rst +++ b/docs/dev/onnx_operators.rst @@ -1,22 +1,24 @@ +.. meta:: + :description: MIGraphX supported ONNX operators + :keywords: MIGraphX, code base, contribution, developing, ONNX operators + Supported ONNX Operators ======================== MIGraphX supports operators up to Opset 19. Latest information of ONNX -operators can be found -`here `__ +operators can be found in `the ONNX GitHub repository `_. -MIGraphX supports the following ONNX data types: BOOL, UINT8, UINT16, -UINT32, UINT64, INT8, INT16, INT32, INT64, FLOAT8, FLOAT16, FLOAT32, -DOUBLE +MIGraphX supports the following ONNX data types: BOOL, UINT8, UINT16, UINT32, UINT64, INT8, INT16, INT32, INT64, FLOAT8, FLOAT16, FLOAT32, and DOUBLE - NOTE: FP8 support is only for E4M3FNUZ, see - `here `__ + .. Note:: + + FP8 support is only for E4M3FNUZ, see `Float stored in 8 bits `_ in the ONNX documentation. See below for the support matrix of ONNX operators in MIGraphX. - NOTE: Supported Types are from ONNX specification. An operator might - support more datatypes (e.g. integer type for float operator) than - listed. + .. Note:: + + The listed supported types are taken from the ONNX specification. An operator might support other additional datatypes. Operator Support Matrix ----------------------- diff --git a/docs/dev/operators.rst b/docs/dev/operators.rst index 15691feb92f..8cf641f5767 100755 --- a/docs/dev/operators.rst +++ b/docs/dev/operators.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal operators + :keywords: MIGraphX, code base, contribution, developing, operators + Operators ========= diff --git a/docs/dev/pass.rst b/docs/dev/pass.rst index 4c27b706252..feada6df969 100755 --- a/docs/dev/pass.rst +++ b/docs/dev/pass.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal passes + :keywords: MIGraphX, code base, contribution, developing, passes + Passes ====== diff --git a/docs/dev/program.rst b/docs/dev/program.rst index 65b99343a9b..fe1ab3cfa38 100755 --- a/docs/dev/program.rst +++ b/docs/dev/program.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX program + :keywords: MIGraphX, code base, contribution, developing, program + Program ======= diff --git a/docs/dev/quantization.rst b/docs/dev/quantization.rst index aecbd63188f..16e79c8a93d 100755 --- a/docs/dev/quantization.rst +++ b/docs/dev/quantization.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX internal quantization + :keywords: MIGraphX, code base, contribution, developing, quantization + Quantization ============ diff --git a/docs/dev/targets.rst b/docs/dev/targets.rst index eb1ee223ca6..3f95688ea84 100755 --- a/docs/dev/targets.rst +++ b/docs/dev/targets.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX targets + :keywords: MIGraphX, code base, contribution, developing, targets + Targets ======= diff --git a/docs/dev/tools.rst b/docs/dev/tools.rst index 077eb5b9208..43847e399e4 100644 --- a/docs/dev/tools.rst +++ b/docs/dev/tools.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX tools + :keywords: MIGraphX, code base, contribution, developing, tooks, knobs + .. _tools: Tools diff --git a/docs/dev/triage-rocmlir.rst b/docs/dev/triage-rocmlir.rst index 63bed90455f..7f21c0a167a 100644 --- a/docs/dev/triage-rocmlir.rst +++ b/docs/dev/triage-rocmlir.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: Issue Triaging Guide for suspected issues + :keywords: MIGraphX, rocMLIR, issues, pipeline, compilation, bug, code base, kernel, contribution, developing + Issue Triaging Guide for suspected rocMLIR issue ================================================ diff --git a/docs/reference/cpp.rst b/docs/reference/cpp.rst index 57baef8bda1..1982328f7f6 100755 --- a/docs/reference/cpp.rst +++ b/docs/reference/cpp.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX C++ API reference + :keywords: MIGraphX, ROCm, C++, API, reference, development, developer + .. _cpp-api-reference: C++ API reference diff --git a/docs/reference/driver-options.rst b/docs/reference/driver-options.rst index 55012aa0fb1..e58a3752135 100644 --- a/docs/reference/driver-options.rst +++ b/docs/reference/driver-options.rst @@ -1,6 +1,6 @@ .. meta:: - :description: MIGraphX provides an optimized execution engine for deep learning neural networks - :keywords: MIGraphX, ROCm, library, API, tool + :description: MIGraphX driver options + :keywords: MIGraphX, ROCm, driver, options .. _driver-options: diff --git a/docs/reference/py.rst b/docs/reference/py.rst index c68a2df0e54..17077dc120a 100755 --- a/docs/reference/py.rst +++ b/docs/reference/py.rst @@ -1,3 +1,7 @@ +.. meta:: + :description: MIGraphX Python API reference + :keywords: MIGraphX, ROCm, Python, API, reference, development, developer + .. py:module:: migraphx .. _python-api-reference: From 871fd568cf27dae09ebf5f5634f5b21b53762f93 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Wed, 4 Dec 2024 10:42:53 -0600 Subject: [PATCH 4/7] Refactor GPU math functions (#3657) --- .../include/migraphx/kernels/functional.hpp | 4 +- .../kernels/include/migraphx/kernels/math.hpp | 282 +++++++----------- .../kernels/include/migraphx/kernels/pp.hpp | 77 +++++ .../include/migraphx/kernels/type_traits.hpp | 16 + test/gpu/jit.cpp | 37 ++- 5 files changed, 245 insertions(+), 171 deletions(-) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp index 1a263e98b2d..3e9d802611f 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp @@ -63,11 +63,11 @@ template struct overloaded : Fs... { using Fs::operator()...; - overloaded(Fs... fs) : Fs(fs)... {} + constexpr overloaded(Fs... fs) : Fs(fs)... {} }; template -overloaded overload(Fs... fs) +constexpr overloaded overload(Fs... fs) { return {fs...}; } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp index 611ac93a721..790a82cdd8d 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/math.hpp @@ -30,24 +30,63 @@ #include #include #include +#include namespace migraphx { namespace math { -constexpr float as_float(migraphx::half x) { return x; } - -constexpr float as_float(migraphx::fp8::fp8e4m3fnuz x) { return x; } -constexpr float as_float(migraphx::fp8::fp8e5m2fnuz x) { return x; } -constexpr float as_float(migraphx::fp8::fp8e4m3fn x) { return x; } -constexpr float as_float(migraphx::fp8::fp8e5m2 x) { return x; } template -constexpr T as_float(T x) +constexpr auto as_float(T x) +{ + if constexpr(is_integral{}) + return x; + else + return float(x); +} + +template ())> +__device__ auto wrap(F f, T x, Ts... xs) { - return x; + if constexpr(is_integral{}) + { + return wrap(f, double(x), double(xs)...); + } + else if constexpr(is_callable{}) + { + return f(x, xs...); + } + else + { + T result = f(as_float(x), as_float(xs)...); + return result; + } } + } // namespace math +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_LIFT_IMPL(type, ...) \ + [](type x, auto... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(x, xs...)) + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_LIFT(...) MIGRAPHX_DEVICE_MATH_LIFT_IMPL(__VA_ARGS__) + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_PARSE(x) x, + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_EACH(f) MIGRAPHX_DEVICE_MATH_LIFT(MIGRAPHX_DEVICE_MATH_PARSE f) + +// NOLINTNEXTLINE +#define MIGRAPHX_DEVICE_MATH_WRAP(name, ...) \ + namespace math { \ + inline static constexpr auto wrap_##name = \ + overload(MIGRAPHX_PP_TRANSFORM_ARGS(MIGRAPHX_DEVICE_MATH_EACH, __VA_ARGS__)); \ + } \ + template \ + auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(math::wrap(math::wrap_##name, xs...)) + // NOLINTNEXTLINE #define MIGRAPHX_DEVICE_MATH(name, fname) \ template ())> \ @@ -73,169 +112,47 @@ constexpr T as_float(T x) #define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \ inline auto __device__ name(type x, type y) -> type { return fname(x, y); } -// NOLINTNEXTLINE -#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \ - template ())> \ - auto __device__ name(migraphx::half x, Ts... xs) \ - MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...)) - -// NOLINTNEXTLINE -#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \ - template ())> \ - auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \ - migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...))) \ - \ - template ())> \ - auto __device__ name(migraphx::fp8::fp8e5m2fnuz x, Ts... xs) MIGRAPHX_RETURNS( \ - migraphx::fp8::fp8e5m2fnuz(fname(math::as_float(x), math::as_float(xs)...))) \ - \ - template ())> \ - auto __device__ name(migraphx::fp8::fp8e4m3fn x, Ts... xs) MIGRAPHX_RETURNS( \ - migraphx::fp8::fp8e4m3fn(fname(math::as_float(x), math::as_float(xs)...))) \ - \ - template ())> \ - auto __device__ name(migraphx::fp8::fp8e5m2 x, Ts... xs) MIGRAPHX_RETURNS( \ - migraphx::fp8::fp8e5m2(fname(math::as_float(x), math::as_float(xs)...))) - // Template with two overloads for math functions, one for half2 type and one for more generic // vectorization where N is 4 or another even number. - // NOLINTNEXTLINE -#define MIGRAPHX_DEVICE_MATH_HALF2(name, fname) \ - template \ - auto __device__ name(migraphx::vec x, Ts... xs) \ - MIGRAPHX_RETURNS(migraphx::vec{fname(x, xs...)}); \ - template 2))> \ - auto __device__ name(migraphx::vec x, Ts... xs) \ - { \ - return vec_packed_transform<2>(x, xs...)( \ - [](auto... ys) -> migraphx::vec { return fname(ys...); }); \ +#define MIGRAPHX_DEVICE_MATH_VEC2(type, name, fname) \ + template \ + auto __device__ name(migraphx::vec x, Ts... xs) \ + MIGRAPHX_RETURNS(migraphx::vec{fname(x, xs...)}); \ + template 2))> \ + auto __device__ name(migraphx::vec x, Ts... xs) \ + { \ + return vec_packed_transform<2>(x, xs...)( \ + [](auto... ys) -> migraphx::vec { return fname(ys...); }); \ } -MIGRAPHX_DEVICE_MATH(abs, ::abs) -MIGRAPHX_DEVICE_MATH(acos, ::acos) -MIGRAPHX_DEVICE_MATH(acosh, ::acosh) -MIGRAPHX_DEVICE_MATH(asin, ::asin) -MIGRAPHX_DEVICE_MATH(asinh, ::asinh) -MIGRAPHX_DEVICE_MATH(atan, ::atan) -MIGRAPHX_DEVICE_MATH(atanh, ::atanh) -MIGRAPHX_DEVICE_MATH(ceil, ::ceil) -MIGRAPHX_DEVICE_MATH(cos, ::cos) -MIGRAPHX_DEVICE_MATH(cosh, ::cosh) -MIGRAPHX_DEVICE_MATH(erf, ::erf) -MIGRAPHX_DEVICE_MATH(exp, ::exp) -MIGRAPHX_DEVICE_MATH(floor, ::floor) -MIGRAPHX_DEVICE_MATH(isnan, ::isnan) -MIGRAPHX_DEVICE_MATH(isinf, ::isinf) -MIGRAPHX_DEVICE_MATH(log, ::log) -MIGRAPHX_DEVICE_MATH(log2, ::log2) -MIGRAPHX_DEVICE_MATH(nearbyint, ::nearbyint) -MIGRAPHX_DEVICE_MATH(pow, ::pow) -MIGRAPHX_DEVICE_MATH(remainder, ::remainder) -MIGRAPHX_DEVICE_MATH(round, ::round) -MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt) -MIGRAPHX_DEVICE_MATH(sin, ::sin) -MIGRAPHX_DEVICE_MATH(sinh, ::sinh) -MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt) -MIGRAPHX_DEVICE_MATH(tan, ::tan) -MIGRAPHX_DEVICE_MATH(tanh, ::tanh) -MIGRAPHX_DEVICE_MATH(fmod, ::fmod) - -// Float overloads -MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf) -MIGRAPHX_DEVICE_MATH_FOR(float, acosh, ::acoshf) -MIGRAPHX_DEVICE_MATH_FOR(float, asin, ::asinf) -MIGRAPHX_DEVICE_MATH_FOR(float, asinh, ::asinhf) -MIGRAPHX_DEVICE_MATH_FOR(float, atan, ::atanf) -MIGRAPHX_DEVICE_MATH_FOR(float, atanh, ::atanhf) -MIGRAPHX_DEVICE_MATH_FOR(float, cos, ::cosf) -MIGRAPHX_DEVICE_MATH_FOR(float, cosh, ::coshf) -MIGRAPHX_DEVICE_MATH_FOR(float, rsqrt, ::rsqrtf) -MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf) -MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf) -MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf) -MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf) -MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf) - -// Builtin half functions -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, ceil, ::hceil) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, cos, ::hcos) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isinf, ::__hisinf) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log2, ::hlog2) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin) -MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt) - -// Use float to compute half overload -MIGRAPHX_DEVICE_MATH_HALF(acos, ::acos) -MIGRAPHX_DEVICE_MATH_HALF(acosh, ::acosh) -MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin) -MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh) -MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan) -MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh) -MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh) -MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) -MIGRAPHX_DEVICE_MATH_HALF(nearbyint, ::nearbyint) -MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) -MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder) -MIGRAPHX_DEVICE_MATH_HALF(round, ::round) -MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh) -MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) -MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) -MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod) - -// use float to compute fp8 overload -MIGRAPHX_DEVICE_MATH_FP8(abs, ::abs) -MIGRAPHX_DEVICE_MATH_FP8(acos, ::acos) -MIGRAPHX_DEVICE_MATH_FP8(acosh, ::acosh) -MIGRAPHX_DEVICE_MATH_FP8(asin, ::asin) -MIGRAPHX_DEVICE_MATH_FP8(asinh, ::asinh) -MIGRAPHX_DEVICE_MATH_FP8(atan, ::atan) -MIGRAPHX_DEVICE_MATH_FP8(atanh, ::atanh) -MIGRAPHX_DEVICE_MATH_FP8(ceil, ::ceil) -MIGRAPHX_DEVICE_MATH_FP8(cos, ::cos) -MIGRAPHX_DEVICE_MATH_FP8(cosh, ::cosh) -MIGRAPHX_DEVICE_MATH_FP8(erf, ::erf) -MIGRAPHX_DEVICE_MATH_FP8(exp, ::exp) -MIGRAPHX_DEVICE_MATH_FP8(floor, ::floor) -MIGRAPHX_DEVICE_MATH_FP8(isnan, ::isnan) -MIGRAPHX_DEVICE_MATH_FP8(log, ::log) -MIGRAPHX_DEVICE_MATH_FP8(log2, ::log2) -MIGRAPHX_DEVICE_MATH_FP8(pow, ::pow) -MIGRAPHX_DEVICE_MATH_FP8(remainder, ::remainder) -MIGRAPHX_DEVICE_MATH_FP8(round, ::round) -MIGRAPHX_DEVICE_MATH_FP8(rsqrt, ::rsqrt) -MIGRAPHX_DEVICE_MATH_FP8(sin, ::sin) -MIGRAPHX_DEVICE_MATH_FP8(sinh, ::sinh) -MIGRAPHX_DEVICE_MATH_FP8(sqrt, ::sqrt) -MIGRAPHX_DEVICE_MATH_FP8(tan, ::tan) -MIGRAPHX_DEVICE_MATH_FP8(tanh, ::tanh) -MIGRAPHX_DEVICE_MATH_FP8(fmod, ::fmod) - -// Map math functions to hip half2 functions -// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats -// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names -// Most but not all of these math ops have operators of the same names. -MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2) -MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil) -MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos) -MIGRAPHX_DEVICE_MATH_HALF2(exp, ::h2exp) -MIGRAPHX_DEVICE_MATH_HALF2(exp10, ::h2exp10) -MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2) -MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor) -MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2) -MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2) -MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log) -MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10) -MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2) -MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt) -MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin) -MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt) +MIGRAPHX_DEVICE_MATH_WRAP(acos, (double)::acos, (float)::acosf); +MIGRAPHX_DEVICE_MATH_WRAP(acosh, (double)::acosh, (float)::acoshf); +MIGRAPHX_DEVICE_MATH_WRAP(asin, (double)::asin, (float)::asinf); +MIGRAPHX_DEVICE_MATH_WRAP(asinh, (double)::asinh, (float)::asinh); +MIGRAPHX_DEVICE_MATH_WRAP(atan, (double)::atan, (float)::atan); +MIGRAPHX_DEVICE_MATH_WRAP(atanh, (double)::atanh, (float)::atanh); +MIGRAPHX_DEVICE_MATH_WRAP(ceil, (double)::ceil, (float)::ceilf, (half)::hceil); +MIGRAPHX_DEVICE_MATH_WRAP(cos, (double)::cos, (float)::cosf, (half)::hcos); +MIGRAPHX_DEVICE_MATH_WRAP(cosh, (double)::cosh, (float)::coshf); +MIGRAPHX_DEVICE_MATH_WRAP(erf, (double)::erf, (float)::erff); +MIGRAPHX_DEVICE_MATH_WRAP(exp, (double)::exp, (float)::expf, (half)::hexp); +MIGRAPHX_DEVICE_MATH_WRAP(floor, (double)::floor, (float)::floorf, (half)::hfloor); +MIGRAPHX_DEVICE_MATH_WRAP(isnan, (double)::isnan, (float)::isnan, (half)::__hisnan); +MIGRAPHX_DEVICE_MATH_WRAP(isinf, (double)::isinf, (float)::isinf, (half)::__hisinf); +MIGRAPHX_DEVICE_MATH_WRAP(log, (double)::log, (float)::logf, (half)::hlog); +MIGRAPHX_DEVICE_MATH_WRAP(log2, (double)::log2, (float)::log2f, (half)::hlog2); +MIGRAPHX_DEVICE_MATH_WRAP(nearbyint, (double)::nearbyint, (float)::nearbyintf); +MIGRAPHX_DEVICE_MATH_WRAP(pow, (double)::pow, (float)::powf); +MIGRAPHX_DEVICE_MATH_WRAP(remainder, (double)::remainder, (float)::remainderf); +MIGRAPHX_DEVICE_MATH_WRAP(round, (double)::round, (float)::roundf); +MIGRAPHX_DEVICE_MATH_WRAP(rsqrt, (double)::rsqrt, (float)::rsqrtf, (half)::hrsqrt); +MIGRAPHX_DEVICE_MATH_WRAP(sin, (double)::sin, (float)::sinf, (half)::hsin); +MIGRAPHX_DEVICE_MATH_WRAP(sinh, (double)::sinh, (float)::sinhf); +MIGRAPHX_DEVICE_MATH_WRAP(sqrt, (double)::sqrt, (float)::sqrtf, (half)::hsqrt); +MIGRAPHX_DEVICE_MATH_WRAP(tan, (double)::tan, (float)::tanf); +MIGRAPHX_DEVICE_MATH_WRAP(tanh, (double)::tanh, (float)::tanhf); +MIGRAPHX_DEVICE_MATH_WRAP(fmod, (double)::fmod, (float)::fmodf); template constexpr auto where(bool cond, const T& a, const U& b) @@ -243,13 +160,22 @@ constexpr auto where(bool cond, const T& a, const U& b) return cond ? a : b; } -MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max) -MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min) +MIGRAPHX_DEVICE_MATH_FOR(float, abs, ::abs) +MIGRAPHX_DEVICE_MATH_FOR(double, abs, ::abs) +MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) +MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::fmaxf) +MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::fminf) MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max) MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::__hmax) MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::__hmin) +template () and is_integral{})> +constexpr auto abs(const T& a) +{ + return where(a < 0, -a, a); +} + template ())> constexpr auto max(const T& a, const T& b) { @@ -322,6 +248,26 @@ MIGRAPHX_DEVICE_MATH_VEC(tan) MIGRAPHX_DEVICE_MATH_VEC(tanh) MIGRAPHX_DEVICE_MATH_VEC(where) +// Map math functions to hip half2 functions +// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats +// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names +// Most but not all of these math ops have operators of the same names. +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, abs, ::__habs2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, ceil, ::h2ceil) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, cos, ::h2cos) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, exp, ::h2exp) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, exp10, ::h2exp10) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, exp2, ::h2exp2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, floor, ::h2floor) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, isinf, ::__hisinf2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, isnan, ::__hisnan2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, log, ::h2log) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, log10, ::h2log10) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, log2, ::h2log2) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, rsqrt, ::h2rsqrt) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, sin, ::h2sin) +MIGRAPHX_DEVICE_MATH_VEC2(migraphx::half, sqrt, ::h2sqrt) + template constexpr auto convert(U v) { diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp index 272e0ca0d10..62739f166b9 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp @@ -29,6 +29,34 @@ #define MIGRAPHX_PP_EAT(...) #define MIGRAPHX_PP_EXPAND(...) __VA_ARGS__ +#define MIGRAPHX_PP_COMMA(...) , + +#define MIGRAPHX_PP_IIF(c) MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_IIF_, c) +#define MIGRAPHX_PP_IIF_0(t, ...) __VA_ARGS__ +#define MIGRAPHX_PP_IIF_1(t, ...) t + +#define MIGRAPHX_PP_COMPL(b) MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_COMPL_, b) +#define MIGRAPHX_PP_COMPL_0 1 +#define MIGRAPHX_PP_COMPL_1 0 + +#define MIGRAPHX_PP_BITAND(x) MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_BITAND_, x) +#define MIGRAPHX_PP_BITAND_0(y) 0 +#define MIGRAPHX_PP_BITAND_1(y) y + +#define MIGRAPHX_PP_CHECK(...) MIGRAPHX_PP_CHECK_N(__VA_ARGS__, 0, ) +#define MIGRAPHX_PP_CHECK_N(x, n, ...) n +#define MIGRAPHX_PP_PROBE(x) x, 1, + +#define MIGRAPHX_PP_IS_PAREN(x) MIGRAPHX_PP_CHECK(MIGRAPHX_PP_IS_PAREN_PROBE x) +#define MIGRAPHX_PP_IS_PAREN_PROBE(...) MIGRAPHX_PP_PROBE(~) + +#define MIGRAPHX_PP_PRIMITIVE_IS_EMPTY(x) \ + MIGRAPHX_PP_CHECK(MIGRAPHX_PP_PRIMITIVE_IS_EMPTY_PROBE x()) +#define MIGRAPHX_PP_PRIMITIVE_IS_EMPTY_PROBE(...) MIGRAPHX_PP_PROBE(~) + +#define MIGRAPHX_PP_IS_EMPTY_ARG(x) \ + MIGRAPHX_PP_BITAND(MIGRAPHX_PP_COMPL(MIGRAPHX_PP_IS_PAREN(x))) \ + (MIGRAPHX_PP_PRIMITIVE_IS_EMPTY(x)) #define MIGRAPHX_PP_REPEAT0(m, ...) m(0, __VA_ARGS__) #define MIGRAPHX_PP_REPEAT1(m, ...) MIGRAPHX_PP_REPEAT0(m, __VA_ARGS__) m(1, __VA_ARGS__) @@ -45,4 +73,53 @@ #define MIGRAPHX_PP_REPEAT(n, m, ...) \ MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_REPEAT, n)(m, __VA_ARGS__) +#define MIGRAPHX_PP_RES_ARGS() , , , , , , , , , , , , , , , + +#define MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS(...) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS_IMPL(__VA_ARGS__) + +#define MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS_IMPL( \ + m, delim, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, ...) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x0) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x1) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x1) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x2) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x2) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x3) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x3) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x4) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x4) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x5) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x5) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x6) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x6) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x7) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x7) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x8) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x8) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x9) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x9) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x10) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x10) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x11) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x11) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x12) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x12) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x13) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x13) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x14) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x14) \ + MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(delim, x15) MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x15) + +#define MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARG(m, x) \ + MIGRAPHX_PP_IIF(MIGRAPHX_PP_IS_EMPTY_ARG(x))(MIGRAPHX_PP_EAT, m)(x) + +#define MIGRAPHX_PP_EACH_ARGS(m, ...) \ + MIGRAPHX_PP_EXPAND(MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS( \ + m, MIGRAPHX_PP_EAT, __VA_ARGS__, MIGRAPHX_PP_RES_ARGS())) + +#define MIGRAPHX_PP_TRANSFORM_ARGS(m, ...) \ + MIGRAPHX_PP_EXPAND(MIGRAPHX_PP_PRIMITIVE_TRANSFORM_ARGS( \ + m, MIGRAPHX_PP_COMMA, __VA_ARGS__, MIGRAPHX_PP_RES_ARGS())) + #endif // MIGRAPHX_GUARD_KERNELS_PP_HPP diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp index 60f68029304..1b0d1343ea2 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp @@ -29,6 +29,9 @@ namespace migraphx { +template +using void_t = void; + template U private_declval(int); @@ -38,6 +41,19 @@ T private_declval(long); template auto declval() noexcept -> decltype(private_declval(0)); +template +struct is_callable_impl : false_type +{ +}; + +template +struct is_callable_impl()(declval()...))>, F, Ts...> : true_type +{ +}; + +template +using is_callable = is_callable_impl; + template struct type_identity { diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp index 750c1e03c23..01306af8422 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -141,11 +141,46 @@ const std::string math_template = R"__migraphx__( #include namespace migraphx { + +template +struct test_implicit_conversion_op +{ + T x; + + template + constexpr operator vec() const + { + if constexpr(vec_size() == 0) + { + return x; + } + else + { + static_assert(vec_size() == N, "Vector mismatch size"); + return __builtin_convertvector(x, vec); + } + } + + template + constexpr operator U() const + { + static_assert(is_same{} or not is_same{} or is_same{}); + return static_cast(x); + } +}; + +template +constexpr test_implicit_conversion_op test_implicit_conversion(T x) +{ + return {x}; +} + + extern "C" { __global__ void kernel(${type}* p) { auto x = *p; - *p = migraphx::implicit_conversion(migraphx::${invoke}); + *p = migraphx::test_implicit_conversion(migraphx::${invoke}); } } From 935b96b29ecdfd6cda816f0725f28a19cf965415 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Wed, 4 Dec 2024 11:05:34 -0600 Subject: [PATCH 5/7] Track broadcast axes in the shape_transform_descriptor (#3610) Although, this prevents simplifying as much, it does help preserve the permutation of the broadcasted axes. So if we have a tensor of `{2, 16, 10240}` that goes into a reduction along the last axis it will output to `{2, 16, 1}`, which may be broadcasted back into `{2, 16, 10240}`, but there could be more shape transformations after the reduce but before an pointwise operator: ``` @1 = multibroadcast[out_lens={2, 16, 10240},out_dyn_dims={}](@0) -> int64_type, {2, 16, 10240}, {16, 1, 0} @2 = reshape[dims={2, 160, 32, 32}](@1) -> int64_type, {2, 160, 32, 32}, {163840, 1024, 32, 1} @3 = transpose[permutation={0, 2, 3, 1}](@2) -> int64_type, {2, 32, 32, 160}, {163840, 32, 1, 1024} ``` On develop this would be simplified to: ``` @1 = unsqueeze[axes={1, 2, 5},steps={}](@0) -> int64_type, {2, 1, 1, 16, 1, 1}, {16, 16, 16, 1, 1, 1} @2 = multibroadcast[out_lens={2, 1, 1, 16, 1, 10},out_dyn_dims={}](@1) -> int64_type, {2, 1, 1, 16, 1, 10}, {16, 16, 16, 1, 1, 0} @3 = reshape[dims={2, 1, 1, 160}](@2) -> int64_type, {2, 1, 1, 160}, {160, 160, 160, 1} @4 = multibroadcast[out_lens={2, 32, 32, 160},out_dyn_dims={}](@3) -> int64_type, {2, 32, 32, 160}, {160, 0, 0, 1} ``` Ideally, we would want to apply these transformations without the broadcast before the reduction but if it simplified like above because the shape_transform_descriptor doesnt track the permutation of the the broadcasted axes. With this PR, it will simplify to: ``` @1 = unsqueeze[axes={3, 4},steps={}](@0) -> int64_type, {2, 16, 1, 1, 1}, {16, 1, 1, 1, 1} @2 = transpose[permutation={0, 3, 4, 1, 2}](@1) -> int64_type, {2, 1, 1, 16, 1}, {16, 1, 1, 1, 1} @3 = multibroadcast[out_lens={2, 1, 1, 16, 10},out_dyn_dims={}](@2) -> int64_type, {2, 1, 1, 16, 10}, {16, 1, 1, 1, 0} @4 = reshape[dims={2, 1, 1, 160}](@3) -> int64_type, {2, 1, 1, 160}, {160, 160, 160, 1} @5 = multibroadcast[out_lens={2, 32, 32, 160},out_dyn_dims={}](@4) -> int64_type, {2, 32, 32, 160}, {160, 0, 0, 1} ``` This has a transpose because the shape_transform_descriptor understands how it will output in NHWC, which means we can make the input to the reduction NHWC layout as well. This PR doesn't enable such rewriting, it only modifies the shape_transform descriptor to track such layouts. Also, there is some updates to the tests as well: - Validate that a simplified transformation produces the same result - Check that the simplification cannot be simplified further --- .../migraphx/shape_transform_descriptor.hpp | 7 +- src/shape_transform_descriptor.cpp | 201 +++++++++++--- test/shape_transform_descriptor.cpp | 259 ++++++++++++------ test/simplify_reshapes_test.cpp | 14 +- 4 files changed, 353 insertions(+), 128 deletions(-) diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index 4160759d47d..029f5b85f3d 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -98,7 +98,12 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor // the axis. However, it still needs to accounted for. After we // generate the broadcast we will set the axis to the hidden // axis, and then length to 1. - optional hidden_axis = nullopt; + std::vector hidden_axis = {}; + + const std::vector& origin_axis() const; + bool has_hidden_axis() const; + + void add_split_axis(std::size_t i); MIGRAPHX_EXPORT friend bool operator==(const sub& x, const sub& y); MIGRAPHX_EXPORT friend bool operator!=(const sub& x, const sub& y); diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 51802ef0fae..42b354eda5f 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -28,10 +28,12 @@ #include #include #include +#include #include #include #include #include +#include #include namespace migraphx { @@ -99,6 +101,16 @@ std::vector compute_dims(const std::vector& ops, return s.lens(); } +static dimension::sub* get_last_subdimension(std::vector& dims) +{ + if(dims.empty()) + return {}; + auto& d = dims.back(); + if(d.subdimensions.empty()) + return nullptr; + return &d.subdimensions.back(); +} + bool shape_transform_descriptor::apply(const std::vector& ops) { std::vector dims; @@ -196,8 +208,7 @@ bool shape_transform_descriptor::apply_reshape_impl(const std::vector dimension { auto new_sub = sub; - if(not new_sub.axis.empty()) - new_sub.axis.push_back(j); + new_sub.add_split_axis(j); new_sub.len = start[j]; return {{new_sub}}; }); @@ -209,12 +220,20 @@ bool shape_transform_descriptor::apply_reshape_impl(const std::vector{} : sub->axis; + auto trailing_dims = range(rdims.begin() + new_dims.size(), rdims.end()); + if(any_of(trailing_dims, [](auto d) { return d != 1; })) + return false; + if(distance(trailing_dims) > 1) + sub->add_split_axis(0); + transform(range(distance(trailing_dims)), + std::back_inserter(new_dims), + [&](std::size_t j) -> dimension { + dimension::sub s{1, axis}; + s.add_split_axis(j + 1); + return {{s}}; + }); } assert(rdims.size() == new_dims.size()); if(rdims.size() != new_dims.size()) @@ -252,7 +271,20 @@ bool shape_transform_descriptor::apply_broadcast(const std::vector& return dim; if(dim.len() != 1) MIGRAPHX_THROW("Wrong out_lens for broadcast"); - return {{dimension::sub{len, {}}}}; + auto new_subs = dim.subdimensions; + if(not new_subs.empty()) + { + new_subs.front().len = len; + } + for(auto& s : new_subs) + { + if(not s.axis.empty()) + { + s.hidden_axis = s.axis; + s.axis.clear(); + } + } + return {new_subs}; }); std::transform(out_lens.begin() + offset + dimensions.size(), out_lens.end(), @@ -281,14 +313,19 @@ void dimension::simplify() remove_1_sub_dims(subdimensions); // Flatten adjacent dimensions adjacent_for_each(subdimensions.begin(), subdimensions.end(), [&](sub& d1, sub& d2) { - if(d1.axis.size() < 2) + if(d1.origin_axis().size() < 2) + return; + if(d2.origin_axis().size() < 2) return; - if(d2.axis.size() < 2) + if(d1.has_hidden_axis() != d2.has_hidden_axis()) return; - if(not std::equal(d1.axis.begin(), d1.axis.end() - 1, d2.axis.begin(), d2.axis.end() - 1)) + if(not std::equal(d1.origin_axis().begin(), + d1.origin_axis().end() - 1, + d2.origin_axis().begin(), + d2.origin_axis().end() - 1)) return; - auto a1 = d1.axis.back(); - auto a2 = d2.axis.back(); + auto a1 = d1.origin_axis().back(); + auto a2 = d2.origin_axis().back(); assert(a2 != a1); if(a2 <= a1) return; @@ -347,7 +384,7 @@ static bool missing_leading_axis(const dimension& d) if(d.subdimensions.empty()) return true; const auto& sub = d.subdimensions.front(); - return sub.axis.empty(); + return sub.origin_axis().empty(); } static void set_broadcast_dim(dimension& d, std::size_t axis) @@ -355,7 +392,10 @@ static void set_broadcast_dim(dimension& d, std::size_t axis) if(d.subdimensions.empty()) d.subdimensions.push_back({1, {axis}}); else - d.subdimensions.front().hidden_axis = axis; + { + assert(d.subdimensions.front().hidden_axis.empty()); + d.subdimensions.front().hidden_axis = {axis}; + } } // Group all axes into a map with a key of the axis and the value is vector of @@ -368,14 +408,77 @@ group_axes(std::vector& dimensions) { for(auto& s : d.subdimensions) { - if(s.axis.empty()) + if(s.origin_axis().empty()) continue; - axes_map[s.axis.front()].push_back(&s); + axes_map[s.origin_axis().front()].push_back(&s); } } return axes_map; } +static void set_origin_axis(dimension::sub& s, const std::vector& axis) +{ + if(s.has_hidden_axis()) + s.hidden_axis = axis; + else + s.axis = axis; +} + +// If an axis is split and some dimensions are hidden and others are not, then +// remove the hidden axis so only the non-hidden axis is used in +// simplificaiton +static void remove_split_hidden_axes(std::map>& axes_map) +{ + for(auto&& p : axes_map) + { + auto& subs = p.second; + if(std::all_of(subs.begin(), subs.end(), [](const dimension::sub* s) { + return s->has_hidden_axis(); + })) + continue; + for(auto* sub : subs) + { + if(not sub->has_hidden_axis()) + continue; + sub->hidden_axis.clear(); + } + // Remove the subdimesions that no longer have an axis + subs.erase(std::remove_if(subs.begin(), + subs.end(), + [](const dimension::sub* s) { + return s->axis.empty() and s->hidden_axis.empty(); + }), + subs.end()); + } + // Remove axis from group if empty + erase_if(axes_map, [](auto&& p) { return p.second.empty(); }); +} + +// If this is scalar, then remove all axes +static void remove_scalar_axis(std::vector& dimensions) +{ + dimension::sub* s = nullptr; + for(auto& d : dimensions) + { + auto has_axis = [](const dimension::sub& x) { return not x.origin_axis().empty(); }; + auto it = std::find_if(d.subdimensions.begin(), d.subdimensions.end(), has_axis); + if(it == d.subdimensions.end()) + continue; + if(s != nullptr) + return; + if(std::count_if(std::next(it), d.subdimensions.end(), has_axis) > 0) + return; + s = &*it; + } + if(s != nullptr) + { + if(s->has_hidden_axis()) + s->hidden_axis.clear(); + if(s->len == 1) + s->axis.clear(); + } +} + // Renumber all axes while preserving the order of the axes static void renumber_axes(std::map>& axes_map) { @@ -385,15 +488,15 @@ static void renumber_axes(std::map>& a auto& subs = p.second; if(subs.size() == 1) { - subs[0]->axis = {axis}; + set_origin_axis(*subs[0], {axis}); } else { std::sort(subs.begin(), subs.end(), by(std::less<>{}, [](const dimension::sub* s) { - return s->axis; + return s->origin_axis(); })); for(std::size_t i : range(subs.size())) - subs[i]->axis = {axis, i}; + set_origin_axis(*subs[i], {axis, i}); } } } @@ -437,6 +540,8 @@ void shape_transform_descriptor::simplify() for(auto& d : dimensions) d.simplify(); + remove_scalar_axis(dimensions); + std::map missing_axes; std::vector last_axis; { @@ -445,6 +550,7 @@ void shape_transform_descriptor::simplify() if(axes_map.empty()) return; + remove_split_hidden_axes(axes_map); renumber_axes(axes_map); // Find last axis @@ -471,8 +577,8 @@ void shape_transform_descriptor::simplify() { assert(not last->subdimensions.empty()); const auto& sub = last->subdimensions.front(); - assert(not sub.axis.empty()); - axis = sub.axis.front(); + assert(not sub.origin_axis().empty()); + axis = sub.origin_axis().front(); } std::deque dims(std::distance(start, last)); std::iota(dims.begin(), dims.end(), std::distance(dimensions.begin(), start)); @@ -518,18 +624,18 @@ void shape_transform_descriptor::simplify() // Search for the subdimension that has the next axis and try to // insert the axis before it will be in order. auto [sub, it, prev] = find_subdimension(*this, [&](const dimension::sub& s) { - if(s.axis.empty()) + if(s.origin_axis().empty()) return false; - if(s.axis.front() != next_axis) + if(s.origin_axis().front() != next_axis) return false; - if(s.axis.size() == 1) + if(s.origin_axis().size() == 1) return true; - assert(s.axis.size() == 2); - return s.axis.back() == 0; + assert(s.origin_axis().size() == 2); + return s.origin_axis().back() == 0; }); bool in_order = false; - if(prev.has_value() and not(*prev)->axis.empty()) - in_order = (*prev)->axis.front() == missing_axis - 1; + if(prev.has_value() and not(*prev)->origin_axis().empty()) + in_order = (*prev)->origin_axis().front() == missing_axis - 1; // If the axis is not inorder then see if we can find a broadcast axis to place it auto bdims = in_order ? broadcast_dims_map.end() : broadcast_dims_map.upper_bound(missing_axis); @@ -611,17 +717,15 @@ static void flatten_broadcasted_dim(dimension::sub& s) if(s.axis.empty()) { s.len = 1; - if(s.hidden_axis.has_value()) - { - s.axis = {s.hidden_axis.value()}; - s.hidden_axis = nullopt; - } + s.axis = s.hidden_axis; + s.hidden_axis.clear(); } } static operation make_reshape_unsqueeze(const std::vector& subs) { bool use_reshape = false; + std::unordered_set all_1s; // Check if split dimensions are all additional 1s if(std::any_of( subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() > 1; })) @@ -645,6 +749,8 @@ static operation make_reshape_unsqueeze(const std::vector& subs) // Number of elements that are 1 auto n1 = std::count_if(start, last, [](const dimension::sub& s) { return s.len == 1; }); + if(n == n1 and not start->axis.empty()) + all_1s.insert(start->axis.front()); use_reshape |= std::max(0, n - n1 - 1) > 0; }, by_axis); @@ -672,6 +778,8 @@ static operation make_reshape_unsqueeze(const std::vector& subs) continue; if(sub.len != 1 and not sub.axis.empty()) continue; + if(not sub.axis.empty() and contains(all_1s, sub.axis.front()) and sub.axis.back() == 0) + continue; axes.push_back(i); } return make_op("unsqueeze", {{"axes", axes}}); @@ -681,7 +789,7 @@ static operation make_reshape_unsqueeze(const std::vector& subs) static bool has_no_axes(const dimension& d) { return std::all_of(d.subdimensions.begin(), d.subdimensions.end(), [](const dimension::sub& s) { - return s.axis.empty() and not s.hidden_axis.has_value(); + return s.axis.empty() and s.hidden_axis.empty(); }); } static bool has_axes(const dimension& d) @@ -824,6 +932,23 @@ std::size_t shape_transform_descriptor::elements() const [](const auto& s) { return s.len(); }); } +const std::vector& shape_transform_descriptor::dimension::sub::origin_axis() const +{ + return axis.empty() ? hidden_axis : axis; +} +bool shape_transform_descriptor::dimension::sub::has_hidden_axis() const +{ + return axis.empty() and not hidden_axis.empty(); +} + +void shape_transform_descriptor::dimension::sub::add_split_axis(std::size_t i) +{ + if(not axis.empty()) + axis.push_back(i); + if(not hidden_axis.empty()) + hidden_axis.push_back(i); +} + bool operator==(const dimension::sub& x, const dimension::sub& y) { return by(std::equal_to<>{}, @@ -833,8 +958,8 @@ bool operator!=(const dimension::sub& x, const dimension::sub& y) { return not(x std::ostream& operator<<(std::ostream& os, const dimension::sub& x) { os << x.len << ":" << to_string_range(x.axis, "x"); - if(x.hidden_axis.has_value()) - os << "$" << x.hidden_axis.value(); + if(not x.hidden_axis.empty()) + os << "$" << to_string_range(x.hidden_axis, "x"); return os; } bool operator==(const dimension& x, const dimension& y) diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index 8a5b7cf34f2..d6168ecbe7c 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -24,6 +24,8 @@ */ #include #include +#include +#include #include using migraphx::make_op; @@ -76,6 +78,34 @@ all_axes get_all_axes(const shape_transform_descriptor& d) return result; } +std::vector run_shape_transforms(const std::vector& dims, + const std::vector& ops) +{ + migraphx::shape s{migraphx::shape::int64_type, dims}; + std::vector data(s.elements()); + std::iota(data.begin(), data.end(), 0); + + migraphx::program p; + auto* mm = p.get_main_module(); + auto start = mm->add_literal(s, data); + for(const auto& op : ops) + start = mm->add_instruction(op, start); + mm->add_return({start}); + + auto result = p.eval({}).at(0); + return result.to_vector(); +} + +std::vector +check_optimize_shape_transforms(const std::vector& dims, + const std::vector& ops) +{ + auto result = migraphx::optimize_shape_transforms(dims, ops); + CHECK(run_shape_transforms(dims, ops) == run_shape_transforms(dims, result)); + CHECK(result == migraphx::optimize_shape_transforms(dims, result)); + return result; +} + template shape_transform_descriptor make_descriptor(const std::vector& dims, const Ts&... xs) { @@ -115,7 +145,7 @@ TEST_CASE(record_reshape_trailing_1s) EXPECT(get_final_lens(desc) == final_lens{3, 4, 4, 1, 1}); EXPECT(get_all_lens(desc) == all_lens{{3}, {4}, {4}, {1}, {1}}); EXPECT(get_all_axes(desc) == - all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2}}, d_axes{{}}, d_axes{{}}}); + all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2, 0}}, d_axes{{2, 1}}, d_axes{{2, 2}}}); } TEST_CASE(record_reshape_merge) @@ -158,7 +188,7 @@ TEST_CASE(record_reshape_squeeze_trailing_1s) make_op("reshape", {{"dims", {3, 4, 4}}})); EXPECT(get_final_lens(desc) == final_lens{3, 4, 4}); EXPECT(get_all_lens(desc) == all_lens{{3}, {4}, {4}}); - EXPECT(get_all_axes(desc) == all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2}}}); + EXPECT(get_all_axes(desc) == all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2, 0}}}); } TEST_CASE(record_reshape_non_divisible_fail) @@ -234,41 +264,41 @@ TEST_CASE(simplify_dimension_remove_1_dim) TEST_CASE(optimize_transpose_transpose) { - EXPECT(migraphx::optimize_shape_transforms( - {3, 5, 2}, - { - make_op("transpose", {{"permutation", {0, 2, 1}}}), - make_op("transpose", {{"permutation", {1, 0, 2}}}), - }) == ops{ - make_op("transpose", {{"permutation", {2, 0, 1}}}), - }); + EXPECT(check_optimize_shape_transforms({3, 5, 2}, + { + make_op("transpose", {{"permutation", {0, 2, 1}}}), + make_op("transpose", {{"permutation", {1, 0, 2}}}), + }) == + ops{ + make_op("transpose", {{"permutation", {2, 0, 1}}}), + }); } TEST_CASE(optimize_reshape_reshape1) { - EXPECT(migraphx::optimize_shape_transforms({3, 5, 2}, - { - make_op("reshape", {{"dims", {30}}}), - make_op("reshape", {{"dims", {3, 10}}}), - }) == ops{ - make_op("reshape", {{"dims", {3, 10}}}), - }); + EXPECT(check_optimize_shape_transforms({3, 5, 2}, + { + make_op("reshape", {{"dims", {30}}}), + make_op("reshape", {{"dims", {3, 10}}}), + }) == ops{ + make_op("reshape", {{"dims", {3, 10}}}), + }); } TEST_CASE(optimize_reshape_reshape2) { - EXPECT(migraphx::optimize_shape_transforms({15, 4}, - { - make_op("reshape", {{"dims", {3, 5, 2, 2}}}), - make_op("reshape", {{"dims", {15, 2, 2}}}), - }) == ops{ - make_op("reshape", {{"dims", {15, 2, 2}}}), - }); + EXPECT(check_optimize_shape_transforms({15, 4}, + { + make_op("reshape", {{"dims", {3, 5, 2, 2}}}), + make_op("reshape", {{"dims", {15, 2, 2}}}), + }) == ops{ + make_op("reshape", {{"dims", {15, 2, 2}}}), + }); } TEST_CASE(optimize_reshape_transpose_reshape_to_none) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {6, 5, 2}, { make_op("reshape", {{"dims", {6, 5, 2, 1, 1}}}), @@ -279,22 +309,22 @@ TEST_CASE(optimize_reshape_transpose_reshape_to_none) TEST_CASE(optimize_reshape_transpose_reshape_to_same) { - EXPECT(migraphx::optimize_shape_transforms( - {1, 112, 56, 56}, + EXPECT(check_optimize_shape_transforms( + {1, 112, 7, 7}, { - make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), + make_op("reshape", {{"dims", {1, 4, 28, 7, 7}}}), make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}), - make_op("reshape", {{"dims", {1, 112, 56, 56}}}), + make_op("reshape", {{"dims", {1, 112, 7, 7}}}), }) == ops{ - make_op("reshape", {{"dims", {1, 4, 28, 56, 56}}}), + make_op("reshape", {{"dims", {1, 4, 28, 7, 7}}}), make_op("transpose", {{"permutation", {0, 2, 1, 3, 4}}}), - make_op("reshape", {{"dims", {1, 112, 56, 56}}}), + make_op("reshape", {{"dims", {1, 112, 7, 7}}}), }); } TEST_CASE(optimize_reshape_transpose_reshape_to_transpose) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {6, 5, 2}, { make_op("reshape", {{"dims", {2, 3, 5, 2}}}), @@ -307,20 +337,20 @@ TEST_CASE(optimize_reshape_transpose_reshape_to_transpose) TEST_CASE(optimize_reshape_transpose_reshape_to_reshape) { - EXPECT(migraphx::optimize_shape_transforms( - {6, 5, 2}, - { - make_op("reshape", {{"dims", {6, 5, 2, 1}}}), - make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), - make_op("reshape", {{"dims", {6, 10}}}), - }) == ops{ - make_op("reshape", {{"dims", {6, 10}}}), - }); + EXPECT( + check_optimize_shape_transforms({6, 5, 2}, + { + make_op("reshape", {{"dims", {6, 5, 2, 1}}}), + make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), + make_op("reshape", {{"dims", {6, 10}}}), + }) == ops{ + make_op("reshape", {{"dims", {6, 10}}}), + }); } TEST_CASE(optimize_multibroadcast_transpose_reshape) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {1, 5, 2}, { make_op("multibroadcast", {{"out_lens", {20, 5, 2}}}), @@ -335,7 +365,7 @@ TEST_CASE(optimize_multibroadcast_transpose_reshape) TEST_CASE(optimize_resize1) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {3, 4, 4}, { make_op("reshape", {{"dims", {3, 1, 4, 1, 4}}}), @@ -350,7 +380,7 @@ TEST_CASE(optimize_resize1) TEST_CASE(optimize_resize2) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {1, 1, 2, 2}, { make_op("reshape", {{"dims", {1, 1, 2, 1, 2, 1}}}), @@ -366,54 +396,53 @@ TEST_CASE(optimize_resize2) TEST_CASE(optimize_reshape_2_squeeze) { - EXPECT(migraphx::optimize_shape_transforms({3, 1, 5, 1, 2, 1, 1}, - { - make_op("reshape", {{"dims", {3, 5, 2}}}), - }) == - ops{ - make_op("squeeze", {{"axes", {1, 3, 5, 6}}}), - }); + EXPECT(check_optimize_shape_transforms({3, 1, 5, 1, 2, 1, 1}, + { + make_op("reshape", {{"dims", {3, 5, 2}}}), + }) == ops{ + make_op("squeeze", {{"axes", {1, 3, 5, 6}}}), + }); } TEST_CASE(optimize_reshape_2_unsqueeze) { - EXPECT(migraphx::optimize_shape_transforms( - {3, 5, 2}, - { - make_op("reshape", {{"dims", {3, 1, 5, 1, 2, 1, 1}}}), - }) == ops{ - make_op("unsqueeze", {{"axes", {1, 3, 5, 6}}}), - }); + EXPECT( + check_optimize_shape_transforms({3, 5, 2}, + { + make_op("reshape", {{"dims", {3, 1, 5, 1, 2, 1, 1}}}), + }) == ops{ + make_op("unsqueeze", {{"axes", {1, 3, 5, 6}}}), + }); } TEST_CASE(optimize_unsqueeze_multibroadcast) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {32, 10}, { make_op("unsqueeze", {{"axes", {0, 3, 4}}}), - make_op("multibroadcast", {{"out_lens", {256, 32, 10, 16, 16}}}), + make_op("multibroadcast", {{"out_lens", {4, 32, 10, 16, 16}}}), }) == ops{ - make_op("broadcast", {{"axis", 1}, {"out_lens", {256, 32, 10, 16, 16}}}), + make_op("broadcast", {{"axis", 1}, {"out_lens", {4, 32, 10, 16, 16}}}), }); } TEST_CASE(optimize_multibroadcast_reshape) { - EXPECT(migraphx::optimize_shape_transforms( - {1, 4, 1}, - { - make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), - make_op("reshape", {{"dims", {2, 2, 2, 6}}}), - }) == ops{ - make_op("reshape", {{"dims", {1, 2, 2, 1}}}), - make_op("multibroadcast", {{"out_lens", {2, 2, 2, 6}}}), - }); + EXPECT(check_optimize_shape_transforms({1, 4, 1}, + { + make_op("multibroadcast", {{"out_lens", {2, 4, 6}}}), + make_op("reshape", {{"dims", {2, 2, 2, 6}}}), + }) == + ops{ + make_op("reshape", {{"dims", {1, 2, 2, 1}}}), + make_op("multibroadcast", {{"out_lens", {2, 2, 2, 6}}}), + }); } -TEST_CASE(optimize_squeeze_broadcast) +TEST_CASE(optimize_squeeze_broadcast1) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {256, 1, 1}, { make_op("squeeze"), @@ -424,9 +453,22 @@ TEST_CASE(optimize_squeeze_broadcast) }); } +TEST_CASE(optimize_squeeze_broadcast2) +{ + EXPECT(check_optimize_shape_transforms( + {1, 128, 1}, + { + make_op("squeeze", {{"axes", {0}}}), + make_op("multibroadcast", {{"out_lens", {128, 768}}}), + }) == ops{ + make_op("squeeze", {{"axes", {0}}}), + make_op("multibroadcast", {{"out_lens", {128, 768}}}), + }); +} + TEST_CASE(optimize_squeeze_unsqueeze_broadcast_to_broadcast) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {256}, { make_op("unsqueeze", {{"axes", {0}}}), @@ -439,7 +481,7 @@ TEST_CASE(optimize_squeeze_unsqueeze_broadcast_to_broadcast) TEST_CASE(optimize_transpose_reshape_to_transpose) { - EXPECT(migraphx::optimize_shape_transforms( + EXPECT(check_optimize_shape_transforms( {3, 3, 3, 1}, { make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), @@ -451,14 +493,73 @@ TEST_CASE(optimize_transpose_reshape_to_transpose) TEST_CASE(optimize_scalar_broadcast_unsqueeze) { - EXPECT(migraphx::optimize_shape_transforms({1}, - { - make_op("multibroadcast", {{"out_lens", {2}}}), - make_op("unsqueeze", {{"axes", {1}}}), - }) == + EXPECT(check_optimize_shape_transforms({1}, + { + make_op("multibroadcast", {{"out_lens", {2}}}), + make_op("unsqueeze", {{"axes", {1}}}), + }) == ops{ make_op("multibroadcast", {{"out_lens", {2, 1}}}), }); } +TEST_CASE(optimize_broadcast_reshape_transpose) +{ + EXPECT(check_optimize_shape_transforms( + {2, 16, 1}, + { + make_op("multibroadcast", {{"out_lens", {2, 16, 10240}}}), + make_op("reshape", {{"dims", {2, 160, 32, 32}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), + }) == ops{ + make_op("unsqueeze", {{"axes", {3, 4}}}), + make_op("transpose", {{"permutation", {0, 3, 4, 1, 2}}}), + make_op("multibroadcast", {{"out_lens", {2, 1, 1, 16, 10}}}), + make_op("reshape", {{"dims", {2, 1, 1, 160}}}), + make_op("multibroadcast", {{"out_lens", {2, 32, 32, 160}}}), + }); +} + +TEST_CASE(optimize_multibroadcast_transpose) +{ + EXPECT(check_optimize_shape_transforms( + {320, 1, 1}, + { + make_op("multibroadcast", {{"out_lens", {2, 320, 64, 64}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), + }) == ops{ + make_op("unsqueeze", {{"axes", {0}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), + make_op("multibroadcast", {{"out_lens", {2, 64, 64, 320}}}), + }); +} + +TEST_CASE(optimize_unsqueeze_transpose_squeeze_multibroadcast) +{ + EXPECT(check_optimize_shape_transforms( + {320, 1, 1}, + { + make_op("unsqueeze", {{"axes", {0}}}), + make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), + make_op("squeeze", {{"axes", {0, 1}}}), + make_op("multibroadcast", {{"out_lens", {320, 320}}}), + }) == ops{ + make_op("multibroadcast", {{"out_lens", {320, 1, 320}}}), + make_op("squeeze", {{"axes", {1}}}), + }); +} + +TEST_CASE(optimize_squeeze_multibroadcast_transpose) +{ + EXPECT(check_optimize_shape_transforms( + {16, 1, 16}, + { + make_op("squeeze", {{"axes", {1}}}), + make_op("multibroadcast", {{"out_lens", {4, 16, 16}}}), + make_op("transpose", {{"permutation", {1, 0, 2}}}), + }) == ops{ + make_op("multibroadcast", {{"out_lens", {16, 4, 16}}}), + }); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index a1dcd2a28be..3f1f5ebabb6 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -1296,17 +1296,11 @@ TEST_CASE(concat_reshape_broadcast) } migraphx::module m2; { - auto x = m2.add_parameter("x", s); - auto y = m2.add_parameter("y", s); - auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); - // TODO: This could just be a broadcast - // auto broadcast = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", - // {22016, 32, 128}}}), concat); - auto unsqueeze = - m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), concat); + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x, y); auto broadcast = m2.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {22016, 32, 1, 128}}}), unsqueeze); - + migraphx::make_op("multibroadcast", {{"out_lens", {22016, 32, 128}}}), concat); auto reshape = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {22016, 4096}}}), broadcast); m2.add_return({reshape}); From 88327d7e03baee9117b5b3686beb3d2bd95ee05a Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Wed, 4 Dec 2024 14:19:13 -0600 Subject: [PATCH 6/7] Increase timeout to 3 hours (#3675) --- Jenkinsfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 0f6725d29eb..5ab5af29cb6 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -42,7 +42,7 @@ def rocmtestnode(Map conf) { rm -rf build mkdir build cd build - cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On -DCMAKE_EXECUTE_PROCESS_COMMAND_ECHO=STDOUT -DMIGRAPHX_DISABLE_VIRTUAL_ENV=ON ${flags} .. + cmake -DCTEST_TIMEOUT=3600 -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_DEV=On -DCMAKE_EXECUTE_PROCESS_COMMAND_ECHO=STDOUT -DMIGRAPHX_DISABLE_VIRTUAL_ENV=ON ${flags} .. git diff git diff-index --quiet HEAD || (echo "Git repo is not clean after running cmake." && exit 1) make -j\$(nproc) generate VERBOSE=1 @@ -71,7 +71,7 @@ def rocmtestnode(Map conf) { pre() sh "docker pull ${DOCKER_IMAGE}:${env.IMAGE_TAG}" withDockerContainer(image: "${DOCKER_IMAGE}:${env.IMAGE_TAG}", args: "--device=/dev/kfd --device=/dev/dri --group-add video --cap-add SYS_PTRACE -v=/home/jenkins/:/home/jenkins ${docker_args}") { - timeout(time: 2, unit: 'HOURS') { + timeout(time: 3, unit: 'HOURS') { body(cmake_build) } } From dde79865ee219badf35b509dad0f783ccf0505f3 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Wed, 4 Dec 2024 15:21:50 -0500 Subject: [PATCH 7/7] bit_cast operator (#3655) --- src/CMakeLists.txt | 1 + src/include/migraphx/op/bit_cast.hpp | 104 ++++++++++++++++++ .../include/migraphx/kernels/bit_cast.hpp | 11 +- test/op_shape_test.cpp | 15 +++ test/ref/bit_cast.cpp | 75 +++++++++++++ test/verify/main.cpp | 8 +- test/verify/test_bit_cast.cpp | 55 +++++++++ 7 files changed, 264 insertions(+), 5 deletions(-) create mode 100644 src/include/migraphx/op/bit_cast.hpp create mode 100644 test/ref/bit_cast.cpp create mode 100644 test/verify/test_bit_cast.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 334c9615f68..eda6ea626e4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -145,6 +145,7 @@ register_migraphx_ops( as_shape atanh atan + bit_cast bitwise_and broadcast broadcast_for_dot diff --git a/src/include/migraphx/op/bit_cast.hpp b/src/include/migraphx/op/bit_cast.hpp new file mode 100644 index 00000000000..eb233ad8b36 --- /dev/null +++ b/src/include/migraphx/op/bit_cast.hpp @@ -0,0 +1,104 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP +#define MIGRAPHX_GUARD_OPERATORS_BIT_CAST_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +/** + * Obtain a value of type `target_type` by reinterpreting + * the object represnetaion of the input. Originally used + * for casting from fp8e4m3fn to fp8e4m3fnuz. + */ +struct bit_cast : unary +{ + shape::type_t target_type; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.target_type, "target_type")); + } + + shape compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this, true}.has(1); + auto input = inputs.at(0); + std::size_t target_type_size; + shape::visit(target_type, [&](auto as) { target_type_size = as.size(); }); + if(input.type_size() != target_type_size) + { + MIGRAPHX_THROW("BIT_CAST: target_type has different type_size from input's"); + } + if(input.dynamic()) + { + return {target_type, input.dyn_dims()}; + } + else + { + return {target_type, input.lens(), input.strides()}; + } + } + + std::string point_op() const + { + return "${function:bit_cast}<" + shape::cpp_type(target_type) + ">(${0})"; + } + + argument compute(const dyn_output& dyn_out, std::vector args) const + { + argument result{dyn_out.computed_shape}; + result.visit([&](auto output) { + using otype = typename decltype(output)::value_type; + args[0].visit([&](auto input) { + using itype = typename decltype(input)::value_type; + if constexpr(sizeof(otype) == sizeof(itype)) + { + par_transform(input.begin(), input.end(), output.begin(), [&](auto x) { + return migraphx::bit_cast(x); + }); + } + else + { + // not possible to hit this unless somehow the types change after compute_shape + // is called + MIGRAPHX_THROW("BIT_CAST: type size mismatch"); + } + }); + }); + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp index c98395bbe10..e559658a004 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp @@ -23,15 +23,20 @@ #define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP #include +#include namespace migraphx { + template {} and is_trivially_copyable{})> -inline constexpr To bit_cast(From fr) noexcept +inline constexpr auto bit_cast(From fr) noexcept { - static_assert(sizeof(To) == sizeof(From)); - return __builtin_bit_cast(To, fr); + return vec_transform(fr)([](auto x) -> To { + static_assert(sizeof(To) == sizeof(decltype(x))); + return __builtin_bit_cast(To, x); + }); } + } // namespace migraphx #endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 8d08455d814..66d54c8a460 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -201,6 +201,21 @@ TEST_CASE(binary_dyn_static_error) throws_shape(migraphx::make_op("add"), a_shape, b_shape); } +TEST_CASE(bit_cast_typesize_mismatch) +{ + migraphx::shape a_shape{migraphx::shape::int8_type, {1, 4, 4}}; + throws_shape(migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::int32_type}}), + a_shape); +} + +TEST_CASE(bit_cast_dyn) +{ + migraphx::shape a_shape{migraphx::shape::int8_type, {{1, 1}, {4, 8}, {4, 8}}}; + expect_shape(migraphx::shape{migraphx::shape::uint8_type, {{1, 1}, {4, 8}, {4, 8}}}, + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::uint8_type}}), + a_shape); +} + TEST_CASE(bitwise_and_not_integral_error) { migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}}; diff --git a/test/ref/bit_cast.cpp b/test/ref/bit_cast.cpp new file mode 100644 index 00000000000..4f9438ef4fd --- /dev/null +++ b/test/ref/bit_cast.cpp @@ -0,0 +1,75 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include + +#include + +TEST_CASE(bit_cast_fp8) +{ + using migraphx::fp8::fp8e4m3fn; + using migraphx::fp8::fp8e4m3fnuz; + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::fp8e4m3fn_type, {2, 2}}; + std::vector data; + data.push_back(fp8e4m3fn{26.0f}); + data.push_back(fp8e4m3fn{3.0f}); + data.push_back(fp8e4m3fn{96.0f}); + data.push_back(fp8e4m3fn{-1.25f}); + auto lit = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), lit); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold; + gold.push_back(fp8e4m3fnuz{13.0f}); + gold.push_back(fp8e4m3fnuz{1.5f}); + gold.push_back(fp8e4m3fnuz{48.0f}); + gold.push_back(fp8e4m3fnuz{-0.625f}); + EXPECT(results_vector == gold); +} + +TEST_CASE(bit_cast_uint8) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::int8_type, {2, 2}}; + std::vector data = {23, -3, 0, -1}; + auto lit = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::shape::uint8_type}}), lit); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector(4); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold = {23, 253, 0, 255}; + EXPECT(results_vector == gold); +} diff --git a/test/verify/main.cpp b/test/verify/main.cpp index 5daa8a858d6..876db639644 100644 --- a/test/verify/main.cpp +++ b/test/verify/main.cpp @@ -129,14 +129,18 @@ int main(int argc, const char* argv[]) "test_block_reduce_small<67, migraphx::shape::int8_type>", "test_block_reduce_small<128, migraphx::shape::int8_type>", "test_block_reduce_small<129, migraphx::shape::int8_type>", + // disabled because CPU does eliminate_data_type to float for everything "test_bitwise_and", "test_bitwise_and", - "test_unpack_int4", "test_unpack_int4", "test_unpack_int4", - "test_unpack_int4"}); + "test_unpack_int4", + "test_bit_cast", + "test_bit_cast", + "test_bit_cast", + "test_bit_cast"}); rv.disable_test_for("gpu", { // These passes on MI300 but fails on others, same issue as CPU. diff --git a/test/verify/test_bit_cast.cpp b/test/verify/test_bit_cast.cpp new file mode 100644 index 00000000000..24f9a7fc745 --- /dev/null +++ b/test/verify/test_bit_cast.cpp @@ -0,0 +1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +template +struct test_bit_cast : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{From, {8}}; + auto pa = mm->add_parameter("a", s); + auto pb = mm->add_parameter("b", s); + auto ia = mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pa); + auto ib = mm->add_instruction( + migraphx::make_op("bit_cast", {{"target_type", migraphx::to_value(To)}}), pb); + auto ret = mm->add_instruction(migraphx::make_op("add"), ia, ib); + mm->add_return({ret}); + return p; + }; +}; + +template struct test_bit_cast; +template struct test_bit_cast; +template struct test_bit_cast; +template struct test_bit_cast;