[INT4] Compress model by quantizing weights to int4 #3307

Open
8 of 18 tasks
umangyadav opened this issue Jul 25, 2024 · 29 comments · Fixed by #3507 or #3510

@umangyadav
Member

umangyadav commented Jul 25, 2024

Idea

Use int4 as a compression technique to fit larger models onto Navi machines, and possibly MI-series machines. Weights would be compressed using an encoding scheme that packs two 4-bit numbers into a single uint8 value.
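A minimal sketch of the packing step, assuming each value is already in [0, 15] and that the even-indexed element goes into the low nibble (this nibble order is an assumption, not necessarily what PackInt4 uses). pack_nibbles and unpack_nibbles are hypothetical helper names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: packs 4-bit values (each already in [0, 15]) two per byte.
// Assumption: element 2*i goes into the low nibble, element 2*i+1 into the high nibble.
std::vector<std::uint8_t> pack_nibbles(const std::vector<std::uint8_t>& values)
{
    assert(values.size() % 2 == 0); // assume an even element count for simplicity
    std::vector<std::uint8_t> packed(values.size() / 2);
    for(std::size_t i = 0; i < packed.size(); ++i)
    {
        int lo    = values[2 * i] & 0x0F;
        int hi    = values[2 * i + 1] & 0x0F;
        packed[i] = static_cast<std::uint8_t>(lo | (hi << 4));
    }
    return packed;
}

// Inverse of pack_nibbles: expands each byte back into two uint8 values in [0, 15].
std::vector<std::uint8_t> unpack_nibbles(const std::vector<std::uint8_t>& packed)
{
    std::vector<std::uint8_t> values(packed.size() * 2);
    for(std::size_t i = 0; i < packed.size(); ++i)
    {
        values[2 * i]     = static_cast<std::uint8_t>(packed[i] & 0x0F);
        values[2 * i + 1] = static_cast<std::uint8_t>(packed[i] >> 4);
    }
    return values;
}
```

Round-tripping through pack_nibbles and unpack_nibbles is lossless for values in [0, 15]; all of the precision loss happens in QuantizeLinear.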

Input to MIGraphX

The input model to MIGraphX would be an fp16 or fp32 model, with weights in fp16 or fp32 as well.

Operations to Focus

Only GEMMs and Convolutions for now

Targeted ASICs

Navi3x/4x. Can anyone (or @hgaspar) please confirm whether MI300 or beyond should be part of this?

Workflow

Given fp16 or fp32 weights as a node/literal, MIGraphX would transform that weight literal/node into the following set of operations:

Fp16/32 Weight -> QuantizeLinear -> UInt8 Weights -> PackInt4 -> Int4 weights -> UnpackInt4 -> UInt8 Weights -> DequantizeLinear -> Fp16/32 Weights

During quantization, the QuantizeLinear operation would set the zero point such that the UInt8 weights come out as unsigned integers in the range [0, 15]. The range used to compute the scale parameter for QuantizeLinear should be set accordingly.
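A sketch of how the scale and zero point could be chosen so QuantizeLinear lands in [0, 15], assuming a simple per-tensor min/max calibration that always includes 0 (so zero stays exactly representable). Real int4 weights usually come from retraining, so this kind of calibration is only useful for perf testing. The helper names are hypothetical; the rounding and saturation follow the ONNX QuantizeLinear/DequantizeLinear formulas.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Per-tensor parameters for asymmetric uint4 quantization (illustrative only).
struct quant_params
{
    float scale;
    std::uint8_t zero_point; // in [0, 15]
};

// Assumption: min/max calibration over the whole tensor, with 0 forced into the range.
quant_params compute_uint4_params(const std::vector<float>& w)
{
    assert(not w.empty());
    float lo    = std::min(0.0f, *std::min_element(w.begin(), w.end()));
    float hi    = std::max(0.0f, *std::max_element(w.begin(), w.end()));
    float scale = (hi - lo) / 15.0f;
    if(scale == 0.0f)
        scale = 1.0f; // all-zero tensor: any scale works
    float zp = std::clamp(std::nearbyint(-lo / scale), 0.0f, 15.0f);
    return {scale, static_cast<std::uint8_t>(zp)};
}

// QuantizeLinear-style: round(x / scale) + zero_point, saturated to [0, 15].
// std::nearbyint rounds half to even under the default rounding mode, as ONNX does.
std::uint8_t quantize_uint4(float x, quant_params p)
{
    float q = std::nearbyint(x / p.scale) + p.zero_point;
    return static_cast<std::uint8_t>(std::clamp(q, 0.0f, 15.0f));
}

// DequantizeLinear-style: (q - zero_point) * scale.
float dequantize_uint4(std::uint8_t q, quant_params p)
{
    return (static_cast<float>(q) - p.zero_point) * p.scale;
}
```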

Special handling is required to disable constant propagation on the above graph. Otherwise, constant folding would collapse it back into the original fp16/fp32 weights and undo the compression.

rocMLIR

MLIR would take the following operations from the transformed graph above and make them part of the fused kernel for conv/gemm:
PackedInt4Weights -> UnpackInt4 -> UInt8 Weights -> DequantizeLinear -> Fp16/32 Weights
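To illustrate what fusing UnpackInt4 and DequantizeLinear into a GEMM means, here is a scalar CPU sketch of one dot product that reads the weights straight from packed bytes, so neither the unpacked UInt8 tensor nor the dequantized weight tensor is ever materialized. It assumes a single per-tensor scale and zero point and low-nibble-first packing; dot_packed is a hypothetical name, and this does not model the actual rocMLIR kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// y = sum_k x[k] * dequantize(w[k]), where w is stored two 4-bit values per byte.
// Unpacking and dequantization happen in registers inside the reduction loop.
float dot_packed(const std::vector<float>& x,
                 const std::vector<std::uint8_t>& packed,
                 float scale,
                 std::uint8_t zero_point)
{
    assert(x.size() == packed.size() * 2);
    float acc = 0.0f;
    for(std::size_t i = 0; i < packed.size(); ++i)
    {
        int lo = packed[i] & 0x0F; // element 2*i (assumed low nibble)
        int hi = packed[i] >> 4;   // element 2*i+1
        acc += x[2 * i] * (static_cast<float>(lo - zero_point) * scale);
        acc += x[2 * i + 1] * (static_cast<float>(hi - zero_point) * scale);
    }
    return acc;
}
```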

MIGraphX Work Items list

Testing

  • Get such a prequantized FP16 model and test it for accuracy on Navi/MI machines. (Contact @hgaspar for access to such models.)
  • Make sure this works on Windows.
  • Create test plan for QA to test this feature.
  • Figure out how to test this feature in CI

Future work

  • KV-cache with Int4

cc : @pfultz2 @causten @hgaspar @krzysz00

@pfultz2
Collaborator

pfultz2 commented Jul 26, 2024

Enable the "Ref" pipeline by allowing const folding on UnpackInt4 so that it runs the model in original precision

I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.

Add flag in compile options to expose Int4 in C++/Python APIs.

This should just be a function call similar to the quantize_fp16 as well.

@umangyadav
Member Author

umangyadav commented Jul 26, 2024

I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.

Yes, it's just a checkbox that needs to be checked off. No work is required.

@umangyadav
Member Author

umangyadav commented Jul 26, 2024

This should just be a function call similar to the quantize_fp16 as well.

@lakhinderwalia, please follow up on this one: how should we expose this in the APIs, either

  1. As onnx_options or
  2. as a separate function call.

@pfultz2
Collaborator

pfultz2 commented Jul 26, 2024

This should just be a function call similar to the quantize_fp16 as well.

@lakhinderwalia, please follow up on this one: how should we expose this in the APIs, either

  1. As onnx_options or
  2. as a separate function call.

Actually, for int4 we shouldn't expose an API for this, since we most likely won't compute correct int4 weights (training is needed to get correct values). We can have an internal function for perf testing (similar to our int8 flags in the driver).

We do need to be able to read them from the onnx file correctly, though. We don't need to add an onnx option for that; we just need to add a pass to find the clips and replace them with int4 versions.

@pfultz2
Collaborator

pfultz2 commented Jul 26, 2024

Make sure MIGraphX parses such models correctly, recognizes the patterns, and inserts "Pack" after the "Clip" to make it a packed Int4 weight

The "Clip" operator is just for notation purposes so we should replace clip with the pack/unpack pair.
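A toy sketch of such a rewrite on a linear chain of operator names (hypothetical; this is not the MIGraphX matcher API): a clip sitting between quantizelinear and dequantizelinear is replaced by a pack_int4/unpack_int4 pair. A real pass would also check that the clip bounds actually match the int4 range.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy linear IR: each entry is an operator applied to the previous result.
using op_chain = std::vector<std::string>;

// Replace "quantizelinear -> clip -> dequantizelinear" with
// "quantizelinear -> pack_int4 -> unpack_int4 -> dequantizelinear".
// The clip only marks the int4 range, so it is dropped; pack_int4 is assumed
// to do the clipping itself. Clip-bound checks are omitted in this sketch.
op_chain replace_int4_clips(const op_chain& ops)
{
    op_chain out;
    for(std::size_t i = 0; i < ops.size(); ++i)
    {
        bool is_int4_clip = ops[i] == "clip" and i > 0 and i + 1 < ops.size() and
                            ops[i - 1] == "quantizelinear" and
                            ops[i + 1] == "dequantizelinear";
        if(is_int4_clip)
        {
            out.push_back("pack_int4");
            out.push_back("unpack_int4");
        }
        else
        {
            out.push_back(ops[i]);
        }
    }
    return out;
}
```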

@umangyadav
Member Author

We do need to be able to read them from the onnx file correctly, though. We don't need to add an onnx option for that; we just need to add a pass to find the clips and replace them with int4 versions.

I can imagine a case where, say, a client is using the same fake-quantized int4 model on two different machines, one Navi and one MI.
On Navi they probably want to "realize" the compression to Int4. Having QuantizeLinear and Clip would likely have an accuracy impact.
On MI machines they probably don't want to "realize" int4 compression, and would rather const-fold the "QDQ" pairs to preserve accuracy.

MIGraphX would need to provide a switch for that.

@pfultz2
Collaborator

pfultz2 commented Jul 26, 2024

let's say Client is using same fake-quantized int4 model

This is very unlikely. A fake-quantized model implies that the weights can be computed with a simple scale and shift from the original floats, which is not the case. The values are carefully chosen from retraining the model.

MIGraphX would need to provide a switch for that.

We don't provide this switch for fake-quantized int8 either. I think this is out of scope for this feature, and we can decide whether it is needed at a later time.

@umangyadav
Member Author

umangyadav commented Jul 26, 2024

I think this is out of scope for this feature, and we can decide whether it is needed at a later time.

Sounds good. Updated work list.

@umangyadav
Member Author

I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.

Need a way to remove "pack" and "unpack" though for the "Ref" run.

@pfultz2
Collaborator

pfultz2 commented Jul 26, 2024

Need a way to remove "pack" and "unpack" though for the "Ref" run.

Why? It will still run with those operators in there.

@umangyadav
Member Author

Why? It will still run with those operators in there.

I see what you are saying. It will run the entire Q + Pack + Unpack + DQ pipeline and therefore shouldn't require any special handling.
Updated work items list.

@pfultz2
Collaborator

pfultz2 commented Aug 5, 2024

A couple more tasks that need to be addressed with onnx support:

@lakhinderwalia
Contributor

@pfultz2, given some const fp16/32 node, the pass would transform it, per the workflow mentioned above, as:
Fp16/32 Weight -> QuantizeLinear -> UInt8 Weights -> PackInt4 -> Int4 weights -> UnpackInt4 -> UInt8 Weights -> DequantizeLinear -> Fp16/32 Weights
This wouldn't work for some models. A variation, based on a supplied sample model, has these pre-supplied nodes:
[Int4 ZeroPoints] & [FP16 Weights] -> QuantizeLinear (output in Int4) -> [Int4 ZeroPoints] & [FP16 Scales] -> DequantizeLinear (output in FP16)
In this case, no extra node should be inserted; instead, we should support this directly in QuantizeLinear and DequantizeLinear.

@krzysz00
Contributor

krzysz00 commented Aug 5, 2024

If we had the ability to name the type of the packed tensor something other than uint8, say int4x2 (an alias for uint8, except that trying to do scalar arithmetic on it is an error), then you'd just have reinterpret [ZeroPoints byte literal] as int4x2, where that int4x2 is the same thing unpack produces.

But for the immediate case, you could unpack the [Int4 ZeroPoints] and then rewrite QuantizeLinear to QuantizeLinear + pack (or, my preference, QuantizeLinear + clip + pack because I really don't like implicit clipping behavior)
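A sketch of the int4x2 idea in C++ terms, assuming a strong type over a single byte: it has the same size as uint8 but defines no arithmetic operators, so scalar math on packed values fails to compile, and the two int4 lanes are only reachable through explicit accessors. The type name, the from_bits/lane members, and the low-nibble-first lane order are all hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical packed type: same bits as a uint8, but no operator+, operator*, etc.,
// so `a + b` on two int4x2 values is a compile error.
struct int4x2
{
    std::uint8_t bits;

    // Reinterpret an existing byte (e.g. from a zero-point literal) as int4x2.
    static constexpr int4x2 from_bits(std::uint8_t b) { return int4x2{b}; }

    // Signed int4 lane in [-8, 7]; lane 0 is assumed to live in the low nibble.
    constexpr int lane(int i) const
    {
        int nibble = (i == 0) ? (bits & 0x0F) : (bits >> 4);
        return nibble >= 8 ? nibble - 16 : nibble; // sign-extend 4 bits
    }
};

static_assert(sizeof(int4x2) == sizeof(std::uint8_t), "int4x2 must alias a single byte");
```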

@lakhinderwalia
Contributor

Thanks, @krzysz00. That clip would still work in int8, however.

@pfultz2
Collaborator

pfultz2 commented Aug 7, 2024

That clip would still work in int8, however.

QuantizeLinear already does clipping, so it will clip to the int8 range, and then we just need to update pack to clip to the int4 range.

We don't want to insert an explicit clip. It is true that it would work for this case, but it won't work for packing other data types such as fp4/fp6, so for consistency we should just clip in the pack operator.

Also, we already do implicit clipping in the convert operator, and the pack_int4 is just a "fancy" convert, so for consistency we should clip there as well.
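A minimal sketch of a pack that clips implicitly, like a saturating convert, assuming int8 inputs that QuantizeLinear has already saturated, signed int4 output in [-8, 7], and low-nibble-first ordering. pack_int4_saturating is a hypothetical name, not the MIGraphX pack_int4 implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Saturate each int8 value to [-8, 7], then store two 4-bit two's-complement
// values per byte (even element in the low nibble; an odd tail leaves the last
// high nibble as zero).
std::vector<std::uint8_t> pack_int4_saturating(const std::vector<std::int8_t>& values)
{
    std::vector<std::uint8_t> packed((values.size() + 1) / 2, 0);
    for(std::size_t i = 0; i < values.size(); ++i)
    {
        int v       = std::clamp(static_cast<int>(values[i]), -8, 7); // implicit clip to int4
        auto nibble = static_cast<std::uint8_t>(v & 0x0F);           // low 4 bits of two's complement
        packed[i / 2] |= static_cast<std::uint8_t>(i % 2 == 0 ? nibble : nibble << 4);
    }
    return packed;
}
```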

@umangyadav
Member Author

umangyadav commented Aug 12, 2024

Not related to this issue or near-term deliverables, but at some point in the future we will need to:

  • Update onnx.proto to allow parsing of INT4 and UINT4 types from onnx, which would require bumping onnx version as well.
  • Enable ONNX backend tests for Int4 and UInt4.

This may require having int4 as a native type in the MIGraphX IR in some form.

@yiakwy-xpu-ml-framework-team
Copy link

yiakwy-xpu-ml-framework-team commented Aug 14, 2024

@umangyadav thank you for sharing the roadmap. I subscribed to the MIGraphX issues and occasionally see your updates on quantization support.

As for this question

"Navi3x/4x. Can anyone or @hgaspar please confirm if MI300 or beyond should be part of this or not?"

What is the team's plan for LLMs?

I see MIGraphX as an inference engine with a compiler stack comparable to TRT, but not to TRT-LLM: there is a huge gap in running LLMs (Hugging Face converter, chunked-prefill optimization, continuous batching manager, ...). Currently our Llama 2 support is demo level.

Meanwhile, to support MI300, which mainly runs LLM applications in data centers, we must develop a clear roadmap for LLM support (NLP rather than CV applications, though multimodal models such as LLaVA are possible), for example a MIGraphX-LLM.

Looking forward to your feedback.

@umangyadav
Member Author

umangyadav commented Aug 14, 2024

What is the team's plan for LLMs?

For now, you can use vLLM for best support and performance on MI300s. MIGraphX can run LLMs too but it doesn't have kv-cache support yet. It's in progress and on future roadmap.

For these kinds of general questions, I find MIGraphX Discussions a better place: https://github.com/ROCm/AMDMIGraphX/discussions

@lakhinderwalia
Contributor

  • Handle QDQ pairs for int4 zero_points, i.e. int4 tensors that are being unpacked

@lakhinderwalia
Contributor

lakhinderwalia commented Aug 17, 2024

@yiakwy-xpu-ml-framework-team

What is the team's plan for LLMs?

For now, you can use vLLM for best support and performance on MI300s. MIGraphX can run LLMs too but it doesn't have kv-cache support yet. It's in progress and on future roadmap.

For these kinds of general questions, I find MIGraphX Discussions a better place: https://github.com/ROCm/AMDMIGraphX/discussions

@umangyadav Thank you for this message. Yes, for running LLMs on MI300s, vLLM is currently the best-supported option. However, vLLM does not support graph-level optimizations such as HIP graph capture (and its memory management), layout optimization (arith pass in MLIR), or op scheduling.

I used to work on ONNX-based solutions (Graphcore PopART, PopRT), so I personally hope MIGraphX can stand out. I also wish our compiler architects could work together to build a strong product with significant impact.

@pfultz2
Collaborator

pfultz2 commented Sep 23, 2024

The tasks still needed are:

  • Enable fusing the unpack_int4 and dequantizelinear operators on the weights with MLIR.
  • Improve constant propagation so it doesn't convert unpack_int4 or the scales from block quantization to constants.
  • Enable fusing the reshapes for the scales and zero points into the MLIR operator, so we don't convert them to constants later on (in eliminate_contiguous).
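For reference, a sketch of what block-quantized dequantization computes, assuming an N x K row-major weight matrix (already unpacked) where each run of `block` consecutive elements along K shares one scale and one zero point. Indexing the scales with k / block is what the reshape plus broadcast of the scales expresses in the graph; presumably folding that broadcast into a constant would materialize N x K scales and defeat the compression. block_dequantize is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// q: N x K quantized weights; scales/zero_points: N x (K / block).
std::vector<float> block_dequantize(const std::vector<std::uint8_t>& q,
                                    const std::vector<float>& scales,
                                    const std::vector<std::uint8_t>& zero_points,
                                    std::size_t n,
                                    std::size_t k,
                                    std::size_t block)
{
    assert(k % block == 0);
    std::size_t blocks_per_row = k / block;
    assert(q.size() == n * k);
    assert(scales.size() == n * blocks_per_row and zero_points.size() == scales.size());
    std::vector<float> out(n * k);
    for(std::size_t row = 0; row < n; ++row)
    {
        for(std::size_t col = 0; col < k; ++col)
        {
            std::size_t b = row * blocks_per_row + col / block; // which block's scale/zp
            out[row * k + col] =
                (static_cast<float>(q[row * k + col]) - zero_points[b]) * scales[b];
        }
    }
    return out;
}
```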

@pfultz2
Collaborator

pfultz2 commented Sep 23, 2024

To get constant propagation working, I think we can just skip over aliases (and reshape, which is almost an alias):

bool skip_propagate(instruction_ref ins)
{
    // For these ops, decide based on the first input instead
    if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name()))
        return skip_propagate(ins->inputs().front());
    auto alias = instruction::get_output_alias(ins, true);
    if(alias != ins)
        return skip_propagate(alias);
    if(ins->name() == "unpack_int4")
        return true;
    auto&& s = ins->get_shape();
    if(s.broadcasted() and not s.scalar())
        return true;
    if(s.scalar() and s.elements() != 1)
        return true;
    return false;
}

We may want to add an additional condition that the number of elements is not smaller after the alias, if(alias != ins and ins->get_shape().elements() >= alias->get_shape().elements()), so that we won't skip over operators like slice or step. However, block quantization does use a slice in some cases, so we might need to tweak this further if we add this condition.
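A toy model of the proposed element-count condition (hypothetical node type, not MIGraphX's instruction class): the walk goes back through an alias only when the alias does not shrink the tensor, so a broadcast of unpack_int4 is still skipped while a slice of it can be folded.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>

// Hypothetical stand-in for an instruction with a single input.
struct node
{
    std::string name;
    std::size_t elements;        // number of elements in the output shape
    std::shared_ptr<node> input; // single input, or nullptr
    bool is_alias = false;       // output aliases the input (broadcast, slice, ...)
};

// Follow aliases back to their source only if they do not reduce the element
// count; skip propagation when the source is unpack_int4.
bool toy_skip_propagate(const node& n)
{
    if(n.is_alias and n.input and n.elements >= n.input->elements)
        return toy_skip_propagate(*n.input);
    return n.name == "unpack_int4";
}
```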

@lakhinderwalia
Contributor

Other changes:
#3511 For global counters. (merged)
#3513 For Dequantizelinear Input Fusion.
#3494 For Propagate Constant. (merged)

@lakhinderwalia
Contributor

lakhinderwalia commented Oct 16, 2024

#3528 unpack_int4 kernel.
#3531 dequantizelinear: remove ZP with zeros.
ROCm/rocMLIR#1682 rocMLIR vectorization fix.
#3523 Onnx Verify tests for unpack_int4.

@lakhinderwalia
Contributor

#3541 Enable non packed inputs for MLIR.
#3609 Always output a packed type for q/dq

@lakhinderwalia
Contributor

More relevant PRs: #3645, #3629, #3632, #3637, #3582

5 participants