Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Cumulative reduction (max, min, sum, prod) forward with small last dim #3297

Open
wants to merge 71 commits into
base: develop
Choose a base branch
from

Conversation

long10024070
Copy link
Collaborator

@long10024070 long10024070 commented Oct 4, 2024

This PR is a continuation of PR #3182. Accidently, I have closed the older PR, and then made change to the working branch, which makes me cannot reopen the older once. There are not many comments in that PR, I hope it doesn't interrupt your reviewing process. And again, sorry for this Inconvenience.

  • Added cumulative reduction forward operation and kernel with solver, support binary operators (max, min, sum, prod). This operation equivalent to cummax, cummin, cumsum, cumprod in Pytorch.
  • Added driver test and gtest for cumulative reduction.
  • New API is guarded by MIOPEN_BETA_API macro.
  • Compared to ROCm pytorch, there is a performance improvement when operation is performed on the dim with size smaller or equal to 256 and stride value at that dim of both input, output and indices tensor must equal to 1. For that reason, IsApplicable constraint makes sure that the operation only works with the above case.
float16
op_name dtype size dim contiguous model direction ROCm pytorch MIOpen HIP Improvement
CumMax float16 [512 64 112 112] -1 TRUE random fwd 79103622 10290800 7.69
CumMax float16 [512 64 56 56] -1 TRUE random fwd 39319091 2490330 15.79
CumMax float16 [512 128 56 56] -1 TRUE random fwd 78599721 4982140 15.78
CumMax float16 [512 128 28 28] -1 TRUE random fwd 39227767 2479240 15.82
CumMax float16 [512 256 28 28] -1 TRUE random fwd 78414528 4955720 15.82
CumMax float16 [512 256 14 14] -1 TRUE random fwd 39164283 2427920 16.13
CumMax float16 [512 512 14 14] -1 TRUE random fwd 80268168 4854160 16.54
CumMax float16 [512 512 7 7] -1 TRUE random fwd 39191305 2401980 16.32
CumMax float16 [512 1024 7 7] -1 TRUE random fwd 78271414 4805250 16.29
CumMax float16 [512 1024 100] -1 TRUE random fwd 11277661 1463220 7.71
CumMax float16 [1024 1024 7 7] -1 TRUE random fwd 156666821 10686200 14.66
CumMax float16 [1024 1024 100] -1 TRUE random fwd 22540060 2920460 7.72
CumMin float16 [512 64 112 112] -1 TRUE random fwd 79032894 10293300 7.68
CumMin float16 [512 64 56 56] -1 TRUE random fwd 39290595 2491030 15.77
CumMin float16 [512 128 56 56] -1 TRUE random fwd 78578550 4982730 15.77
CumMin float16 [512 128 28 28] -1 TRUE random fwd 39189412 2478940 15.81
CumMin float16 [512 256 28 28] -1 TRUE random fwd 78419674 4956120 15.82
CumMin float16 [512 256 14 14] -1 TRUE random fwd 39156197 2426850 16.13
CumMin float16 [512 512 14 14] -1 TRUE random fwd 78311994 4855330 16.13
CumMin float16 [512 512 7 7] -1 TRUE random fwd 39105638 2400610 16.29
CumMin float16 [512 1024 7 7] -1 TRUE random fwd 78254683 4805610 16.28
CumMin float16 [512 1024 100] -1 TRUE random fwd 11269600 1461600 7.71
CumMin float16 [1024 1024 7 7] -1 TRUE random fwd 156521111 10696300 14.63
CumMin float16 [1024 1024 100] -1 TRUE random fwd 22551889 5641600 4.00
CumSum float16 [512 64 112 112] -1 TRUE random fwd 36839240 6739680 5.47
CumSum float16 [512 64 56 56] -1 TRUE random fwd 18283694 2321070 7.88
CumSum float16 [512 128 56 56] -1 TRUE random fwd 36585132 4639960 7.88
CumSum float16 [512 128 28 28] -1 TRUE random fwd 18230703 2307310 7.90
CumSum float16 [512 256 28 28] -1 TRUE random fwd 36477501 4612030 7.91
CumSum float16 [512 256 14 14] -1 TRUE random fwd 18207967 2298780 7.92
CumSum float16 [512 512 14 14] -1 TRUE random fwd 36433086 4594060 7.93
CumSum float16 [512 512 7 7] -1 TRUE random fwd 18215727 2291620 7.95
CumSum float16 [512 1024 7 7] -1 TRUE random fwd 36442782 4580620 7.96
CumSum float16 [512 1024 100] -1 TRUE random fwd 5255286 956699 5.49
CumSum float16 [1024 1024 7 7] -1 TRUE random fwd 72925500 9161610 7.96
CumSum float16 [1024 1024 100] -1 TRUE random fwd 10510668 1924150 5.46
CumProd float16 [512 64 112 112] -1 TRUE random fwd 36853144 6734100 5.47
CumProd float16 [512 64 56 56] -1 TRUE random fwd 18310781 2320740 7.89
CumProd float16 [512 128 56 56] -1 TRUE random fwd 36623723 4640240 7.89
CumProd float16 [512 128 28 28] -1 TRUE random fwd 18271694 2309390 7.91
CumProd float16 [512 256 28 28] -1 TRUE random fwd 36523629 4616960 7.91
CumProd float16 [512 256 14 14] -1 TRUE random fwd 18221247 2301290 7.92
CumProd float16 [512 512 14 14] -1 TRUE random fwd 36498109 4601770 7.93
CumProd float16 [512 512 7 7] -1 TRUE random fwd 18246623 2295060 7.95
CumProd float16 [512 1024 7 7] -1 TRUE random fwd 39393701 4588880 8.58
CumProd float16 [512 1024 100] -1 TRUE random fwd 5260982 956966 5.50
CumProd float16 [1024 1024 7 7] -1 TRUE random fwd 73033611 10276700 7.11
CumProd float16 [1024 1024 100] -1 TRUE random fwd 10518988 1910770 5.51
float32
op_name dtype size dim contiguous model direction ROCm pytorch MIOpen HIP Improvement
CumMax float32 [512 64 112 112] -1 TRUE random fwd 79353556 10510300 7.55
CumMax float32 [512 64 56 56] -1 TRUE random fwd 39444502 2528340 15.60
CumMax float32 [512 128 56 56] -1 TRUE random fwd 78924619 5057250 15.61
CumMax float32 [512 128 28 28] -1 TRUE random fwd 39394950 2517180 15.65
CumMax float32 [512 256 28 28] -1 TRUE random fwd 78769181 5027420 15.67
CumMax float32 [512 256 14 14] -1 TRUE random fwd 39279320 2443630 16.07
CumMax float32 [512 512 14 14] -1 TRUE random fwd 80072059 4885720 16.39
CumMax float32 [512 512 7 7] -1 TRUE random fwd 39238905 2423280 16.19
CumMax float32 [512 1024 7 7] -1 TRUE random fwd 78490449 4814360 16.30
CumMax float32 [512 1024 100] -1 TRUE random fwd 11320257 1490630 7.59
CumMax float32 [1024 1024 7 7] -1 TRUE random fwd 157026754 9636110 16.30
CumMax float32 [1024 1024 100] -1 TRUE random fwd 22649554 2982970 7.59
CumMin float32 [512 64 112 112] -1 TRUE random fwd 79317382 10511900 7.55
CumMin float32 [512 64 56 56] -1 TRUE random fwd 39419030 2529820 15.58
CumMin float32 [512 128 56 56] -1 TRUE random fwd 78850445 9170440 8.60
CumMin float32 [512 128 28 28] -1 TRUE random fwd 39393495 2515360 15.66
CumMin float32 [512 256 28 28] -1 TRUE random fwd 78737166 5027230 15.66
CumMin float32 [512 256 14 14] -1 TRUE random fwd 39270408 2443900 16.07
CumMin float32 [512 512 14 14] -1 TRUE random fwd 79465092 4885980 16.26
CumMin float32 [512 512 7 7] -1 TRUE random fwd 39264681 2410990 16.29
CumMin float32 [512 1024 7 7] -1 TRUE random fwd 78513042 4815270 16.31
CumMin float32 [512 1024 100] -1 TRUE random fwd 11321649 1490860 7.59
CumMin float32 [1024 1024 7 7] -1 TRUE random fwd 157041875 9633820 16.30
CumMin float32 [1024 1024 100] -1 TRUE random fwd 22661778 5730720 3.95
CumSum float32 [512 64 112 112] -1 TRUE random fwd 37420899 7051980 5.31
CumSum float32 [512 64 56 56] -1 TRUE random fwd 18553115 2330090 7.96
CumSum float32 [512 128 56 56] -1 TRUE random fwd 37096775 4656530 7.97
CumSum float32 [512 128 28 28] -1 TRUE random fwd 18498636 2312900 8.00
CumSum float32 [512 256 28 28] -1 TRUE random fwd 37008008 4623340 8.00
CumSum float32 [512 256 14 14] -1 TRUE random fwd 18427773 2301890 8.01
CumSum float32 [512 512 14 14] -1 TRUE random fwd 36886474 4601850 8.02
CumSum float32 [512 512 7 7] -1 TRUE random fwd 18399326 2293910 8.02
CumSum float32 [512 1024 7 7] -1 TRUE random fwd 36863786 4586130 8.04
CumSum float32 [512 1024 100] -1 TRUE random fwd 5337701 998352 5.35
CumSum float32 [1024 1024 7 7] -1 TRUE random fwd 75153089 9171500 8.19
CumSum float32 [1024 1024 100] -1 TRUE random fwd 10686874 1993500 5.36
CumProd float32 [512 64 112 112] -1 TRUE random fwd 37492178 7043960 5.32
CumProd float32 [512 64 56 56] -1 TRUE random fwd 18602251 2328130 7.99
CumProd float32 [512 128 56 56] -1 TRUE random fwd 37180790 4653930 7.99
CumProd float32 [512 128 28 28] -1 TRUE random fwd 18552732 2312610 8.02
CumProd float32 [512 256 28 28] -1 TRUE random fwd 37102295 4625170 8.02
CumProd float32 [512 256 14 14] -1 TRUE random fwd 18471901 2303490 8.02
CumProd float32 [512 512 14 14] -1 TRUE random fwd 36980297 4605450 8.03
CumProd float32 [512 512 7 7] -1 TRUE random fwd 18449117 2295490 8.04
CumProd float32 [512 1024 7 7] -1 TRUE random fwd 36929706 4589030 8.05
CumProd float32 [512 1024 100] -1 TRUE random fwd 5350325 996876 5.37
CumProd float32 [1024 1024 7 7] -1 TRUE random fwd 73828228 9180310 8.04
CumProd float32 [1024 1024 100] -1 TRUE random fwd 10692522 1992210 5.37
bfloat16
op_name dtype size dim contiguous model direction ROCm pytorch MIOpen HIP Improvement
CumMax bfloat16 [512 64 112 112] -1 TRUE random fwd 82001795 10583800 7.75
CumMax bfloat16 [512 64 56 56] -1 TRUE random fwd 40779253 2538400 16.06
CumMax bfloat16 [512 128 56 56] -1 TRUE random fwd 81604184 5080060 16.06
CumMax bfloat16 [512 128 28 28] -1 TRUE random fwd 40773765 2517520 16.20
CumMax bfloat16 [512 256 28 28] -1 TRUE random fwd 81536393 5032490 16.20
CumMax bfloat16 [512 256 14 14] -1 TRUE random fwd 40759269 2462850 16.55
CumMax bfloat16 [512 512 14 14] -1 TRUE random fwd 81497354 4925450 16.55
CumMax bfloat16 [512 512 7 7] -1 TRUE random fwd 41580409 2433980 17.08
CumMax bfloat16 [512 1024 7 7] -1 TRUE random fwd 81399196 4872310 16.71
CumMax bfloat16 [512 1024 100] -1 TRUE random fwd 11702748 1502110 7.79
CumMax bfloat16 [1024 1024 7 7] -1 TRUE random fwd 162846550 9719290 16.75
CumMax bfloat16 [1024 1024 100] -1 TRUE random fwd 23391864 2999200 7.80
CumMin bfloat16 [512 64 112 112] -1 TRUE random fwd 82027507 10510000 7.80
CumMin bfloat16 [512 64 56 56] -1 TRUE random fwd 40770069 2529640 16.12
CumMin bfloat16 [512 128 56 56] -1 TRUE random fwd 81606825 5061550 16.12
CumMin bfloat16 [512 128 28 28] -1 TRUE random fwd 40762245 2513820 16.22
CumMin bfloat16 [512 256 28 28] -1 TRUE random fwd 81501883 5028670 16.21
CumMin bfloat16 [512 256 14 14] -1 TRUE random fwd 40744486 2462740 16.54
CumMin bfloat16 [512 512 14 14] -1 TRUE random fwd 81500475 4921650 16.56
CumMin bfloat16 [512 512 7 7] -1 TRUE random fwd 40697830 2433980 16.72
CumMin bfloat16 [512 1024 7 7] -1 TRUE random fwd 81402956 4870530 16.71
CumMin bfloat16 [512 1024 100] -1 TRUE random fwd 11700876 1492480 7.84
CumMin bfloat16 [1024 1024 7 7] -1 TRUE random fwd 162799578 9718710 16.75
CumMin bfloat16 [1024 1024 100] -1 TRUE random fwd 23387721 2980940 7.85
CumSum bfloat16 [512 64 112 112] -1 TRUE random fwd 46814849 6889390 6.80
CumSum bfloat16 [512 64 56 56] -1 TRUE random fwd 23282362 2320860 10.03
CumSum bfloat16 [512 128 56 56] -1 TRUE random fwd 46555589 6526530 7.13
CumSum bfloat16 [512 128 28 28] -1 TRUE random fwd 23230827 2307880 10.07
CumSum bfloat16 [512 256 28 28] -1 TRUE random fwd 46477910 4613740 10.07
CumSum bfloat16 [512 256 14 14] -1 TRUE random fwd 23206284 2299330 10.09
CumSum bfloat16 [512 512 14 14] -1 TRUE random fwd 46414775 4595450 10.10
CumSum bfloat16 [512 512 7 7] -1 TRUE random fwd 23198811 2292080 10.12
CumSum bfloat16 [512 1024 7 7] -1 TRUE random fwd 46428582 4581610 10.13
CumSum bfloat16 [512 1024 100] -1 TRUE random fwd 6686083 978317 6.83
CumSum bfloat16 [1024 1024 7 7] -1 TRUE random fwd 92810670 9164860 10.13
CumSum bfloat16 [1024 1024 100] -1 TRUE random fwd 13372485 1954060 6.84
CumProd bfloat16 [512 64 112 112] -1 TRUE random fwd 46990639 6894460 6.82
CumProd bfloat16 [512 64 56 56] -1 TRUE random fwd 23371545 2323120 10.06
CumProd bfloat16 [512 128 56 56] -1 TRUE random fwd 46773218 4644510 10.07
CumProd bfloat16 [512 128 28 28] -1 TRUE random fwd 23333338 2310600 10.10
CumProd bfloat16 [512 256 28 28] -1 TRUE random fwd 46674003 4619370 10.10
CumProd bfloat16 [512 256 14 14] -1 TRUE random fwd 23306058 2302690 10.12
CumProd bfloat16 [512 512 14 14] -1 TRUE random fwd 46625844 4603780 10.13
CumProd bfloat16 [512 512 7 7] -1 TRUE random fwd 23304010 2295150 10.15
CumProd bfloat16 [512 1024 7 7] -1 TRUE random fwd 46605092 4588370 10.16
CumProd bfloat16 [512 1024 100] -1 TRUE random fwd 6709842 979597 6.85
CumProd bfloat16 [1024 1024 7 7] -1 TRUE random fwd 93215385 9178690 10.16
CumProd bfloat16 [1024 1024 100] -1 TRUE random fwd 13421396 1954820 6.87

Average over all cases:

type average
float16 10.42
float32 10.43
bfloat16 11.32

long10024070 and others added 30 commits June 19, 2024 10:37
Copy link
Contributor

@iq136boy iq136boy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI log:
3297_log.txt

@long10024070 long10024070 force-pushed the impl_cumulative_reduction_improvedOverROCM branch from ef7bb56 to 7b6cb42 Compare October 9, 2024 04:55
Copy link
Contributor

@iq136boy iq136boy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error_log:
3297_log (2).txt

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants