[INT4] Compress model by quantizing weights to int4 #3307
Comments
I don't think this is needed. Ref doesn't run const folding, and it is fine if it runs it at runtime anyway.
This should just be a function call similar to the …
Yes, it's just a checkbox that needs to be checked. No work is required.
@lakhinderwalia, follow up on this one: how to expose that to APIs, whether …
Actually for int4, we shouldn't expose an API for this since we most likely won't compute correct int4 weights (training is needed to get correct values). We can have an internal function for perf testing (similar to our int8 flags in the driver). We do still need to be able to read them from the onnx file correctly, though. We don't need to add an onnx option for that; we just need to add a pass to find the clips and replace them with int4 versions.
The "Clip" operator is just for notation purposes, so we should replace clip with the pack/unpack pair.
I can imagine a case where, say, a client is using the same fake-quantized int4 model on two different machines. MIGraphX would need to provide a switch for that.
This is very unlikely. A fake-quantized model implies that the weights can be computed with a simple scale and shift from the original floats, which is not the case. The values are carefully chosen by retraining the model.
We don't provide this switch for fake-quantized int8, either. I think this is out of scope for this feature, and we can decide whether it is needed at a later time.
Sounds good. Updated the work list.
We need a way to remove "pack" and "unpack" for the "Ref" run, though.
Why? It will still run with those operators in there.
I see what you are saying. It will run the entire Q + Pack + Unpack + DQ pipeline and therefore shouldn't require any special handling.
A couple more tasks that need to be addressed with onnx support:
@pfultz2, given some const Fp16/32 node, it would transform per the workflow mentioned above as:
If we had the ability to name the type of the packed tensor as something other than … But for the immediate case, you could …
Thanks, @krzysz00. That …
quantizelinear already does clipping, so it will clip for int8, and then we just need to update pack to clip for int4. We don't want to insert an explicit clip. It is true that it would work for this case, but for packing other data types such as fp4/fp6 it won't work, so for consistency we should just clip in the pack operator. Also, we already do implicit clipping in the …
Not related to this issue or near-term deliverables, but at some point in the future we would require:
This may require having int4 as a native type in the migraphx IR in some form.
@umangyadav thank you for sharing the roadmap. I subscribed to the MIGraphX issues panel and occasionally see your updates on quantization support. As for this question:
I wonder what the team's plan for LLMs is. I see MIGraphX as an inference engine with a compiler stack comparable to TRT, but not to TRT-LLM: there is a huge gap in running the LLM part (Hugging Face converter, chunked-prefill optimization, continuous batching manager, ...). Currently our llama2 is at demo level. Meanwhile, to support MI300, which mainly runs in datacenters for LLM applications, we must develop a clear roadmap to support LLMs (NLP rather than CV applications, though multi-modal such as LLaVA is possible): MIGraphX-LLM, for example. Waiting to see your feedback.
For now, you can use … For these kinds of general questions, I find the MIGraphX discussions a better place: https://github.com/ROCm/AMDMIGraphX/discussions
@umangyadav Thank you for this message. Yes, to run LLMs on MI300s, vLLM is temporarily the best-supported option. However, vLLM does not support graph-level optimizations such as HIP graph capture (and its memory management), layout optimization (arith pass in MLIR), or op scheduling. I used to work with ONNX solutions (Graphcore PopART, PopRT), so I personally hope MIGraphX can stand out. Another reason: I wish our compiler architects could work together to form a strong product and make a significant impact.
The tasks still needed are:
To get constant propagation working, I think we can just skip over aliases (and reshape, which is almost an alias):

```cpp
bool skip_propagate(instruction_ref ins)
{
    // Look through layout/quantization wrappers and recurse on their input.
    if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name()))
        return skip_propagate(ins->inputs().front());
    // Follow output aliases as well.
    auto alias = instruction::get_output_alias(ins, true);
    if(alias != ins)
        return skip_propagate(alias);
    // Never fold away the int4 unpack, or the compression is undone.
    if(ins->name() == "unpack_int4")
        return true;
    auto&& s = ins->get_shape();
    if(s.broadcasted() and not s.scalar())
        return true;
    if(s.scalar() and s.elements() != 1)
        return true;
    return false;
}
```

We may want to add an additional condition that the number of elements is not smaller after the alias:
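That extra condition isn't spelled out above; one hypothetical way to write it (reusing the names from the snippet and comparing element counts before following the alias) would be to change the alias branch to something like:

```cpp
// Hypothetical: only recurse through the output alias when doing so does not
// reduce the number of elements being propagated.
auto alias = instruction::get_output_alias(ins, true);
if(alias != ins and alias->get_shape().elements() >= ins->get_shape().elements())
    return skip_propagate(alias);
```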
#3528
Idea
Use int4 as a compression technique to fit larger models onto Navi machines, or possibly MI-series machines. Weights would be compressed using an encoding scheme that packs two 4-bit numbers inside a single uint8 value.
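For concreteness, here is a minimal standalone sketch of that packing scheme (not the MIGraphX operator itself; the even-index-to-low-nibble layout is an assumption):

```cpp
#include <cstdint>
#include <vector>

// Pack pairs of 4-bit values (already in [0, 15]) into single bytes:
// even index -> low nibble, odd index -> high nibble (layout is an assumption).
std::vector<uint8_t> pack_uint4(const std::vector<uint8_t>& vals)
{
    std::vector<uint8_t> out((vals.size() + 1) / 2, 0);
    for(std::size_t i = 0; i < vals.size(); ++i)
        out[i / 2] |= static_cast<uint8_t>((vals[i] & 0x0F) << (4 * (i % 2)));
    return out;
}

// Inverse: expand each byte back into two uint8 values holding one nibble each.
std::vector<uint8_t> unpack_uint4(const std::vector<uint8_t>& packed, std::size_t n)
{
    std::vector<uint8_t> out(n);
    for(std::size_t i = 0; i < n; ++i)
        out[i] = (packed[i / 2] >> (4 * (i % 2))) & 0x0F;
    return out;
}
```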
Input to MIGraphX
The input model to MIGraphX would be an fp16 or fp32 model, with weights entirely in fp16 or fp32 as well.
Operations to Focus On
Only GEMMs and Convolutions for now
Targeted ASICs
Navi3x/4x. Can anyone, or @hgaspar, please confirm whether MI300 or beyond should be part of this?
Workflow
Given fp16 or fp32 weights as a node/literal, MIGraphX would transform that weight literal/node into the following set of operations:
Fp16/32 Weights -> QuantizeLinear -> UInt8 Weights -> PackInt4 -> Int4 Weights -> UnpackInt4 -> UInt8 Weights -> DequantizeLinear -> Fp16/32 Weights
During quantization, the QuantizeLinear operation would set the zero point such that the UInt8 weights come out as unsigned integers in the range [0, 15]. The range used to compute the scale parameter for QuantizeLinear should be set accordingly. Special handling is required to disable constant propagation on the above graph; otherwise, it would undo what's being done.
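As a back-of-the-envelope illustration of that range handling (a sketch only, not MIGraphX's quantizer; the asymmetric min/max scheme and function name are assumptions):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

struct QuantParams
{
    float scale;
    uint8_t zero_point;
};

// Map the observed weight range [min_val, max_val] onto the uint4 range [0, 15]:
// q = clamp(round(x / scale) + zero_point, 0, 15).
QuantParams compute_uint4_params(float min_val, float max_val)
{
    min_val = std::min(min_val, 0.0f); // keep 0.0 exactly representable
    max_val = std::max(max_val, 0.0f);
    const float qmin = 0.0f;
    const float qmax = 15.0f;
    float scale = (max_val - min_val) / (qmax - qmin);
    if(scale == 0.0f)
        scale = 1.0f; // degenerate all-zero weights
    const float zp = std::clamp(std::round(qmin - min_val / scale), qmin, qmax);
    return {scale, static_cast<uint8_t>(zp)};
}
```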
rocMLIR
MLIR would take the following operations from the above transformed graph and make them part of the fused kernel for conv/gemm:
PackedInt4Weights -> UnpackInt4 -> UInt8 Weights -> DequantizeLinear -> Fp16/32 Weights
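In other words, for each weight element the fused kernel would effectively do something like the following (a sketch of the per-element math only; the real rocMLIR codegen and packed-axis indexing are more involved, and the names here are hypothetical):

```cpp
#include <cstddef>
#include <cstdint>

// Recover one fp32 weight from the packed buffer: pick the byte, extract the
// nibble, then apply the usual dequantize formula (x - zero_point) * scale.
float load_int4_weight(const uint8_t* packed, std::size_t idx,
                       float scale, uint8_t zero_point)
{
    uint8_t byte   = packed[idx / 2];
    uint8_t nibble = (byte >> (4 * (idx % 2))) & 0x0F;           // UnpackInt4
    return (static_cast<float>(nibble) - zero_point) * scale;    // DequantizeLinear
}
```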
MIGraphX Work Items list
- Add PackInt4 operator. Done with PackInt4 Operator #2730.
- Add UnpackInt4 operator. Done with Unpack int4 #2779.
- Disable constant folding on UnpackInt4 to avoid undoing compression. #3323 This can be done by adding the UnpackInt4 operation to the skip_ops list here: AMDMIGraphX/src/include/migraphx/propagate_constant.hpp, line 41 in 4a3c7b7. Add unit tests for this, similar to AMDMIGraphX/test/propagate_constant_test.cpp, line 189 in 4a3c7b7.
- Hook the Pack and Unpack operators in MIGraphX up to rocMLIR's corresponding operators. This should mostly work automatically if the operator names for pack and unpack are the same across MLIR and MIGraphX. Add verification tests for a simple program with pack/unpack instructions if MLIR materializes them internally.
- Enable the fusion pipeline with Int4 weights in the fuse_mlir pass. Make sure INT4 convs/gemms are offloaded to MLIR and not to BLAS/MIOpen.
- Make sure MLIR knows which axis is packed.
- Introduce an --int4-weights option in migraphx-driver. This would require changes in MIGraphX's naive quantizer to set the range to [0, 15]. During quantization it should also insert "pack" and "unpack" instructions. #3341
- Inspect pre-quantized Int4 onnx models to identify quantization patterns. One such pattern: Int4 quantization would appear as a "QuantizeLinear + Clip" pattern on the weights. Make sure MIGraphX parses such models correctly, recognizes the pattern, and inserts "Pack" after the "Clip" to produce packed Int4 weights.
- Enable the "Ref" pipeline by allowing const folding on UnpackInt4 so that it runs the model in the original precision. (Not needed, see discussion.)
- Add a flag in compile options to expose Int4 in the C++/Python APIs. (Out of scope for now.)
- Add signed int4 support for the pack_int4 operator. Currently it only supports unsigned. #3358
- Add signed int4 support for the unpack_int4 operator. Currently it only supports unsigned. (A rough sketch of signed-nibble unpacking follows this list.)
- Onnx support: parse int4 constants in onnx by inserting the packed buffer and an unpack operator. #3374
- Handle const folding of zero-ed int4 zero_points.
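On the signed int4 items above: the main difference from the unsigned path is sign-extending each nibble from [0, 15] to [-8, 7]. A hypothetical sketch of that step (not the MIGraphX operator implementation):

```cpp
#include <cstdint>

// Sign-extend a 4-bit two's-complement value stored in one nibble:
// 0x0..0x7 -> 0..7, 0x8..0xF -> -8..-1.
int8_t unpack_int4_signed(uint8_t packed_byte, bool high_nibble)
{
    uint8_t nibble = high_nibble ? (packed_byte >> 4) : (packed_byte & 0x0F);
    return static_cast<int8_t>((nibble ^ 0x8) - 0x8); // sign-extension trick
}
```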
Testing
Future work
cc: @pfultz2 @causten @hgaspar @krzysz00