Merge the Q4 and Q4F16 branches of main() and apply default values for the
block size (Q4/Q4F16: 32, BNB4: 64) and for the BNB4 quant type (NF4).

Review fix: plain Q4 must still write its result to save_path. Only Q4F16
defers saving to quantize_fp16, so save_path is passed as None in that mode
only. Previously the merged branch always passed save_path=None.

NOTE(review): leading whitespace was reconstructed from a whitespace-collapsed
copy of this patch -- confirm it matches scripts/quantize.py before applying.
NOTE(review): the post-image blob hash on the index line predates the review
fix and is stale.

diff --git a/scripts/quantize.py b/scripts/quantize.py
index 1691d1325..fab5f73f0 100644
--- a/scripts/quantize.py
+++ b/scripts/quantize.py
@@ -164,34 +164,28 @@ def main():
                 save_path,
             )
 
-        elif mode == QuantMode.Q4:
-            quantize_q4(
-                model,
-                save_path,
-                block_size=quantization_args.block_size,
-                is_symmetric=quantization_args.is_symmetric,
-                accuracy_level=quantization_args.accuracy_level,
-            )
-
-        elif mode == QuantMode.Q4F16:
+        elif mode in (QuantMode.Q4, QuantMode.Q4F16):
+            block_size = quantization_args.block_size or 32
+
             q4_model = quantize_q4(
                 model,
-                save_path=None,
+                save_path=None if mode == QuantMode.Q4F16 else save_path,
-                block_size=quantization_args.block_size,
+                block_size=block_size,
                 is_symmetric=quantization_args.is_symmetric,
                 accuracy_level=quantization_args.accuracy_level,
             )
-            quantize_fp16(
-                q4_model,
-                save_path,
-            )
+            if mode == QuantMode.Q4F16:
+                quantize_fp16(
+                    q4_model,
+                    save_path,
+                )
 
         elif mode == QuantMode.BNB4:
             quantize_bnb4(
                 model,
                 save_path,
-                block_size=quantization_args.block_size,
-                quant_type=quantization_args.quant_type,
+                block_size=quantization_args.block_size or 64,
+                quant_type=quantization_args.quant_type if quantization_args.quant_type is not None else MatMulBnb4Quantizer.NF4,
             )
 
         elif mode in (QuantMode.Q8, QuantMode.QI8, QuantMode.QU8):