From eff700935b3f739c741053d3a7cfb0f050ba5db2 Mon Sep 17 00:00:00 2001
From: rnwang04
Date: Fri, 13 Dec 2024 10:56:55 +0800
Subject: [PATCH] further exp of hqq q4_0

---
 .../transformers/npu_models/convert.py        |  10 +-
 .../transformers/npu_models/quantize.py       | 128 ++++++++++++++++--
 2 files changed, 124 insertions(+), 14 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index 4fcd3f901ad..7cedf31b4ba 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -109,11 +109,11 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
                                              enable_scale_search=enable_scale_search,
                                              imatrix=imatrix)
         if qtype == "sym_int4_rtn" and os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0":
-            from .quantize import update_scale_grid_search
-            # scale grid search
-            qweights, scale = update_scale_grid_search(layer.weight.data.to(torch.float32),
-                                                       (1.0 / scale.to(torch.float32)),
-                                                       [-8, 7])
+            from .quantize import update_scale_inverse_median
+            # scale search by hqq
+            qweights, scale = update_scale_inverse_median(layer.weight.data.to(torch.float32),
+                                                          (1.0 / scale.to(torch.float32)),
+                                                          [-8, 7])
         zero = None
         # split scale to scale & zero
         if qtype == "asym_int4_rtn":
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py
index c47df92d340..68430dbef38 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py
@@ -33,17 +33,18 @@
 
 import torch
 from torch import Tensor
+import numpy as np
 
 
 def c_round(x: Tensor):
     return torch.sign(x) * torch.floor(torch.abs(x) + 0.5)
 
+
 def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = 128 + 1):
     iscale = iscale.unsqueeze(1)
 
     assert N % 2 == 1, "Please check whether N: odd number"
     rng_dump = 0.05  # 0.05 / 1.
-    z_val = 2e-4
     device = iscale.device
     dtype = iscale.dtype
 
@@ -58,14 +59,6 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int =
         .repeat(n_clusters, 1)
     ) + iscale
 
-    # Safe inverse
-    iscale_shifted[
-        torch.logical_and(iscale_shifted >= 0, torch.abs(iscale_shifted) <= z_val)
-    ] = z_val
-    iscale_shifted[
-        torch.logical_and(iscale_shifted < 0, torch.abs(iscale_shifted) <= z_val)
-    ] = -z_val
-
     err = torch.empty([n_clusters, N], dtype=dtype, device=device)
     for i in range(N):
         W_r = W_q * iscale_shifted[:, i][:, None]
@@ -92,3 +85,120 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int =
     qweights = high_bit | low_bit
 
     return qweights.view(torch.uint8), scale_b.to(torch.float16)
+
+
+# Shrinking operator
+def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor:
+    """Shrink ``x`` towards zero elementwise.
+
+    With ``lp_norm == 1`` this is soft-thresholding by ``1 / beta``; otherwise
+    the threshold is additionally scaled by ``|x| ** (lp_norm - 1)``.
+    """
+    if lp_norm == 1:
+        return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
+    else:
+        return torch.sign(x) * torch.nn.functional.relu(
+            torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1)
+        )
+
+
+def update_scale_hqq(x: Tensor, iscale: Tensor, min_max: list):
+    """Refine the per-row inverse scale with an HQQ-style shrink-and-refit loop.
+
+    ``x`` is treated as a 2-D weight whose rows share one inverse scale
+    (``iscale`` is unsqueezed to a column before use). Returns the packed
+    uint8 weights, two 4-bit values per byte, and the float16 scale.
+    """
+    iscale = iscale.unsqueeze(1)
+    opt_params: dict = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20}
+    lp_norm, beta, kappa, iters = (
+        opt_params["lp_norm"],
+        opt_params["beta"],
+        opt_params["kappa"],
+        opt_params["iters"],
+    )
+    z_val = 1e-4
+    delta = 1e-4
+
+    best_error = 1e4
+    for i in range(iters):
+        W_q = c_round(x * iscale).clamp(min_max[0], min_max[1])
+        W_q_mask = W_q == 0
+        W_q[W_q_mask] = delta
+        W_r = W_q / iscale
+        W_e = shrink_lp_op(x - W_r, beta, lp_norm)
+        W_ = (x - W_e).clone()
+        W_mask = torch.abs(W_) < z_val
+        W_[W_mask] = z_val
+        # Fit the inverse scale to the shrunk target W_ (W_q / W_q is always 1).
+        iscale, _ = torch.median(W_q / W_, axis=1, keepdim=True)
+        beta *= kappa
+
+        current_error = float(torch.abs(x - W_r).mean())
+        if current_error < best_error:
+            best_error = current_error
+        else:
+            break
+
+    scale_b = 1.0 / iscale
+    qweights = (c_round(x * iscale)).clamp(min_max[0], min_max[1]).to(torch.int8)  # m * n
+    qweights = qweights.reshape(x.shape[0], -1, 2)  # m * n/2 * 2
+    low_bit, high_bit = qweights.split(1, dim=-1)
+    high_bit = high_bit.squeeze().view(torch.int8)
+    low_bit = low_bit.squeeze().view(torch.int8)
+    high_bit = high_bit << 4
+    low_bit = low_bit & 0x0f
+    qweights = high_bit | low_bit
+
+    return qweights.view(torch.uint8), scale_b.to(torch.float16)
+
+
+
+# re-estimate the scale based on the inverse median: Only tested with axis==0
+def update_scale_inverse_median(
+    W_f: Tensor, iscale: Tensor, min_max: list
+) -> tuple:
+    """Re-estimate the per-row inverse scale from the median of ``W_q / W_f``.
+
+    The median-based scale replaces the incoming ``iscale`` only on rows where
+    it lowers the mean absolute reconstruction error. Returns the packed uint8
+    weights, two 4-bit values per byte, and the float16 scale.
+    """
+    iscale = iscale.unsqueeze(1)
+    scale_rng = 2e4
+    z_val = 1e-4
+
+    W_q = c_round(W_f * iscale).clamp(min_max[0], min_max[1])
+
+    # Build scale tensor
+    W_f_c = W_f.clone()
+    W_f_c_mask = torch.abs(W_f_c) < z_val
+    W_f_c[W_f_c_mask] = z_val
+
+    scale_tensor = (W_q).float() / W_f_c.float()
+
+    # Normalize scale_tensor
+    scale_b = torch.median(scale_tensor, axis=1, keepdim=True)[0]
+    scale_b = scale_b.clamp(min=-scale_rng, max=scale_rng)
+
+    # Mix with older scale
+    W_r = (W_q) / scale_b
+    err_b = torch.abs(W_f - W_r).mean(axis=1, keepdim=True)
+
+    W_r = (W_q) / iscale
+    err_a = torch.abs(W_f - W_r).mean(axis=1, keepdim=True)
+
+    mask = (err_b < err_a)
+    iscale_b = mask * scale_b + (~mask) * iscale
+
+    scale_b = 1.0 / iscale_b
+    qweights = (c_round(W_f * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8)  # m * n
+    qweights = qweights.reshape(W_f.shape[0], -1, 2)  # m * n/2 * 2
+    low_bit, high_bit = qweights.split(1, dim=-1)
+    high_bit = high_bit.squeeze().view(torch.int8)
+    low_bit = low_bit.squeeze().view(torch.int8)
+    high_bit = high_bit << 4
+    low_bit = low_bit & 0x0f
+    qweights = high_bit | low_bit
+
+    return qweights.view(torch.uint8), scale_b.to(torch.float16)