
Support batch size 0 #101

Open · Birch-san opened this issue Aug 18, 2024 · 0 comments

Downstream PyTorch issue:
pytorch/pytorch#133780

Describe the bug
The cuDNN frontend rejects input with batch_size=0, failing with CUDNN_STATUS_BAD_PARAM.

Expected behavior
cuDNN should return a tensor of shape [0, num_heads, sequence_length, dims_per_head] (the heads/sequence dims may be permuted depending on layout, but the important part is that the batch dimension is 0). In other words, the output would have the same dimensions as Q; you could even just return an empty tensor shaped like Q.
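
For illustration, a minimal sketch of the expected semantics (my own, using the shapes from the repro below, not current behavior): the output is simply an empty tensor with Q's shape.

import torch

# Illustration of the expected result, not current behavior: with
# batch_size == 0, the output is an empty tensor shaped like Q.
q = torch.zeros(0, 10, 3952, 64, dtype=torch.float16)
expected = torch.empty_like(q)
assert expected.shape == (0, 10, 3952, 64)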

System Environment:
  • Accessing cuDNN via torch SDPA; PyTorch 2.5.0.dev20240811+cu121
  • torch.backends.cudnn.version(): 90100
  • cudnn_frontend version: not sure how to look this up in PyTorch
  • cudnn_backend version: 90100
  • GPU arch: RTX 4090
  • CUDA runtime version: PyTorch-bundled 12.1 (though 12.2 is installed on the system)
  • CUDA driver version: 535.183.01
  • host compiler: g++ (Ubuntu 12.3.0-17ubuntu1) 12.3.0
  • OS: Ubuntu 24.04

API logs
Logs for both cudnn_frontend and cudnn_backend follow.

[cudnn_frontend] INFO: Validating SDPANode CUDNN_SDPA...
[cudnn_frontend] INFO: Inferrencing properties for Scaled_dot_product_flash_attention node  CUDNN_SDPA...
[cudnn_frontend] INFO: Validating matmul node bmm1...
[cudnn_frontend] INFO: Inferrencing properties for matmul node bmm1...
[cudnn_frontend] INFO: Validating pointwise node attn_scale...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node attn_scale...
[cudnn_frontend] INFO: Validating SoftmaxNode softmax...
[cudnn_frontend] INFO: Inferrencing properties for Softmax node softmax.
[cudnn_frontend] INFO: Validating reduction node M...
[cudnn_frontend] INFO: Inferrencing properties for reduction node M...
[cudnn_frontend] INFO: Validating pointwise node sub...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node sub...
[cudnn_frontend] INFO: Validating pointwise node exp...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node exp...
[cudnn_frontend] INFO: Validating reduction node sum...
[cudnn_frontend] INFO: Inferrencing properties for reduction node sum...
[cudnn_frontend] INFO: Validating pointwise node log...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node log...
[cudnn_frontend] INFO: Validating pointwise node add...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node add...
[cudnn_frontend] INFO: Validating pointwise node div...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node div...
[cudnn_frontend] INFO: Validating matmul node bmm2...
[cudnn_frontend] INFO: Inferrencing properties for matmul node bmm2...
[cudnn_frontend] INFO: Creating cudnn tensors for node named 'CUDNN_SDPA':
[cudnn_frontend] ERROR: CUDNN_BACKEND_TENSOR_DESCRIPTOR: Check and Set the CUDNN_ATTR_TENSOR_DIMENSIONS Correctly cudnn_status: CUDNN_STATUS_BAD_PARAM. ["CUDNN_BACKEND_API_FAILED"] because (e.getCudnnStatus() != CUDNN_STATUS_SUCCESS) at /pytorch/third_party/cudnn_frontend/include/cudnn_frontend/cudnn_interface.h:89
[cudnn_frontend] ERROR: create_cudnn_tensor(tensor, tensors) at /pytorch/third_party/cudnn_frontend/include/cudnn_frontend/node_interface.h:707
[cudnn_frontend] ERROR: create_cudnn_tensors_(uid_to_backend_tensors) at /pytorch/third_party/cudnn_frontend/include/cudnn_frontend/node_interface.h:316
[cudnn_frontend] ERROR: sub_node->create_cudnn_tensors(uid_to_backend_tensors) at /pytorch/third_party/cudnn_frontend/include/cudnn_frontend/node_interface.h:318
[cudnn_frontend] ERROR: create_cudnn_tensors(uid_to_tensors) at /pytorch/third_party/cudnn_frontend/include/cudnn_frontend/node_interface.h:408


I! CuDNN (v90100 70) function cudnnCreate() called:
i!     handle: location=host; addr=0x582fcf32a680;
i! Time: 2024-08-18T02:59:15.167979 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnGraphLibraryConfigInit() called:
i!     apiLog: type=cudnnLibConfig_t; val=CUDNN_STANDARD;
i! Time: 2024-08-18T02:59:15.168030 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnSetStream() called:
i!     handle: type=cudnnHandle_t; streamId=(nil) (defaultStream);
i!     streamId: type=cudaStream_t; streamId=(nil) (defaultStream);
i! Time: 2024-08-18T02:59:15.169306 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=0; Handle=0x582fcf7a4130; StreamId=(nil) (defaultStream).


I! CuDNN (v90100 70) function cudnnGetVersion() called:
i! Time: 2024-08-18T02:59:15.169405 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnGetVersion() called:
i! Time: 2024-08-18T02:59:15.169445 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnGetVersion() called:
i! Time: 2024-08-18T02:59:15.169458 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnGetVersion() called:
i! Time: 2024-08-18T02:59:15.169468 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnGetVersion() called:
i! Time: 2024-08-18T02:59:15.169483 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnGetVersion() called:
i! Time: 2024-08-18T02:59:15.169494 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnGetVersion() called:
i! Time: 2024-08-18T02:59:15.169516 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnGetVersion() called:
i! Time: 2024-08-18T02:59:15.169533 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90100 70) function cudnnGetErrorString() called:
i!     status: type=int; val=2000;
i! Time: 2024-08-18T02:59:15.169624 (0d+0h+0m+0s since start)
i! Process=7823; Thread=7823; GPU=NULL; Handle=NULL; StreamId=NULL.


To Reproduce
Steps to reproduce the behavior:

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

device = torch.device('cuda')
dtype = torch.float16

# Zero-sized batch; all other dimensions are ordinary.
batch = 0
q_heads = kv_heads = 10
q_tokens = 3952
kv_tokens = 16
head_dim = 64
q = torch.zeros(batch, q_heads, q_tokens, head_dim, device=device, dtype=dtype)
k = torch.zeros(batch, kv_heads, kv_tokens, head_dim, device=device, dtype=dtype)
v = torch.zeros(batch, kv_heads, kv_tokens, head_dim, device=device, dtype=dtype)

# Forcing the cuDNN backend raises CUDNN_STATUS_BAD_PARAM.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    scaled_dot_product_attention(q, k, v)
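
In the meantime, a hypothetical workaround sketch (my own helper, reusing the imports above): skip the kernel for empty batches and return an empty tensor shaped like Q, matching the expected behavior described earlier.

def sdpa_or_empty(q, k, v):
    # Hypothetical helper, not a PyTorch API: bypass the cuDNN kernel when
    # the batch is empty, since the result would be an empty tensor like Q.
    if q.shape[0] == 0:
        return torch.empty_like(q)
    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
        return scaled_dot_product_attention(q, k, v)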

Additional context

I'm trying to run attention on a batch of zero because my program uses a static graph: I rely on zero-batching (index_select a zero-batch of inputs, index_add a zero-batch of outputs) to toggle functionality without adding branches to the logic; see the sketch below.
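
For context, a minimal sketch of that pattern (names are hypothetical): the optional op always executes, and an index tensor, empty when the feature is off, decides how many batch rows flow through it. The op still runs on a zero-sized batch, which is exactly where cuDNN's rejection bites.

import torch

def apply_optional(fn, x, idx):
    # idx selects which batch rows the optional op sees; pass an empty
    # LongTensor to disable fn without changing the graph structure.
    sub = x.index_select(0, idx)         # (len(idx), ...) slice of the batch
    return x.index_add(0, idx, fn(sub))  # residual add; a no-op when idx is empty

x = torch.randn(4, 8)
enabled = torch.arange(4)                    # run fn on every row
disabled = torch.empty(0, dtype=torch.long)  # zero-batch: fn becomes a no-op
assert torch.equal(apply_optional(torch.relu, x, disabled), x)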
