Bugs in AWQ models deployed on multiple GPUs. #662

Open
Phoenix-Shen opened this issue Nov 30, 2024 · 1 comment

@Phoenix-Shen

Description

I want to deploy qwen/Qwen2.5-32B-Instruct-AWQ across two 24 GB RTX 4090 GPUs. When I set device_map="auto", I get the error ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?).

Environment

autoawq=0.2.7.post2
transformers=4.46.3
triton=3.1.0
torch=2.5.1+cu124

How to reproduce this error

  • I have 2 RTX 4090 GPUs; the model parameters take ~20 GB of GPU memory.
  • Just use the example code from https://huggingface.co/Qwen/Qwen2.5-32B-Instruct-AWQ (copied below).
  • In my case, ~10 GB of memory is occupied per GPU, so the model is split across both cards (a quick check for confirming the split is sketched after the traceback below).
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model_name = "Qwen/Qwen2.5-32B-Instruct-AWQ"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    prompt = "Give me a short introduction to large language model."
    messages = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  • It will report this error (I've only pasted part of the error message):
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 501, in forward
        query_states = self.q_proj(hidden_states)
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/accelerate/hooks.py", line 170, in new_forward
        output = module._old_forward(*args, **kwargs)
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/awq/modules/linear/gemm.py", line 271, in forward
        out = WQLinearMMFunction.apply(
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/autograd/function.py", line 575, in apply
        return super().apply(*args, **kwargs)  # type: ignore[misc]
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/awq/modules/linear/gemm.py", line 63, in forward
        out = awq_dequantize_triton(qweight, scales, qzeros)
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/awq/modules/triton/gemm.py", line 284, in awq_dequantize_triton
        awq_dequantize_kernel[grid](
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/triton/runtime/jit.py", line 345, in <lambda>
        return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/triton/runtime/jit.py", line 691, in run
        kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
      File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
        self.launch(*args, **kwargs)
    ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
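
For reference, a quick way to confirm that the model really is sharded across both GPUs (and therefore hits the multi-device code path) is to inspect the device map that accelerate attaches to the model. This is only a minimal sketch, assuming model was loaded exactly as in the snippet above; hf_device_map is populated by transformers/accelerate when device_map is used.

    # Minimal sketch: confirm the layer-to-GPU assignment after loading with device_map="auto".
    # Assumes `model` was created as in the reproduction snippet above.
    from collections import Counter

    print(model.hf_device_map)  # e.g. {'model.embed_tokens': 0, ..., 'lm_head': 1}

    # Count how many parameters ended up on each device.
    print(Counter(str(p.device) for p in model.parameters()))
    # Expect both cuda:0 and cuda:1 for the 32B AWQ checkpoint on two 24 GB cards.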
    

Analysis & Solution

  • If I use a smaller model, such as Qwen2.5-14B-Instruct-AWQ, the same code runs without errors because the model fits on a single 4090. So somewhere along the multi-GPU path a tensor is not placed on (or launched from) the correct device.
  • I have found some solutions to this ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) error, e.g. https://huggingface.co/microsoft/Phi-3-small-8k-instruct/discussions/23 and https://huggingface.co/THUDM/cogagent-chat-hf/blob/d519da3b191401234f4bd86ce1c287c61bc276a3/util.py#L210
  • The same fix also applies to the file site-packages/awq/modules/triton/gemm.py (https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/triton/gemm.py).
  • So the fix is: wrap the calls to awq_dequantize_kernel (line 283) and awq_gemm_kernel (line 335) in with torch.cuda.device(qweight.device.index):
        # Line 283: wrap the kernel launch in the device context of qweight
        with torch.cuda.device(qweight.device.index):
           awq_dequantize_kernel[grid](
               qweight,
               scales,
               zeros,
               group_size,
               result,
               X,
               Y,
               BLOCK_SIZE_X=block_size_x,
               BLOCK_SIZE_Y=block_size_y,
           )
        # Line 335: apply the same fix to the GEMM kernel launch
        with torch.cuda.device(qweight.device.index):
           awq_gemm_kernel[grid](
               input,
               qweight,
               result,
               qzeros,
               scales,
               M,
               N,
               K,
               group_size,
               BLOCK_SIZE_M=block_size_m,
               BLOCK_SIZE_N=block_size_n,
               BLOCK_SIZE_K=block_size_k,
               SPLIT_K=split_k_iters,
           )
  • I've tried this and it works for me, but I'm not sure whether other files have the same type of problem, so I didn't submit a PR (a runtime monkey-patch version of the same idea is sketched below).
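
For anyone who wants to try this without editing the installed package, here is a hedged sketch of the same idea as a runtime monkey patch. It only covers the dequantize path: it wraps the awq_dequantize_triton name that awq/modules/linear/gemm.py calls (module path and first argument taken from the traceback above); the awq_gemm_kernel launch would still need the in-file change shown above. Names may differ between AutoAWQ versions, so treat this as an untested illustration, not the official fix.

    # Sketch of a runtime workaround (assumption: awq/modules/linear/gemm.py imports
    # awq_dequantize_triton by name, as the traceback suggests). Apply before generate().
    import functools
    import torch
    import awq.modules.linear.gemm as awq_linear_gemm

    def _with_qweight_device(fn):
        @functools.wraps(fn)
        def wrapper(qweight, *args, **kwargs):
            # qweight is the first argument at the call site in the traceback;
            # make its GPU the current CUDA device for the Triton launch.
            with torch.cuda.device(qweight.device.index):
                return fn(qweight, *args, **kwargs)
        return wrapper

    # Rebind the name that awq/modules/linear/gemm.py resolves at call time.
    awq_linear_gemm.awq_dequantize_triton = _with_qweight_device(
        awq_linear_gemm.awq_dequantize_triton
    )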
@casper-hansen
Owner

Hi @Phoenix-Shen, thanks for the detailed report. Please submit a PR with your changes; I agree they are needed.
