EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) is a new baseline for fast decoding of Large Language Models (LLMs) with provable performance maintenance. It extrapolates the second-top-layer contextual feature vectors of LLMs, which significantly boosts generation efficiency. EAGLE is built upon the following First Principle:
The sequence of LLM feature vectors is compressible over time, making the prediction of subsequent feature vectors from previous ones easy.
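As a toy illustration of this principle (not EAGLE's draft head, which is a small trained network operating on real LLM features), the sketch below generates a sequence of 2-D "feature vectors" from a hypothetical linear dynamic. It fits a linear next-feature predictor by least squares on the first half and checks that it predicts the second half. The synthetic data and all function names are assumptions made for illustration.

```python
import math

def make_features(n, angle=0.3, decay=0.99):
    """Hypothetical feature sequence: a slowly decaying rotation in 2-D."""
    c, s = math.cos(angle), math.sin(angle)
    f = [1.0, 0.0]
    seq = [f]
    for _ in range(n - 1):
        f = [decay * (c * f[0] - s * f[1]), decay * (s * f[0] + c * f[1])]
        seq.append(f)
    return seq

def fit_linear_predictor(seq):
    """Least-squares fit of a 2x2 matrix A such that f[t+1] is approximately A @ f[t]."""
    # Accumulate sum(x x^T) and sum(y x^T), where x = f[t] and y = f[t+1]
    xtx = [[0.0, 0.0], [0.0, 0.0]]
    ytx = [[0.0, 0.0], [0.0, 0.0]]
    for x, y in zip(seq[:-1], seq[1:]):
        for i in range(2):
            for j in range(2):
                xtx[i][j] += x[i] * x[j]
                ytx[i][j] += y[i] * x[j]
    # Invert the 2x2 matrix sum(x x^T)
    det = xtx[0][0] * xtx[1][1] - xtx[0][1] * xtx[1][0]
    inv = [[xtx[1][1] / det, -xtx[0][1] / det],
           [-xtx[1][0] / det, xtx[0][0] / det]]
    # Normal equations: A = sum(y x^T) @ inverse(sum(x x^T))
    return [[sum(ytx[i][k] * inv[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]

def predict(A, f):
    """Apply the fitted predictor to one feature vector."""
    return [A[0][0] * f[0] + A[0][1] * f[1], A[1][0] * f[0] + A[1][1] * f[1]]

seq = make_features(40)
A = fit_linear_predictor(seq[:20])
errors = [math.dist(predict(A, seq[t]), seq[t + 1]) for t in range(20, 39)]
print(max(errors))  # tiny, because this toy sequence is exactly linear
```

Real LLM features are far from linear, which is why EAGLE learns the predictor instead of fitting it in closed form; the toy only shows what "predicting later features from earlier ones" means.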
EAGLE is:
- certified by a third-party evaluation as the fastest speculative decoding method so far.
- achieving a 2x speedup on gpt-fast.
- 3x faster than vanilla decoding (13B).
- 2x faster than Lookahead (13B).
- 1.6x faster than Medusa (13B).
- provably maintaining consistency with vanilla decoding in the distribution of generated texts.
- trainable (within 1-2 days) and testable on 8x RTX 3090 GPUs, so even the GPU-poor can afford it.
- combinable with other parallel techniques such as vLLM, DeepSpeed, Mamba, FlashAttention, quantization, and hardware optimization.
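Distribution preservation in speculative methods rests on the standard speculative sampling acceptance rule: a token x proposed by the draft distribution q is accepted with probability min(1, p(x)/q(x)), where p is the target model's distribution, and on rejection a replacement is drawn from the normalized residual max(0, p - q). The pure-Python sketch below applies this rule to a single token over a toy vocabulary with made-up distributions. EAGLE's implementation verifies several drafted tokens at once, which this sketch does not model.

```python
import random

def speculative_step(p, q, rng):
    """One speculative-sampling step over a toy vocabulary.

    p: target-model probabilities, q: draft-model probabilities (lists summing to 1).
    The returned token is distributed exactly according to p.
    """
    vocab = range(len(p))
    x = rng.choices(vocab, weights=q)[0]            # draft proposes x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):        # accept with prob min(1, p/q)
        return x
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0]  # resample from norm(max(0, p - q))

rng = random.Random(0)
p = [0.5, 0.3, 0.2]  # hypothetical target distribution
q = [0.2, 0.2, 0.6]  # hypothetical draft distribution
n = 100_000
counts = [0, 0, 0]
for _ in range(n):
    counts[speculative_step(p, q, rng)] += 1
print([c / n for c in counts])  # close to p, not q
```

The empirical frequencies track p even though every proposal comes from q; a better draft model only raises the acceptance rate, it never changes the output distribution.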
Inference is conducted on RTX 3090 GPUs at fp16 precision using the Vicuna 33B model. For an enhanced viewing experience, the animation has been sped up fourfold.
2024.2.25: EAGLE is certified by a third-party evaluation as the fastest speculative method.
2024.1.17: We now support Mixtral-8x7B-Instruct.
2024.1.17: We have integrated gpt-fast into EAGLE, further accelerating the generation speed.
2024.1.15: We now support batch size > 1 generation.
2023.12.8: EAGLE v1.0 is released.
- Support non-greedy inference (provably maintaining text distribution).
- Support bs > 1.
- Support gpt-fast.
- Support more LLMs such as Mixtral 8x7B.
- Support vLLM.
Install from PyPI:
```bash
pip install eagle-llm
```
or install from source:
```bash
git clone https://github.com/SafeAILab/EAGLE.git
cd EAGLE
pip install -e .
```
Base Model | EAGLE on Hugging Face | # EAGLE Parameters |
---|---|---|
Vicuna-7B-v1.3 | yuhuili/EAGLE-Vicuna-7B-v1.3 | 0.24B |
Vicuna-13B-v1.3 | yuhuili/EAGLE-Vicuna-13B-v1.3 | 0.37B |
Vicuna-33B-v1.3 | yuhuili/EAGLE-Vicuna-33B-v1.3 | 0.56B |
LLaMA2-Chat 7B | yuhuili/EAGLE-llama2-chat-7B | 0.24B |
LLaMA2-Chat 13B | yuhuili/EAGLE-llama2-chat-13B | 0.37B |
LLaMA2-Chat 70B | yuhuili/EAGLE-llama2-chat-70B | 0.99B |
Mixtral-8x7B-Instruct-v0.1 | yuhuili/EAGLE-mixtral-instruct-8x7B | 0.28B |
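To put the draft-head sizes in perspective, the sketch below computes each EAGLE head's parameter count as a percentage of its base model. It uses the EAGLE parameter counts from the table and, as an approximation, the nominal base-model sizes implied by the model names (for example, 7B for Vicuna-7B). The Mixtral row is omitted because its total parameter count is not stated here.

```python
# (nominal base-model size in billions, EAGLE parameters in billions), from the table above
sizes = {
    "Vicuna-7B-v1.3": (7, 0.24),
    "Vicuna-13B-v1.3": (13, 0.37),
    "Vicuna-33B-v1.3": (33, 0.56),
    "LLaMA2-Chat 7B": (7, 0.24),
    "LLaMA2-Chat 13B": (13, 0.37),
    "LLaMA2-Chat 70B": (70, 0.99),
}

def overhead(base_b, eagle_b):
    """EAGLE parameters as a percentage of the nominal base-model parameters."""
    return 100.0 * eagle_b / base_b

for name, (base_b, eagle_b) in sizes.items():
    print(f"{name}: {overhead(base_b, eagle_b):.1f}%")
```

The relative overhead shrinks as the base model grows, from about 3.4% at 7B to about 1.4% at 70B.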
The inference code we provide automatically allocates model weights (loading a model across multiple GPUs), allowing you to run models that exceed the memory of a single GPU.
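As a rough back-of-the-envelope check of when multi-GPU allocation becomes necessary: fp16 weights take 2 bytes per parameter, so a 33B model needs about 66 GB for its weights alone, more than a single 24 GB RTX 3090 provides. The helper below is a hypothetical estimate that counts weights only and ignores the KV cache, activations, and framework overhead, so real requirements are higher.

```python
import math

def min_gpus_for_weights(params_billion, bytes_per_param=2, gpu_mem_gb=24):
    """Lower bound on the number of GPUs needed just to hold the weights (fp16 by default).

    Uses decimal GB: 1e9 parameters * bytes_per_param bytes = params_billion * bytes_per_param GB.
    """
    weight_gb = params_billion * bytes_per_param
    return math.ceil(weight_gb / gpu_mem_gb)

print(min_gpus_for_weights(33))  # 66 GB of fp16 weights -> at least 3 x 24 GB GPUs
```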
We provide a web interface, which you can launch with the following command. Once the model is fully loaded, a URL is printed in the terminal; open it in your browser to access the interface.
```bash
python -m eagle.application.webui --ea-model-path [path of EAGLE weight] \
    --base-model-path [path of the original model] \
    --model-type [vicuna or llama-2-chat]
```
You can use our `eagenerate` method to accelerate generation, just as you would use `generate` from Hugging Face. Here is an example.
```python
import torch
from eagle.model.ea_model import EaModel
from fastchat.model import get_conversation_template

base_model_path = "[path of the original model]"
EAGLE_model_path = "[path of EAGLE weight]"
use_llama_2_chat = True  # set to match your base model
use_vicuna = not use_llama_2_chat

model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=EAGLE_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
model.eval()

your_message = "Hello"
if use_llama_2_chat:
    conv = get_conversation_template("llama-2-chat")
    sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
    conv.system_message = sys_p
    conv.append_message(conv.roles[0], your_message)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt() + " "
if use_vicuna:
    conv = get_conversation_template("vicuna")
    conv.append_message(conv.roles[0], your_message)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

# Tokenize the prompt and move it to the GPU
input_ids = model.tokenizer([prompt]).input_ids
input_ids = torch.as_tensor(input_ids).cuda()
# Accelerated generation with EAGLE
output_ids = model.eagenerate(input_ids, temperature=0.5, max_new_tokens=512)
output = model.tokenizer.decode(output_ids[0])
print(output)
```
Note: Vicuna and LLaMA2-Chat are both chat models. You must use the correct chat template; otherwise the model may produce abnormal output, which also degrades EAGLE's performance.
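To make "chat template" concrete, the hypothetical helper below hand-assembles a single-turn prompt in the [INST] / <<SYS>> layout that Meta published for Llama 2 chat models. It is only an approximation for illustration: the BOS token is left to the tokenizer, and FastChat's template may differ in whitespace or special tokens, so use `get_conversation_template` in practice.

```python
def manual_llama2_prompt(system_message, user_message):
    """Approximate single-turn Llama 2 chat layout (hypothetical helper, not FastChat's)."""
    return (
        f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
        f"{user_message} [/INST]"
    )

prompt = manual_llama2_prompt("You are a helpful assistant.", "Hello")
print(prompt)
```

Passing the bare string "Hello" instead would drop the [INST] markers the chat model was fine-tuned to expect.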
Here is an example of generation with batch size > 1. Note that left padding is required.
```python
import torch
from eagle.modelbsne1.ea_model import EaModel
from fastchat.model import get_conversation_template

base_model_path = "[path of the original model]"
EAGLE_model_path = "[path of EAGLE weight]"

model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=EAGLE_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
model.eval()
# left padding
model.tokenizer.padding_side = "left"
model.tokenizer.pad_token = model.tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

sys_p = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."

def build_llama2_chat_prompt(message):
    """Wrap a user message in the LLaMA2-Chat template."""
    conv = get_conversation_template("llama-2-chat")
    conv.system_message = sys_p
    conv.append_message(conv.roles[0], message)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt() + " "

prompt1 = build_llama2_chat_prompt("Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.")
prompt2 = build_llama2_chat_prompt("Hello")

input_s = model.tokenizer([prompt1, prompt2], return_tensors="pt", padding=True).to("cuda")
output_ids = model.eagenerate(input_s.input_ids, input_s.attention_mask, temperature=0.0, max_new_tokens=512, top_k=15)
output = model.tokenizer.batch_decode(output_ids)
print(output)

# vanilla auto-regression
# output_ids, new_token, idx = model.naivegenerate(input_s.input_ids, input_s.attention_mask, temperature=0.0, max_new_tokens=512, top_k=15, log=True)
```
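Left padding matters because a decoder-only model appends new tokens on the right: with pads on the left, the last position of every row is a real token, so all sequences in the batch can be extended in lockstep. The pure-Python sketch below mimics what the tokenizer does with `padding_side = "left"`; the function name is illustrative, and `pad_id=2` is a stand-in for the EOS id reused as the pad token.

```python
def left_pad(batch, pad_id):
    """Left-pad token-id lists to equal length and build the matching attention mask."""
    width = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = width - len(ids)
        input_ids.append([pad_id] * n_pad + ids)
        attention_mask.append([0] * n_pad + [1] * len(ids))  # 0 marks padding
    return input_ids, attention_mask

ids, mask = left_pad([[11, 12, 13], [21]], pad_id=2)
print(ids)   # [[11, 12, 13], [2, 2, 21]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

With right padding, the shorter row would end in pad tokens, and the next generated token would be conditioned on padding rather than on the prompt.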
You can run the following command to generate the training data.
```bash
python -m eagle.ge_data.allocation --outdir [path of data]
```
Then train EAGLE with:
```bash
accelerate launch -m --mixed_precision=bf16 eagle.train.main --tmpdir [path of data] \
    --cpdir [path of checkpoints] --configpath [path of config file]
```
eagle/train provides examples of configuration files.
If your base LLM's architecture differs from LLaMA and Mixtral, you can use EAGLE in one of two ways.
The first approach directly wraps the native Transformers LLM. Here is an example. Note: the transformers version should be higher than 4.36.
```python
import torch
from eagle.modeling_eagle import EAGLE
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_path = "[path of the original model]"
eagle_path = "[path of EAGLE weight]"
bs = 1  # batch size
# prompt1 (and prompt2 when bs > 1) are chat-formatted prompts, built as in the examples above

tokenizer = AutoTokenizer.from_pretrained(base_model_path)
model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, device_map="auto")
# for bs > 1, the padding side should be left
if bs > 1:
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

text = prompt1
# text = [prompt1, prompt2]
inputs = tokenizer(text, return_tensors="pt", padding=True)
eagle = EAGLE(model, eagle_path)
outs = eagle.generate(**inputs, max_new_tokens=200, temperature=0.0)
output = tokenizer.decode(outs)
# output = tokenizer.batch_decode(outs)
```
The second approach is to copy `modeling_basemodelname.py` from the Transformers library and modify it to use the pre-allocated `kv_cache`, which speeds up the base model. Refer to `model/modeling_llama_kv.py` for guidance; the places that require modifications are annotated with `# [MODIFIED]`. These modifications are minimal.
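The idea behind a pre-allocated KV cache is to reserve a fixed-size buffer once and write new keys and values into the next free slot, instead of concatenating tensors (and reallocating memory) at every decoding step. The class below is a hypothetical pure-Python sketch of that bookkeeping for a single attention layer, with lists standing in for tensors; it is not the code in `model/modeling_llama_kv.py`.

```python
class PreallocatedKVCache:
    """Fixed-capacity key/value buffer for one attention layer (illustrative only)."""

    def __init__(self, max_len, dim):
        # Allocate the whole buffer up front; nothing is reallocated while decoding.
        self.keys = [[0.0] * dim for _ in range(max_len)]
        self.values = [[0.0] * dim for _ in range(max_len)]
        self.length = 0  # number of valid positions currently stored

    def append(self, k, v):
        """Write one position's key and value into the next free slot."""
        if self.length >= len(self.keys):
            raise IndexError("KV cache is full")
        self.keys[self.length] = k
        self.values[self.length] = v
        self.length += 1

    def truncate(self, new_length):
        """Forget positions past new_length, e.g. draft tokens rejected during verification."""
        if not 0 <= new_length <= self.length:
            raise ValueError("new_length out of range")
        self.length = new_length

    def view(self):
        """Return the valid keys and values, i.e. the part attention should read."""
        return self.keys[: self.length], self.values[: self.length]

cache = PreallocatedKVCache(max_len=8, dim=2)
for t in range(5):
    cache.append([t, t], [t * 10, t * 10])
cache.truncate(3)  # keep only the first 3 positions
keys, values = cache.view()
print(len(keys), values[-1])  # 3 [20, 20]
```

Besides avoiding reallocation, a length pointer makes discarding positions a constant-time update, which is convenient when speculative decoding throws away rejected draft tokens.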
You can test the speed of EAGLE on MT-bench using the following command.
```bash
python -m eagle.evaluation.gen_ea_answer_vicuna \
    --ea-model-path [path of EAGLE weight] \
    --base-model-path [path of the original model]
```
For LLaMA2-Chat, use `gen_ea_answer_vicuna_llama2chat` instead of `gen_ea_answer_vicuna`.
To obtain the exact speedup ratio, also run the following command to measure the speed of vanilla autoregressive decoding.
```bash
python -m eagle.evaluation.gen_baseline_answer_vicuna \
    --ea-model-path [path of EAGLE weight] \
    --base-model-path [path of the original model]
```
For LLaMA2-Chat, use `gen_ea_answer_vicuna_llama2chat` instead.
The above two commands will each generate a .jsonl file that records the generation results and wall time. Then, you can use evaluation/speed.py to calculate the ratio of speeds.
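The speedup is EAGLE's throughput divided by the vanilla throughput. As a minimal sketch, assume each .jsonl line carries a token count and a wall time under the hypothetical keys `new_tokens` and `wall_time`. The files written by the evaluation scripts may use a different, nested schema, so prefer `evaluation/speed.py` for real measurements; the records below are made up.

```python
import io
import json

def tokens_per_second(jsonl_file):
    """Aggregate throughput over all records in a .jsonl file (hypothetical schema)."""
    total_tokens, total_time = 0, 0.0
    for line in jsonl_file:
        if line.strip():
            record = json.loads(line)
            total_tokens += record["new_tokens"]
            total_time += record["wall_time"]
    return total_tokens / total_time

# Made-up records standing in for the two generated files
eagle_file = io.StringIO('{"new_tokens": 300, "wall_time": 3.0}\n{"new_tokens": 200, "wall_time": 2.0}\n')
baseline_file = io.StringIO('{"new_tokens": 300, "wall_time": 10.0}\n{"new_tokens": 200, "wall_time": 6.0}\n')

speedup = tokens_per_second(eagle_file) / tokens_per_second(baseline_file)
print(round(speedup, 2))  # 3.2
```

Dividing total tokens by total time weights long answers more heavily than averaging per-answer speeds; either convention is defensible, as long as both files are treated the same way.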
gpt-fast primarily accelerates generation through quantization and compilation, both of which we have integrated into EAGLE. Here are the results of an experiment on MT-bench with a single RTX 3090, using LLaMA2-Chat 7B.
Precision | fp16 | int4 |
---|---|---|
vanilla | 24.5 tokens/s | N/A |
gpt-fast | 55.1 tokens/s | 106.9 tokens/s |
EAGLE+gpt-fast | 100.2 tokens/s | 160.4 tokens/s |
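The relative speedups implied by the table can be computed directly from its throughput numbers. Because vanilla decoding has no int4 entry, both precisions are compared against vanilla fp16.

```python
# Throughput in tokens/s from the table above (LLaMA2-Chat 7B, single RTX 3090, MT-bench)
vanilla_fp16 = 24.5
gpt_fast = {"fp16": 55.1, "int4": 106.9}
eagle_gpt_fast = {"fp16": 100.2, "int4": 160.4}

for precision in ("fp16", "int4"):
    over_vanilla = eagle_gpt_fast[precision] / vanilla_fp16
    over_gpt_fast = eagle_gpt_fast[precision] / gpt_fast[precision]
    print(f"{precision}: {over_vanilla:.2f}x vs. vanilla fp16, {over_gpt_fast:.2f}x vs. gpt-fast")
# fp16: 4.09x vs. vanilla fp16, 1.82x vs. gpt-fast
# int4: 6.55x vs. vanilla fp16, 1.50x vs. gpt-fast
```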
Inference is conducted on a single RTX3090 GPU at int4 precision using the LLaMA2-chat 7B model. No additional training required.
In EAGLE, using gpt-fast only requires three steps: setting up the environment, quantizing weights, and modifying the model path.
Switch to the eaglefast branch:
```bash
git clone https://github.com/SafeAILab/EAGLE.git
cd EAGLE
git checkout eaglefast
```
Install the Preview (Nightly) version of PyTorch with CUDA 12.1. Do not use `pip install torch`, as it installs the Stable version, which lacks some of the new features used by gpt-fast.
This requirement applies only to gpt-fast; other branches of EAGLE can use the Stable version of PyTorch.
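Before continuing, you can check which PyTorch build is active. Nightly builds report a version string containing ".dev" (of the form 2.3.0.devYYYYMMDD+cu121), while stable releases do not. The helper below is a hypothetical heuristic on the version string, not part of EAGLE or gpt-fast.

```python
def is_nightly(version):
    """Heuristic: PyTorch nightly (preview) builds carry '.dev' in their version string."""
    return ".dev" in version

# With PyTorch installed:
#   import torch
#   print(is_nightly(torch.__version__), torch.version.cuda)  # expect True and "12.1"
print(is_nightly("2.3.0.dev20240101+cu121"))  # True
print(is_nightly("2.1.2"))                    # False
```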
Convert Hugging Face weights to the format required by gpt-fast:
```bash
python convert/convert_hf_checkpoint.py --checkpoint_dir path_of_base_model
python convert/convert_hf_checkpoint_EAGLE.py --checkpoint_dir path_of_eagle
```
Quantize weights:
```bash
python -m model.quantize_llama --checkpoint_path path_of_base_model/model.pth
python -m model.quantize_EAGLE --checkpoint_path path_of_eagle/model.pth
```
When specifying the model weights (including the base model and EAGLE), change "path" to "path/model_int4.g32.pth".
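The file name model_int4.g32.pth suggests 4-bit weights quantized in groups of 32 values, each group with its own scale. The pure-Python sketch below illustrates that general idea with asymmetric (min/max) group-wise quantization; it is an illustration of the technique under that assumption, not the exact scheme implemented by `model.quantize_llama` or gpt-fast.

```python
import random

def quantize_groups(weights, group_size=32, bits=4):
    """Asymmetric group-wise quantization: each group gets its own scale and minimum."""
    qmax = 2 ** bits - 1  # 15 for int4
    groups = []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / qmax if hi > lo else 1.0  # avoid division by zero for constant groups
        codes = [round((w - lo) / scale) for w in group]  # integers in [0, qmax]
        groups.append((codes, scale, lo))
    return groups

def dequantize_groups(groups):
    """Map integer codes back to approximate floating-point weights."""
    return [lo + c * scale for codes, scale, lo in groups for c in codes]

rng = random.Random(0)
weights = [rng.uniform(-1.0, 1.0) for _ in range(64)]  # two groups of 32
groups = quantize_groups(weights)
restored = dequantize_groups(groups)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(len(groups), max_err <= max(s for _, s, _ in groups) / 2 + 1e-12)  # 2 True
```

Smaller groups track local weight ranges more closely, which lowers rounding error at the cost of storing more scales and offsets.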
A heartfelt thank you to all our contributors.
For technical details and full experimental results, please check the paper.
```bibtex
@article{li2024eagle,
  author = {Yuhui Li and Fangyun Wei and Chao Zhang and Hongyang Zhang},
  title = {EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty},
  journal = {arXiv preprint arXiv:2401.15077},
  year = {2024}
}
```
This project has been influenced by many excellent projects in the LLM community, such as Medusa, FastChat, and others. The logo is designed by GPT-4. We also appreciate many valuable discussions with Tianle Cai, Hao Zhang, Ziteng Sun, and others.