
Error in attention code with various sharding configurations #1031

Open
Corendos opened this issue Nov 11, 2024 · 3 comments
Labels
bug (Something isn't working), Inf2


@Corendos

Summary

Hi again! I encountered a bug while playing with attention and sharding in JAX. The issue occurs with specific sharding setups and fails under certain core configurations.

Steps to Reproduce

The following code snippet can reproduce the issue:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

device_count = len(jax.devices())

# Define meshes for sharding
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
mesh2 = jax.sharding.Mesh(np.array(jax.devices()).reshape((device_count // 2, 2)), ('x', 'y'))

# Model and input parameters
BATCH_SIZE = 1
SEQ_LEN = 256
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 256
DIM = HEAD_DIM * NUM_HEADS
KV_DIM = HEAD_DIM * NUM_KV_HEADS
SIZE = 5000
DTYPE = jnp.bfloat16

# Sharding configuration
SHARDING_TYPE = "double"

if SHARDING_TYPE == "simple":
    input_sharding = jax.sharding.NamedSharding(mesh, P(None, 'x'))
    output_sharding = jax.sharding.NamedSharding(mesh, P('x', None))
elif SHARDING_TYPE == "double":
    input_sharding = jax.sharding.NamedSharding(mesh2, P('y', 'x'))
    output_sharding = jax.sharding.NamedSharding(mesh2, P('x', 'y'))
else:
    input_sharding = None
    output_sharding = None

# Attention function
def attention(x: jax.Array, q_proj: jax.Array, k_proj: jax.Array, v_proj: jax.Array, o_proj: jax.Array) -> jax.Array:
    q = jax.lax.dot_general(x, q_proj, (([2], [0]), ([], []))).reshape(BATCH_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM)
    k = jax.lax.dot_general(x, k_proj, (([2], [0]), ([], []))).reshape(BATCH_SIZE, SEQ_LEN, NUM_KV_HEADS, HEAD_DIM)
    v = jax.lax.dot_general(x, v_proj, (([2], [0]), ([], []))).reshape(BATCH_SIZE, SEQ_LEN, NUM_KV_HEADS, HEAD_DIM)

    k = jnp.repeat(k, (NUM_HEADS // NUM_KV_HEADS), 2)
    v = jnp.repeat(v, (NUM_HEADS // NUM_KV_HEADS), 2)

    attn = jax.nn.dot_product_attention(q, k, v, is_causal=False).reshape(BATCH_SIZE, SEQ_LEN, DIM)
    out = jax.lax.dot_general(attn, o_proj, (([2], [0]), ([], [])))
    return out

# Test inputs
x = jnp.ones((BATCH_SIZE, SEQ_LEN, SIZE), DTYPE)
q_proj = jnp.ones((SIZE, DIM), DTYPE, device=input_sharding)
k_proj = jnp.ones((SIZE, KV_DIM), DTYPE, device=input_sharding)
v_proj = jnp.ones((SIZE, KV_DIM), DTYPE, device=input_sharding)
o_proj = jnp.ones((DIM, SIZE), DTYPE, device=output_sharding)

lowered = jax.jit(attention).lower(x, q_proj, k_proj, v_proj, o_proj)
compiled = lowered.compile()

print(f"lowered: {lowered.as_text()}")
print(f"compiled: {compiled.as_text()}")
print(compiled(x, q_proj, k_proj, v_proj, o_proj))

Observed Behavior

  1. When SHARDING_TYPE is set to "simple"

    • Run with NEURON_RT_NUM_CORES=16 python xxx.py
    • Fails with this error message:
  2. When SHARDING_TYPE is set to "double"

    • Run with NEURON_RT_NUM_CORES=16 python xxx.py
    • Also fails with this error message:
  3. When SHARDING_TYPE is "none" (sharding disabled)

    • Run with NEURON_RT_NUM_CORES=16 python xxx.py
    • The code works without errors.

Additional Observations

When I increased the number of heads to:

NUM_HEADS = 48
NUM_KV_HEADS = 12

and used NEURON_RT_NUM_CORES=24, the script worked even with sharding enabled. Interestingly, with these higher head counts, setting NEURON_RT_NUM_CORES=16 made the "simple" case work (but not the "double").

Expected Behavior

The code should work consistently across different sharding configurations and core settings.

Environment

  • Python 3.10
  • Packages:
    • neuronx-cc==2.15.141.0+d3cfc8ca
    • libneuronxla==2.0.4986.0
    • jaxlib==0.4.31
    • jax-neuronx==0.1.1
    • jax==0.4.31
  • inf2.48xlarge instance

Additional Information

Please let me know if further details are needed for reproducing or debugging the issue. Thank you!

@aws-taylor added the bug (Something isn't working) and Inf2 labels on Nov 11, 2024
@devesr-amzn
Contributor

Thank you for reporting the issue; we are able to reproduce it with the provided sample. We currently do not support all sharding configurations: meshes which use non-connected devices might result in runtime failures during execution. The topology for Inferentia instances can be seen here, or by running neuron-ls --topology.
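
For context, a quick way to see which device groups each mesh axis will run collectives over (so they can be compared against the neuron-ls --topology output) is to print the mesh's device array. This is just a sanity-check sketch built on the meshes from the repro, not an official verification step:

import numpy as np
import jax

devices = jax.devices()
mesh = jax.sharding.Mesh(devices, ('x',))
mesh2 = jax.sharding.Mesh(np.array(devices).reshape(len(devices) // 2, 2), ('x', 'y'))

# Collectives over a mesh axis run within groups of devices that vary along that
# axis; each such group should map onto directly connected NeuronCores.
print(mesh.devices)   # 1-D array: the 'x' axis spans every core
print(mesh2.devices)  # 2-D array: each column is an 'x' group, each row a 'y' group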

@Corendos
Author

Thank you for your answer!

So, now that we know for sure that it's not supported, I have some follow-up questions:

  • Is there a known workaround for that? Because I fail to see what can be done. To me, sharding is quite opaque from the StableHLO perspective, and the collective operations that are produced in the HLO are generic. Is something preventing the Neuron compiler from producing collective operations that respect the topology?
  • Is this something that will be supported in the future?

@devesr-amzn
Contributor

Is there a known workaround for that? Because I fail to see what can be done. To me, sharding is quite opaque from the StableHLO perspective, and the collective operations that are produced in the HLO are generic.

It can be handled by using shard_map together with the jax.lax collective APIs.
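
To illustrate the kind of workaround meant here, below is a minimal sketch of the Megatron-style pattern that roughly corresponds to the "simple" configuration in the repro. The names (mlp_block, w_up, w_down) are made up for the example: the first weight is column-sharded so each core computes its own output slice with no communication, the second is row-sharded, and the partial results are combined with a single explicit jax.lax.psum placed via shard_map.

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('x',))

@partial(shard_map, mesh=mesh,
         in_specs=(P(), P(None, 'x'), P('x', None)),
         out_specs=P())
def mlp_block(x, w_up, w_down):
    # w_up is column-sharded: each core holds a (SIZE, DIM // n_cores) slice and
    # computes its slice of the hidden activations locally, with no communication.
    h = jnp.einsum('bsd,df->bsf', x, w_up)
    # w_down is row-sharded: each core contracts over its slice of the hidden
    # dimension, producing a partial result of the full output shape.
    out = jnp.einsum('bsf,fd->bsd', h, w_down)
    # The only collective is this explicit all-reduce over the 'x' mesh axis.
    return jax.lax.psum(out, 'x')

Applied to the attention function above, the q/k/v projections would take the column-sharded role (each core then attends over its own subset of heads locally) and o_proj the row-sharded one, so the psum after the output projection is the only cross-core collective, and it runs over a device group you control.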

Is this something that will be supported in the future?

We will look into fixing collectives for GSPMD support on the supported mesh configurations.
