Got Error when I load a 2of4 model using vllm. #926

Open
jiangjiadi opened this issue Nov 19, 2024 · 13 comments

@jiangjiadi

Describe the bug
I compressed a Qwen2.5-7B model with examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py, but I fail to load the stage_sparsity model. The error is shown below:
[screenshot of the loading error]
And when I run inference with the stage_quantization model in vLLM, the output is abnormal. See below:
[screenshot of the abnormal output]

Expected behavior
The stage_sparsity model should be loaded normally and the output of the stage_quantization model should be normal.

Environment
Include all relevant environment information:

  1. OS [e.g. Ubuntu 20.04]:
  2. Python version [e.g. 3.7]: 3.10.13
  3. LLM Compressor version or commit hash [e.g. 0.1.0, f7245c8]: 0.3.0
  4. ML framework version(s) [e.g. torch 2.3.1]:
  5. Other Python package versions [e.g. vLLM, compressed-tensors, numpy, ONNX]: compressed-tensors=0.8.0
  6. Other relevant environment information [e.g. hardware, CUDA version]:

To Reproduce
Exact steps to reproduce the behavior:

  • Set model_stub = Qwen/Qwen2.5-7B-Instruct, then run examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py to produce the model.
  • Use LLM(model_path) to load the model and run inference, roughly as in the sketch below.
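
A minimal sketch of that loading/inference step (the model path, prompt, and sampling settings are placeholders I chose for illustration):

from vllm import LLM, SamplingParams

MODEL_PATH = ""  # path to the stage_quantization (or stage_sparsity) output directory
llm = LLM(MODEL_PATH, tensor_parallel_size=1)

params = SamplingParams(max_tokens=64, temperature=0.0)
outputs = llm.generate(["Who are you?"], sampling_params=params)
print(outputs[0].outputs[0].text)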
@jiangjiadi jiangjiadi added the bug Something isn't working label Nov 19, 2024
@dsikka
Collaborator

dsikka commented Nov 19, 2024

@jiangjiadi Hi! Can you share the config that you're running in vllm? Running a model with 2:4 sparsity will require setting the dtype to float16 explicitly as that is required by the kernel. Can you try updating that as well?
Thanks

@dsikka dsikka self-assigned this Nov 19, 2024
@jiangjiadi
Author

@dsikka Here is my config:

sparsity_stage:
  run_type: oneshot
  sparsity_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      mask_structure: "2:4"
      sequential_update: false
quantization_stage:
  run_type: oneshot
  quantization_modifiers:
    GPTQModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 4
            type: "int"
            symmetric: true
            strategy: "channel"
          targets: ["Linear"]

Setting the dtype to float16 doesn't work.

@dsikka
Collaborator

dsikka commented Nov 19, 2024

Hi @jiangjiadi this is the recipe. Do you mind sharing the config.json file?

@jiangjiadi
Author

@dsikka The config.json for the stage_sparsity model:

{
  "_name_or_path": "/models/Qwen__Qwen2.5-7B-Instruct",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "quantization_config": {
    "quant_method": "compressed-tensors",
    "sparsity_config": {
      "format": "sparse-bitmask",
      "global_sparsity": 0.42841499382998344,
      "ignore": [
        "lm_head"
      ],
      "registry_requires_subclass": false,
      "sparsity_structure": "2:4",
      "targets": [
        "Linear"
      ]
    },
    "version": "0.8.0"
  },
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.44.2",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 152064
}

The config.json for the stage_quantization model:

{
  "_name_or_path": "/models/Qwen__Qwen2.5-7B-Instruct",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "input_activations": null,
        "output_activations": null,
        "targets": [
          "Linear"
        ],
        "weights": {
          "actorder": null,
          "block_structure": null,
          "dynamic": false,
          "group_size": null,
          "num_bits": 4,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "channel",
          "symmetric": true,
          "type": "int"
        }
      }
    },
    "format": "marlin-24",
    "global_compression_ratio": 1.8962897500663005,
    "ignore": [
      "lm_head"
    ],
    "kv_cache_scheme": null,
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "sparsity_config": {
      "format": "dense",
      "global_sparsity": 0.45400111152191547,
      "ignore": [
        "lm_head"
      ],
      "registry_requires_subclass": false,
      "sparsity_structure": "2:4",
      "targets": [
        "Linear"
      ]
    }
  },
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.44.2",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 152064
}

@dsikka
Collaborator

dsikka commented Nov 20, 2024

We only support running the model with both sparsity and quantization at the moment, so you will only be able to run the model produced after the quantization stage - could you try running this model in vLLM with dtype=torch.float16? Thanks

@jiangjiadi
Author

@dsikka I changed the torch_dtype to 'float16' in the config.json of the stage_quantization model,
[screenshot of the edited config.json]
and the inference result is still abnormal:
[screenshot of the abnormal output]

@dsikka
Collaborator

dsikka commented Nov 20, 2024

Hi @jiangjiadi, sorry for being unclear; please change the dtype when calling vLLM:
Example:

import torch
from vllm import LLM

llm = LLM(
    model=model,
    dtype=torch.float16,
)

@jiangjiadi
Author

@dsikka The result is the same.

@jiangjiadi
Author

jiangjiadi commented Nov 20, 2024

@dsikka I explicitly set save_compressed=False
[screenshot of the save call]
and got an uncompressed model.
I used the code in examples/compressed_inference/fp8_compressed_inference.py for inference (with run_compressed set to False). I found that the uncompressed model produces normal output:
[screenshot of the normal output]
At the same time, I got an error when running inference with the compressed model:
[screenshot of the error]
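
For reference, the way I load the uncompressed model follows the fp8_compressed_inference.py pattern, roughly as sketched here (the path, prompt, and generation length are placeholders; run_compressed=False as mentioned above):

from llmcompressor.transformers import SparseAutoModelForCausalLM
from transformers import AutoTokenizer

UNCOMPRESSED_PATH = ""  # path to the model saved with save_compressed=False
model = SparseAutoModelForCausalLM.from_pretrained(
    UNCOMPRESSED_PATH, torch_dtype="auto", device_map="cuda:0", run_compressed=False
)
tokenizer = AutoTokenizer.from_pretrained(UNCOMPRESSED_PATH)

inputs = tokenizer("Who are you?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0]))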

There must be something wrong when compressing the model.

When I use the code below

import torch
from transformers import AutoModelForCausalLM

output_dir = ""  # path to the uncompressed model saved above
compressed_output_dir = "output_llama7b_2of4_w4a16_channel_compressed"
model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
model.save_pretrained(compressed_output_dir, save_compressed=True)

to compress the uncompressed model, I encounter an error:
[screenshot of the error]
What's wrong?

Besides, when I use vLLM to load the uncompressed model, I also encounter an error:
[screenshot of the error]

@jiangjiadi
Author

Here is the config.json of the uncompressed model.

{
  "_name_or_path": "/ossfs/workspace/EngAI/models/Qwen__Qwen2.5-7B-Instruct",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "input_activations": null,
        "output_activations": null,
        "targets": [
          "Linear"
        ],
        "weights": {
          "actorder": null,
          "block_structure": null,
          "dynamic": false,
          "group_size": null,
          "num_bits": 4,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "channel",
          "symmetric": true,
          "type": "int"
        }
      }
    },
    "format": "dense",
    "global_compression_ratio": 1.8908578382098262,
    "ignore": [
      "lm_head"
    ],
    "kv_cache_scheme": null,
    "quant_method": "compressed-tensors",
    "quantization_status": "frozen",
    "sparsity_config": {
      "format": "dense",
      "global_sparsity": 0.45400111152191547,
      "ignore": [
        "lm_head"
      ],
      "registry_requires_subclass": false,
      "sparsity_structure": "2:4",
      "targets": [
        "Linear"
      ]
    }
  },
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.2",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 152064
}

The main difference between the compressed and uncompressed configs is the format in quantization_config ("marlin-24" vs. "dense").

@jiangjiadi
Author

@dsikka Further investigation revealed that there is no issue with saving the model parameters; the problem lies in loading them. When calling SparseAutoModelForCausalLM.from_pretrained, the model parameters are not loaded back.
[screenshot of the parameters after loading]
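
One quick way to see this is to list what the checkpoint actually stores for a given layer and compare it with the loaded parameter (a sketch; the shard path is a placeholder, since large checkpoints are split into model-00001-of-0000N.safetensors files):

from safetensors import safe_open

CHECKPOINT = ""  # path to a .safetensors shard in the stage_quantization output
with safe_open(CHECKPOINT, framework="pt") as f:
    for name in sorted(f.keys()):
        if "layers.0.mlp.down_proj" in name:
            print(name, f.get_tensor(name).shape)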

@dsikka
Collaborator

dsikka commented Nov 22, 2024

Hi @jiangjiadi - we do not support decompression of marlin-24 models in compressed-tensors yet. You should, however, be able to load the model in vLLM. Do you mind sharing the code you're using to run it in vLLM?

@jiangjiadi
Author

jiangjiadi commented Nov 22, 2024

@dsikka Sure, you can follow the steps below to reproduce my issue.

  • Env: llmcompressor==0.3.0, transformers==4.46.3, vllm==0.5.5
  • Run examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py to get the model with the following configuration:
    • Set model_stub = Qwen/Qwen2.5-7B-Instruct.
    • Use the recipe below:
sparsity_stage:
  run_type: oneshot
  sparsity_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      mask_structure: "2:4"
      sequential_update: false
quantization_stage:
  run_type: oneshot
  quantization_modifiers:
    GPTQModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 4
            type: "int"
            symmetric: true
            strategy: "channel"
          targets: ["Linear"]
  • Use vLLM to run inference with the code below; you will get the abnormal output (please set MODEL_PATH to the stage_quantization model path yourself):
import torch
from vllm import LLM
from vllm.sampling_params import SamplingParams

MODEL_PATH = ""
model = LLM(MODEL_PATH, tensor_parallel_size=1, dtype=torch.float16)

max_tokens = 500
sample_params = SamplingParams(max_tokens=max_tokens, ignore_eos=False, temperature=0.0)

text = "Who are you?"
outputs = model.generate([text], sampling_params=sample_params, prompt_token_ids=None)
print("\n", outputs[0].outputs[0].text, "\n")

[screenshot of the abnormal output]

  • Use SparseAutoModelForCausalLM.from_pretrained to load the stage_quantization model; I found the model parameters were not loaded back. The code is below:
import torch
import time
from llmcompressor.transformers import SparseAutoModelForCausalLM

start_time = time.time()
MODEL_PATH = ""

model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype="auto", device_map="cpu"
)
print(f"load model cost: {time.time() - start_time} s")

key = "model.layers.0.mlp.down_proj.weight"
tensor = None
for n, p in model.named_parameters():
    if n == key:
        tensor = p
        break
print(tensor)
