
Commit

[ Misc ] Refactor w8a8 to use process_weights_after_load (Simplify Weight Loading) (vllm-project#5940)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
robertgshaw2-neuralmagic and Robert Shaw committed Jul 1, 2024
1 parent 4153e58 commit 27a711a
Showing 10 changed files with 153 additions and 158 deletions.
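
For orientation, the refactor replaces per-shard scale fixups inside the weight loaders (the shard_splitter / fp8_scales_shard_indexer hooks) with a single post-load pass: once all shards of a layer have been copied in, the quantization method's process_weights_after_loading hook runs once and can reshape scales into the layout the kernels expect. A minimal loader-side sketch, assuming a simplified finalize_loaded_model helper rather than the exact vLLM call site:

import torch


def finalize_loaded_model(model: torch.nn.Module) -> None:
    # Hedged sketch: walk the modules and run each quant method's post-load
    # hook exactly once, after all checkpoint shards have been copied in.
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            # For compressed-tensors layers this forwards to
            # module.scheme.process_weights_after_loading(module).
            quant_method.process_weights_after_loading(module)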
28 changes: 19 additions & 9 deletions tests/quantization/test_compressed_tensors.py
@@ -12,18 +12,22 @@
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType)

if should_skip_test_group(group_name="TEST_QUANTIZATION"):
pytest.skip("TEST_QUANTIZATION=DISABLE, skipping quantization test group",
allow_module_level=True)


@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
QuantizationType.INT, 2560),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
QuantizationType.INT, 2560),
])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy = model_args
model_path, strategy, quant_type, shape_0 = model_args
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
@@ -39,17 +43,23 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method,
CompressedTensorsLinearMethod)

assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)

assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8
assert o_proj.weight.dtype is torch.int8
assert gate_up_proj.weight.dtype is torch.int8
expected_type = (torch.int8 if quant_type == QuantizationType.INT else
torch.float8_e4m3fn)

assert qkv_proj.weight.dtype is expected_type
assert o_proj.weight.dtype is expected_type
assert gate_up_proj.weight.dtype is expected_type

if qkv_proj.scheme.strategy == "tensor":
assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
# Make sure it is a channelwise buffer
# After running process_weights_after_loading
assert len(qkv_proj.weight_scale.shape) == 2
assert qkv_proj.weight_scale.shape[0] == shape_0
assert qkv_proj.weight_scale.shape[1] == 1
assert qkv_proj.weight_scale.dtype is torch.float32
assert qkv_proj.input_scale.dtype is torch.float32


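A side note on the new shape_0 == 2560 expectation in the test above: assuming the TinyLlama-1.1B geometry behind these checkpoints (hidden_size 2048, 32 attention heads, 4 KV heads, head_dim 64), the fused QKV output dimension comes to 2560, which is the number of rows the channelwise weight_scale buffer should have after process_weights_after_loading:

# Hedged arithmetic check, assuming the standard TinyLlama-1.1B config.
hidden_size = 2048
num_heads = 32
num_kv_heads = 4
head_dim = hidden_size // num_heads      # 64
q_size = num_heads * head_dim            # 2048
kv_size = num_kv_heads * head_dim        # 256 each for K and V
print(q_size + 2 * kv_size)              # 2560 == expected weight_scale.shape[0]
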
17 changes: 17 additions & 0 deletions tests/quantization/test_fp8.py
@@ -14,6 +14,23 @@
pytest.skip("TEST_QUANTIZATION=DISABLE, skipping quantization test group",
allow_module_level=True)

MODELS = [
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
"nm-testing/Phi-3-mini-128k-instruct-FP8",
]


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
def test_model_load_and_run(vllm_runner, model: str):
with vllm_runner(model) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)
print(outputs[0][1])


@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
110 changes: 39 additions & 71 deletions vllm/model_executor/layers/linear.py
@@ -43,6 +43,29 @@ def adjust_bitsandbytes_shard(param: Parameter,
return quantized_size, quantized_offset


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
"""For fused modules (QKV and MLP) we have an array of length
N that holds 1 scale for each "logical" matrix. So the param
is an array of length N. The loaded_weight corresponds to
one of the shards on disk. Here, we slice the param based on
the shard_id for loading.
"""
qkv_idxs = {"q": 0, "k": 1, "v": 2}

if isinstance(shard_id, str):
shard_id = qkv_idxs[shard_id]
elif not isinstance(shard_id, int):
raise ValueError(f"Unknown Shard Id {shard_id}")

# AutoFP8 scales do not have a shape
# compressed-tensors scales do have a shape
if len(loaded_weight.shape) != 0:
assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0]

return param[shard_id], loaded_weight


class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""

@@ -366,37 +389,15 @@ def weight_loader(self,
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)

param_shard_splitter = getattr(param, "shard_splitter", None)

if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)

# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv/mlp).
if output_dim is None:
# If fp8 + scale, need to send to each shard.
if fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
if needs_scalar_to_array is not None:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -458,15 +459,9 @@ def weight_loader(self,
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)

# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)

# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id)

else:
@@ -569,36 +564,15 @@ def weight_loader(self,
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)

param_shard_splitter = getattr(param, "shard_splitter", None)

if output_dim is not None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if loaded_shard_id is None and param_shard_splitter is not None:
raise NotImplementedError(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)

# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv/mlp).
if output_dim is None:
# If fp8 + scale, need to send to each shard.
if fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
if needs_scalar_to_array is not None:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -688,15 +662,9 @@ def weight_loader(self,
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)

# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
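The contract that replaces the removed shard_splitter and fp8_scales_shard_indexer hooks is now just a parameter attribute: a scheme tags its per-tensor scale with needs_scalar_to_array, and the fused-layer weight loaders above route such parameters through adjust_scalar_to_fused_array. A hedged sketch of the scheme side, with layer and weight_loader as stand-ins for what the linear layer provides:

import torch
from torch.nn import Parameter

from vllm.model_executor.utils import set_weight_attrs

layer = torch.nn.Module()   # stand-in for a fused linear layer


def weight_loader(param, loaded_weight, shard_id=None):
    # Stand-in; the real loader is the fused layer's weight_loader, which
    # dispatches to adjust_scalar_to_fused_array when the flag below is set.
    pass


# Three logical matrices fused into one module (e.g. QKV), one per-tensor
# scale each, so the parameter is an array of length 3.
weight_scale = Parameter(torch.empty(3, dtype=torch.float32),
                         requires_grad=False)
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {
    "weight_loader": weight_loader,
    "needs_scalar_to_array": True,
})
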
@@ -186,6 +186,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return layer.scheme.process_weights_after_loading(layer)

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
@@ -31,3 +31,11 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
"""
raise NotImplementedError

@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError
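
A minimal sketch of a concrete scheme implementing this hook; MyScheme is a hypothetical name for illustration, and schemes with nothing to fix up simply pass, as the unquantized and group-quantized schemes below do:

import torch


class MyScheme:  # in real code this would subclass CompressedTensorsScheme
    """Hypothetical scheme illustrating the post-load hook."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Runs once per layer after every shard has been loaded, e.g. to
        # repack or reshape parameters into the layout the kernels expect.
        pass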
@@ -18,6 +18,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
"""

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
@@ -29,6 +29,9 @@ def __init__(self,
raise ValueError(
"group_size must be given when using strategy group")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
@@ -15,70 +15,63 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
def __init__(self, strategy: str):
self.strategy = strategy

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id

assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]

def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight
# Cutlass kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), we convert to the per-channel case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if (self.strategy == QuantizationStrategy.TENSOR
and len(self.logical_widths) > 1):

# Load the N per-tensor scales into the channelwise buffer.
weight_scale_channel = torch.empty(
(sum(self.logical_widths), 1),
dtype=torch.float32,
device=layer.weight_scale.device)
start = 0
for idx, logical_width in enumerate(self.logical_widths):
end = start + logical_width
weight_scale_channel[start:end, :] = layer.weight_scale[idx]
start = end

layer.weight_scale = Parameter(weight_scale_channel,
requires_grad=False)

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes

is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(output_partition_sizes) if (
is_tensor_partitioned
or self.strategy == QuantizationStrategy.CHANNEL) else 1

shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
# WEIGHT SCALE
shape: Union[Tuple[int], Tuple[int, int]]
if self.strategy == QuantizationStrategy.CHANNEL:
shape = (weight_scale_dim, 1)
shape = (sum(self.logical_widths), 1)
else:
shape = (len(self.logical_widths), )

weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
requires_grad=False)

layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"weight_loader": weight_loader,
"output_dim": 0,
})
else:
set_weight_attrs(weight_scale, {
"weight_loader": weight_loader,
"needs_scalar_to_array": True,
})

# WEIGHT
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)

layer.register_parameter("weight", weight)
set_weight_attrs(
weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
"logical_widths": output_partition_sizes
})

# Don't need a shard_splitter for channel-wise quantization
# Use the default loading method
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"output_dim": 0,
})
else:
set_weight_attrs(
weight_scale, {
"logical_widths": output_partition_sizes,
"shard_splitter": self.scales_shard_splitter,
})
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
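
For the tensor strategy with a fused module, process_weights_after_loading above broadcasts each of the N per-tensor scales across its logical width, producing the (sum(logical_widths), 1) channelwise buffer that the Cutlass kernels accept. A hedged numerical sketch of that expansion with made-up values:

import torch

# Fused QKV example: logical widths 2048, 256 and 256, each of which was
# loaded with a single per-tensor scale.
logical_widths = [2048, 256, 256]
weight_scale = torch.tensor([0.010, 0.020, 0.030])   # shape (3,), made up

# Same expansion as in process_weights_after_loading above.
weight_scale_channel = torch.empty((sum(logical_widths), 1),
                                   dtype=torch.float32)
start = 0
for idx, width in enumerate(logical_widths):
    end = start + width
    weight_scale_channel[start:end, :] = weight_scale[idx]
    start = end

print(weight_scale_channel.shape)   # torch.Size([2560, 1])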
@@ -29,6 +29,9 @@ def __init__(self,
raise ValueError(
"group_size must be given when using strategy group")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,