
Rework split_reshape simplification to handle multiple reshapes from same slice #2146

Merged
merged 13 commits into develop from fix_split_reshape on Sep 19, 2023

Conversation

shivadbhavsar
Contributor

This resolves an edge case found in multiple huggingface models where the following happens:

slice1 -+-> cont -> reshape1
        |
        +-> cont -> reshape2

slice2 -+-> cont -> reshape1
        |
        +-> cont -> reshape2

slice3 -+-> cont -> reshape1
        |
        +-> cont -> reshape2

In some cases the find_split_reshape matcher will match with reshape2, but vec_rsp will consist of reshape1 dims, causing a shape mismatch error. The solution is to include rsp when checking that all reshapes are the same.

The added unit test reproduces this issue in the current build.
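
As a rough illustration, here is a minimal standalone sketch of the described check; the helper name all_reshapes_match and the plain vector types are assumptions for illustration, not the actual MIGraphX code:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    using dims = std::vector<std::size_t>;

    // The matched reshape's dims (rsp) must agree with every reshape collected
    // from the split outputs, not just the collected reshapes with each other.
    bool all_reshapes_match(const std::vector<dims>& collected, const dims& matched_rsp)
    {
        return std::all_of(collected.begin(), collected.end(),
                           [&](const dims& d) { return d == matched_rsp; });
    }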

@pfultz2
Collaborator

pfultz2 commented Aug 31, 2023

It seems like this should be written as:

reshape1 -> slice1, slice2, slice3
reshape2 -> slice4, slice5, slice6

Rather than just stopping.

@codecov

codecov bot commented Aug 31, 2023

Codecov Report

Merging #2146 (4cc9076) into develop (c2e01b1) will increase coverage by 0.00%.
The diff coverage is 100.00%.

❗ Current head 4cc9076 differs from pull request most recent head 3400122. Consider uploading reports for the commit 3400122 to get more accurate results

@@           Coverage Diff            @@
##           develop    #2146   +/-   ##
========================================
  Coverage    91.48%   91.49%           
========================================
  Files          427      427           
  Lines        15938    15953   +15     
========================================
+ Hits         14581    14596   +15     
  Misses        1357     1357           
Files Changed Coverage Δ
src/simplify_algebra.cpp 96.85% <100.00%> (+0.06%) ⬆️

@migraphx-bot
Collaborator

migraphx-bot commented Sep 1, 2023

Test                          Batch   Rate new (340012)   Rate old (c2e01b)   Diff     Compare
torchvision-resnet50 64 2,282.79 2,284.87 -0.09%
torchvision-resnet50_fp16 64 5,367.78 5,357.62 0.19%
torchvision-densenet121 32 1,835.07 1,829.81 0.29%
torchvision-densenet121_fp16 32 3,391.29 3,378.73 0.37%
torchvision-inceptionv3 32 1,336.41 1,335.78 0.05%
torchvision-inceptionv3_fp16 32 2,588.08 2,585.25 0.11%
cadene-inceptionv4 16 678.73 679.39 -0.10%
cadene-resnext64x4 16 589.72 590.19 -0.08%
slim-mobilenet 64 7,213.88 7,216.76 -0.04%
slim-nasnetalarge 64 237.05 236.57 0.20%
slim-resnet50v2 64 2,528.69 2,530.71 -0.08%
bert-mrpc-onnx 8 721.09 721.25 -0.02%
bert-mrpc-tf 1 390.79 391.28 -0.12%
pytorch-examples-wlang-gru 1 310.08 303.90 2.04%
pytorch-examples-wlang-lstm 1 310.72 311.77 -0.34%
torchvision-resnet50_1 1 556.07 555.70 0.07%
torchvision-inceptionv3_1 1 307.56 306.97 0.19%
cadene-dpn92_1 1 352.44 351.70 0.21%
cadene-resnext101_1 1 220.50 220.52 -0.01%
slim-vgg16_1 1 224.86 224.37 0.22%
slim-mobilenet_1 1 1,483.04 1,468.80 0.97%
slim-inceptionv4_1 1 221.72 221.26 0.21%
onnx-taau-downsample 1 322.55 322.42 0.04%
dlrm-criteoterabyte 1 21.68 21.67 0.02%
dlrm-criteoterabyte_fp16 1 40.62 40.54 0.19%
agentmodel 1 5,836.42 5,779.11 0.99%
unet_fp16 2 55.08 55.07 0.02%

This build is OK for merge ✅

@migraphx-bot
Collaborator

migraphx-bot commented Sep 1, 2023


✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

🔴 torchvision-inceptionv3_1: FAILED: MIGraphX is not within tolerance - check verbose output

🔴 cadene-dpn92_1: FAILED: MIGraphX is not within tolerance - check verbose output

✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

✅ slim-vgg16_1: PASSED: MIGraphX meets tolerance

✅ slim-mobilenet_1: PASSED: MIGraphX meets tolerance

🔴 slim-inceptionv4_1: FAILED: MIGraphX is not within tolerance - check verbose output

✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

✅ agentmodel: PASSED: MIGraphX meets tolerance

✅ unet: PASSED: MIGraphX meets tolerance

@umangyadav
Member

In some cases the find_split_reshape matcher will match with reshape2, but vec_rsp will consist of reshape1 dims causing a shape mismatch error

Looks like the matcher was assuming each slice only has one output, but is not enforcing it.

        // Only want to apply this optimization if each split output is followed by
        // a contiguous op and a reshape
        if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
               if(i->outputs().size() == 1)
               {
                   auto cont = i->outputs().front();
                   return cont->outputs().size() != 1;
               }
               return false;
           }))
        {
            return;
        }

In this part, if there is more than one output then it returns false, but I think it should return true.
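
For illustration, here is a self-contained sketch of the predicate suggested above; the node struct and the name should_skip are hypothetical, and the merged PR ultimately reworked the pass instead (per the author's summary below):

    #include <algorithm>
    #include <vector>

    struct node
    {
        std::vector<node*> outs; // consumers of this instruction
    };

    // Suggested behavior: bail out when a split output has more than one
    // consumer, or when its single consumer (contiguous) does not feed
    // exactly one reshape.
    bool should_skip(const std::vector<node*>& split_outputs)
    {
        return std::any_of(split_outputs.begin(), split_outputs.end(), [](const node* i) {
            if(i->outs.size() != 1)
                return true; // more than one output: returning true skips the rewrite
            const node* cont = i->outs.front();
            return cont->outs.size() != 1;
        });
    }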

@shivadbhavsar
Contributor Author

In some cases the find_split_reshape matcher will match with reshape2, but vec_rsp will consist of reshape1 dims, causing a shape mismatch error

Looks like the matcher was assuming each slice only has one output, but is not enforcing it. [...] I think it should return true.

That was my initial fix, actually. But I figured it was coded this way because we wanted to allow the splits to have other outputs but only simplify the instructions leading to the reshapes. I guess it's not really doing that though, because we just assume a single output (and that it's cont --> reshape) for each slice when we create vec_rsp.

In any case, this would solve the bug, but it won't do what @pfultz2 is suggesting. For that we'd have to add more checks for what goes into vec_rsp.

@shivadbhavsar
Contributor Author

This turned into a larger update than intended, but here is a summary of changes:

  1. The initial bug was in part due to a set of (slice -> cont -> reshape)s where the reshapes did not use the entire input shape. E.g. Input: [4, 6, 30], Slices: [4, 2, 30], [4, 2, 30], Reshapes: [4, 60], [4, 60]. Here, our current logic would try to create a new reshape (from the input [4, 6, 30]) of size [4, 120], which is invalid (it should be [4, 180]). Allowing this kind of simplification required modifying how the new slice indices are computed (see the sketch after this list).
  2. It no longer assumes that the slices have only a single (contiguous -> reshape) attached. We do not modify the existing slice -> contiguous, so any nodes using these as inputs should still be valid. If no nodes use these as inputs, dead code elimination will take care of it.
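
A tiny standalone check of the arithmetic in point 1 (the element counts come from the example above; the variable names are illustrative):

    #include <cassert>

    int main()
    {
        const int input_elems = 4 * 6 * 30; // 720 elements in the full [4, 6, 30] input
        const int old_reshape = 4 * 120;    // old logic: scale [4, 60] by the 2 slices -> 480
        const int new_reshape = 4 * 180;    // fixed logic: 720 / 4 = 180 on the reshape axis
        assert(old_reshape != input_elems); // the old result cannot hold the input
        assert(new_reshape == input_elems); // the corrected result can
    }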

@shivadbhavsar shivadbhavsar changed the title Include matched reshape instruction in split_reshape in same_ops check Rework split_reshape simplification to handle multiple reshapes from same slice Sep 6, 2023
src/simplify_algebra.cpp (outdated review thread, resolved)
@umangyadav
Member

Unet accuracy check has failed on compilation. Can you check?

@shivadbhavsar
Contributor Author

Unet accuracy check has failed on compilation. Can you check?

This should be fixed in the new run

// cannot create a valid reshape for simplification
if(input_size % rsp_fixed_size != 0)
{
    return;
}
Collaborator


Is it possible to add a unit test for this case?

Contributor Author


added

@shivadbhavsar shivadbhavsar linked an issue Sep 12, 2023 that may be closed by this pull request
@TedThemistokleous TedThemistokleous added roadmap Tasks to finish for a release enhancement New feature or request labels Sep 15, 2023
@TedThemistokleous TedThemistokleous added TorchMIGraphX high priority A PR with high priority for review and merging. labels Sep 15, 2023
@causten
Collaborator

causten commented Sep 16, 2023

But I figured it was coded this way

Are there enough comments in the code to help understand why something was done a certain way? If you need to make a code change for any other reason, feel free to improve the code comments.

@shivadbhavsar
Contributor Author

For more clarity, here's an example walking through the code (this simplification wouldn't work with the original implementation):

Original Instructions:
Input: {128, 96}
Slices: 
s0 slice(axes={0}, starts={0}, ends={8})(inp) -> {8, 96}
s1 slice(axes={0}, starts={8}, ends={16})(inp) -> {8, 96}
s2 slice(axes={0}, starts={16}, ends={24})(inp) -> {8, 96}
s3 slice(axes={0}, starts={24}, ends={128})(inp) -> {104, 96}
Reshapes:
reshape(dims={2, 4, 96})(s0)
reshape(dims={2, 4, 96})(s1)
reshape(dims={2, 4, 96})(s2)

Intermediate Variables:
axis = 0
slc_lens = {8, 96}
slc_axis_len = input[axis] = 128
slc_dim_size = 8*96 = 768

rsp_lens = {2, 4, 96}, strides: {96*4*2, 96*4, 96, 1} = {768, 384, 96, 1} (with added stride)
rsp_axis = <index where (rsp_stride == slc_dim_size)> = 0

rsp_out_lens (initial) = {1, 4, 96}
rsp_fixed_size = 1*4*96 = 384
rsp_axis_len = 128 * 96 / 384 = 32 
rsp_out_lens (final) = {32, 4, 96}

new_starts = {0*32/128,  8*32/128, 16*32/128} = {0, 2, 4}
new_ends   = {8*32/128, 16*32/128, 24*32/128} = {2, 4, 6}


New Instructions:
reshape(dims={32, 4, 96})
slice(axes={0}, starts={0}, ends={2}) -> {2, 4, 96}
slice(axes={0}, starts={2}, ends={4}) -> {2, 4, 96}
slice(axes={0}, starts={4}, ends={6}) -> {2, 4, 96}
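
The same walkthrough, condensed into a standalone sketch (the variable names mirror the comment above; this is an illustration of the index math under those assumptions, not the actual pass code):

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main()
    {
        using vec = std::vector<std::size_t>;
        auto product = [](const vec& v) {
            return std::accumulate(v.begin(), v.end(), std::size_t{1}, std::multiplies<>{});
        };

        const vec input             = {128, 96};       // input lens
        const std::size_t axis      = 0;               // slice axis
        const std::size_t slc_axis_len = input[axis];  // 128
        const std::size_t input_size   = product(input); // 12288

        // Reshape output lens with the slice axis set to 1, isolating the
        // "fixed" part of the shape.
        vec rsp_out_lens = {1, 4, 96};
        const std::size_t rsp_fixed_size = product(rsp_out_lens); // 384

        // Bail out when the full input cannot be reshaped evenly (change 1).
        if(input_size % rsp_fixed_size != 0)
            return 0;

        const std::size_t rsp_axis_len = input_size / rsp_fixed_size; // 32
        rsp_out_lens[0] = rsp_axis_len;                               // {32, 4, 96}

        // Rescale the old slice starts/ends onto the new axis length.
        const vec starts = {0, 8, 16};
        const vec ends   = {8, 16, 24};
        vec new_starts, new_ends;
        for(std::size_t k = 0; k < starts.size(); ++k)
        {
            new_starts.push_back(starts[k] * rsp_axis_len / slc_axis_len);
            new_ends.push_back(ends[k] * rsp_axis_len / slc_axis_len);
        }

        assert((new_starts == vec{0, 2, 4}));
        assert((new_ends == vec{2, 4, 6}));
    }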

@umangyadav
Member

umangyadav commented Sep 19, 2023

For more clarity, here's an example walking through the code (this simplification wouldn't work with the original implementation): [...]

It would be good to add this as a test, instead of only the one with reorder_reshape_slice_uneven_slice, which would have worked without your changes as well. Better to keep both.

@umangyadav
Member

@shivadbhavsar

onnx-taau-downsample 1 247.81 322.80 -23.23% 🔴

Can you check this one?

@causten causten merged commit 0bb8508 into develop Sep 19, 2023
12 checks passed
@causten causten deleted the fix_split_reshape branch September 19, 2023 23:47
Closes: Dynamo Benchmarks - Simplify algebra reshape size error