Hi again! I encountered a bug while playing with attention and sharding in JAX. The issue occurs with specific sharding setups and fails under certain core configurations.
Steps to Reproduce
The following code snippet reproduces the issue:

Observed Behavior

When SHARDING_TYPE is set to "simple": NEURON_RT_NUM_CORES=16 python xxx.py
When SHARDING_TYPE is set to "double": NEURON_RT_NUM_CORES=16 python xxx.py
When SHARDING_TYPE is "none" (sharding disabled): NEURON_RT_NUM_CORES=16 python xxx.py

Additional Observations

When I increased the number of heads and used NEURON_RT_NUM_CORES=24, the script worked even with sharding enabled. Interestingly, with these higher head counts, setting NEURON_RT_NUM_CORES=16 made the "simple" case work (but not the "double" one).
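Since the original snippet is not reproduced above, here is a minimal sketch of what the three SHARDING_TYPE setups might look like. The axis names, mesh shapes, and tensor shapes are assumptions for illustration, not the original repro:

```python
# Hypothetical reconstruction of the sharding setups; axis names, mesh
# shapes, and tensor shapes are assumptions, not the original script.
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

SHARDING_TYPE = "simple"  # "none" | "simple" | "double"

devices = jax.devices()
n = len(devices)  # controlled by NEURON_RT_NUM_CORES on Neuron

if SHARDING_TYPE == "simple":
    # 1D mesh: every core sits on a single "heads" axis.
    mesh = Mesh(np.array(devices), axis_names=("heads",))
    spec = P(None, "heads", None, None)  # (batch, heads, seq, head_dim)
elif SHARDING_TYPE == "double":
    # 2D mesh: cores split across "batch" and "heads" (n must be even).
    mesh = Mesh(np.array(devices).reshape(2, -1), axis_names=("batch", "heads"))
    spec = P("batch", "heads", None, None)
else:
    mesh, spec = None, None

x = jnp.ones((4, n, 128, 64))  # head count chosen to divide the mesh evenly
if mesh is not None:
    x = jax.device_put(x, NamedSharding(mesh, spec))
```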
Expected Behavior
The code should work consistently across different sharding configurations and core settings.
Environment
Python 3.10
Packages:
neuronx-cc==2.15.141.0+d3cfc8ca
libneuronxla==2.0.4986.0
jaxlib==0.4.31
jax-neuronx==0.1.1
jax==0.4.31
inf2.48xlarge instance
Additional Information
Please let me know if further details are needed for reproducing or debugging the issue. Thank you!
Thank you for reporting the issue; we were able to reproduce it with the provided sample. We currently do not support all sharding configurations: meshes that use non-connected devices may result in runtime failures during execution. The topology for Inferentia instances can be seen here, or by running neuron-ls --topology.
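For example, a minimal sketch of restricting a mesh to a subset of cores — assuming (and this is an assumption to verify with neuron-ls --topology) that the device order returned by jax.devices() follows the physical topology:

```python
import numpy as np

import jax
from jax.sharding import Mesh

# Build the mesh over a contiguous group of cores only. Whether these
# cores are actually interconnected depends on the instance topology;
# verify with `neuron-ls --topology` (assumption, not guaranteed).
connected = np.array(jax.devices()[:16])
mesh = Mesh(connected, axis_names=("model",))
```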
So, now that we know for sure that it's not supported, I have some follow-up questions:
Is there a known workaround for that? Because I fail to see what can be done: to me, sharding is quite opaque from the StableHLO perspective, and the collective operations that are produced in HLO are generic. Is something preventing the Neuron compiler from producing collective operations that respect the topology?
Is this something that will be supported in the future?
Is there a known workaround for that? Because I fail to see what can be done. To me, sharding is quite opaque from the StableHLO perspective, and the collective operations that are produced in HLO are generic.
It can be handled by using shard_map together with the jax.lax collective APIs.
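For illustration, a minimal sketch of that approach, writing the collective explicitly inside a shard_map region — the axis name and the toy reduction are assumptions, not code from this issue:

```python
# Sketch of the suggested workaround: spell out collectives with
# shard_map + jax.lax instead of letting GSPMD insert them.
# The "heads" axis name and the toy psum reduction are assumptions.
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(jax.devices(), axis_names=("heads",))

@partial(shard_map, mesh=mesh,
         in_specs=P("heads", None), out_specs=P(None, None))
def sum_over_heads(x):
    # Each core sees only its local block of the heads axis; the
    # collective is explicit, so the device grouping is under our control.
    return jax.lax.psum(x, axis_name="heads")

x = jnp.ones((2 * n, 64))  # heads axis sharded across all cores
y = sum_over_heads(x)      # replicated (2, 64) result
```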
Is this something that will be supported in the future?
We will look into fixing collectives for GSPMD support, for the supported mesh configurations.