Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Qwen2 (Llama2 based) #101

Open
wants to merge 5 commits into
base: release-2.20
Choose a base branch
from

Conversation

bevhanno
Copy link

Adding Qwen2 model module that hast been tested with Qwen2-7B.

The Qwen2 module is based on the Llama2 module and differs in following points:

  • the QwenAttention module has bias=True in the QKV projection
  • the QKV projection bias needs to be loaded in the load_weights function of the model

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@weiliw-amz
Copy link

weiliw-amz commented Dec 12, 2024

Does this work for Qwen2.5? I tried Qwen/Qwen2.5-0.5B-Instructand received following errors:

Code to reproduce:

model_name = "Qwen/Qwen2.5-0.5B-Instruct"

n_neuron_cores = 1
os.environ["NEURON_RT_NUM_CORES"] = str(n_neuron_cores)

batch_size = 16
sequence_length = 2048

neuron_config = NeuronConfig(
    continuous_batching=ContinuousBatchingConfig(batch_size_for_shared_caches=batch_size),
    group_query_attention=constants.GQA.SHARD_OVER_HEADS,
    attention_layout="HSB",
    fuse_qkv=True,
)

model_neuron = Qwen2ForSampling.from_pretrained(
    model_name,
    batch_size=batch_size,
    tp_degree=n_neuron_cores,
    n_positions=sequence_length,
    amp="f16",
    neuron_config=neuron_config,
)
model_neuron.to_neuron()

Errors:

File [/opt/aws_neuronx_venv_transformers_neuronx/lib64/python3.9/site-packages/transformers_neuronx/base.py:81](https://localhost:9999/opt/aws_neuronx_venv_transformers_neuronx/lib64/python3.9/site-packages/transformers_neuronx/base.py#line=80), in NeuronModelBase.to_neuron(self)
     79     self.load_presharded_weights()
     80 else:
---> 81     self.load_weights()
     82 if hasattr(self, "_compiled_artifacts_directory"):
     83     self._load_compiled_artifacts(self._compiled_artifacts_directory)

File [~/neuronx_code/bevhanno-transformers-neuronx/src/transformers_neuronx/qwen2/model.py:159](https://localhost:9999/home/ec2-user/neuronx_code/bevhanno-transformers-neuronx/src/transformers_neuronx/qwen2/model.py#line=158), in Qwen2ForSampling.load_weights(self)
    156             else:
    157                 new_layer.add_parameter(mlp.down_proj.weight, sharding=1, allow_pad=True,
    158                                     allow_quantize=True, out_feature_dim=0)
--> 159     new_layer.to_neuron()
    160     layer.nullify()
    161 if self.neuron_config.shard_over_sequence:

File [/opt/aws_neuronx_venv_transformers_neuronx/lib64/python3.9/site-packages/transformers_neuronx/decoder.py:1433](https://localhost:9999/opt/aws_neuronx_venv_transformers_neuronx/lib64/python3.9/site-packages/transformers_neuronx/decoder.py#line=1432), in DecoderLayer.to_neuron(self)
   1431 fused_qkv_weight = interleave_qkv(self.attn_q_weight, self.attn_k_weight, self.attn_v_weight, self.tp_degree, dim=1)
   1432 if self.attn_q_bias is not None:
-> 1433     fused_qkv_bias = interleave_qkv(self.attn_q_bias, self.attn_k_bias, self.attn_v_bias, self.tp_degree, dim=0)
   1434 else:
   1435     fused_qkv_bias = None

File [/opt/aws_neuronx_venv_transformers_neuronx/lib64/python3.9/site-packages/transformers_neuronx/utils.py:229](https://localhost:9999/opt/aws_neuronx_venv_transformers_neuronx/lib64/python3.9/site-packages/transformers_neuronx/utils.py#line=228), in interleave_qkv(q, k, v, tp_degree, dim)
    227             tensor[:, (idx)*shard.shape[dim]:(idx+1)*shard.shape[dim]] = shard
    228 else:
--> 229     q_hidden_dim, q_interleave_dim = q.shape
    230     _, kv_interleave_dim = k.shape
    231     tensor = torch.zeros((q_hidden_dim, q_interleave_dim + kv_interleave_dim * 2), dtype=q.dtype)

ValueError: not enough values to unpack (expected 2, got 1)

@bevhanno
Copy link
Author

Does this work for Qwen2.5? I tried Qwen/Qwen2.5-0.5B-Instructand received following errors:

Yes I remember I tested 2.5 as well. Does the original Qwen2 config work for you ? What TP degree did you use?

@weiliw-amz
Copy link

Does this work for Qwen2.5? I tried Qwen/Qwen2.5-0.5B-Instructand received following errors:

Yes I remember I tested 2.5 as well. Does the original Qwen2 config work for you ? What TP degree did you use?

My TP degree in the test code is 1. You can take a look at my code above for full configurations.

@bevhanno
Copy link
Author

My TP degree in the test code is 1. You can take a look at my code above for full configurations.
Please try TP=2 (there are two cores per device), not tested unsharded as that's not an option for 7B

@weiliw-amz
Copy link

weiliw-amz commented Dec 13, 2024

OK, I figured out the configuration error for Qwen2.5-0.5B.

  • For TP=1 one cannot configure group_query_attention=GQA.SHARD_OVER_HEADS and fuse_qkv=True at the same time, otherwise run into this compiling error. However, one can configure it to group_query_attention=GQA.REPLICATED_HEADS and fuse_qkv=True and it passes the compiling.

  • For TP=2 one cannot configure fuse_qkv=True, otherwise run into this error.

  • For TP>2 I haven't got a working configuration, all report this error:

    RuntimeError: nrt_tensor_write status=1 message="Unknown Failure"
    

For Qwen2-7B, fuse_qkv=True can work for TP=2.

A little weird to me... Is there an explanation for this configuration?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants