
Implement Where backward #3295

Open · wants to merge 50 commits into `develop`
Conversation

cognaiger9 (Collaborator):

  • Add the Where operation with a contiguous backward kernel.
  • Add a driver and gtest for the kernel.
  • MIOpen performs better than ROCm when:
    • the input, other, and condition tensors have the same shape
    • all tensors are contiguous

Average improvement over ROCm

| dtype | bwd |
| --- | --- |
| float16 | 1.79 |
| float32 | 1.74 |
| bfloat16 | 1.80 |

Detailed benchmarks

float16

| op_name | dtype | size | contiguous | direction | rocm_kernel_avg | rocm_op_avg | MIOpen | Improvement over ROCm |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Where | float16 | [4 8 8] | contiguous | bwd | 11693 | 55810 | 3732 | 3.13 |
| Where | float16 | [16 64 256] | contiguous | bwd | 14428 | 51283 | 6754 | 2.14 |
| Where | float16 | [32 256 1024] | contiguous | bwd | 107766 | 129265 | 74424 | 1.45 |
| Where | float16 | [380 114 60] | contiguous | bwd | 43893 | 61296 | 26236 | 1.67 |
| Where | float16 | [378 482 201] | contiguous | bwd | 429045 | 455280 | 314658 | 1.36 |
| Where | float16 | [24 131 197] | contiguous | bwd | 18092 | 49236 | 9367 | 1.93 |
| Where | float16 | [123 329 190] | contiguous | bwd | 101927 | 126642 | 69199 | 1.47 |
| Where | float16 | [393 183 475] | contiguous | bwd | 403248 | 430666 | 291288 | 1.38 |
| Where | float16 | [46 62 101] | contiguous | bwd | 14350 | 55155 | 7252 | 1.98 |
| Where | float16 | [427 499 454] | contiguous | bwd | 1111200 | 1138201 | 823626 | 1.35 |
float32

| op_name | dtype | size | contiguous | direction | rocm_kernel_avg | rocm_op_avg | MIOpen | Improvement over ROCm |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Where | float32 | [4 8 8] | contiguous | bwd | 11756 | 57617 | 3892 | 3.02 |
| Where | float32 | [16 64 256] | contiguous | bwd | 15035 | 43637 | 7501 | 2.00 |
| Where | float32 | [32 256 1024] | contiguous | bwd | 136830 | 161000 | 94421 | 1.45 |
| Where | float32 | [380 114 60] | contiguous | bwd | 53634 | 69534 | 33417 | 1.60 |
| Where | float32 | [378 482 201] | contiguous | bwd | 541003 | 566405 | 415464 | 1.30 |
| Where | float32 | [24 131 197] | contiguous | bwd | 21931 | 56819 | 11092 | 1.98 |
| Where | float32 | [123 329 190] | contiguous | bwd | 125266 | 149916 | 87490 | 1.43 |
| Where | float32 | [393 183 475] | contiguous | bwd | 504520 | 532226 | 387420 | 1.30 |
| Where | float32 | [46 62 101] | contiguous | bwd | 14988 | 43733 | 7447 | 2.01 |
| Where | float32 | [427 499 454] | contiguous | bwd | 1406867 | 1431501 | 1107410 | 1.27 |
bfloat16

| op_name | dtype | size | contiguous | direction | rocm_kernel_avg | rocm_op_avg | MIOpen | Improvement over ROCm |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Where | bfloat16 | [4 8 8] | contiguous | bwd | 11501 | 53699 | 3697 | 3.11 |
| Where | bfloat16 | [16 64 256] | contiguous | bwd | 14238 | 51059 | 6950 | 2.05 |
| Where | bfloat16 | [32 256 1024] | contiguous | bwd | 108724 | 123713 | 73749 | 1.47 |
| Where | bfloat16 | [380 114 60] | contiguous | bwd | 44805 | 72910 | 26005 | 1.72 |
| Where | bfloat16 | [378 482 201] | contiguous | bwd | 430119 | 458033 | 308848 | 1.39 |
| Where | bfloat16 | [24 131 197] | contiguous | bwd | 18667 | 49172 | 9563 | 1.95 |
| Where | bfloat16 | [123 329 190] | contiguous | bwd | 101143 | 122994 | 68329 | 1.48 |
| Where | bfloat16 | [393 183 475] | contiguous | bwd | 404289 | 433658 | 287344 | 1.41 |
| Where | bfloat16 | [46 62 101] | contiguous | bwd | 14797 | 45157 | 7163 | 2.07 |
| Where | bfloat16 | [427 499 454] | contiguous | bwd | 1112234 | 1139333 | 813237 | 1.37 |

@cognaiger9 cognaiger9 requested a review from a team as a code owner October 3, 2024 02:59
@amd-jnovotny (Contributor):

@cognaiger9 : Do we need to add any material to the ROCm docs to cover this?

@cognaiger9 (Collaborator, Author):

> @cognaiger9: Do we need to add any material to the ROCm docs to cover this?

This operation belongs to the joining-operations category according to the PyTorch documentation. MIOpen doesn't currently have this category, so I added a new one. Should I change it to a more general category, or use an existing category from the ROCm documentation?

@amd-jnovotny (Contributor):

@cognaiger9: Oh, in terms of where you added it in reference/index.rst, I think it's fine. I was only wondering whether we need to add extra material to any of the conceptual or how-to documents. (Maybe I'm not sure what you're referring to by the new material?)

@cognaiger9 (Collaborator, Author):

@amd-jnovotny I think the current docs are sufficient and do not require extra material.

@iq136boy (Contributor) left a comment:

CI log: 3295_log.txt

@long10024070 long10024070 changed the title Implement Where Implement Where Backward Nov 4, 2024
@long10024070 long10024070 changed the title Implement Where Backward Implement Where backward Nov 4, 2024
@iq136boy (Contributor) left a comment:

Error log: 3295_log (2).txt

```cpp
std::unique_ptr<GPUMem> inGrad_dev    = nullptr;
std::unique_ptr<GPUMem> otherGrad_dev = nullptr;

std::vector<uint8_t> cond;
```
A collaborator commented on this hunk:

> What is the reason for using type `uint8_t` for the cond tensor instead of `bool`?
