Commit fc58012: remove generalized pack routine, save for future PR, plus add removed init file
LucasWilkinson committed Jul 31, 2024
1 parent ef704ac commit fc58012
Showing 3 changed files with 4 additions and 31 deletions.
5 changes: 3 additions & 2 deletions tests/kernels/test_marlinv2_gemm.py
@@ -11,7 +11,7 @@

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    pack_weights_into_int32, quantize_weights)
+    pack_rows, quantize_weights)
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types

@@ -68,7 +68,8 @@ def marlinv2_quantize_and_pack(w: torch.Tensor,
         # to match how the kernel applies zps
         ref_zero_points_after_scales=True)
 
-    w_q = pack_weights_into_int32(w_q, wtype)
+    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
+    w_q = w_q.t().contiguous().t()  # convert to col major
     w_q_marlinv2 = ops.marlinv2_prepack_B(w_q, wtype)
 
     return w_ref, w_q_marlinv2, w_s, w_zp
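For context on the replacement: judging by the call above, `pack_rows(w_q, wtype.size_bits, *w_q.shape)` packs each group of `32 // size_bits` consecutive values along the row (K) dimension into a single int32 word. Below is a minimal, hypothetical sketch of such a routine (`pack_rows_sketch` is an editorial stand-in, not the actual vllm helper), mirroring the bit-twiddling of the generalized routine removed from quant_utils.py further down, followed by the column-major conversion used in the new test code.

```python
import numpy
import torch


def pack_rows_sketch(w_q: torch.Tensor, size_bits: int) -> torch.Tensor:
    # Editorial stand-in for quant_utils.pack_rows(q_w, num_bits, size_k, size_n):
    # packs every group of `pack_factor` rows into one row of int32 words.
    pack_factor = 32 // size_bits   # e.g. 8 four-bit values per int32 word
    mask = (1 << size_bits) - 1     # e.g. 0xF for 4-bit values
    size_k, size_n = w_q.shape
    assert size_k % pack_factor == 0

    w = w_q.cpu().numpy().astype(numpy.uint32)
    res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
    for i in range(pack_factor):
        # row i of each group occupies bit slot i of the packed word
        res |= (w[i::pack_factor, :] & mask) << (size_bits * i)
    return torch.from_numpy(res.astype(numpy.int32))


w_q = torch.randint(0, 16, (16, 4), dtype=torch.int32)  # 4-bit values
packed = pack_rows_sketch(w_q, size_bits=4)             # shape (2, 4)

# The new test line `w_q.t().contiguous().t()` keeps the logical shape but
# forces column-major storage, presumably what ops.marlinv2_prepack_B expects
# given the "# convert to col major" comment in the diff.
packed = packed.t().contiguous().t()
assert packed.t().is_contiguous()
```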
Empty file (the re-added init file referenced in the commit message).
30 changes: 1 addition & 29 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -1,5 +1,5 @@
"""This file is used for /tests and /benchmarks"""
from typing import List
from typing import List, Union

import numpy
import torch
@@ -229,34 +229,6 @@ def pack_rows(
     return q_res
 
 
-def pack_weights_into_int32(w_q: torch.Tensor,
-                            wtype: ScalarType,
-                            dim: int = 0):
-    orig_device = w_q.device
-
-    # move dim to pack to the end
-    perm = (*[i for i in range(len(w_q.shape)) if i != dim], dim)
-    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
-    w_q_perm = w_q.permute(perm)
-
-    w_q_perm = w_q_perm.cpu().numpy().astype(numpy.uint32)
-    pack_factor = 32 // wtype.size_bits
-    mask = (1 << wtype.size_bits) - 1
-
-    new_shape_perm = list(w_q_perm.shape)
-    new_shape_perm[-1] //= pack_factor
-    assert new_shape_perm[-1] % pack_factor == 0
-
-    w_q_res = numpy.zeros(new_shape_perm, dtype=numpy.uint32)
-    for i in range(pack_factor):
-        w_q_res |= (w_q_perm[..., i::pack_factor]
-                    & mask) << wtype.size_bits * i
-
-    w_q_res = torch.from_numpy(w_q_res.astype(numpy.int32)).to(orig_device)
-    w_q_res = w_q_res.permute(inv_perm)
-
-    return w_q_res
-
 
 def pack_cols(
     q_w: torch.Tensor,
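Two things stand out in the removed routine. It generalizes row packing to an arbitrary `dim` by permuting the packing dim to the end, packing the last axis, then undoing the permutation via `inv_perm`. It also floor-divides `new_shape_perm[-1]` in place before asserting divisibility, so the assert checks the already-divided size rather than the original one; the presumably intended order (an editorial sketch, not part of this commit) would be:

```python
new_shape_perm = list(w_q_perm.shape)
assert new_shape_perm[-1] % pack_factor == 0  # check before dividing
new_shape_perm[-1] //= pack_factor
```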
