+```
+
+## Detailed of Pramater
+
+### Initialization
+
+Initialize **LLMLingua**, **LongLLMLingua**, and **LLMLingua-2** with the following parameters:
```python
from llmlingua import PromptCompressor
llm_lingua = PromptCompressor(
- model_name="NousResearch/Llama-2-7b-hf", # Default model
+ model_name="NousResearch/Llama-2-7b-hf", # Default model, use "microsoft/llmlingua-2-xlm-roberta-large-meetingbank" or "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank" for LLMLingua-2
device_map="cuda", # Device environment (e.g., 'cuda', 'cpu', 'mps')
model_config={}, # Configuration for the Huggingface model
open_api_config={}, # Configuration for OpenAI Embedding
+ use_llmlingua2=False, # Whether to use llmlingua-2
)
```
-### Parameters
+#### Parameters
-- **model_name** (str): Name of the small language model from Huggingface. Defaults to "NousResearch/Llama-2-7b-hf".
+- **model_name** (str): Name of the small language model from Huggingface, use "microsoft/llmlingua-2-xlm-roberta-large-meetingbank" or "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank" for LLMLingua-2. Defaults to "NousResearch/Llama-2-7b-hf".
- **device_map** (str): The computing environment. Options include 'cuda', 'cpu', 'mps', 'balanced', 'balanced_low_0', 'auto'. Default is 'cuda'.
- **model_config** (dict, optional): Configuration for the Huggingface model. Defaults to {}.
- **open_api_config** (dict, optional): Configuration for OpenAI Embedding in coarse-level prompt compression. Defaults to {}.
+- **use_llmlingua2** (bool, optional): Whether to use llmlingua-2 for prompt compression. Defaults is False.
-## Function Call
+### Function Call
Utilize (Long)LLMLingua for prompt compression with a range of customizable parameters:
@@ -61,10 +272,21 @@ compressed_prompt = llm_lingua.compress_prompt(
add_instruction: bool = False, # Adds instruction before the prompt
rank_method: str = "longllmlingua", # Method for ranking in coarse-level compression
concate_question: bool = True, # Includes the question in the compressed prompt
+ # Parameters for LLMLingua-2
+ target_context: int = -1, # Context Budget for Coarse-level Prompt Compression
+ context_level_rate: float = 1.0, # Compression rate for Coarse-level Prompt Compression
+ context_level_target_token: int = -1, # Token Budget for Coarse-level Prompt Compression
+ return_word_label: bool = False, # Whether to return words with corresponding labels. Default is False.
+ word_sep: str = '\t\t|\t\t', # The sep token used in fn_labeled_original_prompt to partition words.
+ label_sep: str = " ", # The sep token used in fn_labeled_original_prompt to partition word and label.
+ token_to_word: str = 'mean', # How to convert token probability to word probability. Default is 'mean'.
+ force_tokens: List[str] = [], # List of specific tokens to always include in the compressed result. Default is [].
+ force_reserve_digit: bool = False, # Whether to forcibly reserve tokens that containing digit (0,...,9). Default is False.
+ drop_consecutive: bool = False, # Whether to drop tokens which are in 'force_tokens' but appears consecutively in compressed prompt. Default is False
+ chunk_end_tokens: List[str] = [".", "\n"] # The early stop tokens for segmenting chunk. Default is [".", "\n"].
)
```
-
### Parameters
- **context** (str or List[str]): Contexts, documents, or demonstrations in the prompt, exhibiting low sensitivity to compression.
@@ -83,7 +305,7 @@ compressed_prompt = llm_lingua.compress_prompt(
- **keep_last_sentence** (int, optional): Specifies whether to retain the last 'k' sentences in each context. Default is 0.
- **keep_sentence_number** (int, optional): Specifies the number of sentences to retain in each context. Default is 0.
- **high_priority_bonus** (int, optional): Assigns a priority bonus to sentences retained by the **keep_first_sentence** or **keep_last_sentence** settings. Default is 100.
-- **context_budget** (str, optional): Budget for Coarse-level Prompt Compression, with supported operators like "*1.5" or "+100". Default is "+100".
+- **context_budget** (str, optional): Budget for Coarse-level Prompt Compression, with supported operators like "\*1.5" or "+100". Default is "+100".
- **token_budget_ratio** (float, optional): Budget ratio for sentence-level Prompt Compression. Default is 1.4.
- **condition_in_question** (str, optional): Determines the use of question-aware coarse-level prompt compression. Options include "none", "after", "before". Default is "none".
- **reorder_context** (str, optional): Method for document reordering before compression in LongLLMLingua. Options include "original", "sort", "two_stage". Default is "original".
@@ -91,23 +313,33 @@ compressed_prompt = llm_lingua.compress_prompt(
- **condition_compare** (bool, optional): Enables Iterative Token-level Question-aware Fine-Grained Compression in LongLLMLingua. Default is False.
- **add_instruction** (bool, optional): Determines whether to add an instruction before the prompt in Iterative Token-level Question-aware Fine-Grained Compression. Default is False.
- **rank_method** (str, optional): Selects the ranking method for Coarse-level Prompt Compression, with support for various embedding and reranker methods, as well as LLMLingua and LongLLMLingua. Default is "llmlingua".
- - "llmlingua": Employs the coarse-grained prompt compression technique of **LLMLingua**.
- - "longllmlingua": Utilizes the question-aware coarse-grained prompt compression method in **LongLLMLingua** (recommended).
- - Traditional Retrieval Methods:
- - "bm25": A bag-of-words retrieval function that ranks documents based on the occurrence of query terms, irrespective of their proximity within the documents.
- - "gzip": A retrieval method based on GZIP compression. For further information, see [GZIP Retrieval Method](https://aclanthology.org/2023.findings-acl.426).
- - Embedding-Based Retrieval Methods:
- - "sentbert": An embedding-based retrieval method. Learn more at [SentenceBERT](https://www.sbert.net).
- - "openai": Utilizes "text-embedding-ada-002" as the embedding model from OpenAI.
- - "bge": An embedding-based retrieval method using "BAAI/bge-large-en-v1.5". For additional information, visit [BGE-Large-EN-V1.5](https://huggingface.co/BAAI/bge-large-en-v1.5).
- - "voyageai": An embedding-based retrieval method provided by VoyageAI. More details at [VoyageAI](https://www.voyageai.com).
- - "jinza": An embedding-based retrieval method using "jinaai/jina-embeddings-v2-base-en". Further details are available at [JinaAI Embeddings](https://huggingface.co/jinaai/jina-embeddings-v2-base-en).
- - Reranker Methods:
- - "bge_reranker": A reranker-based method using "BAAI/bge-reranker-large". More information can be found at [BGE Reranker Large](https://huggingface.co/BAAI/bge-reranker-large).
- - "bge_llmembedder": A reranker-based method using "BAAI/llm-embedder". For more details, refer to [BAAI LLM Embedder](https://huggingface.co/BAAI/llm-embedder).
- - "cohere": A reranker-based method using "rerank-english-v2.0" from Cohere. Learn more at [Cohere Rerank](https://cohere.com/rerank).
+ - "llmlingua": Employs the coarse-grained prompt compression technique of **LLMLingua**.
+ - "longllmlingua": Utilizes the question-aware coarse-grained prompt compression method in **LongLLMLingua** (recommended).
+ - Traditional Retrieval Methods:
+ - "bm25": A bag-of-words retrieval function that ranks documents based on the occurrence of query terms, irrespective of their proximity within the documents.
+ - "gzip": A retrieval method based on GZIP compression. For further information, see [GZIP Retrieval Method](https://aclanthology.org/2023.findings-acl.426).
+ - Embedding-Based Retrieval Methods:
+ - "sentbert": An embedding-based retrieval method. Learn more at [SentenceBERT](https://www.sbert.net).
+ - "openai": Utilizes "text-embedding-ada-002" as the embedding model from OpenAI.
+ - "bge": An embedding-based retrieval method using "BAAI/bge-large-en-v1.5". For additional information, visit [BGE-Large-EN-V1.5](https://huggingface.co/BAAI/bge-large-en-v1.5).
+ - "voyageai": An embedding-based retrieval method provided by VoyageAI. More details at [VoyageAI](https://www.voyageai.com).
+ - "jinza": An embedding-based retrieval method using "jinaai/jina-embeddings-v2-base-en". Further details are available at [JinaAI Embeddings](https://huggingface.co/jinaai/jina-embeddings-v2-base-en).
+ - Reranker Methods:
+ - "bge_reranker": A reranker-based method using "BAAI/bge-reranker-large". More information can be found at [BGE Reranker Large](https://huggingface.co/BAAI/bge-reranker-large).
+ - "bge_llmembedder": A reranker-based method using "BAAI/llm-embedder". For more details, refer to [BAAI LLM Embedder](https://huggingface.co/BAAI/llm-embedder).
+ - "cohere": A reranker-based method using "rerank-english-v2.0" from Cohere. Learn more at [Cohere Rerank](https://cohere.com/rerank).
- **concate_question** (bool, optional): Determines whether to include the question in the compressed prompt. Default is True.
-
+- **target_context** (int): The maximum number of contexts to be achieved in context level compression. Default is -1 (no compression on context level).
+- **context_level_rate** (float): The compression rate target to be achieved in context level. Default is 1.0 (no compression on context level).
+- **context_level_target_token** (int): The maximum number of tokens to be achieved in context level compression. Default is -1 (no compression on context level).
+- **return_word_label** (bool): Whether to return words with corresponding labels. Default is False.
+- **word_sep** (str): The sep token used in fn_labeled_original_prompt to partition words. Only used when return_word_label==True. Default is '\t\t|\t\t'
+- **label_sep** (str): The sep token used in fn_labeled_original_prompt to partition word and label. Only used when return_word_label==True. Default is ' '
+- **token_to_word** (str): The method to convert token probability to word probability. Default is 'mean'
+- **force_tokens** (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
+- **force_reserve_digit** (bool, optional): Whether to forcibly reserve tokens that containing digit (0,...,9). Default is False.
+- **drop_consecutive** (bool, optinal): Whether to drop tokens which are in 'force_tokens' but appears consecutively in compressed prompt. Default is False.
+- **chunk_end_tokens** (List[str], optinal): The early stop tokens for segmenting chunk. Default is [".", "\n"].
### Response
@@ -117,7 +349,12 @@ compressed_prompt = llm_lingua.compress_prompt(
- **ratio** (str): Actual compression ratio.
- **saving** (str): Savings in GPT-4 cost.
-## Post-Processing
+Additional Response Parameter for LLMLingua-2.
+
+- **fn_labeled_original_prompt** (str): original words along with their labels indicating whether to reserve in compressed prompt, in the format (word1 label_sep label2 word_sep word2 label_sep label2 ...). Only return when return_word_label==True.
+- **compressed_prompt_list** (str): List of the compressed prompt.
+
+### Post-Processing
Recover the original response from a compressed prompt:
@@ -129,93 +366,12 @@ recovered_response = llm_lingua.recover(
)
```
-### Parameters
+#### Parameters
- **original_prompt** (str): The original prompt.
- **compressed_prompt** (str): The compressed prompt.
- **response** (str): The response from black-box LLMs based on the compressed prompt.
-### Response
+#### Response
- **recovered_response** (str): The recovered response, integrating the original prompt's context.
-
-## Advanced Usage
-
-### Utilizing Small Models
-
-### Using phi-2
-
-Thanks to the efforts of the community, phi-2 is now available for use in LLMLingua.
-
-Before using it, please update your transformers to the GitHub version by running `pip install -U git+https://github.com/huggingface/transformers.git`.
-
-```python
-llm_lingua = PromptCompressor("microsoft/phi-2")
-```
-
-### Quantized Models
-
-(Long)LLMLingua supports the use of quantized small models such as `TheBloke/Llama-2-7b-Chat-GPTQ`, which require less than 8GB of GPU memory.
-
-To begin, ensure you install the necessary packages with:
-
-```bash
-pip install optimum auto-gptq
-```
-
-Then, initialize your model as follows:
-
-```python
-from llmlingua import PromptCompressor
-
-llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
-```
-
-### Integration with LangChain
-
-Thanks to the contributions of Ayo Ayibiowu (@thehapyone), (Long)LLMLingua can be seamlessly integrated into LangChain. Here's an example of how to initialize (Long)LLMLingua within LangChain:
-
-```python
-from langchain.retrievers import ContextualCompressionRetriever
-from langchain_community.retrievers.document_compressors import LLMLinguaCompressor
-from langchain_openai import ChatOpenAI
-
-llm = ChatOpenAI(temperature=0)
-
-compressor = LLMLinguaCompressor(model_name="openai-community/gpt2", device_map="cpu")
-compression_retriever = ContextualCompressionRetriever(
- base_compressor=compressor, base_retriever=retriever
-)
-
-compressed_docs = compression_retriever.get_relevant_documents(
- "What did the president say about Ketanji Jackson Brown"
-)
-pretty_print_docs(compressed_docs)
-```
-
-For a more detailed guide, please refer to [Notebook](https://github.com/langchain-ai/langchain/blob/master/docs/docs/integrations/retrievers/llmlingua.ipynb).
-
-### Integration with LlamaIndex
-
-Thanks to the contributions of Jerry Liu (@jerryjliu), (Long)LLMLingua can be seamlessly integrated into LlamaIndex. Here's an example of how to initialize (Long)LLMLingua within LlamaIndex:
-
-```python
-from llama_index.query_engine import RetrieverQueryEngine
-from llama_index.response_synthesizers import CompactAndRefine
-from llama_index.indices.postprocessor import LongLLMLinguaPostprocessor
-
-node_postprocessor = LongLLMLinguaPostprocessor(
- instruction_str="Given the context, please answer the final question",
- target_token=300,
- rank_method="longllmlingua",
- additional_compress_kwargs={
- "condition_compare": True,
- "condition_in_question": "after",
- "context_budget": "+100",
- "reorder_context": "sort", # Enables document reordering
- "dynamic_context_compression_ratio": 0.4, # Enables dynamic compression ratio
- },
-)
-```
-
-For a more detailed guide, please refer to [RAGLlamaIndex Example](https://github.com/microsoft/LLMLingua/blob/main/examples/RAGLlamaIndex.ipynb).
diff --git a/README.md b/README.md
index 65489dc..2d89810 100644
--- a/README.md
+++ b/README.md
@@ -3,21 +3,24 @@
-
(Long)LLMLingua: Enhancing Large Language Model Inference via Prompt Compression
+ LLMLingua Series | Effectively Deliver Information to LLMs via Prompt Compression
| Project Page |
- LLMLingua Paper |
- LongLLMLingua Paper |
- HF Space Demo |
+ LLMLingua |
+ LongLLMLingua |
+ LLMLingua-2 |
+ LLMLingua Demo |
+ LLMLingua-2 Demo |
https://github.com/microsoft/LLMLingua/assets/30883354/eb0ea70d-6d4c-4aa7-8977-61f94bb87438
## News
+- 🦚 We're excited to announce the release of **LLMLingua-2**, boasting a 3x-6x speed improvement over LLMLingua! For more information, check out our [paper](https://arxiv.org/abs/2403.), visit the [project page](https://llmlingua.com/llmlingua-2.html), and explore our [demo](https://huggingface.co/spaces/microsoft/LLMLingua-2).
- 👾 LLMLingua has been integrated into [LangChain](https://github.com/langchain-ai/langchain/blob/master/docs/docs/integrations/retrievers/llmlingua.ipynb) and [LlamaIndex](https://github.com/run-llama/llama_index/blob/main/docs/examples/node_postprocessor/LongLLMLingua.ipynb), two widely-used RAG frameworks.
- 🤳 Talk slides are available in [AI Time Jan, 24](https://drive.google.com/file/d/1fzK3wOvy2boF7XzaYuq2bQ3jFeP1WMk3/view?usp=sharing).
- 🖥 EMNLP'23 slides are available in [Session 5](https://drive.google.com/file/d/1GxQLAEN8bBB2yiEdQdW4UKoJzZc0es9t/view) and [BoF-6](https://drive.google.com/file/d/1LJBUfJrKxbpdkwo13SgPOqugk-UjLVIF/view).
@@ -28,13 +31,19 @@ https://github.com/microsoft/LLMLingua/assets/30883354/eb0ea70d-6d4c-4aa7-8977-6
## TL;DR
LLMLingua utilizes a compact, well-trained language model (e.g., GPT2-small, LLaMA-7B) to identify and remove non-essential tokens in prompts. This approach enables efficient inference with large language models (LLMs), achieving up to 20x compression with minimal performance loss.
+
- [LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models](https://aclanthology.org/2023.emnlp-main.825/) (EMNLP 2023)
-_Huiqiang Jiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
+ _Huiqiang Jiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
LongLLMLingua mitigates the 'lost in the middle' issue in LLMs, enhancing long-context information processing. It reduces costs and boosts efficiency with prompt compression, improving RAG performance by up to 21.4% using only 1/4 of the tokens.
+
- [LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression](https://arxiv.org/abs/2310.06839) (ICLR ME-FoMo 2024)
-_Huiqiang Jiang, Qianhui Wu, Xufang Luo, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
+ _Huiqiang Jiang, Qianhui Wu, Xufang Luo, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_
+
+LLMLingua-2, a small-size yet powerful prompt compression method trained via data distillation from GPT-4 for token classification with a BERT-level encoder, excels in task-agnostic compression. It surpasses LLMLingua in handling out-of-domain data, offering 3x-6x faster performance.
+- [LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression](https://arxiv.org/abs/2403.) (Under Review)
+ _Zhuoshi Pan, Qianhui Wu, Huiqiang Jiang, Menglin Xia, Xufang Luo, Jue Zhang, Qingwei Lin, Victor Ruhle, Yuqing Yang, Chin-Yew Lin, H. Vicky Zhao, Lili Qiu, Dongmei Zhang_
## 🎥 Overview
@@ -48,11 +57,11 @@ While Large Language Models like ChatGPT and GPT-4 excel in generalization and r
![Motivation for LLMLingua](./images/motivation.png)
-Now you can use **LLMLingua** & **LongLLMLingua**!
+Now you can use **LLMLingua**, **LongLLMLingua**, and **LLMLingua-2**!
These tools offer an efficient solution to compress prompts by up to **20x**, enhancing the utility of LLMs.
-- 💰 **Cost Savings**: Reduces both prompt and generation lengths.
+- 💰 **Cost Savings**: Reduces both prompt and generation lengths with minimal overhead.
- 📝 **Extended Context Support**: Enhances support for longer contexts, mitigates the "lost in the middle" issue, and boosts overall performance.
- ⚖️ **Robustness**: No additional training needed for LLMs.
- 🕵️ **Knowledge Retention**: Maintains original prompt information like ICL and reasoning.
@@ -63,7 +72,7 @@ These tools offer an efficient solution to compress prompts by up to **20x**, en
![Framework of LongLLMLingua](./images/LongLLMLingua.png)
-![Demo of LLMLingua](./images/LLMLingua_demo.png)
+![Framework of LLMLingua-2](./images/LLMLingua-2.png)
PS: This demo is based on the [alt-gpt](https://github.com/feedox/alt-gpt) project. Special thanks to @Livshitz for their valuable contribution.
@@ -82,6 +91,7 @@ If you find this repo helpful, please cite the following papers:
pages = "13358--13376",
}
```
+
```bibtex
@article{jiang-etal-2023-longllmlingua,
title = "{L}ong{LLML}ingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression",
@@ -93,19 +103,30 @@ If you find this repo helpful, please cite the following papers:
}
```
+```bibtex
+@article{wu2024llmlingua2,
+ title = "{LLML}ingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression",
+ author = "Zhuoshi Pan and Qianhui Wu and Huiqiang Jiang and Menglin Xia and Xufang Luo and Jue Zhang and Qingwei Lin and Victor Ruhle and Yuqing Yang and Chin-Yew Lin and H. Vicky Zhao and Lili Qiu and Dongmei Zhang",
+ url = "https://arxiv.org/abs/2403.",
+ journal = "ArXiv preprint",
+ volume = "abs/2403.",
+ year = "2024",
+}
+```
+
## 🎯 Quick Start
-#### 1. **Installing (Long)LLMLingua:**
+#### 1. **Installing LLMLingua:**
-To get started with (Long)LLMLingua, simply install it using pip:
+To get started with LLMLingua, simply install it using pip:
```bash
pip install llmlingua
```
-#### 2. **Using (Long)LLMLingua for Prompt Compression:**
+#### 2. **Using LLMLingua Series Methods for Prompt Compression:**
-With (Long)LLMLingua, you can easily compress your prompts. Here’s how you can do it:
+With **LLMLingua**, you can easily compress your prompts. Here’s how you can do it:
```python
from llmlingua import PromptCompressor
@@ -120,7 +141,6 @@ compressed_prompt = llm_lingua.compress_prompt(prompt, instruction="", question=
# 'saving': ', Saving $0.1 in GPT-4.'}
## Or use the phi-2 model,
-## Before that, you need to update the transformers to the github version, like pip install -U git+https://github.com/huggingface/transformers.git
llm_lingua = PromptCompressor("microsoft/phi-2")
## Or use the quantation model, like TheBloke/Llama-2-7b-Chat-GPTQ, only need <8GB GPU memory.
@@ -128,6 +148,44 @@ llm_lingua = PromptCompressor("microsoft/phi-2")
llm_lingua = PromptCompressor("TheBloke/Llama-2-7b-Chat-GPTQ", model_config={"revision": "main"})
```
+To try **LongLLMLingua** in your scenorias, you can use
+
+```python
+from llmlingua import PromptCompressor
+
+llm_lingua = PromptCompressor()
+compressed_prompt = llm_lingua.compress_prompt(
+ prompt_list,
+ question=question,
+ ratio=0.55,
+ # Set the special parameter for LongLLMLingua
+ condition_in_question="after_condition",
+ reorder_context="sort",
+ dynamic_context_compression_ratio=0.3, # or 0.4
+ condition_compare=True,
+ context_budget="+100",
+ rank_method="longllmlingua",
+)
+```
+
+To try **LLMLingua-2** in your scenorias, you can use
+
+```python
+from llmlingua import PromptCompressor
+
+llm_lingua = PromptCompressor(
+ model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
+ use_llmlingua2=True, # Whether to use llmlingua-2
+)
+compressed_prompt = llm_lingua.compress_prompt(prompt, rate=0.33, force_tokens = ['\n', '?'])
+
+## Or use LLMLingua-2-small model
+llm_lingua = PromptCompressor(
+ model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
+ use_llmlingua2=True, # Whether to use llmlingua-2
+)
+```
+
#### 3. **Advanced usage - Structured Prompt Compression:**
Split text into sections, decide on whether to compress and its rate. Use `` tags for context segmentation, with optional rate and compress parameters.
@@ -148,13 +206,17 @@ print(compressed_prompt['compressed_prompt'])
To understand how to apply LLMLingua and LongLLMLingua in real-world scenarios like RAG, Online Meetings, CoT, and Code, please refer to our [**examples**](./examples). For detailed guidance, the [**documentation**](./DOCUMENT.md) provides extensive recommendations on effectively utilizing LLMLingua.
+#### 5. **Data collection and model training of LLMLingua-2:**
+
+To train the compressor on your custom data, please refer to our [**data_collection**](./experiments/llmlingua2/data_collection) and [**model_training**](./experiments/llmlingua2/model_training).
+
## Frequently Asked Questions
For more insights and answers, visit our [FAQ section](./Transparency_FAQ.md).
## Contributing
-This project welcomes contributions and suggestions. Most contributions require you to agree to a
+This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
diff --git a/Transparency_FAQ.md b/Transparency_FAQ.md
index 61f8f3b..150b791 100644
--- a/Transparency_FAQ.md
+++ b/Transparency_FAQ.md
@@ -120,7 +120,7 @@ Out[3]:
}
```
-## How to reproduce the result in LLMLingua & LongLLMLingua?
+## How to reproduce the result in LLMLingua Series work?
We release the parameter in the [issue1](https://github.com/microsoft/LLMLingua/issues/76), [issue2](https://github.com/microsoft/LLMLingua/issues/86).
@@ -157,6 +157,25 @@ compressed_prompt = llm_lingua.compress_prompt(
Experiments in LLMLingua and most experiments in LongLLMLingua were conducted in completion mode, whereas chat mode tends to be more sensitive to token-level compression. However, OpenAI has currently disabled GPT-3.5-turbo's completion; you can use GPT-3.5-turbo-instruction or Azure OpenAI service instead.
+**LLMLingua-2**:
+
+```python
+from llmlingua import PromptCompressor
+
+llm_lingua = PromptCompressor(
+ model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
+ use_llmlingua2=True, # Whether to use llmlingua-2
+)
+compressed_prompt = llm_lingua.compress_prompt(prompt, rate=0.33, force_tokens = ['\n', '?'])
+
+## Or use LLMLingua-2-small model
+llm_lingua = PromptCompressor(
+ model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
+ use_llmlingua2=True, # Whether to use llmlingua-2
+)
+```
+
+And you can find the details of the LLMLingua-2 experiments at [experiments/llmlingua2](./examples/llmlingua2).
## How to use LLMLingua in LangChain and LlamaIndex?
diff --git a/examples/LLMLingua2.ipynb b/examples/LLMLingua2.ipynb
new file mode 100644
index 0000000..c20f917
--- /dev/null
+++ b/examples/LLMLingua2.ipynb
@@ -0,0 +1,1032 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## LLMLingua2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ " \n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "LLMLingua-2 focuses on task-agnostic prompt compression for better generalizability and efficiency. It is a small-size yet powerful prompt compression method trained via data distillation from GPT-4 for token classification with a BERT-level encoder, excels in task-agnostic compression. It surpasses LLMLingua in handling out-of-domain data, offering 3x-6x faster performance.\n",
+ "\n",
+ "Below, We showcase the usage and compression results of LLMLingua-2 on both in-domain and out-of-domain datasets, including various tasks such as single-document QA, multi-document QA, summarization and in-context learning.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from llmlingua import PromptCompressor\n",
+ "\n",
+ "llm_lingua = PromptCompressor(\n",
+ " model_name=\"microsoft/llmlingua-2-xlm-roberta-large-meetingbank\",\n",
+ " use_llmlingua2=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Target LLM Config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install openai==0.28"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Using the OAI\n",
+ "import openai\n",
+ "\n",
+ "openai.api_key = \"\"\n",
+ "\n",
+ "# or Using the AOAI\n",
+ "import openai\n",
+ "\n",
+ "openai.api_key = \"\"\n",
+ "openai.api_base = \"\"\n",
+ "openai.api_type = \"azure\"\n",
+ "openai.api_version = \"2023-05-15\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## In-Domain"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below, we present the results of LLMLingua-2 compared to the strong baselines on In-Domain data: test set of MeetingBank.\n",
+ "Despite the fact that our compressors are much smaller than the LLaMa-2-7B used in the baselines, \n",
+ "our approach achieves significantly better performance on both the QA and Summary tasks, and comes close to matching the performance of the original prompt."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### MeetingBank\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 117,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download the original prompt and dataset\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "dataset = load_dataset(\"huuuyeah/meetingbank\", split=\"test\")\n",
+ "context = dataset[0][\"transcript\"]\n",
+ "\n",
+ "question = \"What is the agenda item three resolution 31669 about?\\nAnswer:\"\n",
+ "reference = \"Encouraging individualized tenant assessment.\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"id\": \"chatcmpl-94T49ZkAUgmY2EQQVuzS8EcklBZQO\",\n",
+ " \"object\": \"chat.completion\",\n",
+ " \"created\": 1710852069,\n",
+ " \"model\": \"gpt-4-32k\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"index\": 0,\n",
+ " \"message\": {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"Agenda item three resolution 31669 is about encouraging the use of an individualized tenant assessment using the Fair Housing Act's discriminatory effect standards to avoid Fair Housing Act violations when criminal history is used as a screening criterion in the Landlord Screening Process. The resolution aims to ensure that landlords understand the law when it comes to making decisions based on criminal history. It also highlights the policies that the Department of Housing and Urban Development (HUD) is currently promoting and the policy direction that the city will be pursuing\"\n",
+ " },\n",
+ " \"finish_reason\": \"length\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 1362,\n",
+ " \"completion_tokens\": 100,\n",
+ " \"total_tokens\": 1462\n",
+ " },\n",
+ " \"system_fingerprint\": null\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The response from original prompt, using GPT-4-32k\n",
+ "import json\n",
+ "\n",
+ "prompt = \"\\n\\n\".join([context, question])\n",
+ "\n",
+ "message = [\n",
+ " {\"role\": \"user\", \"content\": prompt},\n",
+ "]\n",
+ "\n",
+ "request_data = {\n",
+ " \"messages\": message,\n",
+ " \"max_tokens\": 100,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ "}\n",
+ "response = openai.ChatCompletion.create(\n",
+ " engine=\"gpt-4-32k\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(json.dumps(response, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 2000 Compression\n",
+ "compressed_prompt = llm_lingua.compress_prompt(\n",
+ " context,\n",
+ " rate=0.33,\n",
+ " force_tokens=[\"!\", \".\", \"?\", \"\\n\"],\n",
+ " drop_consecutive=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"id\": \"chatcmpl-94T4RJZkt4dZz01FQv5gZhNp0qGq5\",\n",
+ " \"object\": \"chat.completion\",\n",
+ " \"created\": 1710852087,\n",
+ " \"model\": \"gpt-4-32k\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"index\": 0,\n",
+ " \"message\": {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"The agenda item three resolution 31669 is about individualized tenant assessment under the Fair Housing Act. It aims to avoid discriminatory standards and violations in the landlord screening process, particularly in relation to criminal history screening criteria. The resolution also discusses the Certificate of Restoration of Opportunity, a state legislation designed to provide potential employers and housing providers with information about individuals who have served prison time and have been released, to facilitate their reintegration into society.\"\n",
+ " },\n",
+ " \"finish_reason\": \"stop\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 444,\n",
+ " \"completion_tokens\": 87,\n",
+ " \"total_tokens\": 531\n",
+ " },\n",
+ " \"system_fingerprint\": null\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The response from original prompt, using GPT-4-32k\n",
+ "import json\n",
+ "\n",
+ "prompt = \"\\n\\n\".join([compressed_prompt[\"compressed_prompt\"], question])\n",
+ "\n",
+ "message = [\n",
+ " {\"role\": \"user\", \"content\": prompt},\n",
+ "]\n",
+ "\n",
+ "request_data = {\n",
+ " \"messages\": message,\n",
+ " \"max_tokens\": 100,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ "}\n",
+ "response = openai.ChatCompletion.create(\n",
+ " engine=\"gpt-4-32k\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(json.dumps(response, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Out-of-Domain\n",
+ "\n",
+ "As our model is only trained on meeting transcripts data from MeetingBank, here we explore its generalization ability across various benchmarks of long-context scenarios, reasoning, and in-context learning.\n",
+ "Although the compressor of LLMLingua-2 is only trained on MeetingBank data, LLMLingua-2 is also effective on out-of-domain data, \n",
+ "with its performance comparable to or even surpassing the SOTA task-agnostic compression baselines. \n",
+ "\n",
+ "Below, we showcase several compression results on LongBench and GSM8K, including single-document QA, multi-document QA, summarization and in-context learning tasks."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load LongBench Prompt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset2prompt = {\n",
+ " \"narrativeqa\": \"You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\\n\\nStory: {context}\\n\\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\\n\\nQuestion: {input}\\n\\nAnswer:\",\n",
+ " \"gov_report\": \"You are given a report by a government agency. Write a one-page summary of the report.\\n\\nReport:\\n{context}\\n\\nNow, write a one-page summary of the report.\\n\\nSummary:\",\n",
+ " \"triviaqa\": \"Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\\n\\n{context}\\n\\n{input}\",\n",
+ "}\n",
+ "\n",
+ "dataset2maxlen = {\n",
+ " \"narrativeqa\": 128,\n",
+ " \"gov_report\": 512,\n",
+ " \"triviaqa\": 32,\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Single-Doc QA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['To smuggle Socrates out of prison and into a life of exile.']\n"
+ ]
+ }
+ ],
+ "source": [
+ "task = \"narrativeqa\"\n",
+ "dataset = load_dataset(\"THUDM/LongBench\", task, split=\"test\")\n",
+ "sample = dataset[3]\n",
+ "context = sample[\"context\"]\n",
+ "reference = sample[\"answers\"]\n",
+ "print(reference)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"id\": \"chatcmpl-94TFlOkf2ps8qW0Y4jKf6cnsn6Mbi\",\n",
+ " \"object\": \"chat.completion\",\n",
+ " \"created\": 1710852789,\n",
+ " \"model\": \"gpt-4-32k\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"index\": 0,\n",
+ " \"message\": {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"To convince Socrates to escape from prison.\"\n",
+ " },\n",
+ " \"finish_reason\": \"stop\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 9059,\n",
+ " \"completion_tokens\": 9,\n",
+ " \"total_tokens\": 9068\n",
+ " },\n",
+ " \"system_fingerprint\": null\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The response from original prompt, using GPT-4-32k\n",
+ "import json\n",
+ "\n",
+ "prompt_format = dataset2prompt[task]\n",
+ "max_gen = int(dataset2maxlen[task])\n",
+ "prompt = prompt_format.format(**sample)\n",
+ "\n",
+ "message = [\n",
+ " {\"role\": \"user\", \"content\": prompt},\n",
+ "]\n",
+ "\n",
+ "request_data = {\n",
+ " \"messages\": message,\n",
+ " \"max_tokens\": max_gen,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ "}\n",
+ "response = openai.ChatCompletion.create(\n",
+ " engine=\"gpt-4-32k\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(json.dumps(response, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 3000 Compression\n",
+ "compressed_prompt = llm_lingua.compress_prompt(\n",
+ " context,\n",
+ " target_token=3000,\n",
+ " force_tokens=[\"!\", \".\", \"?\", \"\\n\"],\n",
+ " drop_consecutive=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"id\": \"chatcmpl-94TG5pKpMAsmKwGKcQy3ZvFjKPM5S\",\n",
+ " \"object\": \"chat.completion\",\n",
+ " \"created\": 1710852809,\n",
+ " \"model\": \"gpt-4-32k\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"index\": 0,\n",
+ " \"message\": {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"To persuade Socrates to escape from prison.\"\n",
+ " },\n",
+ " \"finish_reason\": \"stop\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 3064,\n",
+ " \"completion_tokens\": 9,\n",
+ " \"total_tokens\": 3073\n",
+ " },\n",
+ " \"system_fingerprint\": null\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The response from original prompt, using GPT-4-32k\n",
+ "import json\n",
+ "\n",
+ "prompt_format = dataset2prompt[task]\n",
+ "max_gen = int(dataset2maxlen[task])\n",
+ "sample[\"context\"] = compressed_prompt[\"compressed_prompt\"]\n",
+ "prompt = prompt_format.format(**sample)\n",
+ "\n",
+ "message = [\n",
+ " {\"role\": \"user\", \"content\": prompt},\n",
+ "]\n",
+ "\n",
+ "request_data = {\n",
+ " \"messages\": message,\n",
+ " \"max_tokens\": max_gen,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ "}\n",
+ "response = openai.ChatCompletion.create(\n",
+ " engine=\"gpt-4-32k\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(json.dumps(response, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Multi-Doc QA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 105,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['The United States of America', 'United States Of Amerca', 'Us of a', 'U.–S.–A.', 'Americaland', 'United States (U.S.A.)', 'Amurika', 'Unite states of america', 'United States of America (redirect)', 'The U S A', 'Unietd States', 'EE UU', 'The U.S.A.', 'U.-S.-A.', 'Usa', 'United Staets of America', 'Unites States', \"États-Unis d'Amérique\", 'Verenigde State', 'U.–S.', 'The United States of America.', 'The U-S-A', 'EEUU', 'U. S. A.', 'Nagkaisang mga Estado', 'The U. S. of America', 'The USA', 'America (United States)', 'The U. S. A.', 'U S of America', 'UNITED STATES', 'Estados Unidos', 'The U–S', 'American United States', 'US and A', 'Unitd states', 'The US of A', 'EE.UU.', 'U-S', 'The U-S', 'Etymology of the United States', 'U.S.A.)', 'EE. UU.', 'United states of america', 'US of america', 'Verenigde State van Amerika', 'Nited States', 'United-States', 'Unite States', 'Estados Unidos de América', 'UnitedStates', 'Estaos Unios', 'US of America', 'The Usa', 'United states of America', 'Untied States of America', 'The U S of America', 'THE AMERICAN UNITED STATES', 'The United-States', 'U S A', 'AmericA', 'Estados Unidos de America', 'United states', 'The U.S. of America', 'Amerka', 'United–States', 'U.s.a.', 'United States of America', 'United State of America', 'United States (US)', 'The U.S. of A', 'America', 'Amercia', \"Stati Uniti d'America\", 'Los Estados Unidos de America', 'United Stated', 'U.S.', 'United States (of America)', 'United States', 'States of America', 'America-class', 'Los Estados Unidos', 'U,S,', 'United States (country)', 'Federal United States', 'ISO 3166-1:US', 'Untied States', 'The U.–S.–A.', 'VS America', 'Amurica', \"Etats-Unis d'Amerique\", 'US', 'U.S. OF A', 'USofA', 'Etats-Unis', 'U.S. of A', 'United States of America (U.S.A.)', 'Amarica', 'The United States', 'U-S-A', 'United States/Introduction', 'The Us', 'Unitesd states', 'The U S of A', 'America class', 'America magazine', 'الولايات المتحدة الأمريكية', 'The U. S. of A', 'U S', '(USA)', 'The United–States', 'United States (U.S.)', 'U.-S.', 'United States of America (USA)', \"'merica\", 'The US', 'United States of America.', 'UNited States', 'The U.S.', 'AMERICA', 'United States of America/OldPage', 'United+States', 'The U S', 'United Sates', 'THE UNITED STATES OF AMERICA', 'U–S–A', 'United States Of America', 'U.S. of America', 'U–S', 'Los Estados Unidos de América', 'The U.-S.', 'United sates', 'The United States Of America', 'America (country)', 'United States of American', 'United state of america', 'The U.–S.', 'Amurka', 'U. S. of A', 'The U. S.', 'United States America', 'US of A', 'États-Unis', 'USoA', 'USA', 'Estaos Uníos', 'America, United States of', 'U. S. of America', 'U.S.American', '(US)', 'The U–S–A', 'U. S.', 'U.S. America', 'U.S. A', 'Yankee land', 'America (US)', 'U.S', 'America (United States of)', 'US (country)', 'UNITED STATES OF AMERICA', 'U.S.A', 'Estados unidos', 'Americia', 'The US of america', 'Vereinigte Staaten', 'US America', 'These United States of America', 'VS Amerika', 'Name of the United States', 'The united states of america', 'Estatos Unitos', 'America (USA)', 'The U.-S.-A.', 'United States of America/Introduction', 'The US of America', 'Americophile', 'V.S. America', 'U.S.A.', 'U S of A', 'V.S. Amerika', 'United+States+of+America', 'The Unites States of America']\n"
+ ]
+ }
+ ],
+ "source": [
+ "task = \"triviaqa\"\n",
+ "dataset = load_dataset(\"THUDM/LongBench\", task, split=\"test\")\n",
+ "sample = dataset[0]\n",
+ "context = sample[\"context\"]\n",
+ "reference = sample[\"answers\"]\n",
+ "print(reference)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 106,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"id\": \"chatcmpl-94U7u0jJgszqemEVuDemWOAA09ivp\",\n",
+ " \"object\": \"chat.completion\",\n",
+ " \"created\": 1710856146,\n",
+ " \"model\": \"gpt-4-32k\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"index\": 0,\n",
+ " \"message\": {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"United States\"\n",
+ " },\n",
+ " \"finish_reason\": \"stop\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 5527,\n",
+ " \"completion_tokens\": 2,\n",
+ " \"total_tokens\": 5529\n",
+ " },\n",
+ " \"system_fingerprint\": null\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The response from original prompt, using GPT-4-32k\n",
+ "import json\n",
+ "\n",
+ "prompt_format = dataset2prompt[task]\n",
+ "max_gen = int(dataset2maxlen[task])\n",
+ "prompt = prompt_format.format(**sample)\n",
+ "\n",
+ "message = [\n",
+ " {\"role\": \"user\", \"content\": prompt},\n",
+ "]\n",
+ "\n",
+ "request_data = {\n",
+ " \"messages\": message,\n",
+ " \"max_tokens\": max_gen,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ "}\n",
+ "response = openai.ChatCompletion.create(\n",
+ " engine=\"gpt-4-32k\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(json.dumps(response, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "context_list = context.split(\"\\nPassage:\")\n",
+ "context_list = [\"\\nPassage:\" + c for c in context_list]\n",
+ "\n",
+ "# 2000 Compression\n",
+ "compressed_prompt = llm_lingua.compress_prompt(\n",
+ " context_list,\n",
+ " target_token=2000,\n",
+ " force_tokens=[\"\\nPassage:\", \".\", \"?\", \"\\n\"],\n",
+ " drop_consecutive=True,\n",
+ " use_context_level_filter=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 112,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"id\": \"chatcmpl-94UAXMDXa6LtrDpYz35inJxWPChyc\",\n",
+ " \"object\": \"chat.completion\",\n",
+ " \"created\": 1710856309,\n",
+ " \"model\": \"gpt-4-32k\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"index\": 0,\n",
+ " \"message\": {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"United States\"\n",
+ " },\n",
+ " \"finish_reason\": \"stop\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 1805,\n",
+ " \"completion_tokens\": 2,\n",
+ " \"total_tokens\": 1807\n",
+ " },\n",
+ " \"system_fingerprint\": null\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The response from original prompt, using GPT-4-32k\n",
+ "import json\n",
+ "\n",
+ "prompt_format = dataset2prompt[task]\n",
+ "max_gen = int(dataset2maxlen[task])\n",
+ "sample[\"context\"] = compressed_prompt[\"compressed_prompt\"]\n",
+ "prompt = prompt_format.format(**sample)\n",
+ "\n",
+ "message = [\n",
+ " {\"role\": \"user\", \"content\": prompt},\n",
+ "]\n",
+ "\n",
+ "request_data = {\n",
+ " \"messages\": message,\n",
+ " \"max_tokens\": max_gen,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ "}\n",
+ "response = openai.ChatCompletion.create(\n",
+ " engine=\"gpt-4-32k\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(json.dumps(response, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Summarization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[\"Multiyear procurement (MYP) and block buy contracting (BBC) are special contracting mechanisms that Congress permits the Department of Defense (DOD) to use for a limited number of defense acquisition programs. Compared to the standard or default approach of annual contracting, MYP and BBC have the potential for reducing weapon procurement costs by a few or several percent. Under annual contracting, DOD uses one or more contracts for each year's worth of procurement of a given kind of item. Under MYP, DOD instead uses a single contract for two to five years' worth of procurement of a given kind of item without having to exercise a contract option for each year after the first year. DOD needs congressional approval for each use of MYP. There is a permanent statute governing MYP contracting—10 U.S.C. 2306b. Under this statute, a program must meet several criteria to qualify for MYP. Compared with estimated costs under annual contracting, estimated savings for programs being proposed for MYP have ranged from less than 5% to more than 15%, depending on the particulars of the program in question, with many estimates falling in the range of 5% to 10%. In practice, actual savings from using MYP rather than annual contracting can be difficult to observe or verify because of cost growth during the execution of the contract due to changes in the program independent of the use of MYP rather than annual contracting. BBC is similar to MYP in that it permits DOD to use a single contract for more than one year's worth of procurement of a given kind of item without having to exercise a contract option for each year after the first year. BBC is also similar to MYP in that DOD needs congressional approval for each use of BBC. BBC differs from MYP in the following ways: There is no permanent statute governing the use of BBC. There is no requirement that BBC be approved in both a DOD appropriations act and an act other than a DOD appropriations act. Programs being considered for BBC do not need to meet any legal criteria to qualify for BBC, because there is no permanent statute governing the use of BBC that establishes such criteria. A BBC contract can cover more than five years of planned procurements. Economic order quantity (EOQ) authority—the authority to bring forward selected key components of the items to be procured under the contract and purchase the components in batch form during the first year or two of the contract—does not come automatically as part of BBC authority because there is no permanent statute governing the use of BBC that includes EOQ authority as an automatic feature. BBC contracts are less likely to include cancellation penalties. Potential issues for Congress concerning MYP and BBC include whether to use MYP and BBC in the future more frequently, less frequently, or about as frequently as they are currently used; whether to create a permanent statute to govern the use of BBC, analogous to the permanent statute that governs the use of MYP; and whether the Coast Guard should begin making use of MYP and BBC.\"]\n"
+ ]
+ }
+ ],
+ "source": [
+ "task = \"gov_report\"\n",
+ "dataset = load_dataset(\"THUDM/LongBench\", task, split=\"test\")\n",
+ "sample = dataset[0]\n",
+ "context = sample[\"context\"]\n",
+ "reference = sample[\"answers\"]\n",
+ "print(reference)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"id\": \"chatcmpl-94TMClPIEoBfTqHqw78DJV2W5r97x\",\n",
+ " \"object\": \"chat.completion\",\n",
+ " \"created\": 1710853188,\n",
+ " \"model\": \"gpt-4-32k\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"index\": 0,\n",
+ " \"message\": {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"The report discusses the use of multiyear procurement (MYP) and block buy contracting (BBC) by the Department of Defense (DOD) for defense acquisition programs. These special contracting mechanisms, permitted by Congress, have the potential to reduce weapon procurement costs by a few or several percent. The report explores whether MYP and BBC should be used more or less frequently in the future, and whether a permanent statute should be created to govern the use of BBC, similar to the one that exists for MYP. It also discusses whether the Coast Guard should start using MYP and BBC. The report clarifies that MYP and BBC are contracting mechanisms, not funding approaches, and that they can significantly change the total procurement cost of a ship. The report also explains the difference between MYP and annual contracting, and how much MYP can save. It also discusses the potential savings from BBC compared to MYP. The report concludes by discussing potential issues for Congress concerning MYP and BBC.\"\n",
+ " },\n",
+ " \"finish_reason\": \"stop\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 10705,\n",
+ " \"completion_tokens\": 197,\n",
+ " \"total_tokens\": 10902\n",
+ " },\n",
+ " \"system_fingerprint\": null\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The response from original prompt, using GPT-4-32k\n",
+ "import json\n",
+ "\n",
+ "prompt_format = dataset2prompt[task]\n",
+ "max_gen = int(dataset2maxlen[task])\n",
+ "prompt = prompt_format.format(**sample)\n",
+ "\n",
+ "message = [\n",
+ " {\"role\": \"user\", \"content\": prompt},\n",
+ "]\n",
+ "\n",
+ "request_data = {\n",
+ " \"messages\": message,\n",
+ " \"max_tokens\": max_gen,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ "}\n",
+ "response = openai.ChatCompletion.create(\n",
+ " engine=\"gpt-4-32k\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(json.dumps(response, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 3000 Compression\n",
+ "compressed_prompt = llm_lingua.compress_prompt(\n",
+ " context,\n",
+ " target_token=3000,\n",
+ " force_tokens=[\"!\", \".\", \"?\", \"\\n\"],\n",
+ " drop_consecutive=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"id\": \"chatcmpl-94TN0NVjsaQRlqjLtyTNQP88shi6f\",\n",
+ " \"object\": \"chat.completion\",\n",
+ " \"created\": 1710853238,\n",
+ " \"model\": \"gpt-4-32k\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"index\": 0,\n",
+ " \"message\": {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": \"The report discusses the issues related to multiyear procurement (MYP) and block buy contracting (BBC), special mechanisms used by the Department of Defense (DOD) for certain defense acquisition programs. These mechanisms can potentially reduce weapon procurement costs. However, they also affect defense practices, funding, and the industrial base. The report highlights that most DOD programs use traditional full funding and annual contracting, with a few using incremental funding. MYP and BBC are used in limited DOD programs. \\n\\nThe report explains that MYP is an alternative to annual contracting, allowing for a single contract for two to five years of procurement without congressional approval. The savings from MYP can range from 5% to 15%, but these savings can be difficult to assess due to cost growth. The report also discusses the benefits of MYP, such as cost optimization, workforce stability, and production facility investments. \\n\\nThe report also covers block buy contracting (BBC), which allows for a single contract for one year of procurement without a contract option. BBC can reduce unit procurement costs, but the savings are typically less than MYP. \\n\\nThe report also discusses potential issues for Congress, such as the lack of a permanent statute for BBC, the risks and benefits of MYP and BBC, and the use of these mechanisms by the Coast Guard. \\n\\nThe report concludes by discussing the Department of Defense's funding requests for MYP contracts and new MYP and block contracts for major acquisition programs in FY2020.\"\n",
+ " },\n",
+ " \"finish_reason\": \"stop\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 3254,\n",
+ " \"completion_tokens\": 298,\n",
+ " \"total_tokens\": 3552\n",
+ " },\n",
+ " \"system_fingerprint\": null\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The response from original prompt, using GPT-4-32k\n",
+ "import json\n",
+ "\n",
+ "prompt_format = dataset2prompt[task]\n",
+ "max_gen = int(dataset2maxlen[task])\n",
+ "sample[\"context\"] = compressed_prompt[\"compressed_prompt\"]\n",
+ "prompt = prompt_format.format(**sample)\n",
+ "\n",
+ "message = [\n",
+ " {\"role\": \"user\", \"content\": prompt},\n",
+ "]\n",
+ "\n",
+ "request_data = {\n",
+ " \"messages\": message,\n",
+ " \"max_tokens\": max_gen,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ "}\n",
+ "response = openai.ChatCompletion.create(\n",
+ " engine=\"gpt-4-32k\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(json.dumps(response, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### In-Context Learning (GSM8K)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--2024-03-19 21:21:01-- https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt\n",
+ "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...\n",
+ "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 8464 (8.3K) [text/plain]\n",
+ "Saving to: ‘prompt_hardest.txt’\n",
+ "\n",
+ "prompt_hardest.txt 100%[===================>] 8.27K --.-KB/s in 0s \n",
+ "\n",
+ "2024-03-19 21:21:02 (32.6 MB/s) - ‘prompt_hardest.txt’ saved [8464/8464]\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading readme: 100%|██████████| 7.94k/7.94k [00:00<00:00, 8.38MB/s]\n",
+ "Downloading data: 100%|██████████| 2.31M/2.31M [00:01<00:00, 2.16MB/s]\n",
+ "Downloading data: 100%|██████████| 419k/419k [00:00<00:00, 1.25MB/s]\n",
+ "Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 306960.40 examples/s]\n",
+ "Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 249866.17 examples/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "!wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt\n",
+ "prompt_complex = open(\"./prompt_hardest.txt\").read()\n",
+ "gsm8k = load_dataset(\"gsm8k\", \"main\")\n",
+ "gsm8k_test = gsm8k[\"test\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Question: Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?\n",
+ "Answer: The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130,000\n",
+ "He increased the value of the house by 80,000*1.5=<<80000*1.5=120000>>120,000\n",
+ "So the new value of the house is 120,000+80,000=$<<120000+80000=200000>>200,000\n",
+ "So he made a profit of 200,000-130,000=$<<200000-130000=70000>>70,000\n",
+ "#### 70000\n"
+ ]
+ }
+ ],
+ "source": [
+ "# select an example from GSM8K\n",
+ "question, answer = [gsm8k_test[2][key] for key in [\"question\", \"answer\"]]\n",
+ "# Ground-truth Answer\n",
+ "print(\"Question:\", question)\n",
+ "print(\"Answer:\", answer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 67,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{\n",
+ " \"id\": \"cmpl-94TkfUcusW4yGAXSr0mTpc9PURcnJ\",\n",
+ " \"object\": \"text_completion\",\n",
+ " \"created\": 1710854705,\n",
+ " \"model\": \"gpt-35-turbo-instruct\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"text\": \"\\nLet's think step by step\\nThe value of the house increased by 150%, meaning it is now worth 100% + 150% = 250% of its original value.\\nIf the original value of the house was $80,000, then the new value is 250% * $80,000 = $200,000.\\nJosh spent $80,000 to buy the house and $50,000 on repairs, so his total investment was $80,000 + $50,000 = $130,000.\\nHis profit is the new value of the house ($200,000) minus his total investment ($130,000), so his profit is $200,000 - $130,000 = $70,000.\\nThe answer is $70,000\",\n",
+ " \"index\": 0,\n",
+ " \"logprobs\": null,\n",
+ " \"finish_reason\": \"stop\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 2428,\n",
+ " \"completion_tokens\": 158,\n",
+ " \"total_tokens\": 2586\n",
+ " }\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The response from original prompt\n",
+ "import json\n",
+ "\n",
+ "instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
+ "prompt = instruction + prompt_complex + \"\\n\\nQuestion: \" + question\n",
+ "\n",
+ "request_data = {\n",
+ " \"prompt\": prompt,\n",
+ " \"max_tokens\": 400,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ " \"stop\": \"\\n\\n\",\n",
+ "}\n",
+ "response = openai.Completion.create(\n",
+ " engine=\"gpt-35-turbo-instruct\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(json.dumps(response, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'compressed_prompt': 'Sam bought dozen boxes 30 highlighter pens $10 rearranged five boxes six highlighters sold $3 per package sold rest three pens $2 profit\\n Sam bought 12 boxes x $10 = $120 highlighters\\n 12 * 30 = 360 highlighters\\n 5 boxes × 6 highlighters/box = 30\\n sold 5 * $3 = $15\\n 5 360 - 30 = 330 highlighters remaining\\n 330 / 3 = 110 groups three pens\\n sold $2 110 * 2 = $220\\n earned $220 + $15 = $235.\\n original cost $120 earned $235 - $120 = $115 profit\\nThe answer is 115', 'compressed_prompt_list': ['Sam bought dozen boxes 30 highlighter pens $10 rearranged five boxes six highlighters sold $3 per package sold rest three pens $2 profit\\n Sam bought 12 boxes x $10 = $120 highlighters\\n 12 * 30 = 360 highlighters\\n 5 boxes × 6 highlighters/box = 30\\n sold 5 * $3 = $15\\n 5 360 - 30 = 330 highlighters remaining\\n 330 / 3 = 110 groups three pens\\n sold $2 110 * 2 = $220\\n earned $220 + $15 = $235.\\n original cost $120 earned $235 - $120 = $115 profit\\nThe answer is 115'], 'origin_tokens': 2359, 'compressed_tokens': 148, 'ratio': '15.9x', 'rate': '6.3%', 'saving': ', Saving $0.1 in GPT-4.'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 3000 Compression\n",
+ "compressed_prompt = llm_lingua.compress_prompt(\n",
+ " prompt_complex.split(\"\\n\\n\"),\n",
+ " target_token=150,\n",
+ " force_tokens=[\"+\", \"-\", \"*\", \"×\", \"/\", \"÷\", \"=\", \"The answer is\", \"\\n\"],\n",
+ " drop_consecutive=True,\n",
+ " force_reserve_digit=True,\n",
+ " use_context_level_filter=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Response: {\n",
+ " \"id\": \"cmpl-94Tof3oyRFQlgEzjhurEiOoYiYDsR\",\n",
+ " \"object\": \"text_completion\",\n",
+ " \"created\": 1710854953,\n",
+ " \"model\": \"gpt-35-turbo-instruct\",\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"text\": \"\\n\\nTo find the new value of the house, we need to multiply the original value by 150% and add it to the original value.\\n150% of $80,000 = $80,000 * 1.5 = $120,000\\nNew value of the house = $80,000 + $120,000 = $200,000\\nProfit = New value - (Original value + Repair cost)\\n= $200,000 - ($80,000 + $50,000)\\n= $200,000 - $130,000\\n= $70,000\\nJosh made a profit of $70,000.\",\n",
+ " \"index\": 0,\n",
+ " \"logprobs\": null,\n",
+ " \"finish_reason\": \"stop\"\n",
+ " }\n",
+ " ],\n",
+ " \"usage\": {\n",
+ " \"prompt_tokens\": 211,\n",
+ " \"completion_tokens\": 128,\n",
+ " \"total_tokens\": 339\n",
+ " }\n",
+ "}\n"
+ ]
+ }
+ ],
+ "source": [
+ "instruction = \"Please reference the following examples to answer the math question,\\n\"\n",
+ "prompt = (\n",
+ " instruction + compressed_prompt[\"compressed_prompt\"] + \"\\n\\nQuestion: \" + question\n",
+ ")\n",
+ "\n",
+ "request_data = {\n",
+ " \"prompt\": prompt,\n",
+ " \"max_tokens\": 400,\n",
+ " \"temperature\": 0,\n",
+ " \"top_p\": 1,\n",
+ " \"n\": 1,\n",
+ " \"stream\": False,\n",
+ " \"stop\": \"\\r\\n\",\n",
+ "}\n",
+ "response = openai.Completion.create(\n",
+ " engine=\"gpt-35-turbo-instruct\",\n",
+ " **request_data,\n",
+ ")\n",
+ "print(\"Response:\", response)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "zspan",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/OnlineMeeting.ipynb b/examples/OnlineMeeting.ipynb
index 016d292..0a6199b 100644
--- a/examples/OnlineMeeting.ipynb
+++ b/examples/OnlineMeeting.ipynb
@@ -532,7 +532,11 @@
"question = (\n",
" \"Question: what are the arrangements the Police Department will make this year?\"\n",
")\n",
- "reference = \"enhancing community engagement and internal communication models, building a culture of accountability and transparency, and prioritizing recruitment and retention.\""
+ "reference = (\n",
+ " \"enhancing community engagement and internal communication models, building a\"\n",
+ " \" culture of accountability and transparency, and prioritizing recruitment and\"\n",
+ " \" retention.\"\n",
+ ")"
]
},
{
diff --git a/experiments/llmlingua2/README.md b/experiments/llmlingua2/README.md
new file mode 100644
index 0000000..8ba4f62
--- /dev/null
+++ b/experiments/llmlingua2/README.md
@@ -0,0 +1,35 @@
+# LLMLingua-2 Experiments
+
+## Getting Started
+
+To get started with LLMLingua-2 experiments, simply install it using pip:
+
+```bash
+pip install llmlingua
+```
+
+To collect your own data using GPT-4, install the following packages:
+```bash
+pip install openai==0.28
+
+pip install spacy
+python -m spacy download en_core_web_sm
+```
+
+To train your own compressor on the collected data, install:
+```bash
+pip install scikit-learn
+pip install tensorboard
+```
+
+## Data collection
+
+We will release our collected GPT-4 compression result at [HF](https://huggingface.co/datasets/microsoft/LLMLingua-2-data-MeetingBankComp) after review. We also provide the whole data collection pipeline at [**collect_data.sh**](data_collection/collect_data.sh) to help you construct your custom compression dataset.
+
+## Model Training
+
+To train a compressor on the collected data, simply run [**train.sh**](model_training/train.sh)
+
+## Evaluation
+
+We provide a script [**compress.sh**](evaluation/scripts/compress.sh) to compress the original context on several benchmarks. After compression, run [**evaluate.sh**](evaluation/scripts/evaluate.sh) to evalate on down-stream task using the compressed prompt.
diff --git a/experiments/llmlingua2/data_collection/GPT4_compressor.py b/experiments/llmlingua2/data_collection/GPT4_compressor.py
new file mode 100644
index 0000000..e876816
--- /dev/null
+++ b/experiments/llmlingua2/data_collection/GPT4_compressor.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+from time import sleep
+
+from utils import load_model_and_tokenizer
+
+SLEEP_TIME_SUCCESS = 10
+SLEEP_TIME_FAILED = 62
+
+
+class PromptCompressor:
+ def __init__(
+ self,
+ model_name,
+ user_prompt,
+ system_prompt=None,
+ temperature=0.3,
+ top_p=1.0,
+ n_max_token=32700,
+ ):
+ self.model_name = model_name
+ self.temperature = temperature
+ self.top_p = top_p
+
+ self.system_prompt = system_prompt
+ self.user_prompt = user_prompt
+ print(self.system_prompt)
+ print(self.user_prompt)
+
+ self.model, self.tokenizer = load_model_and_tokenizer(
+ self.model_name, chat_completion=True
+ )
+ self.n_max_token = n_max_token
+
+ def query_template(self, text, n_max_new_token=4096):
+ if self.user_prompt and "{text_to_compress}" in self.user_prompt:
+ prompt = self.user_prompt.format(text_to_compress=text)
+ else:
+ prompt = text
+
+ len_sys_prompt = 0
+ if self.system_prompt:
+ messages = [{"role": "system", "content": self.system_prompt}]
+ len_sys_prompt = len(self.tokenizer.encode(self.system_prompt))
+ token_ids = self.tokenizer.encode(prompt)
+ if len(token_ids) > (self.n_max_token - n_max_new_token - len_sys_prompt):
+ half = int((self.n_max_token - n_max_new_token - len_sys_prompt) / 2) - 1
+ prompt = self.tokenizer.decode(token_ids[:half]) + self.tokenizer.decode(
+ token_ids[-half:]
+ )
+ messages.append({"role": "user", "content": prompt})
+ return messages
+
+ def compress(self, text, n_max_new_token=4096):
+ messages = self.query_template(text, n_max_new_token)
+ comp = None
+ while comp is None:
+ try:
+ request = {
+ "messages": messages,
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "max_tokens": n_max_new_token,
+ }
+ response = self.model.create(engine=self.model_name, **request)
+ if "choices" not in response:
+ print(response)
+ comp = response["choices"][0]["message"]["content"]
+ except Exception as e:
+ print(f"error: {e}")
+ sleep(SLEEP_TIME_FAILED)
+ # sleep(SLEEP_TIME_SUCCESS)
+ return comp
diff --git a/experiments/llmlingua2/data_collection/README.md b/experiments/llmlingua2/data_collection/README.md
new file mode 100644
index 0000000..8083e11
--- /dev/null
+++ b/experiments/llmlingua2/data_collection/README.md
@@ -0,0 +1,57 @@
+### Use our collected data
+
+We will release our collected GPT-4 compression result at [HF](https://huggingface.co/datasets/microsoft/LLMLingua-2-data-MeetingBankComp) after review. To load data, simply use
+
+```python
+from datasets import load_dataset
+data = load_dataset("microsoft/LLMLingua-2-data-MeetingBankComp", split="train")
+print(len(data))
+for idx, sample in enumerate(data):
+ # concatenation of all chunks
+ prompt = sample["prompt"]
+ compressed_prompt = sample["compressed_prompt"]
+```
+**prompt** is the original meeting transcript. **compressed_prompt** is the compression result after merging all compressed chunks of a transcript.
+
+To load compressed chunks along with original chunks, simply use
+```python
+from datasets import load_dataset
+data = load_dataset("microsoft/LLMLingua-2-data-MeetingBankComp", split="train")
+print(len(data))
+for idx, sample in enumerate(data):
+ # chunk list
+ prompt_list = sample["prompt_list"]
+ compressed_prompt_list = sample["compressed_prompt_list"]
+```
+
+### Construct your custom compression dataset
+
+First, format your data to a list of dict, with each dict containing at least two keys: *idx* and *prompt*. [**format_data.py**](format_data.py) illustrates how we format the meetingbank data.
+
+Then, instruct GPT-4 to compress the original context.
+
+```bash
+python compress.py --load_origin_from \
+--chunk_size 512 \
+--compressor llmcomp \
+--model_name gpt-4-32k \
+--save_path
+
+```
+
+Then, assign label to the original words and filter out poor compression samples.
+
+
+```bash
+python label_word.py \
+--load_prompt_from \
+--window_size 400 \
+--save_path \
+
+```
+
+Filter out some poorly compressed / labeled samples.
+```bash
+python filter.py --load_path \
+--save_path
+```
diff --git a/experiments/llmlingua2/data_collection/collect_data.sh b/experiments/llmlingua2/data_collection/collect_data.sh
new file mode 100644
index 0000000..acd58f8
--- /dev/null
+++ b/experiments/llmlingua2/data_collection/collect_data.sh
@@ -0,0 +1,16 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+python format_data.py
+
+python compress.py --load_origin_from ../../../results/meetingbank/origin/meetingbank_train_formated.json \
+ --compressor gpt4 \
+ --chunk_size 512 \
+ --save_path ../../../results/meetingbank/gpt-4-32k_comp/compression_cs512_meetingbank_train_formated.json
+
+python label_word.py --load_prompt_from ../../../results/meetingbank/gpt-4-32k_comp/compression_cs512_meetingbank_train_formated.json \
+ --window_size 400 \
+ --save_path ../../../results/meetingbank/gpt-4-32k_comp/annotation_cs512_meetingbank_train_formated.json
+
+python filter.py --load_path ../../../results/meetingbank/gpt-4-32k_comp/annotation_cs512_meetingbank_train_formated.pt \
+ --save_path ../../../results/meetingbank/gpt-4-32k_comp/annotation_kept_cs512_meetingbank_train_formated.pt
diff --git a/experiments/llmlingua2/data_collection/compress.py b/experiments/llmlingua2/data_collection/compress.py
new file mode 100644
index 0000000..2f5c880
--- /dev/null
+++ b/experiments/llmlingua2/data_collection/compress.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import copy
+import json
+import os
+import time
+
+import tiktoken
+from tqdm import tqdm
+
+parser = argparse.ArgumentParser(description="compress any prompt.")
+
+parser.add_argument("--compressor", help="compress method", default="gpt4")
+parser.add_argument("--model_name", help="llm used to compress", default="gpt-4-32k")
+
+parser.add_argument(
+ "--load_origin_from", help="dataset used to compress", required=True
+)
+parser.add_argument(
+ "--load_key", help="the key to load the text to compress", default="prompt"
+)
+parser.add_argument(
+ "--save_key",
+ help="the key to save the compressed text",
+ default="compressed_prompt",
+)
+
+parser.add_argument("--save_path", help="path to save results", required=True)
+# for gpt-4 compression
+parser.add_argument(
+ "--load_prompt_from", help="", default="compression_instructions.json"
+)
+parser.add_argument("--prompt_id", type=int, default=4)
+parser.add_argument("--n_max_new_token", type=int, default=4000)
+# for gpt-4 compression and selective-context
+parser.add_argument("--chunk_size", type=int, default=-1)
+# for llmlingua
+parser.add_argument(
+ "--compression_rate", help="compression rate", type=float, default=0.5
+)
+parser.add_argument(
+ "--n_target_token", help="number of target tokens", type=int, default=-1
+)
+
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+
+data = json.load(open(args.load_origin_from))
+print(f"num data: {len(data)}")
+
+if args.compressor == "gpt4":
+ from GPT4_compressor import PromptCompressor
+
+ prompts = json.load(open(args.load_prompt_from))
+ system_prompt = prompts[str(args.prompt_id)]["system_prompt"]
+ user_prompt = prompts[str(args.prompt_id)]["user_prompt"]
+ compressor = PromptCompressor(
+ model_name=args.model_name, system_prompt=system_prompt, user_prompt=user_prompt
+ )
+elif args.compressor == "llmlingua" or args.compressor == "longllmlingua":
+ from llmlingua import PromptCompressor
+
+ compressor = PromptCompressor()
+elif args.compressor == "sc":
+ from select_context import SelectiveContext
+
+ compressor = SelectiveContext(model_type="NousResearch/Llama-2-7b-hf", lang="en")
+else:
+ raise NotImplementedError()
+
+results = {}
+results_list = []
+total_time = 0
+
+if os.path.exists(args.save_path):
+ results = json.load(open(args.save_path))
+
+tokenizer = tiktoken.encoding_for_model("gpt-4")
+
+
+def chunk_origin(origin_text):
+ origin_list = []
+ origin_token_ids = tokenizer.encode(origin_text)
+ end_token_ids = set(tokenizer.encode(".") + tokenizer.encode("\n"))
+ n = len(origin_token_ids)
+ st = 0
+ while st < n:
+ if st + args.chunk_size > n - 1:
+ chunk = tokenizer.decode(origin_token_ids[st:n])
+ origin_list.append(chunk)
+ break
+ else:
+ ed = st + args.chunk_size
+ for j in range(0, ed - st):
+ if origin_token_ids[ed - j] in end_token_ids:
+ ed = ed - j
+ break
+ chunk = tokenizer.decode(origin_token_ids[st : ed + 1])
+ origin_list.append(chunk)
+ st = ed + 1
+ return origin_list
+
+
+for sample in tqdm(data):
+ idx = int(sample["idx"])
+ origin = copy.deepcopy(sample[args.load_key])
+ if origin is None:
+ continue
+ if idx in results or str(idx) in results:
+ print(f"{idx}-th sample is processed")
+ continue
+
+ t = time.time()
+ if args.compressor == "llmlingua" or args.compressor == "longllmlingua":
+ comp_dict = compressor.compress_prompt(
+ origin, ratio=args.compression_rate, target_token=args.n_target_token
+ )
+ comp = comp_dict["compressed_prompt"]
+ else:
+ # multi document
+ if isinstance(origin, list):
+ if args.chunk_size > 0:
+ chunk_list = []
+ for j, document in enumerate(origin):
+ ori_list = chunk_origin(document)
+ chunk_list.extend(ori_list)
+ origin = chunk_list
+ # single document
+ else:
+ origin = [origin]
+ if args.chunk_size > 0:
+ origin = chunk_origin(origin[0])
+ print(f"num chunk: {len(origin)}")
+ comp_list = []
+ for j, chunk in enumerate(origin):
+ if args.compressor == "gpt4":
+ comp = compressor.compress(chunk, args.n_max_new_token)
+ elif args.compressor == "sc":
+ if args.n_target_token > 0:
+ reduce_ratio = 1 - min(
+ (args.n_target_token // len(origin))
+ / len(tokenizer.encode(chunk)),
+ 1.0,
+ )
+ else:
+ reduce_ratio = 1.0 - args.compression_ratio
+ comp, reduced = compressor(
+ chunk, reduce_ratio=reduce_ratio, reduce_level="token"
+ )
+ comp = comp.replace("", "").replace("", "")
+ comp_list.append(comp)
+ assert len(origin) == len(comp_list)
+ comp = "".join(comp_list)
+
+ total_time += time.time() - t
+ new_sample = copy.deepcopy(sample)
+ new_sample[args.save_key] = comp
+ if (
+ not (args.compressor == "llmlingua" or args.compressor == "longllmlingua")
+ and len(comp_list) > 0
+ ):
+ assert len(origin) == len(comp_list)
+ new_sample["prompt_list"] = origin[:]
+ new_sample["compressed_prompt_list"] = comp_list[:]
+
+ results[idx] = new_sample
+ json.dump(
+ results,
+ open(args.save_path, "w", encoding="utf8"),
+ indent=4,
+ ensure_ascii=False,
+ )
+
+print(args.save_path, total_time)
+json.dump(
+ results, open(args.save_path, "w", encoding="utf8"), indent=4, ensure_ascii=False
+)
diff --git a/experiments/llmlingua2/data_collection/compression_instructions.json b/experiments/llmlingua2/data_collection/compression_instructions.json
new file mode 100644
index 0000000..7744b1b
--- /dev/null
+++ b/experiments/llmlingua2/data_collection/compression_instructions.json
@@ -0,0 +1,7 @@
+{
+ "0":{"system_prompt": "Could you please rephrase the paragraph to make it short, and keep 5% tokens?", "user_prompt": ""},
+ "1":{"system_prompt": "Summarize the provided examples in a few sentences, maintaining all essential reasoning aspects", "user_prompt": ""},
+ "2":{"system_prompt": "Follow these steps to shorten the given text content: 1. First, calculate the amount of information contained in each sentence, and remove sentences with less information. 2. Next, further condense the text by removing stop words, unnecessary punctuation, and redundant expressions. Refine the content while ensuring that all key information is retained. Let's do it step by step.", "user_prompt": ""},
+ "3":{"system_prompt": "Remove redundancy and express the text concisely in English, ensuring that all key information and reasoning processes are preserved.", "user_prompt": ""},
+ "4":{"system_prompt": "You are an excellent linguist and very good at compressing passages into short expressions by removing unimportant words, while retaining as much information as possible.", "user_prompt": "Compress some text to short expressions, and such that you (GPT-4) can reconstruct it as close as possible to the original. Unlike the usual text compression, I need you to comply with the 5 conditions below: 1. You can ONLY remove unimportant words. 2. Do not change the order of words. 3. Do not change the original words, e.g. 'asking'->'ask' is NOT OK, 'current'->'now' is NOT OK. 4. Do not use abbreviations or emojis, e.g. 'without'->'w/o' is NOT OK, 'as soon as possible'->'ASAP' is NOT OK. 5. Do not add new words or symbols, this is very important. For example, 'dedicate 3 hours to each chapter'->'3 hours/chapter' is NOT OK because you add new token '/', just compress it into '3 hours each chapter'. '30 eggs plus 20 eggs equals 50 eggs'->'30+20=50' is also NOT OK becuase you add new symbols + and =, just compress it into '30 plus 20 equals 50'. \nCompress the origin aggressively by removing words only. Compress the origin as short as you can, while retaining as much information as possible. \nIf you understand, please compress the following text: \n{text_to_compress}\nThe compressed text is: "}
+}
diff --git a/experiments/llmlingua2/data_collection/filter.py b/experiments/llmlingua2/data_collection/filter.py
new file mode 100644
index 0000000..5b31417
--- /dev/null
+++ b/experiments/llmlingua2/data_collection/filter.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+from collections import defaultdict
+
+import numpy as np
+import torch
+
+parser = argparse.ArgumentParser(description="compress any prompt.")
+parser.add_argument(
+ "--load_path",
+ help="path to load data",
+ default="../../../results/meetingbank/gpt-4-32k_comp/annotation_cs512_meetingbank_train_formated.pt",
+)
+parser.add_argument(
+ "--save_path",
+ help="path to save filtered data",
+ default="../../../results/meetingbank/gpt-4-32k_comp/annotation_kept_cs512_meetingbank_train_formated.pt",
+)
+args = parser.parse_args()
+
+res_pt = torch.load(args.load_path)
+
+## filtering
+variation_rate_list = res_pt["variation_rate"]
+print(len(variation_rate_list))
+threshold = np.percentile(variation_rate_list, 90)
+kept, filtered = defaultdict(list), defaultdict(list)
+for labels, origin, comp, retrieval, cr, vr, hr, mr, ag in zip(
+ res_pt["labels"],
+ res_pt["origin"],
+ res_pt["comp"],
+ res_pt["retrieval"],
+ res_pt["comp_rate"],
+ res_pt["variation_rate"],
+ res_pt["hitting_rate"],
+ res_pt["matching_rate"],
+ res_pt["alignment_gap"],
+):
+ if vr >= threshold:
+ filtered["labels"].append(labels)
+ filtered["origin"].append(origin)
+ filtered["comp"].append(comp)
+ filtered["retrieval"].append(retrieval)
+ filtered["comp_rate"].append(cr)
+ filtered["variation_rate"].append(vr)
+ filtered["hitting_rate"].append(hr)
+ filtered["matching_rate"].append(mr)
+ filtered["alignment_gap"].append(ag)
+ else:
+ kept["labels"].append(labels)
+ kept["origin"].append(origin)
+ kept["comp"].append(comp)
+ kept["retrieval"].append(retrieval)
+ kept["comp_rate"].append(cr)
+ kept["variation_rate"].append(vr)
+ kept["hitting_rate"].append(hr)
+ kept["matching_rate"].append(mr)
+ kept["alignment_gap"].append(ag)
+
+alignment_gap_list = kept["alignment_gap"]
+threshold = np.percentile(alignment_gap_list, 90)
+kept2 = defaultdict(list)
+for labels, origin, comp, retrieval, cr, vr, hr, mr, ag in zip(
+ kept["labels"],
+ kept["origin"],
+ kept["comp"],
+ res_pt["retrieval"],
+ kept["comp_rate"],
+ kept["variation_rate"],
+ kept["hitting_rate"],
+ kept["matching_rate"],
+ kept["alignment_gap"],
+):
+ if ag >= threshold:
+ filtered["labels"].append(labels)
+ filtered["origin"].append(origin)
+ filtered["comp"].append(comp)
+ filtered["retrieval"].append(retrieval)
+ filtered["comp_rate"].append(cr)
+ filtered["variation_rate"].append(vr)
+ filtered["hitting_rate"].append(hr)
+ filtered["matching_rate"].append(mr)
+ filtered["alignment_gap"].append(ag)
+ else:
+ kept2["labels"].append(labels)
+ kept2["origin"].append(origin)
+ kept2["comp"].append(comp)
+ kept2["retrieval"].append(retrieval)
+ kept2["comp_rate"].append(cr)
+ kept2["variation_rate"].append(vr)
+ kept2["hitting_rate"].append(hr)
+ kept2["matching_rate"].append(mr)
+ kept2["alignment_gap"].append(ag)
+
+torch.save(kept2, args.save_path)
diff --git a/experiments/llmlingua2/data_collection/format_data.py b/experiments/llmlingua2/data_collection/format_data.py
new file mode 100644
index 0000000..41a82a3
--- /dev/null
+++ b/experiments/llmlingua2/data_collection/format_data.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import json
+import os
+
+from datasets import load_dataset
+
+dataset = load_dataset("huuuyeah/meetingbank", split="train")
+data = []
+for idx, instance in enumerate(dataset):
+ temp = {}
+ temp["idx"] = idx
+ temp["prompt"] = instance["transcript"]
+ temp["summary"] = instance["summary"]
+ data.append(temp)
+os.makedirs("../../../results/meetingbank/origin/", exist_ok=True)
+json.dump(
+ data,
+ open("../../../results/meetingbank/origin/meetingbank_train_formated.json", "w"),
+ indent=4,
+)
diff --git a/experiments/llmlingua2/data_collection/label_word.py b/experiments/llmlingua2/data_collection/label_word.py
new file mode 100644
index 0000000..5593b78
--- /dev/null
+++ b/experiments/llmlingua2/data_collection/label_word.py
@@ -0,0 +1,214 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import json
+import logging
+import os
+from collections import defaultdict
+
+import spacy
+import torch
+from tqdm import tqdm
+
+parser = argparse.ArgumentParser(description="annotate token")
+parser.add_argument(
+ "--dataset_name", help="dataset used to compress", default="meetingbank"
+)
+parser.add_argument("--split", help="dataset part", default="train")
+parser.add_argument(
+ "--load_prompt_from",
+ help="where to load compressed prompt",
+ default="results/meetingbank/origin-comp-list_llmcomp_cs512.json",
+)
+parser.add_argument(
+ "--save_path",
+ help="path to save results",
+ default="results/meetingbank/annotation/label_word.json",
+)
+parser.add_argument("--window_size", help="window size", type=int, default=150)
+parser.add_argument(
+ "--verbose",
+ help="print debug info",
+ action=argparse.BooleanOptionalAction,
+ default=False,
+)
+
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+logging.basicConfig(
+ filename=f"{os.path.dirname(args.save_path)}/log.log",
+ level=logging.INFO,
+ format="%(asctime)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger()
+
+nlp = spacy.load("en_core_web_sm")
+
+
+def split_string(input_string, ignore_tokens=set([","])):
+ doc = nlp(input_string)
+ word_list = []
+ for word in doc:
+ if word.lemma_ not in ignore_tokens:
+ word_list.append(word.lemma_)
+ return word_list
+
+
+def is_equal(token1, token2):
+ return token1.lower() == token2.lower()
+
+
+origins, comps = [], []
+raw_data = json.load(open(args.load_prompt_from, "r"))
+for sid, sample in raw_data.items():
+ if len(sample["prompt_list"]) != len(sample["compressed_prompt_list"]):
+ print(f"{sid}-th length not equal")
+ continue
+ origins.extend(sample["prompt_list"])
+ comps.extend(sample["compressed_prompt_list"])
+
+res = {}
+res_pt = defaultdict(list)
+
+num_sample = 0
+compression_rate_avg = 0
+find_rate_avg = 0
+variation_rate_avg = 0
+matching_rate_avg = 0
+hitting_rate_avg = 0
+alignment_gap_avg = 0
+
+for chunk_idx, (origin, comp) in tqdm(enumerate(zip(origins, comps))):
+ num_sample += 1
+ origin_tokens = split_string(origin)
+ comp_tokens = split_string(comp)
+ origin_tokens_set = set(origin_tokens)
+ for token in origin_tokens:
+ origin_tokens_set.add(token.lower())
+
+ num_find = 0
+ prev_idx = 0
+ back_cnt = 0
+ num_origin_tokens = len(origin_tokens)
+ labels = [False] * num_origin_tokens
+ for token in comp_tokens:
+ flag = False
+ if token in origin_tokens_set or token.lower() in origin_tokens_set:
+ num_find += 1
+ for i in range(args.window_size):
+ # look forward
+ token_idx = min(prev_idx + i, num_origin_tokens - 1)
+ if is_equal(origin_tokens[token_idx], token) and not labels[token_idx]:
+ labels[token_idx] = True
+ # window do not go too fast
+ if token_idx - prev_idx > args.window_size // 2:
+ prev_idx += args.window_size // 2
+ else:
+ prev_idx = token_idx
+ if args.verbose:
+ print(
+ token,
+ token_idx,
+ prev_idx,
+ origin_tokens[token_idx - 1 : token_idx + 2],
+ )
+ flag = True
+ break
+ # look backward
+ token_idx = max(prev_idx - i, 0)
+ if is_equal(origin_tokens[token_idx], token) and not labels[token_idx]:
+ labels[token_idx] = True
+ prev_idx = token_idx
+ if args.verbose:
+ print(
+ token,
+ token_idx,
+ prev_idx,
+ origin_tokens[token_idx - 1 : token_idx + 2],
+ )
+ flag = True
+ break
+
+ retrieval_tokens = []
+ for idx, token in enumerate(origin_tokens):
+ if labels[idx]:
+ retrieval_tokens.append(token)
+ retrieval = " ".join(retrieval_tokens)
+
+ comp_rate = len(comp_tokens) / len(origin_tokens)
+ if len(comp_tokens) > 0:
+ find_rate = num_find / len(comp_tokens)
+ else:
+ find_rate = 0.0
+ variation_rate = 1 - find_rate
+ hitting_rate = num_find / len(origin_tokens)
+ matching_rate = sum(labels) / len(labels)
+ alignment_gap = hitting_rate - matching_rate
+
+ compression_rate_avg += comp_rate
+ find_rate_avg += find_rate
+ variation_rate_avg += variation_rate
+ hitting_rate_avg += hitting_rate
+ matching_rate_avg += matching_rate
+ alignment_gap_avg += alignment_gap
+
+ if alignment_gap > 0.1:
+ print(origin)
+ print("-" * 50)
+ print(comp)
+ print("-" * 50)
+ print(retrieval)
+ print("-" * 50)
+ print(origin_tokens)
+ print("-" * 50)
+ print(comp_tokens)
+ print("-" * 50)
+ print(retrieval_tokens)
+ print("=" * 50)
+
+ print(
+ f"comp rate: {comp_rate}, variation_rate: {variation_rate}, alignment_gap: {alignment_gap}"
+ )
+
+ res[chunk_idx] = {
+ "labels": labels,
+ "origin": origin,
+ "comp": comp,
+ "retrieval": retrieval,
+ "origin_tokens": origin_tokens,
+ "comp_rate": comp_rate,
+ "variation_rate": variation_rate,
+ "hitting_rate": hitting_rate,
+ "matching_rate": matching_rate,
+ "alignment_gap": alignment_gap,
+ }
+
+ res_pt["labels"].append(labels)
+ res_pt["origin"].append(origin)
+ res_pt["comp"].append(comp)
+ res_pt["retrieval"].append(retrieval)
+ res_pt["origin_tokens"].append(origin_tokens)
+ res_pt["comp_rate"].append(comp_rate)
+ res_pt["variation_rate"].append(variation_rate)
+ res_pt["hitting_rate"].append(hitting_rate)
+ res_pt["matching_rate"].append(matching_rate)
+ res_pt["alignment_gap"].append(alignment_gap)
+
+ if int(chunk_idx) % 1000 == 0:
+ json.dump(res, open(args.save_path, "w"), indent=4)
+ torch.save(res_pt, args.save_path.replace(".json", ".pt"))
+
+json.dump(res, open(args.save_path, "w"), indent=4)
+torch.save(res_pt, args.save_path.replace(".json", ".pt"))
+
+compression_rate_avg = compression_rate_avg / num_sample
+find_rate_avg = find_rate_avg / num_sample
+variation_rate_avg = variation_rate_avg / num_sample
+matching_rate_avg = matching_rate_avg / num_sample
+hitting_rate_avg = hitting_rate_avg / num_sample
+alignment_gap_avg = alignment_gap_avg / num_sample
+
+print_info = f"window size: {args.window_size}, comp rate: {compression_rate_avg}, hitting_rate: {hitting_rate_avg}, retrieval rate: {matching_rate_avg}"
+print(print_info)
+logger.info(print_info)
diff --git a/experiments/llmlingua2/data_collection/utils.py b/experiments/llmlingua2/data_collection/utils.py
new file mode 100644
index 0000000..165592e
--- /dev/null
+++ b/experiments/llmlingua2/data_collection/utils.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+from time import sleep
+
+import openai
+import tiktoken
+
+
+def query_llm(
+ prompt,
+ model,
+ model_name,
+ max_tokens,
+ tokenizer=None,
+ chat_completion=False,
+ **kwargs,
+):
+ SLEEP_TIME_FAILED = 62
+
+ request = {
+ "temperature": kwargs["temperature"] if "temperature" in kwargs else 0.0,
+ "top_p": kwargs["top_p"] if "top_p" in kwargs else 1.0,
+ "seed": kwargs["seed"] if "seed" in kwargs else 42,
+ "max_tokens": max_tokens,
+ "n": 1,
+ "stream": False,
+ }
+ if chat_completion:
+ request["messages"] = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt},
+ ]
+ else:
+ request["prompt"] = prompt
+
+ answer = None
+ response = None
+ while answer is None:
+ try:
+ response = model.create(engine=model_name, **request)
+ answer = (
+ response["choices"][0]["message"]["content"]
+ if chat_completion
+ else response["choices"][0]["text"]
+ )
+ except Exception as e:
+ answer = None
+ print(f"error: {e}, response: {response}")
+ sleep(SLEEP_TIME_FAILED)
+ # sleep(SLEEP_TIME_SUCCESS)
+ return answer
+
+
+def load_model_and_tokenizer(model_name_or_path, chat_completion=False):
+ openai.api_key = "your_api_key"
+ openai.api_base = "your_api_base"
+ openai.api_type = "azure"
+ openai.api_version = "2023-05-15"
+
+ if chat_completion:
+ model = openai.ChatCompletion
+ else:
+ model = openai.Completion
+
+ tokenizer = tiktoken.encoding_for_model("gpt-4")
+ return model, tokenizer
diff --git a/experiments/llmlingua2/evaluation/compress.py b/experiments/llmlingua2/evaluation/compress.py
new file mode 100644
index 0000000..6ea8ec9
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/compress.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import copy
+import json
+import os
+import time
+
+from tqdm import tqdm
+
+from llmlingua.prompt_compressor import PromptCompressor
+
+parser = argparse.ArgumentParser(description="compress any prompt.")
+
+parser.add_argument("--compressor", help="compress method", default="llmcomp")
+parser.add_argument(
+ "--model_name",
+ help="llm used to compress",
+ default="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
+)
+parser.add_argument(
+ "--load_origin_from", help="dataset used to compress", required=True
+)
+parser.add_argument(
+ "--load_key", help="the key to load the text to compress", default="prompt"
+)
+parser.add_argument(
+ "--save_key",
+ help="the key to save the compressed text",
+ default="compressed_prompt",
+)
+
+parser.add_argument("--save_path", help="path to save results", required=True)
+
+# for llmlingua2
+parser.add_argument(
+ "--compression_rate", help="compression rate", type=float, default=0.5
+)
+parser.add_argument(
+ "--target_token", help="number of target tokens", type=int, default=-1
+)
+# llmlingua2 coarse to fine
+parser.add_argument(
+ "--use_token_level_filter", action=argparse.BooleanOptionalAction, default=True
+)
+parser.add_argument(
+ "--use_context_level_filter", action=argparse.BooleanOptionalAction, default=False
+)
+parser.add_argument("--target_context", type=int, default=-1)
+parser.add_argument("--context_level_compression_rate", type=float, default=1.0)
+parser.add_argument("--context_level_target_token", type=int, default=-1)
+# llmlingua2 details
+parser.add_argument(
+ "--force_tokens",
+ help="the tokens which will be forcely preserved, comma separated",
+ type=str,
+ default=None,
+)
+parser.add_argument(
+ "--drop_consecutive", action=argparse.BooleanOptionalAction, default=True
+)
+parser.add_argument(
+ "--force_reserve_digit", action=argparse.BooleanOptionalAction, default=False
+)
+
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+if args.force_tokens is not None:
+ args.force_tokens = [
+ str(item).replace("\\n", "\n") for item in args.force_tokens.split(",")
+ ]
+else:
+ args.force_tokens = []
+print(f"force tokens: {args.force_tokens}")
+
+data = json.load(open(args.load_origin_from))
+print(f"num data: {len(data)}")
+
+compressor = PromptCompressor(
+ model_name=args.model_name,
+ model_config={},
+ use_llmlingua2=True,
+)
+
+results = {}
+results_list = []
+total_time = 0
+
+if os.path.exists(args.save_path):
+ results = json.load(open(args.save_path))
+
+for sample in tqdm(data):
+ idx = int(sample["idx"])
+ origin = copy.deepcopy(sample[args.load_key])
+ if origin is None:
+ continue
+ if idx in results or str(idx) in results:
+ print(f"{idx}-th sample is processed")
+ continue
+ t = time.time()
+ comp_dict = compressor.compress_prompt_llmlingua2(
+ origin,
+ rate=args.compression_rate,
+ target_token=args.target_token,
+ use_context_level_filter=args.use_context_level_filter,
+ use_token_level_filter=args.use_token_level_filter,
+ target_context=args.target_context,
+ context_level_rate=args.context_level_compression_rate,
+ context_level_target_token=args.context_level_target_token,
+ force_tokens=args.force_tokens,
+ drop_consecutive=args.drop_consecutive,
+ force_reserve_digit=args.force_reserve_digit,
+ )
+ total_time += time.time() - t
+ comp = comp_dict["compressed_prompt"]
+ comp_list = comp_dict["compressed_prompt_list"]
+
+ new_sample = copy.deepcopy(sample)
+ new_sample[args.save_key] = comp
+ if comp_list is not None and args.load_key == "prompt_list":
+ new_sample["compressed_prompt_list"] = comp_list
+ print(len(new_sample["prompt_list"]), len(new_sample["compressed_prompt_list"]))
+
+ results[idx] = new_sample
+ json.dump(
+ results,
+ open(args.save_path, "w", encoding="utf8"),
+ indent=4,
+ ensure_ascii=False,
+ )
+
+print(args.save_path, total_time)
+json.dump(
+ results, open(args.save_path, "w", encoding="utf8"), indent=4, ensure_ascii=False
+)
diff --git a/experiments/llmlingua2/evaluation/eval_bbh.py b/experiments/llmlingua2/evaluation/eval_bbh.py
new file mode 100644
index 0000000..96fb319
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/eval_bbh.py
@@ -0,0 +1,269 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import json
+import os
+import re
+from collections import defaultdict
+
+import tiktoken
+from tqdm import tqdm
+from utils import load_model_and_tokenizer, query_llm
+
+parser = argparse.ArgumentParser(description="compress any prompt.")
+parser.add_argument(
+ "--model_name_or_path", help="LLM used to answer", default="gpt-3.5-turbo-0613"
+)
+
+parser.add_argument("--n_max_token", type=int, default=8100)
+parser.add_argument(
+ "--n_max_token_ans",
+ type=int,
+ default=400,
+ help="token num in answer, following llmlingua",
+)
+
+parser.add_argument(
+ "--load_prompt_from",
+ help="where to load compressed prompt",
+ default="results/gsm8k/origin/gsm8k_test.json",
+)
+parser.add_argument("--load_key", default="prompt", type=str)
+parser.add_argument(
+ "--save_path",
+ help="path to save results",
+ default="results/gsm8k/origin/gpt35_answer/answer_gsm8k_test.json",
+)
+
+parser.add_argument("--num_sample", default=-1, type=int)
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+
+
+MULTIPLE_CHOICE_TASKS = [
+ "temporal_sequences",
+ "disambiguation_qa",
+ "date_understanding",
+ "tracking_shuffled_objects_three_objects",
+ "penguins_in_a_table",
+ "geometric_shapes",
+ "snarks",
+ "ruin_names",
+ "tracking_shuffled_objects_seven_objects",
+ "tracking_shuffled_objects_five_objects",
+ "logical_deduction_three_objects",
+ "hyperbaton",
+ "logical_deduction_five_objects",
+ "logical_deduction_seven_objects",
+ "movie_recommendation",
+ "salient_translation_error_detection",
+ "reasoning_about_colored_objects",
+]
+FREE_FORM_TASKS = [
+ "multistep_arithmetic_two",
+ "navigate",
+ "dyck_languages",
+ "word_sorting",
+ "sports_understanding",
+ "boolean_expressions",
+ "object_counting",
+ "formal_fallacies",
+ "causal_judgement",
+ "web_of_lies",
+]
+
+
+def extract_ans(ans, mode):
+ ans_line = ans.split("answer is ", 1)
+ # Expect to see 'answer is'. If not return whole string
+ if len(ans_line) == 1:
+ return ans
+ else:
+ ans = ans_line[-1].strip()
+
+ if mode == "multiple_choice":
+ options = [
+ "(A)",
+ "(B)",
+ "(C)",
+ "(D)",
+ "(E)",
+ "(F)",
+ "(G)",
+ "(H)",
+ "(I)",
+ "(J)",
+ "(K)",
+ "(L)",
+ "(M)",
+ "(N)",
+ "(O)",
+ "(P)",
+ "(Q)",
+ "(R)",
+ "(S)",
+ "(T)",
+ "(U)",
+ "(V)",
+ "(W)",
+ "(X)",
+ "(Y)",
+ "(Z)",
+ ]
+ match_g = []
+ for option in options:
+ if option in ans:
+ # ans = option[1]
+ match_g.append((ans.index(option), option[1]))
+ if match_g:
+ match_g.sort(key=lambda x: x[0])
+ return match_g[0][1]
+ elif mode == "free_form":
+ ans = ans.split(".", 1)[0]
+ if ans[-1] == ".":
+ ans = ans[:-1]
+ return ans
+
+
+def analyze_cases(good, bad, task):
+ _, good_questions, good_ans_pred, good_ans_gold = good
+ _, bad_questions, bad_ans_pred, bad_ans_gold = bad
+ mode = "multiple_choice" if task in MULTIPLE_CHOICE_TASKS else "free_form"
+ true_map, x_map = {}, {}
+ for q, p, g in zip(good_questions[task], good_ans_pred[task], good_ans_gold[task]):
+ p_ans, g_ans = extract_ans(p, mode), g
+ if p_ans == g_ans:
+ true_map[q] = (p, g, p_ans, g_ans)
+ x_map[q] = (p, g, p_ans, g_ans)
+ false_map = {}
+ for q, p, g in zip(bad_questions[task], bad_ans_pred[task], bad_ans_gold[task]):
+ p_ans, g_ans = extract_ans(p, mode), g
+ if p_ans != g_ans and q in true_map:
+ false_map[q] = (p, g, p_ans, g_ans)
+
+
+def parse_pred_ans(path: str):
+ res = open(path).read()
+ pattern = "Task:(.*?)\n(.*?)\nA_model:(.*?)\nA_target:(.*?)\n\n"
+ g, ans = defaultdict(int), defaultdict(list)
+ questions, ans_models, ans_targets = (
+ defaultdict(list),
+ defaultdict(list),
+ defaultdict(list),
+ )
+ for m in re.findall(pattern, res, re.S):
+ task, question, ans_model, ans_target = m
+ task = task.strip()
+ mode = "multiple_choice" if task in MULTIPLE_CHOICE_TASKS else "free_form"
+ question = question.strip()
+ ans_model = ans_model.strip()
+ ans_target = ans_target.strip()
+ p, gg = extract_ans(ans_model, mode), ans_target
+ g[task] += int(p == gg)
+ ans[task].append((ans_model, gg))
+ questions[task].append(question)
+ ans_models[task].append(ans_model)
+ ans_targets[task].append(ans_target)
+ scores = defaultdict(dict)
+ total_num = 0
+ for task, correct in g.items():
+ scores[task]["acc"] = correct / len(ans[task])
+ scores[task]["num"] = len(ans[task])
+ print(task, correct, len(ans[task]), correct / len(ans[task]))
+ total_num += len(ans[task])
+ print(total_num)
+ score_list = [v["acc"] for v in scores.values()]
+ scores["avg"] = sum(score_list) / len(score_list)
+ # return ans, questions, ans_models, ans_targets
+ return scores
+
+
+def get_generation_token_length(path):
+ res = open(path, "r").read()
+ pattern = "Task:(.*?)\n(.*?)\nA_model:(.*?)\nA_target:(.*?)\n\n"
+ tokenizer = tiktoken.encoding_for_model("gpt-4")
+ tokens = []
+ for m in re.findall(pattern, res, re.S):
+ task, question, ans_model, ans_target = m
+ tokens.append(len(tokenizer.encode(ans_model)))
+ return sum(tokens) / len(tokens)
+
+
+def predict():
+ model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
+
+ results = {}
+ if os.path.exists(args.save_path):
+ results = json.load(open(args.save_path))
+
+ demonstration = json.load(open(args.load_prompt_from))
+ prompts = {}
+ instructions = {}
+ for demon in demonstration.values():
+ task = demon["task"]
+ prompt = demon[args.load_key]
+ instructions[task] = demon["instruction"]
+ prompts[task] = prompt
+ print(prompts)
+ print(instructions)
+
+ dataset = json.load(open("results/bbh/origin/bbh.json"))
+ for sample in tqdm(dataset):
+ idx = sample["idx"]
+ task = sample["task"]
+ task_type = "multiple_choice" if task in MULTIPLE_CHOICE_TASKS else "free_form"
+ cot_prompt = prompts[task]
+ instruction = instructions[task]
+ if args.num_sample > 0 and int(idx) > args.num_sample:
+ break
+ if idx in results or str(idx) in results:
+ print(f"{idx}-th processed")
+ continue
+ q = sample["question"]
+ a = sample["answer"]
+
+ if cot_prompt[0] != "\n":
+ cot_prompt = "\n\n" + cot_prompt
+ # print(cot_prompt)
+ prompt = (
+ f"{instruction}{cot_prompt}\n\nQ: {q}" + "\nA:Let's think step by step.\n"
+ )
+ token_ids = tokenizer.encode(prompt)
+ # drop in middle
+ if len(token_ids) > (args.n_max_token - args.n_max_token_ans):
+ half = int((args.n_max_token - args.n_max_token_ans) / 2) - 1
+ prompt = tokenizer.decode(token_ids[:half]) + tokenizer.decode(
+ token_ids[-half:]
+ )
+ answer = query_llm(
+ prompt,
+ model,
+ args.model_name_or_path,
+ 400 if task != "geometric_shapes" else 800,
+ )
+
+ results[idx] = {"question": q, "model_answer": answer, "truth_answer": a}
+ json.dump(results, open(args.save_path, "w"), indent=4)
+
+ ans_ = extract_ans(answer, task_type)
+ if task_type == "multiple_choice":
+ a = a[1]
+ res = "%dTask:%s\n%s\nA_model:%s\nA_target:%s\n\n" % (
+ idx,
+ task,
+ q.replace("\n", ""),
+ answer.replace("\n", "").replace("Q:", "").replace("A:", ""),
+ a.replace("\n", ""),
+ )
+ with open(args.save_path.replace(".json", ".txt"), "a") as fd:
+ fd.write(res)
+
+
+predict()
+scores = parse_pred_ans(args.save_path.replace(".json", ".txt"))
+save_path2 = os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "metrics"),
+)
+json.dump(scores, open(save_path2, "w"))
diff --git a/experiments/llmlingua2/evaluation/eval_gsm8k.py b/experiments/llmlingua2/evaluation/eval_gsm8k.py
new file mode 100644
index 0000000..259f686
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/eval_gsm8k.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import json
+import os
+import re
+
+from tqdm import tqdm
+from utils import load_model_and_tokenizer, query_llm
+
+parser = argparse.ArgumentParser(description="compress any prompt.")
+parser.add_argument(
+ "--model_name_or_path", help="LLM used to answer", default="gpt-3.5-turbo-0613"
+)
+
+parser.add_argument("--n_max_token", type=int, default=8100)
+parser.add_argument(
+ "--n_max_token_ans",
+ type=int,
+ default=400,
+ help="token num in answer, following llmlingua",
+)
+
+parser.add_argument(
+ "--load_prompt_from",
+ help="where to load compressed prompt",
+ default="results/gsm8k/origin/gsm8k_test.json",
+)
+parser.add_argument("--load_key", default="prompt", type=str)
+parser.add_argument(
+ "--save_path",
+ help="path to save results",
+ default="results/gsm8k/origin/gpt35_answer/answer_gsm8k_test.json",
+)
+
+parser.add_argument("--num_sample", default=-1, type=int)
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+
+
+def extract_ans(ans_model):
+ ans_model = ans_model.split("\n")
+ ans = []
+ residual = []
+ for li, al in enumerate(ans_model):
+ ans.append(al)
+ if "answer is" in al:
+ break
+ residual = list(ans_model[li + 1 :])
+ ans = "\n".join(ans)
+ residual = "\n".join(residual)
+ return ans, residual
+
+
+def parse_pred_ans(filename):
+ with open(filename) as fd:
+ lines = fd.readlines()
+ am, a = None, None
+ num_q, acc = 0, 0
+ current_mode = "none"
+ questions = []
+ ans_pred = []
+ ans_gold = []
+ for l in lines:
+ l = l.replace(",", "")
+ if l.startswith("Q: "):
+ if am is not None and a is not None:
+ questions.append(q)
+ ans_pred.append(am)
+ ans_gold.append(a)
+ if test_answer(am, a):
+ acc += 1
+ current_mode = "q"
+ q = l
+ num_q += 1
+ elif l.startswith("A_model:"):
+ current_mode = "am"
+ am = l
+ elif l.startswith("A:"):
+ current_mode = "a"
+ a = l
+ else:
+ if current_mode == "q":
+ q += l
+ elif current_mode == "am":
+ am += l
+ elif current_mode == "a":
+ a += l
+ else:
+ raise ValueError(current_mode)
+
+ questions.append(q)
+ ans_pred.append(am)
+ ans_gold.append(a)
+ if test_answer(am, a):
+ acc += 1
+ print("num_q %d correct %d ratio %.4f" % (num_q, acc, float(acc / num_q)))
+ return questions, ans_pred, ans_gold
+
+
+def get_result(text: str):
+ pattern = "\d*\.?\d+"
+ res = re.findall(pattern, text)
+ return res[-1] if res else ""
+
+
+def test_answer(pred_str, ans_str):
+ pred, gold = get_result(pred_str), get_result(ans_str)
+ return pred == gold
+
+
+def predict():
+ model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
+ dataset = json.load(open("../../../results/gsm8k/origin/gsm8k_test.json"))
+
+ results = {}
+ if os.path.exists(args.save_path):
+ results = json.load(open(args.save_path))
+
+ demon_dict = json.load(open(args.load_prompt_from))
+ demonstrations = []
+ for demon in demon_dict["0"][args.load_key]:
+ demonstrations.append("\n\nQuestion: " + demon)
+ demonstrations = "".join(demonstrations)
+
+ for sample in tqdm(dataset):
+ idx = sample["idx"]
+ if idx in results or str(idx) in results:
+ print(f"{idx}-th processed")
+ continue
+ q = sample["question"]
+ a = sample["answer"]
+
+ prompt = f"Please reference the following examples to answer the math question. \n {demonstrations}"
+ query = f"\n\nQuestion: {q}" + "\nLet's think step by step."
+ token_ids = tokenizer.encode(prompt)
+ len2 = len(tokenizer.encode(query))
+ # drop in middle
+ if len(token_ids) > (args.n_max_token - args.n_max_token_ans - len2):
+ half = int((args.n_max_token - args.n_max_token_ans - len2) / 2) - 1
+ prompt = tokenizer.decode(token_ids[:half]) + tokenizer.decode(
+ token_ids[-half:]
+ )
+ prompt = prompt + query
+ answer = query_llm(prompt, model, args.model_name_or_path, args.n_max_token_ans)
+
+ results[idx] = {"question": q, "model_answer": answer, "truth_answer": a}
+ json.dump(results, open(args.save_path, "w"), indent=4)
+
+ ans_, _ = extract_ans(answer)
+ res = "Q: %s\nA_model:\n%s\nA:\n%s\n\n" % (
+ q,
+ ans_.replace("Q:", "").replace("A:", ""),
+ a,
+ )
+ with open(args.save_path.replace(".json", ".txt"), "a") as fd:
+ fd.write(res)
+
+
+predict()
+scores = parse_pred_ans(args.save_path.replace(".json", ".txt"))
+save_path2 = os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "metrics"),
+)
+json.dump(scores, open(save_path2, "w"))
diff --git a/experiments/llmlingua2/evaluation/eval_longbench.py b/experiments/llmlingua2/evaluation/eval_longbench.py
new file mode 100644
index 0000000..467d904
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/eval_longbench.py
@@ -0,0 +1,326 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import json
+import os
+from collections import defaultdict
+
+import numpy as np
+from metrics import (
+ classification_score,
+ code_sim_score,
+ count_score,
+ qa_f1_score,
+ qa_f1_zh_score,
+ retrieval_score,
+ retrieval_zh_score,
+ rouge_score,
+ rouge_zh_score,
+)
+from tqdm import tqdm
+from utils import load_model_and_tokenizer, query_llm
+
+dataset2metric = {
+ "narrativeqa": qa_f1_score,
+ "qasper": qa_f1_score,
+ "multifieldqa_en": qa_f1_score,
+ "multifieldqa_zh": qa_f1_zh_score,
+ "hotpotqa": qa_f1_score,
+ "2wikimqa": qa_f1_score,
+ "musique": qa_f1_score,
+ "dureader": rouge_zh_score,
+ "gov_report": rouge_score,
+ "qmsum": rouge_score,
+ "multi_news": rouge_score,
+ "vcsum": rouge_zh_score,
+ "trec": classification_score,
+ "triviaqa": qa_f1_score,
+ "samsum": rouge_score,
+ "lsht": classification_score,
+ "passage_retrieval_en": retrieval_score,
+ "passage_count": count_score,
+ "passage_retrieval_zh": retrieval_zh_score,
+ "lcc": code_sim_score,
+ "repobench-p": code_sim_score,
+}
+
+parser = argparse.ArgumentParser(description="compress any prompt.")
+parser.add_argument(
+ "--model_name_or_path", help="LLM used to answer", default="gpt-3.5-turbo-0613"
+)
+
+parser.add_argument("--n_max_token", type=int, default=8100)
+# parser.add_argument('--n_max_token_ans', type=int, default=400, help='token num in answer, following llmlingua')
+
+parser.add_argument(
+ "--load_prompt_from",
+ help="where to load compressed prompt",
+ default="results/longbench/origin/longbench_test_single_doc_qa_formated.json",
+)
+parser.add_argument("--load_key", default="prompt", type=str)
+parser.add_argument(
+ "--save_path",
+ help="path to save results",
+ default="results/longbench/origin/gpt35_chat_answer/answer_longbench_test_single_doc_qa_formated.json",
+)
+
+parser.add_argument("--e", action=argparse.BooleanOptionalAction, default=True)
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+eng_datasets = [
+ "narrativeqa",
+ "qasper",
+ "multifieldqa_en",
+ "hotpotqa",
+ "2wikimqa",
+ "musique",
+ "gov_report",
+ "qmsum",
+ "multi_news",
+ "trec",
+ "triviaqa",
+ "samsum",
+ "passage_count",
+ "passage_retrieval_en",
+ "lcc",
+ "repobench-p",
+]
+all_datasets = [
+ "narrativeqa",
+ "qasper",
+ "multifieldqa_en",
+ "multifieldqa_zh",
+ "hotpotqa",
+ "2wikimqa",
+ "musique",
+ "dureader",
+ "gov_report",
+ "qmsum",
+ "multi_news",
+ "vcsum",
+ "trec",
+ "triviaqa",
+ "samsum",
+ "lsht",
+ "passage_count",
+ "passage_retrieval_en",
+ "passage_retrieval_zh",
+ "lcc",
+ "repobench-p",
+]
+
+
+def scorer_e(dataset, predictions, answers, lengths, all_classes):
+ scores = {"0-4k": [], "4-8k": [], "8k+": []}
+ for prediction, ground_truths, length in zip(predictions, answers, lengths):
+ score = 0.0
+ if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
+ prediction = prediction.lstrip("\n").split("\n")[0]
+ for ground_truth in ground_truths:
+ score = max(
+ score,
+ dataset2metric[dataset](
+ prediction, ground_truth, all_classes=all_classes
+ ),
+ )
+ if length < 4000:
+ scores["0-4k"].append(score)
+ elif length < 8000:
+ scores["4-8k"].append(score)
+ else:
+ scores["8k+"].append(score)
+ for key in scores.keys():
+ scores[key] = round(100 * np.mean(scores[key]), 2)
+ return scores
+
+
+def scorer(dataset, predictions, answers, all_classes):
+ total_score = 0.0
+ for prediction, ground_truths in zip(predictions, answers):
+ score = 0.0
+ if dataset in [
+ "trec",
+ "triviaqa",
+ "samsum",
+ "lsht",
+ "narrativeqa",
+ "qasper",
+ "multifieldqa_en",
+ "multifieldqa_zh",
+ "hotpotqa",
+ "2wikimqa",
+ "musique",
+ "dureader",
+ "vcsum",
+ ]:
+ prediction = prediction.lstrip("\n").split("\n")[0]
+ # if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
+ # prediction = prediction.lstrip('\n').split('\n')[0]
+ # for ground_truth in ground_truths:
+ # score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
+ # prediction = prediction.lstrip('\n').split('\n')[0]
+ # prediction = prediction.strip("")
+ for ground_truth in ground_truths:
+ score = max(
+ score,
+ dataset2metric[dataset](
+ prediction, ground_truth, all_classes=all_classes
+ ),
+ )
+ total_score += score
+ return round(100 * total_score / len(predictions), 2)
+
+
+def eval(load_path):
+ results = json.load(open(load_path))
+ predictions, answers, lengths = (
+ defaultdict(list),
+ defaultdict(list),
+ defaultdict(list),
+ )
+ all_classes = {}
+ for idx, data in results.items():
+ predictions[data["task"]].append(data["pred"])
+ answers[data["task"]].append(data["answers"])
+ all_classes[data["task"]] = data["all_classes"]
+ if "length" in data:
+ lengths[data["task"]].append(data["length"])
+ scores = {}
+ for task in predictions.keys():
+ pred_list, ans_list, length_list = (
+ predictions[task],
+ answers[task],
+ lengths[task],
+ )
+ score = scorer(task, pred_list, ans_list, all_classes[task])
+ print(score)
+ scores[task] = {"score": score, "num": len(pred_list)}
+ score_list = [s["score"] for s in scores.values()]
+ scores["avg"] = sum(score_list) / len(score_list)
+ return scores
+
+
+dataset2prompt = {
+ "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:',
+ "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
+ "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
+ "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
+ "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
+ "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
+ "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
+ "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
+ "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
+ "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
+ "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
+ "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
+ "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ',
+ "passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:',
+ "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
+ "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n",
+}
+
+dataset2maxlen = {
+ "narrativeqa": 128,
+ "qasper": 128,
+ "multifieldqa_en": 64,
+ "multifieldqa_zh": 64,
+ "hotpotqa": 32,
+ "2wikimqa": 32,
+ "musique": 32,
+ "dureader": 128,
+ "gov_report": 512,
+ "qmsum": 512,
+ "multi_news": 512,
+ "vcsum": 512,
+ "trec": 64,
+ "triviaqa": 32,
+ "samsum": 128,
+ "lsht": 64,
+ "passage_count": 32,
+ "passage_retrieval_en": 32,
+ "passage_retrieval_zh": 32,
+ "lcc": 64,
+ "repobench-p": 64,
+}
+
+
+def predict():
+ model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
+
+ dataset = json.load(open(args.load_prompt_from))
+ print(len(dataset))
+ if isinstance(dataset, dict):
+ dataset = dataset.values()
+ # dataset2prompt = json.load(
+ # open("../data/LongBench/config/dataset2prompt.json", "r")
+ # )
+ # dataset2maxlen = json.load(
+ # open("../data/LongBench/config/dataset2maxlen.json", "r")
+ # )
+ # prompt_format = dataset2prompt[args.task]
+ # max_gen = int(dataset2maxlen[args.task])
+
+ results = {}
+ if os.path.exists(args.save_path):
+ results = json.load(open(args.save_path))
+
+ for sample in tqdm(dataset):
+ idx = int(sample["idx"])
+ task = sample["task"]
+ if idx in results or str(idx) in results:
+ print(f"{idx} processed")
+ continue
+ new_sample = {}
+ new_sample["context"] = sample[args.load_key]
+ new_sample["input"] = sample["question"]
+
+ prompt_format = dataset2prompt[sample["task"]]
+ max_gen = int(dataset2maxlen[sample["task"]])
+ prompt = prompt_format.format(**new_sample)
+ token_ids = tokenizer.encode(prompt)
+
+ if len(token_ids) > (args.n_max_token - max_gen):
+ half = int((args.n_max_token - max_gen) / 2) - 1
+ prompt = tokenizer.decode(token_ids[:half]) + tokenizer.decode(
+ token_ids[-half:]
+ )
+
+ pred = query_llm(
+ prompt, model, args.model_name_or_path, max_gen, tokenizer=tokenizer
+ )
+ results[idx] = {
+ "pred": pred,
+ "answers": sample["answers"],
+ "model_name": args.model_name_or_path,
+ "task": sample["task"],
+ "idx": idx,
+ "all_classes": sample["all_classes"],
+ "length": sample["length"],
+ }
+ json.dump(
+ results,
+ open(args.save_path, "w", encoding="utf8"),
+ indent=4,
+ ensure_ascii=False,
+ )
+
+
+predict()
+score_dict = eval(load_path=args.save_path)
+print(score_dict)
+json.dump(
+ score_dict,
+ open(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "metrics"),
+ ),
+ "w",
+ ),
+)
diff --git a/experiments/llmlingua2/evaluation/eval_meetingbank_qa.py b/experiments/llmlingua2/evaluation/eval_meetingbank_qa.py
new file mode 100644
index 0000000..54835bb
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/eval_meetingbank_qa.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import json
+import os
+from collections import defaultdict
+
+from metrics import evaluate_with_gt
+from tqdm import tqdm
+from utils import load_model_and_tokenizer, query_llm
+
+parser = argparse.ArgumentParser(description="compress any prompt.")
+parser.add_argument(
+ "--model_name_or_path", help="LLM used to answer", default="gpt-3.5-turbo-0613"
+)
+
+parser.add_argument("--n_max_token", type=int, default=8100)
+parser.add_argument(
+ "--n_max_token_ans",
+ type=int,
+ default=100,
+ help="token num in answer, following llmlingua",
+)
+
+parser.add_argument(
+ "--load_prompt_from", help="where to load compressed prompt", required=True
+)
+parser.add_argument("--load_key", default="prompt", type=str)
+parser.add_argument("--save_path", help="path to save results", required=True)
+parser.add_argument("--num_sample", type=int, default=-1)
+
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+data = json.load(open(args.load_prompt_from))
+data = data.values() if isinstance(data, dict) else data
+
+print(f"num data: {len(data)}")
+
+model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
+
+results = defaultdict(dict)
+results_list = defaultdict(list)
+if os.path.exists(args.save_path):
+ prev_results = json.load(open(args.save_path))
+ results.update(prev_results)
+if os.path.exists(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "answer_list"),
+ )
+):
+ results_list = json.load(
+ open(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "answer_list"),
+ )
+ )
+ )
+
+prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
+for sample in tqdm(data):
+ sample_idx = int(sample["idx"])
+ if sample_idx in results or str(sample_idx) in results:
+ print(f"{sample_idx}-th already processed.")
+ continue
+ if args.num_sample > 0 and int(sample_idx) > args.num_sample:
+ break
+ transcript = sample[args.load_key]
+ token_ids = tokenizer.encode(transcript)
+ if len(token_ids) > args.n_max_token - args.n_max_token_ans:
+ transcript = tokenizer.decode(
+ token_ids[: args.n_max_token - args.n_max_token_ans]
+ )
+ qa_list = sample["QA_pairs"]
+ q_list = []
+ a_list = []
+ a_list_model = []
+ for qa in qa_list:
+ q = qa["question"]
+ a = qa["answer"]
+ query = prompt.format(transcript=transcript, question=q)
+ answer = query_llm(
+ query,
+ model,
+ args.model_name_or_path,
+ args.n_max_token_ans,
+ tokenizer=tokenizer,
+ )
+ q_list.append(q)
+ a_list.append(a)
+ a_list_model.append(answer)
+
+ results[sample_idx]["transcript"] = transcript
+ results[sample_idx]["questions"] = q_list[:]
+ results[sample_idx]["answers"] = a_list[:]
+ results[sample_idx]["model_answers"] = a_list_model[:]
+
+ results_list["questions"].extend(q_list[:])
+ results_list["answers"].extend(a_list[:])
+ results_list["model_answers"].extend(a_list_model[:])
+
+ json.dump(results, open(args.save_path, "w"), indent=4)
+ json.dump(
+ results_list,
+ open(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "answer_list"),
+ ),
+ "w",
+ ),
+ indent=4,
+ )
+
+score_dict = evaluate_with_gt(results_list["answers"], results_list["model_answers"])
+json.dump(
+ score_dict,
+ open(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "metrics"),
+ ),
+ "w",
+ ),
+ indent=4,
+)
diff --git a/experiments/llmlingua2/evaluation/eval_meetingbank_summary.py b/experiments/llmlingua2/evaluation/eval_meetingbank_summary.py
new file mode 100644
index 0000000..54d211f
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/eval_meetingbank_summary.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import json
+import os
+from collections import defaultdict
+
+from metrics import evaluate_sim
+from tqdm import tqdm
+from utils import load_model_and_tokenizer, query_llm
+
+parser = argparse.ArgumentParser(description="compress any prompt.")
+parser.add_argument(
+ "--model_name_or_path", help="LLM used to answer", default="gpt-3.5-turbo-0613"
+)
+
+parser.add_argument("--n_max_token", type=int, default=8100)
+parser.add_argument(
+ "--n_max_token_ans",
+ type=int,
+ default=400,
+ help="token num in answer, following llmlingua",
+)
+
+parser.add_argument(
+ "--load_prompt_from", help="where to load compressed prompt", required=True
+)
+parser.add_argument("--load_key", default="prompt", type=str)
+parser.add_argument("--save_path", help="path to save results", required=True)
+parser.add_argument("--num_sample", type=int, default=-1)
+
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+
+
+def predict():
+ model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
+
+ data = json.load(open(args.load_prompt_from))
+ data = data.values() if isinstance(data, dict) else data
+ print(f"num data: {len(data)}")
+
+ results = defaultdict(dict)
+ results_list = defaultdict(list)
+ if os.path.exists(args.save_path):
+ prev_results = json.load(open(args.save_path))
+ results.update(prev_results)
+ if os.path.exists(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "answer_list"),
+ )
+ ):
+ results_list = json.load(
+ open(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "answer_list"),
+ )
+ )
+ )
+
+ prompt = "Summarize the provided meeting transcript (which may be compressed).\n{transcript}\nSummary:"
+ for sample in tqdm(data):
+ if isinstance(sample, float):
+ continue
+ sample_idx = int(sample["idx"])
+ if sample_idx in results or str(sample_idx) in results:
+ print(f"{sample_idx}-th already processed.")
+ continue
+ if args.num_sample > 0 and int(sample_idx) > args.num_sample:
+ break
+ transcript = sample[args.load_key]
+ token_ids = tokenizer.encode(transcript)
+ if len(token_ids) > args.n_max_token - args.n_max_token_ans:
+ transcript = tokenizer.decode(
+ token_ids[: args.n_max_token - args.n_max_token_ans]
+ )
+
+ query = prompt.format(transcript=transcript)
+
+ # t = time.time()
+ model_summary = query_llm(
+ query,
+ model,
+ args.model_name_or_path,
+ args.n_max_token_ans,
+ tokenizer=tokenizer,
+ )
+ # total_time += time.time() - t
+
+ summary = sample["gpt4_summary"]
+
+ results[sample_idx]["transcript"] = transcript
+ results[sample_idx]["model_summary"] = model_summary
+ results[sample_idx]["gpt4_summary"] = summary
+
+ results_list["model_summary"].append(model_summary)
+ results_list["gpt4_summary"].append(summary)
+
+ json.dump(results, open(args.save_path, "w"), indent=4)
+ json.dump(
+ results_list,
+ open(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "answer_list"),
+ ),
+ "w",
+ ),
+ indent=4,
+ )
+
+
+predict()
+results_list = defaultdict(list)
+results_list = json.load(
+ open(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "answer_list"),
+ )
+ )
+)
+score_dict = evaluate_sim(results_list["model_summary"], results_list["gpt4_summary"])
+json.dump(
+ score_dict,
+ open(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "metrics"),
+ ),
+ "w",
+ ),
+ indent=4,
+)
diff --git a/experiments/llmlingua2/evaluation/eval_zero_scrolls.py b/experiments/llmlingua2/evaluation/eval_zero_scrolls.py
new file mode 100644
index 0000000..478ba16
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/eval_zero_scrolls.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import json
+import os
+import shutil
+from collections import defaultdict
+
+import datasets
+from huggingface_hub import hf_hub_download
+from tqdm import tqdm
+from utils import load_model_and_tokenizer, query_llm
+
+parser = argparse.ArgumentParser(description="compress any prompt.")
+parser.add_argument(
+ "--model_name_or_path", help="LLM used to answer", default="gpt-3.5-turbo-0613"
+)
+
+parser.add_argument("--n_max_token", type=int, default=8100)
+# parser.add_argument('--n_max_token_ans', type=int, default=400, help='token num in answer, following llmlingua')
+
+parser.add_argument(
+ "--load_prompt_from",
+ help="where to load compressed prompt",
+ default="results/zero_scrolls/origin/zero_scrolls_validation.json",
+)
+parser.add_argument("--load_key", default="prompt", type=str)
+parser.add_argument(
+ "--save_path",
+ help="path to save results",
+ default="results/zero_scrolls/origin/gpt35_chat_16k_answer/answer_zero_scrolls_validation.json",
+)
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+save_path2 = os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "answer2"),
+)
+
+
+def eval(predict_path: str):
+ def download_metric():
+ zero_scrolls_metric_path = hf_hub_download(
+ repo_id="tau/zero_scrolls",
+ repo_type="dataset",
+ filename="metrics/zero_scrolls.py",
+ )
+ updated_zero_scrolls_metric_path = (
+ os.path.dirname(zero_scrolls_metric_path)
+ + os.path.basename(zero_scrolls_metric_path).replace(".", "_")
+ + ".py"
+ )
+ shutil.copy(zero_scrolls_metric_path, updated_zero_scrolls_metric_path)
+ return updated_zero_scrolls_metric_path
+
+ zero_scrolls_metric_path = download_metric()
+ preds = json.load(open(predict_path))
+ preds_g, refers_g = defaultdict(list), defaultdict(list)
+ for v in preds.values():
+ task, refer, pred = [v[k] for k in ["task", "reference", "pred"]]
+ # if task == "narrative_qa":
+ pred = (
+ pred.split("\n\nQuestion:", 1)[0]
+ .split("\n\nExplanation:", 1)[0]
+ .replace("<|im_end|>", "")
+ .replace("\end{document}", "")
+ .strip()
+ )
+ # .split("\n\nExplanation:", 1)[0]
+ if task == "space_digest":
+ if pred.startswith("0.") and "%" not in pred[:4]:
+ pred = "{:.2f}%".format(float(pred[:4]) * 100)
+ else:
+ pred = pred[:5].strip().replace("%", "") + "%"
+ preds_g[task].append(pred)
+ refers_g[task].append([refer])
+
+ zero_scrolls = []
+ score_dict = {}
+ OUT_TASKS = [
+ "gov_report",
+ "summ_screen_fd",
+ "qmsum",
+ "squality",
+ "quality",
+ "narrative_qa",
+ "qasper",
+ "musique",
+ "space_digest",
+ "book_sum_sort",
+ ]
+ for task in OUT_TASKS:
+ if task not in preds_g:
+ zero_scrolls.append(0)
+ continue
+ p, r = preds_g[task], refers_g[task]
+ zero_scrolls_metric = datasets.load_metric(zero_scrolls_metric_path, task)
+ results = zero_scrolls_metric.compute(predictions=p, references=r)
+ print(task, len(p), results)
+ zero_scrolls.append(results["zero_scrolls_score"])
+ score_dict[task] = {
+ "zero_scrolls_score": results["zero_scrolls_score"],
+ "length": len(p),
+ }
+ print(",".join([f"{ii:.2f}" for ii in zero_scrolls]))
+ score_avg = sum(zero_scrolls) / len(zero_scrolls)
+ score_dict["avg"] = score_avg
+ return score_dict
+
+
+def predict():
+ model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
+
+ dataset = json.load(open(args.load_prompt_from))
+ if isinstance(dataset, dict):
+ dataset = dataset.values()
+
+ res = {}
+ res2 = {}
+ if os.path.exists(args.save_path):
+ res = json.load(open(args.save_path))
+ if os.path.exists(save_path2):
+ res2 = json.load(open(save_path2))
+
+ for sample in tqdm(dataset):
+ idx = int(sample["idx"])
+ if idx in res or str(idx) in res:
+ print(f"{idx} processed")
+ continue
+
+ prompt = sample[args.load_key]
+ max_gen = sample["n_max_token_ans"]
+ token_ids = tokenizer.encode(prompt)
+
+ if len(token_ids) > (args.n_max_token - max_gen):
+ half = int((args.n_max_token - max_gen) / 2) - 1
+ prompt = tokenizer.decode(token_ids[:half]) + tokenizer.decode(
+ token_ids[-half:]
+ )
+
+ pred = query_llm(prompt, model, args.model_name_or_path, max_gen)
+
+ res[idx] = {
+ "pred": pred,
+ "answer": sample["answer"],
+ "model_name": args.model_name_or_path,
+ "task": sample["task"],
+ "idx": idx,
+ }
+ json.dump(res, open(args.save_path, "w"), indent=4)
+ res2[f"{idx},{sample['task']}"] = {
+ "idx": idx,
+ "task": sample["task"],
+ "pred": pred,
+ "reference": sample["answer"],
+ }
+ json.dump(res2, open(save_path2, "w"), indent=4)
+
+
+predict()
+score_dict = eval(save_path2)
+json.dump(
+ score_dict,
+ open(
+ os.path.join(
+ os.path.dirname(args.save_path),
+ os.path.basename(args.save_path).replace("answer", "metrics"),
+ ),
+ "w",
+ ),
+)
diff --git a/experiments/llmlingua2/evaluation/metrics.py b/experiments/llmlingua2/evaluation/metrics.py
new file mode 100644
index 0000000..3da9ecd
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/metrics.py
@@ -0,0 +1,272 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import re
+import string
+from collections import Counter
+from typing import List
+
+import evaluate
+import jieba
+from fuzzywuzzy import fuzz
+from rouge import Rouge
+
+
+def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+
+ def remove_articles(text):
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def normalize_zh_answer(s):
+ """Lower text and remove punctuation, extra whitespace."""
+
+ def white_space_fix(text):
+ return "".join(text.split())
+
+ def remove_punc(text):
+ cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
+ all_punctuation = set(string.punctuation + cn_punctuation)
+ return "".join(ch for ch in text if ch not in all_punctuation)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_punc(lower(s)))
+
+
+def count_score(prediction, ground_truth, **kwargs):
+ numbers = re.findall(r"\d+", prediction)
+ right_num = 0
+ for number in numbers:
+ if str(number) == str(ground_truth):
+ right_num += 1
+ final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+ return float(final_score)
+
+
+def retrieval_score(prediction, ground_truth, **kwargs):
+ pattern = r"Paragraph (\d+)"
+ matches = re.findall(pattern, ground_truth)
+ ground_truth_id = matches[0]
+ numbers = re.findall(r"\d+", prediction)
+ right_num = 0
+ for number in numbers:
+ if str(number) == str(ground_truth_id):
+ right_num += 1
+ final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+ return float(final_score)
+
+
+def retrieval_zh_score(prediction, ground_truth, **kwargs):
+ pattern = r"段落(\d+)"
+ matches = re.findall(pattern, ground_truth)
+ ground_truth_id = matches[0]
+ numbers = re.findall(r"\d+", prediction)
+ right_num = 0
+ for number in numbers:
+ if str(number) == str(ground_truth_id):
+ right_num += 1
+ final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+ return float(final_score)
+
+
+def code_sim_score(prediction, ground_truth, **kwargs):
+ all_lines = prediction.lstrip("\n").split("\n")
+ prediction = ""
+ for line in all_lines:
+ if ("`" not in line) and ("#" not in line) and ("//" not in line):
+ prediction = line
+ break
+ return fuzz.ratio(prediction, ground_truth) / 100
+
+
+def classification_score(prediction, ground_truth, **kwargs):
+ em_match_list = []
+ all_classes = kwargs["all_classes"]
+ for class_name in all_classes:
+ if class_name in prediction:
+ em_match_list.append(class_name)
+ for match_term in em_match_list:
+ if match_term in ground_truth and match_term != ground_truth:
+ em_match_list.remove(match_term)
+ if ground_truth in em_match_list:
+ score = 1.0 / len(em_match_list)
+ else:
+ score = 0.0
+ return score
+
+
+def rouge_score(prediction, ground_truth, **kwargs):
+ rouge = Rouge()
+ try:
+ scores = rouge.get_scores([prediction], [ground_truth], avg=True)
+ except:
+ return 0.0
+ return scores["rouge-l"]["f"]
+
+
+def rouge_zh_score(prediction, ground_truth, **kwargs):
+ prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
+ ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
+ score = rouge_score(prediction, ground_truth)
+ return score
+
+
+def f1_score(prediction, ground_truth, **kwargs):
+ common = Counter(prediction) & Counter(ground_truth)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(prediction)
+ recall = 1.0 * num_same / len(ground_truth)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def qa_f1_score(prediction, ground_truth, **kwargs):
+ normalized_prediction = normalize_answer(prediction)
+ normalized_ground_truth = normalize_answer(ground_truth)
+
+ prediction_tokens = normalized_prediction.split()
+ ground_truth_tokens = normalized_ground_truth.split()
+ return f1_score(prediction_tokens, ground_truth_tokens)
+
+
+def qa_f1_zh_score(prediction, ground_truth, **kwargs):
+ prediction_tokens = list(jieba.cut(prediction, cut_all=False))
+ ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
+ prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
+ ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
+ prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
+ ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
+ return f1_score(prediction_tokens, ground_truth_tokens)
+
+
+import regex
+
+
+def normalize_answer2(s: str) -> str:
+ """Normalization from the SQuAD evaluation script.
+
+ See https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
+ """
+
+ def remove_articles(text):
+ return regex.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def best_subspan_em(prediction: str, ground_truths: List[str]) -> float:
+ normalized_prediction = normalize_answer2(prediction)
+
+ for ground_truth in ground_truths:
+ normalized_ground_truth = normalize_answer2(ground_truth)
+ if normalized_ground_truth.lower() in normalized_prediction.lower():
+ return 1.0
+ return 0.0
+
+
+def evaluate_with_gt(pred_list, gt_list, truncate_pred=True, logger=None):
+ def eval_qa_f1_score(pred, ground_truths):
+ score = 0.0
+ for gt in ground_truths:
+ score = max(score, qa_f1_score(pred, gt))
+ score = score
+ return score
+
+ if truncate_pred:
+ pred_list_truncated = []
+ for pred in pred_list:
+ pred = pred.lstrip("\n").split("\n")[0].strip()
+ pred_list_truncated.append(pred)
+ pred_list = pred_list_truncated
+
+ metrics = {
+ "qa_f1_score": 0.0,
+ "best_subspan_em": 0.0,
+ }
+ for pred, gts in zip(pred_list, gt_list):
+ metrics["qa_f1_score"] += eval_qa_f1_score(pred, gts)
+ metrics["best_subspan_em"] += best_subspan_em(pred, gts)
+ # average
+ for metric_name, score in metrics.items():
+ metrics[metric_name] = score * 100 / len(pred_list)
+ print(f"{metric_name}: {metrics[metric_name]:.3f}")
+ if logger is not None:
+ logger.info(f"{metric_name}: {metrics[metric_name]:.3f}")
+
+ return metrics
+
+
+def evaluate_sim(pred_list, gt_list, truncate_pred=True, truncate_gt=False):
+ if truncate_pred:
+ pred_list_truncated = []
+ for pred in pred_list:
+ pred = pred.lstrip("\n").split("\n")[0].strip()
+ pred_list_truncated.append(pred)
+ pred_list = pred_list_truncated
+ if truncate_gt:
+ gt_list_truncated = []
+ for gt in gt_list:
+ gt = gt.lstrip("\n").split("\n")[0].strip()
+ gt_list_truncated.append(gt)
+ gt_list = gt_list_truncated
+
+ bleu = evaluate.load("bleu")
+ rouge = evaluate.load("rouge")
+ bertscore = evaluate.load("bertscore")
+ bleu_results = bleu.compute(predictions=pred_list, references=gt_list)
+ rouge_results = rouge.compute(predictions=pred_list, references=gt_list)
+ bertscore_results = bertscore.compute(
+ predictions=pred_list, references=gt_list, lang="en"
+ )
+ p, r, f1 = [bertscore_results[k] for k in ["precision", "recall", "f1"]]
+ evs = [
+ bleu_results["bleu"],
+ *[rouge_results[k] for k in ["rouge1", "rouge2", "rougeL", "rougeLsum"]],
+ sum(p) / len(p),
+ sum(r) / len(r),
+ sum(f1) / len(f1),
+ ]
+ metrics = {}
+ for i, metric_name in enumerate(
+ [
+ "bleu",
+ "rouge1",
+ "rouge2",
+ "rougeL",
+ "rougeLsum",
+ "bertscore_precision",
+ "bertscore_recall",
+ "bertscore_f1",
+ ]
+ ):
+ metrics[metric_name] = evs[i]
+ print(",".join([f"{ii * 100:.2f}" for ii in evs]))
+
+ return metrics
diff --git a/experiments/llmlingua2/evaluation/scripts/compress.sh b/experiments/llmlingua2/evaluation/scripts/compress.sh
new file mode 100644
index 0000000..e435994
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/scripts/compress.sh
@@ -0,0 +1,25 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+python compress.py --load_origin_from ../../../results/meetingbank_short/origin/meetingbank_test_3qa_pairs_summary_formated.json \
+ --compression_rate 0.33 \
+ --force_tokens "\n,?,!,." \
+ --save_path ../../../results/meetingbank_short/llmlingua2/compression_ratio33_meetingbank_test_3qa_pairs_summary_formated.json
+
+python compress.py --load_origin_from ../../../results/longbench/origin/longbench_test_single_doc_qa_formated.json \
+ --target_token 2000 \
+ --force_tokens "\n,?,!,." \
+ --save_path ../../../results/longbench/llmlingua2/compression_target2000_longbench_test_single_doc_qa_formated.json
+
+python compress.py --load_origin_from ../../../results/zero_scrolls/origin/zero_scrolls_validation.json \
+ --target_token 2000 \
+ --force_tokens "\n,?,!,." \
+ --save_path ../../../results/zero_scrolls/llmlingua2/compression_target2000_zero_scrolls_validation.json
+
+python compress.py --load_origin_from ../../../results/gsm8k/origin/gsm8k_cot_example_all_in_one.json \
+ --load_key prompt_list \
+ --target_token 250 \
+ --force_tokens "+,-,*,×,/,÷,=,The answer is,\n" \
+ --use_context_level_filter \
+ --force_reserve_digit \
+ --save_path ../../../results/gsm8k/llmlingua2/compression_target250_gsm8k_cot_example_all_in_one.json
diff --git a/experiments/llmlingua2/evaluation/scripts/evaluate.sh b/experiments/llmlingua2/evaluation/scripts/evaluate.sh
new file mode 100644
index 0000000..7d5a215
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/scripts/evaluate.sh
@@ -0,0 +1,22 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+python eval_meetingbank_qa.py --load_prompt_from ../../../results/meetingbank_short/llmlingua2/compression_ratio33_meetingbank_test_3qa_pairs_summary_formated.json \
+ --load_key compressed_prompt \
+ --model_name_or_path gpt-35-turbo-instruct \
+ --save_path ../../../results/meetingbank_short/llmlingua2/gpt35_answer/answer_ratio33_meetingbank_test_3qa_pairs_summary_formated.json
+
+python eval_longbench.py --load_prompt_from ../../../results/longbench/llmlingua2/compression_target2000_longbench_test_single_doc_qa_formated.json \
+ --load_key compressed_prompt \
+ --model_name_or_path gpt-35-turbo-instruct \
+ --save_path ../../../results/longbench/llmlingua2/gpt35_answer/answer_target2000_longbench_test_single_doc_qa_formated.json
+
+python eval_zero_scrolls.py --load_prompt_from ../../../results/zero_scrolls/llmlingua2/compression_target2000_zero_scrolls_validation.json \
+ --load_key compressed_prompt \
+ --model_name_or_path gpt-35-turbo-instruct \
+ --save_path ../../../results/zero_scrolls/llmlingua2/gpt35_answer/answer_target2000_zero_scrolls_validation.json
+
+python eval_gsm8k.py --load_prompt_from ../../../results/gsm8k/llmlingua2/compression_target200_gsm8k_cot_example_all_in_one.json \
+ --load_key compressed_prompt_list \
+ --model_name_or_path gpt-35-turbo-instruct \
+ --save_path ../../../results/gsm8k/llmlingua2/gpt35_answer/answer_target200_gsm8k_cot_example_all_in_one.json
diff --git a/experiments/llmlingua2/evaluation/utils.py b/experiments/llmlingua2/evaluation/utils.py
new file mode 100644
index 0000000..165592e
--- /dev/null
+++ b/experiments/llmlingua2/evaluation/utils.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+from time import sleep
+
+import openai
+import tiktoken
+
+
+def query_llm(
+ prompt,
+ model,
+ model_name,
+ max_tokens,
+ tokenizer=None,
+ chat_completion=False,
+ **kwargs,
+):
+ SLEEP_TIME_FAILED = 62
+
+ request = {
+ "temperature": kwargs["temperature"] if "temperature" in kwargs else 0.0,
+ "top_p": kwargs["top_p"] if "top_p" in kwargs else 1.0,
+ "seed": kwargs["seed"] if "seed" in kwargs else 42,
+ "max_tokens": max_tokens,
+ "n": 1,
+ "stream": False,
+ }
+ if chat_completion:
+ request["messages"] = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt},
+ ]
+ else:
+ request["prompt"] = prompt
+
+ answer = None
+ response = None
+ while answer is None:
+ try:
+ response = model.create(engine=model_name, **request)
+ answer = (
+ response["choices"][0]["message"]["content"]
+ if chat_completion
+ else response["choices"][0]["text"]
+ )
+ except Exception as e:
+ answer = None
+ print(f"error: {e}, response: {response}")
+ sleep(SLEEP_TIME_FAILED)
+ # sleep(SLEEP_TIME_SUCCESS)
+ return answer
+
+
+def load_model_and_tokenizer(model_name_or_path, chat_completion=False):
+ openai.api_key = "your_api_key"
+ openai.api_base = "your_api_base"
+ openai.api_type = "azure"
+ openai.api_version = "2023-05-15"
+
+ if chat_completion:
+ model = openai.ChatCompletion
+ else:
+ model = openai.Completion
+
+ tokenizer = tiktoken.encoding_for_model("gpt-4")
+ return model, tokenizer
diff --git a/experiments/llmlingua2/model_training/train.sh b/experiments/llmlingua2/model_training/train.sh
new file mode 100644
index 0000000..e1ed311
--- /dev/null
+++ b/experiments/llmlingua2/model_training/train.sh
@@ -0,0 +1,5 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+python train_roberta.py --data_path ../../../results/meetingbank/gpt-4-32k_comp/annotation_cs512_meetingbank_train_formated.pt \
+ --save_path ../../../results/models/xlm_roberta_large_meetingbank_only.pth
diff --git a/experiments/llmlingua2/model_training/train_roberta.py b/experiments/llmlingua2/model_training/train_roberta.py
new file mode 100644
index 0000000..69a4e24
--- /dev/null
+++ b/experiments/llmlingua2/model_training/train_roberta.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import os
+import random
+import time
+
+import torch
+from sklearn.metrics import accuracy_score
+from torch import cuda
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+from transformers import AutoModelForTokenClassification, AutoTokenizer
+from utils import TokenClfDataset
+
+MAX_LEN = 512
+MAX_GRAD_NORM = 10
+
+parser = argparse.ArgumentParser(
+ description="train bert to do compression (by token classification)"
+)
+parser.add_argument(
+ "--model_name",
+ help="token classification model",
+ default="FacebookAI/xlm-roberta-large",
+)
+parser.add_argument(
+ "--data_path",
+ help="training and validation data path",
+ default="../../../results/meetingbank/gpt-4-32k_comp/annotation_kept_cs512_meetingbank_train_formated.pt",
+)
+parser.add_argument(
+ "--label_type",
+ help="word label or token label",
+ default="word_label",
+ choices=["word_label", "token_label"],
+)
+parser.add_argument(
+ "--save_path",
+ help="save path",
+ default="../../../results/models/xlm_roberta_large_meetingbank_only.pth",
+)
+parser.add_argument("--lr", help="learning rate", default=1e-5, type=float)
+parser.add_argument(
+ "--num_epoch", help="number of training epoch", default=10, type=int
+)
+parser.add_argument("--batch_size", type=int, default=10)
+
+args = parser.parse_args()
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+writer = SummaryWriter(log_dir=os.path.dirname(args.save_path).replace("model", "log"))
+
+
+def train(epoch):
+ tr_loss, tr_accuracy = 0, 0
+ nb_tr_examples, nb_tr_steps = 0, 0
+ tr_preds, tr_labels = [], []
+ model.train()
+
+ for idx, batch in enumerate(train_dataloader):
+ t = time.time()
+ ids = batch["ids"].to(device, dtype=torch.long)
+ mask = batch["mask"].to(device, dtype=torch.long)
+ targets = batch["targets"].to(device, dtype=torch.long)
+
+ outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
+ loss, tr_logits = outputs.loss, outputs.logits
+ tr_loss += loss.item()
+
+ nb_tr_steps += 1
+ nb_tr_examples += targets.size(0)
+
+ flattened_targets = targets.view(-1)
+ active_logits = tr_logits.view(-1, model.num_labels)
+ flattened_predictions = torch.argmax(active_logits, axis=1)
+ active_accuracy = mask.view(-1) == 1
+ targets = torch.masked_select(flattened_targets, active_accuracy)
+ predictions = torch.masked_select(flattened_predictions, active_accuracy)
+
+ tr_preds.extend(predictions)
+ tr_labels.extend(targets)
+
+ tmp_tr_accuracy = accuracy_score(
+ targets.cpu().numpy(), predictions.cpu().numpy()
+ )
+ tr_accuracy += tmp_tr_accuracy
+
+ if idx % 100 == 0:
+ loss_step = tr_loss / nb_tr_steps
+ acc_step = tr_accuracy / nb_tr_steps
+ writer.add_scalar(
+ "Loss/train", loss_step, idx + epoch * len(train_dataloader)
+ )
+ writer.add_scalar(
+ "Acc/train", acc_step, idx + epoch * len(train_dataloader)
+ )
+ writer.flush()
+ print(f"Training loss per 100 training steps: {loss_step}")
+
+ torch.nn.utils.clip_grad_norm_(
+ parameters=model.parameters(), max_norm=MAX_GRAD_NORM
+ )
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ tr_loss = tr_loss / nb_tr_steps
+ tr_accuracy = tr_accuracy / nb_tr_steps
+ print(f"Training loss epoch: {tr_loss}")
+ print(f"Training accuracy epoch: {tr_accuracy}")
+
+
+def test(model, eval_dataloader):
+ model.eval()
+
+ eval_loss, eval_accuracy = 0, 0
+ nb_eval_examples, nb_eval_steps = 0, 0
+ eval_preds, eval_labels = [], []
+
+ with torch.no_grad():
+ for idx, batch in enumerate(eval_dataloader):
+ ids = batch["ids"].to(device, dtype=torch.long)
+ mask = batch["mask"].to(device, dtype=torch.long)
+ targets = batch["targets"].to(device, dtype=torch.long)
+
+ outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
+ loss, eval_logits = outputs.loss, outputs.logits
+
+ eval_loss += loss.item()
+
+ nb_eval_steps += 1
+ nb_eval_examples += targets.size(0)
+
+ flattened_targets = targets.view(-1)
+ active_logits = eval_logits.view(-1, model.num_labels)
+ flattened_predictions = torch.argmax(active_logits, axis=1)
+ active_accuracy = mask.view(-1) == 1
+ targets = torch.masked_select(flattened_targets, active_accuracy)
+ predictions = torch.masked_select(flattened_predictions, active_accuracy)
+
+ eval_labels.extend(targets)
+ eval_preds.extend(predictions)
+
+ tmp_eval_accuracy = accuracy_score(
+ targets.cpu().numpy(), predictions.cpu().numpy()
+ )
+ eval_accuracy += tmp_eval_accuracy
+
+ labels = [label.item() for label in eval_labels]
+ predictions = [pred.item() for pred in eval_preds]
+
+ eval_loss = eval_loss / nb_eval_steps
+ eval_accuracy = eval_accuracy / nb_eval_steps
+ print(f"Validation Loss: {eval_loss}")
+ print(f"Validation Accuracy: {eval_accuracy}")
+
+ writer.add_scalar("Loss/eval", eval_loss, epoch * len(eval_dataloader))
+ writer.add_scalar("Acc/eval", eval_accuracy, epoch * len(eval_dataloader))
+ writer.flush()
+
+ return eval_accuracy
+
+
+device = "cuda" if cuda.is_available() else "cpu"
+data = torch.load(args.data_path)
+
+tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+model = AutoModelForTokenClassification.from_pretrained(
+ args.model_name, num_labels=2, ignore_mismatched_sizes=True
+)
+model.to(device)
+
+assert len(data["origin"]) == len(data["labels"])
+text_label = [(text, label) for text, label in zip(data["origin"], data["labels"])]
+random.shuffle(text_label)
+train_data = text_label[: int(len(text_label) * 0.8)]
+val_data = text_label[int(len(text_label) * 0.8) :]
+
+train_text = [text for text, label in train_data]
+train_label = [label for text, label in train_data]
+val_text = [text for text, label in val_data]
+val_label = [label for text, label in val_data]
+
+train_dataset = TokenClfDataset(
+ train_text, train_label, MAX_LEN, tokenizer=tokenizer, model_name=args.model_name
+)
+val_dataset = TokenClfDataset(
+ val_text, val_label, MAX_LEN, tokenizer=tokenizer, model_name=args.model_name
+)
+
+print(f"len taining set: {len(train_dataset)}, len validation set: {len(val_dataset)}")
+print(train_dataset[0])
+for token, label in zip(
+ tokenizer.convert_ids_to_tokens(train_dataset[0]["ids"][:30]),
+ train_dataset[0]["targets"][:30],
+):
+ print("{0:10} {1}".format(token, label.item()))
+train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
+
+val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
+
+ids = train_dataset[0]["ids"].unsqueeze(0)
+mask = train_dataset[0]["mask"].unsqueeze(0)
+targets = train_dataset[0]["targets"].unsqueeze(0)
+ids = ids.to(device)
+mask = mask.to(device)
+targets = targets.to(device)
+outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
+initial_loss = outputs[0]
+print(initial_loss)
+
+tr_logits = outputs[1]
+print(tr_logits.shape)
+
+optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
+
+best_acc = 0
+for epoch in tqdm(range(args.num_epoch)):
+ print(f"Training epoch: {epoch + 1}")
+ train(epoch)
+ acc = test(model, val_dataloader)
+ if acc > best_acc:
+ best_acc = acc
+ torch.save(model.state_dict(), args.save_path)
diff --git a/experiments/llmlingua2/model_training/utils.py b/experiments/llmlingua2/model_training/utils.py
new file mode 100644
index 0000000..87877f6
--- /dev/null
+++ b/experiments/llmlingua2/model_training/utils.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import spacy
+import torch
+from torch.utils.data import Dataset
+
+
+class TokenClfDataset(Dataset):
+ def __init__(
+ self,
+ texts,
+ labels=None,
+ max_len=512,
+ tokenizer=None,
+ model_name="bert-base-multilingual-cased",
+ ):
+ self.len = len(texts)
+ self.texts = texts
+ self.tokenizer = tokenizer
+ self.max_len = max_len
+ self.labels = labels
+ self.model_name = model_name
+ if "bert-base-multilingual-cased" in model_name:
+ self.cls_token = "[CLS]"
+ self.sep_token = "[SEP]"
+ self.unk_token = "[UNK]"
+ self.pad_token = "[PAD]"
+ self.mask_token = "[MASK]"
+ elif "xlm-roberta-large" in model_name:
+ self.bos_token = ""
+ self.eos_token = ""
+ self.sep_token = ""
+ self.cls_token = ""
+ self.unk_token = ""
+ self.pad_token = ""
+ self.mask_token = ""
+ else:
+ raise NotImplementedError()
+
+ self.nlp = spacy.load("en_core_web_sm")
+
+ def __getitem__(self, index):
+ text = self.texts[index]
+ if self.labels is not None:
+ labels = self.labels[index][:]
+ tokenized_text, labels = self.tokenize_and_preserve_labels(
+ text, labels, self.tokenizer
+ )
+ assert len(tokenized_text) == len(labels)
+ labels.insert(0, False)
+ labels.insert(-1, False)
+ else:
+ tokenized_text = self.tokenizer.tokenize(text)
+
+ tokenized_text = [self.cls_token] + tokenized_text + [self.sep_token]
+
+ if len(tokenized_text) > self.max_len:
+ tokenized_text = tokenized_text[: self.max_len]
+ if self.labels is not None:
+ labels = labels[: self.max_len]
+ else:
+ tokenized_text = tokenized_text + [
+ self.pad_token for _ in range(self.max_len - len(tokenized_text))
+ ]
+ if self.labels is not None:
+ labels = labels + [False for _ in range(self.max_len - len(labels))]
+
+ attn_mask = [1 if tok != self.pad_token else 0 for tok in tokenized_text]
+
+ ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)
+
+ sample = {
+ "ids": torch.tensor(ids, dtype=torch.long),
+ "mask": torch.tensor(attn_mask, dtype=torch.long),
+ }
+ if self.labels is not None:
+ sample["targets"] = torch.tensor(labels, dtype=torch.long)
+
+ return sample
+
+ def __len__(self):
+ return self.len
+
+ def split_string(self, input_string, ignore_tokens=set([","])):
+ doc = self.nlp(input_string)
+ word_list = []
+ for word in doc:
+ if word.lemma_ not in ignore_tokens:
+ word_list.append(word.lemma_)
+ return word_list
+
+ def tokenize_and_preserve_labels(self, text, text_labels, tokenizer):
+ """
+ Word piece tokenization makes it difficult to match word labels
+ back up with individual word pieces. This function tokenizes each
+ word one at a time so that it is easier to preserve the correct
+ label for each subword. It is, of course, a bit slower in processing
+ time, but it will help our model achieve higher accuracy.
+ """
+
+ tokenized_text = []
+ labels = []
+
+ assert len(self.split_string(text)) == len(text_labels)
+
+ for word, label in zip(self.split_string(text), text_labels):
+ # Tokenize the word and count # of subwords the word is broken into
+ tokenized_word = tokenizer.tokenize(word)
+ n_subwords = len(tokenized_word)
+
+ # Add the tokenized word to the final tokenized word list
+ tokenized_text.extend(tokenized_word)
+
+ # Add the same label to the new list of labels `n_subwords` times
+ labels.extend([label] * n_subwords)
+
+ return tokenized_text, labels
diff --git a/images/LLMLingua-2.png b/images/LLMLingua-2.png
new file mode 100644
index 0000000..f9d0154
Binary files /dev/null and b/images/LLMLingua-2.png differ
diff --git a/images/motivation.png b/images/motivation.png
index e93ca1a..aec5f5a 100644
Binary files a/images/motivation.png and b/images/motivation.png differ
diff --git a/llmlingua/__init__.py b/llmlingua/__init__.py
index b1fc6e0..d750210 100644
--- a/llmlingua/__init__.py
+++ b/llmlingua/__init__.py
@@ -1,8 +1,8 @@
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
+
# flake8: noqa
from .prompt_compressor import PromptCompressor
from .version import VERSION as __version__
-
__all__ = ["PromptCompressor"]
diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py
index 0e7bcb6..a26a569 100644
--- a/llmlingua/prompt_compressor.py
+++ b/llmlingua/prompt_compressor.py
@@ -2,19 +2,32 @@
# Licensed under The MIT License [see LICENSE for details]
import bisect
+import copy
import re
+import string
from collections import defaultdict
from typing import List
-import numpy as np
-import torch
-
import nltk
+import numpy as np
import tiktoken
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-
-
-encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoModelForTokenClassification,
+ AutoTokenizer,
+)
+
+from .utils import (
+ TokenClfDataset,
+ get_pure_token,
+ is_begin_of_new_word,
+ replace_added_token,
+ seed_everything,
+)
class PromptCompressor:
@@ -33,11 +46,20 @@ class PromptCompressor:
device_map (str, optional): The device to load the model onto, e.g., "cuda" for GPU. Default is "cuda".
model_config (dict, optional): A dictionary containing the configuration parameters for the model. Default is an empty dictionary.
open_api_config (dict, optional): A dictionary containing configuration for openai APIs that may be used in conjunction with the model. Default is an empty dictionary.
-
+ use_llmlingua2 (bool, optional): Whether to use llmlingua-2 compressor based on the paper
+ "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression".
+ Zhuoshi Pan, Qianhui Wu, Huiqiang Jiang, Menglin Xia, Xufang Luo, Jue Zhang, Qingwei Lin, Victor Ruhle, Yuqing Yang, Chin-Yew Lin, H. Vicky Zhao, Lili Qiu, Dongmei Zhang.
+ "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression". arXiv preprint arXiv:,
+ Default is True.
+ llmlingua2_config (dict, optional): A dictionary containing the configuration parameters for llmlingua-2. Default is
+ {
+ "max_batch_size": 50,
+ "max_force_token": 100, # max number of the tokens which will be forcely preserved
+ }
Example:
- >>> compress_method = PromptCompressor(model_name="gpt2", device_map="cuda")
+ >>> compress_method = PromptCompressor(model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank", use_llmlingua2=True, )
>>> context = ["This is the first context sentence.", "Here is another context sentence."]
- >>> result = compress_method.compress_prompt(context)
+ >>> result = compress_method.compress_prompt(context, use_context_level_filter=True, target_token=5)
>>> print(result["compressed_prompt"])
# This will print the compressed version of the context.
@@ -51,13 +73,44 @@ def __init__(
device_map: str = "cuda",
model_config: dict = {},
open_api_config: dict = {},
+ use_llmlingua2: bool = False,
+ llmlingua2_config: dict = {},
):
- self.load_model(model_name, device_map, model_config)
+ self.model_name = model_name
+ self.use_llmlingua2 = use_llmlingua2
self.retrieval_model = None
self.retrieval_model_name = None
self.open_api_config = open_api_config
self.cache_bos_num = 10
self.prefix_bos_num = 100
+ self.oai_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+
+ self.load_model(model_name, device_map, model_config)
+ if use_llmlingua2:
+ self.init_llmlingua2(**llmlingua2_config)
+
+ def init_llmlingua2(
+ self,
+ max_batch_size: int = 50,
+ max_force_token: int = 100,
+ ):
+ seed_everything(42)
+ self.max_batch_size = max_batch_size
+ self.max_seq_len = 512
+ self.max_force_token = max_force_token
+ self.special_tokens = set(
+ [
+ v
+ for k, v in self.tokenizer.special_tokens_map.items()
+ if k != "additional_special_tokens"
+ ]
+ )
+
+ self.added_tokens = [f"[NEW{i}]" for i in range(max_force_token)]
+ self.tokenizer.add_special_tokens(
+ {"additional_special_tokens": self.added_tokens}
+ )
+ self.model.resize_token_embeddings(len(self.tokenizer))
def load_model(
self, model_name: str, device_map: str = "cuda", model_config: dict = {}
@@ -65,40 +118,40 @@ def load_model(
trust_remote_code = model_config.get("trust_remote_code", True)
if "trust_remote_code" not in model_config:
model_config["trust_remote_code"] = trust_remote_code
- config = AutoConfig.from_pretrained(
- model_name, trust_remote_code=trust_remote_code
- )
- tokenizer = AutoTokenizer.from_pretrained(
- model_name, trust_remote_code=trust_remote_code
- )
+ config = AutoConfig.from_pretrained(model_name, **model_config)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, **model_config)
if model_config.get("pad_to_left", True):
tokenizer.padding_side = "left"
tokenizer.pad_token_id = (
config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
)
+ MODEL_CLASS = (
+ AutoModelForTokenClassification
+ if any("ForTokenClassification" in ar for ar in config.architectures)
+ else AutoModelForCausalLM
+ )
self.device = (
device_map
if any(key in device_map for key in ["cuda", "cpu", "mps"])
else "cuda"
)
if "cuda" in device_map or "cpu" in device_map:
- model = AutoModelForCausalLM.from_pretrained(
+ model = MODEL_CLASS.from_pretrained(
model_name,
- torch_dtype="auto" if device_map == "cuda" else torch.float32,
+ torch_dtype=model_config.get(
+ "torch_dtype", "auto" if device_map == "cuda" else torch.float32
+ ),
device_map=device_map,
config=config,
ignore_mismatched_sizes=True,
**model_config,
)
else:
- model = AutoModelForCausalLM.from_pretrained(
+ model = MODEL_CLASS.from_pretrained(
model_name,
device_map=device_map,
- torch_dtype="auto",
+ torch_dtype=model_config.get("torch_dtype", "auto"),
pad_token_id=tokenizer.pad_token_id,
- offload_folder="/tmp/offload",
- offload_state_dict=True,
- cache_dir="/tmp/cache",
**model_config,
)
self.tokenizer = tokenizer
@@ -336,6 +389,17 @@ def compress_prompt(
context_segs: List[str] = None,
context_segs_rate: List[float] = None,
context_segs_compress: List[bool] = None,
+ target_context: int = -1,
+ context_level_rate: float = 1.0,
+ context_level_target_token: int = -1,
+ return_word_label: bool = False,
+ word_sep: str = "\t\t|\t\t",
+ label_sep: str = " ",
+ token_to_word: str = "mean",
+ force_tokens: List[str] = [],
+ force_reserve_digit: bool = False,
+ drop_consecutive: bool = False,
+ chunk_end_tokens: List[str] = [".", "\n"],
):
"""
Compresses the given context.
@@ -377,15 +441,53 @@ def compress_prompt(
rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
+ target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
+ context_level_rate (float, optional): The minimum compression rate target to be achieved in context level. Default is 1.0.
+ context_level_target_token (float, optional): The maximum number of tokens to be achieved in context level compression.
+ Default is -1, indicating no specific target. Only used in the coarse-to-fine compression senario.
+ force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+ return_word_label (bool, optional): Whether to return word with corresponding label. Default is False.
+ word_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
+ label_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition word and label. Default is " ".
+ token_to_word (str, optional): How to convert token probability to word probability. Default is "mean".
+ force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
+ force_reserve_digit (bool, optional): Whether to forcibly reserve tokens that containing digit (0,...,9). Default is False.
+ drop_consecutive (bool, optinal): Whether to drop tokens which are in 'force_tokens' but appears consecutively in compressed prompt.
+ Default is False.
+ chunk_end_tokens (List[str], optinal): The early stop tokens for segmenting chunk. Default is [".", "\n"],
Returns:
dict: A dictionary containing:
- "compressed_prompt" (str): The resulting compressed prompt.
+ - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt. Only used in llmlingua2.
+ - "fn_labeled_original_prompt" (str): original words along with their labels
+ indicating whether to reserve in compressed prompt, in the format (word label_sep label)
+ Only used in llmlingua2 when return_word_label = True.
- "origin_tokens" (int): The original number of tokens in the input.
- "compressed_tokens" (int): The number of tokens in the compressed output.
- "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression.
- "rate" (str): The compression rate achieved, in a human-readable format.
- "saving" (str): Estimated savings in GPT-4 token usage.
"""
+ if self.use_llmlingua2:
+ return self.compress_prompt_llmlingua2(
+ context,
+ rate=rate,
+ target_token=target_token,
+ use_context_level_filter=use_context_level_filter,
+ use_token_level_filter=use_token_level_filter,
+ target_context=target_context,
+ context_level_rate=context_level_rate,
+ context_level_target_token=context_level_target_token,
+ force_context_ids=force_context_ids,
+ return_word_label=return_word_label,
+ word_sep=word_sep,
+ label_sep=label_sep,
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ chunk_end_tokens=chunk_end_tokens,
+ )
assert (
rate <= 1.0
), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]."
@@ -409,7 +511,9 @@ def compress_prompt(
else "none_condition"
)
origin_tokens = len(
- encoding.encode("\n\n".join([instruction] + context + [question]).strip())
+ self.oai_tokenizer.encode(
+ "\n\n".join([instruction] + context + [question]).strip()
+ )
)
context_tokens_length = [self.get_token_length(c) for c in context]
instruction_tokens_length, question_tokens_length = self.get_token_length(
@@ -541,7 +645,7 @@ def compress_prompt(
compressed_prompt = "\n\n".join(res)
- compressed_tokens = len(encoding.encode(compressed_prompt))
+ compressed_tokens = len(self.oai_tokenizer.encode(compressed_prompt))
saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
ratio = 1 if compressed_tokens == 0 else origin_tokens / compressed_tokens
rate = 1 / ratio
@@ -554,10 +658,266 @@ def compress_prompt(
"saving": f", Saving ${saving:.1f} in GPT-4.",
}
- def get_token_length(self, text: str, add_special_tokens: bool = True):
- return len(
- self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
- )
+ def compress_prompt_llmlingua2(
+ self,
+ context: List[str],
+ rate: float = 0.5,
+ target_token: int = -1,
+ use_context_level_filter: bool = False,
+ use_token_level_filter: bool = True,
+ target_context: int = -1,
+ context_level_rate: float = 1.0,
+ context_level_target_token: int = -1,
+ force_context_ids: List[int] = [],
+ return_word_label: bool = False,
+ word_sep: str = "\t\t|\t\t",
+ label_sep: str = " ",
+ token_to_word: str = "mean",
+ force_tokens: List[str] = [],
+ force_reserve_digit: bool = False,
+ drop_consecutive: bool = False,
+ chunk_end_tokens: List[str] = [".", "\n"],
+ ):
+ """
+ Compresses the given context, instruction and question.
+
+ Args:
+ context (List[str]): List of context strings that form the basis of the prompt.
+ rate (float, optional): The minimum compression rate target to be achieved. Default is 0.5. The actual compression rate
+ generally exceeds the specified target, but there can be fluctuations due to differences in tokenizers. If specified,
+ it should be a float greater than or equal to 1.0, representing the target compression rate.
+ target_token (int, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target.
+ The actual number of tokens after compression should generally be less than the specified target_token, but there can
+ be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
+ the sole criterion, overriding the rate.
+ target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
+ Only used in the coarse-to-fine compression.
+ context_level_rate (float, optional): The minimum compression rate target to be achieved in context level. Default is 1.0.
+ Only used in the coarse-to-fine compression.
+ context_level_target_token (float, optional): The maximum number of tokens to be achieved in context level compression.
+ Default is -1, indicating no specific target. Only used in the coarse-to-fine compression senario.
+ force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+ return_word_label (bool, optional): Whether to return word with corresponding label. Default is False.
+ word_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
+ label_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition word and label. Default is " ".
+ token_to_word (str, optional): How to convert token probability to word probability. Default is "mean".
+ force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
+ force_reserve_digit (bool, optional): Whether to forcibly reserve tokens that containing digit (0,...,9). Default is False.
+ drop_consecutive (bool, optinal): Whether to drop tokens which are in 'force_tokens' but appears consecutively in compressed prompt.
+ Default is False.
+ chunk_end_tokens (List[str], optional): The early stop tokens for segmenting chunk. Default is [".", "\n"].
+ Returns:
+ dict: A dictionary containing:
+ - "compressed_prompt" (str): The resulting compressed prompt.
+ - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt.
+ - "fn_labeled_original_prompt" (str): original words along with their labels
+ indicating whether to reserve in compressed prompt, in the format (word label_sep label)
+ - "origin_tokens" (int): The original number of tokens in the input.
+ - "compressed_tokens" (int): The number of tokens in the compressed output.
+ - "ratio" (str): The compression ratio achieved, in a human-readable format.
+ - "rate" (str): The compression rate achieved, in a human-readable format.
+ - "saving" (str): Estimated savings in GPT-4 token usage.
+
+ """
+ assert len(force_tokens) <= self.max_force_token
+ token_map = {}
+ for i, t in enumerate(force_tokens):
+ if len(self.tokenizer.tokenize(t)) != 1:
+ token_map[t] = self.added_tokens[i]
+ chunk_end_tokens = copy.deepcopy(chunk_end_tokens)
+ for c in chunk_end_tokens:
+ if c in token_map:
+ chunk_end_tokens.append(token_map[c])
+ chunk_end_tokens = set(chunk_end_tokens)
+
+ if type(context) == str:
+ context = [context]
+ context = copy.deepcopy(context)
+
+ if len(context) == 1 and use_context_level_filter:
+ use_context_level_filter = False
+
+ n_original_token = 0
+ context_chunked = []
+ for i in range(len(context)):
+ n_original_token += self.get_token_length(
+ context[i], use_oai_tokenizer=True
+ )
+ for ori_token, new_token in token_map.items():
+ context[i] = context[i].replace(ori_token, new_token)
+ context_chunked.append(
+ self.__chunk_context(context[i], chunk_end_tokens=chunk_end_tokens)
+ )
+
+ if use_context_level_filter:
+ # want use_context_level_filter but do not specify any parameters in context level?
+ # we will set context_level_rate = (rate + 1.0) / 2 if specify rate or target_token * 2 if specify target_token
+ if (
+ target_context <= 0
+ and context_level_rate >= 1.0
+ and context_level_target_token <= 0
+ ):
+ if target_token < 0 and rate < 1.0:
+ context_level_rate = (
+ (rate + 1.0) / 2 if use_token_level_filter else rate
+ )
+ if target_token >= 0:
+ context_level_target_token = (
+ target_token * 2 if use_token_level_filter else target_token
+ )
+
+ if target_context >= 0:
+ context_level_rate = min(target_context / len(context), 1.0)
+ if context_level_target_token >= 0:
+ context_level_rate = min(
+ context_level_target_token / n_original_token, 1.0
+ )
+
+ context_probs, context_words = self.__get_context_prob(
+ context_chunked,
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ )
+
+ threshold = np.percentile(
+ context_probs, int(100 * (1 - context_level_rate))
+ )
+
+ reserved_context = []
+ context_label = [False] * len(context_probs)
+ for i, p in enumerate(context_probs):
+ if p >= threshold or (
+ force_context_ids is not None and i in force_context_ids
+ ):
+ reserved_context.append(context_chunked[i])
+ context_label[i] = True
+ n_reserved_token = 0
+ for chunks in reserved_context:
+ for c in chunks:
+ n_reserved_token += self.get_token_length(c, use_oai_tokenizer=True)
+ if target_token >= 0:
+ rate = min(target_token / n_reserved_token, 1.0)
+
+ if use_token_level_filter:
+ compressed_context, word_list, word_label_list = self.__compress(
+ reserved_context,
+ reduce_rate=max(0, 1 - rate),
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ )
+ else:
+ compressed_context, word_list, word_label_list = self.__compress(
+ reserved_context,
+ reduce_rate=0,
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ )
+
+ n_compressed_token = 0
+ for c in compressed_context:
+ n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
+ saving = (n_original_token - n_compressed_token) * 0.06 / 1000
+ ratio = (
+ 1 if n_compressed_token == 0 else n_original_token / n_compressed_token
+ )
+ res = {
+ "compressed_prompt": "\n\n".join(compressed_context),
+ "compressed_prompt_list": compressed_context,
+ "origin_tokens": n_original_token,
+ "compressed_tokens": n_compressed_token,
+ "ratio": f"{ratio:.1f}x",
+ "rate": f"{1 / ratio * 100:.1f}%",
+ "saving": f", Saving ${saving:.1f} in GPT-4.",
+ }
+ if return_word_label:
+ words = []
+ labels = []
+ j = 0
+ for i in range(len(context)):
+ if context_label[i]:
+ words.extend(word_list[j])
+ labels.extend(word_label_list[j])
+ j += 1
+ else:
+ words.extend(context_words[i])
+ labels.extend([0] * len(context_words[i]))
+ word_label_lines = word_sep.join(
+ [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
+ )
+ res["fn_labeled_original_prompt"] = word_label_lines
+ return res
+
+ if target_token > 0:
+ rate = min(target_token / n_original_token, 1.0)
+
+ if use_token_level_filter:
+ compressed_context, word_list, word_label_list = self.__compress(
+ context_chunked,
+ reduce_rate=max(0, 1 - rate),
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ )
+ else:
+ compressed_context, word_list, word_label_list = self.__compress(
+ context_chunked,
+ reduce_rate=0,
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ )
+
+ n_compressed_token = 0
+ for c in compressed_context:
+ n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
+ saving = (n_original_token - n_compressed_token) * 0.06 / 1000
+ ratio = 1 if n_compressed_token == 0 else n_original_token / n_compressed_token
+ res = {
+ "compressed_prompt": "\n\n".join(compressed_context),
+ "compressed_prompt_list": compressed_context,
+ "origin_tokens": n_original_token,
+ "compressed_tokens": n_compressed_token,
+ "ratio": f"{ratio:.1f}x",
+ "rate": f"{1 / ratio * 100:.1f}%",
+ "saving": f", Saving ${saving:.1f} in GPT-4.",
+ }
+ if return_word_label:
+ words = []
+ labels = []
+ for w_list, l_list in zip(word_list, word_label_list):
+ words.extend(w_list)
+ labels.extend(l_list)
+
+ word_label_lines = word_sep.join(
+ [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
+ )
+ res["fn_labeled_original_prompt"] = word_label_lines
+ return res
+
+ def get_token_length(
+ self,
+ text: str,
+ add_special_tokens: bool = True,
+ use_oai_tokenizer: bool = False,
+ ):
+ if use_oai_tokenizer:
+ return len(self.oai_tokenizer.encode(text))
+ else:
+ return len(
+ self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
+ )
def get_prefix_length(self, prefix: str, text: str):
possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1)
@@ -1554,7 +1914,6 @@ def get_distance_bge_llmembedder(corpus, query):
def get_distance_jinza(corpus, query):
from numpy.linalg import norm
-
from transformers import AutoModel
def cos_sim(a, b):
@@ -1723,3 +2082,301 @@ def concate_segment_info(
else:
new_segment_info.append((seg_len, seg_ratio, seg_compress))
return new_segment_info
+
+ def __get_context_prob(
+ self,
+ context_list: list,
+ token_to_word="mean",
+ force_tokens: List[str] = [],
+ token_map: dict = {},
+ force_reserve_digit: bool = False,
+ ):
+ chunk_list = []
+ for chunks in context_list:
+ for c in chunks:
+ chunk_list.append(c)
+
+ dataset = TokenClfDataset(
+ chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
+ )
+ dataloader = DataLoader(
+ dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
+ )
+
+ chunk_probs = []
+ chunk_words = []
+ with torch.no_grad():
+ for batch in dataloader:
+ ids = batch["ids"].to(self.device, dtype=torch.long)
+ mask = batch["mask"].to(self.device, dtype=torch.long) == 1
+
+ outputs = self.model(input_ids=ids, attention_mask=mask)
+ loss, logits = outputs.loss, outputs.logits
+ probs = F.softmax(logits, dim=-1)
+
+ for j in range(ids.shape[0]):
+ _probs = probs[j, :, 1]
+ _ids = ids[j]
+ _mask = mask[j]
+
+ active_probs = torch.masked_select(_probs, _mask)
+ active_ids = torch.masked_select(_ids, _mask)
+
+ tokens = self.tokenizer.convert_ids_to_tokens(
+ active_ids.squeeze().tolist()
+ )
+ token_probs = [prob for prob in active_probs.cpu().numpy()]
+
+ (
+ words,
+ valid_token_probs,
+ valid_token_probs_no_force,
+ ) = self.__merge_token_to_word(
+ tokens,
+ token_probs,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ )
+ word_probs_no_force = self.__token_prob_to_word_prob(
+ valid_token_probs_no_force, convert_mode=token_to_word
+ )
+
+ if "xlm-roberta-large" in self.model_name:
+ for i in range(len(words)):
+ words[i] = words[i].lstrip("▁")
+ chunk_words.append(words)
+ chunk_probs.append(word_probs_no_force)
+
+ prev_idx = 0
+ context_probs = []
+ context_words = []
+ for chunk_list in context_list:
+ n_chunk = len(chunk_list)
+ context_probs.append([])
+ context_words.append([])
+ for i in range(n_chunk):
+ context_probs[-1].extend(chunk_probs[prev_idx + i])
+ context_words[-1].extend(chunk_words[prev_idx + i])
+ prev_idx = prev_idx + n_chunk
+ context_probs = [sum(probs) / len(probs) for probs in context_probs]
+ return context_probs, context_words
+
+ def __chunk_context(self, origin_text, chunk_end_tokens):
+ origin_list = []
+ origin_tokens = self.tokenizer.tokenize(origin_text)
+ n = len(origin_tokens)
+ st = 0
+ while st < n:
+ if st + self.max_seq_len > n - 1:
+ chunk = self.tokenizer.convert_tokens_to_string(origin_tokens[st:n])
+ origin_list.append(chunk)
+ break
+ else:
+ ed = st + self.max_seq_len
+ for j in range(0, ed - st):
+ if origin_tokens[ed - j] in chunk_end_tokens:
+ ed = ed - j
+ break
+ chunk = self.tokenizer.convert_tokens_to_string(
+ origin_tokens[st : ed + 1]
+ )
+ origin_list.append(chunk)
+ st = ed + 1
+ return origin_list
+
+ def __merge_token_to_word(
+ self, tokens, token_probs, force_tokens, token_map, force_reserve_digit
+ ):
+ words = []
+ word_probs = []
+ word_probs_no_force = []
+
+ for token, prob in zip(tokens, token_probs):
+ if token in self.special_tokens:
+ continue
+ # add a new word
+ elif is_begin_of_new_word(token, self.model_name, force_tokens, token_map):
+ pure_token = get_pure_token(token, self.model_name)
+ prob_no_force = prob
+ if pure_token in force_tokens or pure_token in set(token_map.values()):
+ prob = 1.0
+ token = replace_added_token(token, token_map)
+ words.append(token)
+ word_probs.append(
+ [
+ 1.0
+ if force_reserve_digit and bool(re.search(r"\d", token))
+ else prob
+ ]
+ )
+ word_probs_no_force.append([prob_no_force])
+ # concatenate with previous token
+ else:
+ pure_token = get_pure_token(token, self.model_name)
+ words[-1] += pure_token
+ word_probs[-1].append(
+ 1.0
+ if force_reserve_digit and bool(re.search(r"\d", token))
+ else prob
+ )
+ word_probs_no_force[-1].append(prob_no_force)
+
+ return words, word_probs, word_probs_no_force
+
+ def __token_prob_to_word_prob(self, token_probs, convert_mode="mean"):
+ if convert_mode == "mean":
+ word_probs = [sum(p) / len(p) for p in token_probs]
+ elif convert_mode == "first":
+ word_probs = [p[0] for p in token_probs]
+ else:
+ raise NotImplementedError()
+
+ return word_probs
+
+ def __compress(
+ self,
+ context_list: list,
+ reduce_rate: float = 0.5,
+ token_to_word: str = "mean",
+ force_tokens: List[str] = [],
+ token_map: dict = {},
+ force_reserve_digit: bool = False,
+ drop_consecutive: bool = False,
+ ):
+ def split_string_to_words(input_string):
+ pattern = r'\b\w+\b|[<>=/!@#$%^&*()?":{}|\\`~;_+-]'
+ result = re.findall(pattern, input_string)
+ return result
+
+ if reduce_rate <= 0:
+ words, word_labels = [], []
+ for i in range(len(context_list)):
+ chunk_list = context_list[i]
+ chunk_words = []
+ chunk_word_labels = []
+ for j in range(len(chunk_list)):
+ # replace to original token
+ for ori_token, new_token in token_map.items():
+ chunk_list[j] = chunk_list[j].replace(new_token, ori_token)
+ ws = split_string_to_words(chunk_list[j])
+ chunk_words.extend(ws)
+ chunk_word_labels.extend([1 for _ in range(len(ws))])
+ context_list[i] = "".join(chunk_list)
+ words.append(chunk_words)
+ word_labels.append(chunk_word_labels)
+ return context_list, words, word_labels
+
+ chunk_list = []
+ for chunks in context_list:
+ for c in chunks:
+ chunk_list.append(c)
+
+ dataset = TokenClfDataset(
+ chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
+ )
+ dataloader = DataLoader(
+ dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
+ )
+
+ compressed_chunk_list = []
+ word_list = []
+ word_label_list = []
+ with torch.no_grad():
+ for batch in dataloader:
+ ids = batch["ids"].to(self.device, dtype=torch.long)
+ mask = batch["mask"].to(self.device, dtype=torch.long) == 1
+
+ outputs = self.model(input_ids=ids, attention_mask=mask)
+ loss, logits = outputs.loss, outputs.logits
+ probs = F.softmax(logits, dim=-1)
+
+ for j in range(ids.shape[0]):
+ chunk_probs = probs[j, :, 1]
+ chunk_ids = ids[j]
+ chunk_mask = mask[j]
+
+ active_probs = torch.masked_select(chunk_probs, chunk_mask)
+ active_ids = torch.masked_select(chunk_ids, chunk_mask)
+
+ tokens = self.tokenizer.convert_ids_to_tokens(
+ active_ids.squeeze().tolist()
+ )
+ token_probs = [prob for prob in active_probs.cpu().numpy()]
+
+ words, valid_token_probs, _ = self.__merge_token_to_word(
+ tokens=tokens,
+ token_probs=token_probs,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ )
+ word_probs = self.__token_prob_to_word_prob(
+ valid_token_probs, convert_mode=token_to_word
+ )
+
+ if drop_consecutive:
+ threshold = np.percentile(word_probs, int(100 * reduce_rate))
+ is_token_between = False
+ prev = None
+ for i, (word, word_prob) in enumerate(zip(words, word_probs)):
+ if word in force_tokens:
+ if is_token_between:
+ is_token_between = False
+ elif not is_token_between and word == prev:
+ word_probs[i] = 0.0
+ prev = word
+ else:
+ is_token_between |= word_prob > threshold
+
+ new_token_probs = []
+ for word, word_prob in zip(words, word_probs):
+ num_token = len(self.oai_tokenizer.encode(word))
+ new_token_probs.extend([word_prob for _ in range(num_token)])
+ threshold = np.percentile(
+ new_token_probs, int(100 * reduce_rate + 1)
+ )
+
+ keep_words = []
+ word_labels = []
+ assert len(words) == len(word_probs)
+ for word, word_porb in zip(words, word_probs):
+ if word_porb > threshold:
+ if (
+ drop_consecutive
+ and word in force_tokens
+ and len(keep_words) > 0
+ and keep_words[-1] == word
+ ):
+ word_labels.append(0)
+ else:
+ keep_words.append(word)
+ word_labels.append(1)
+ else:
+ word_labels.append(0)
+ keep_str = self.tokenizer.convert_tokens_to_string(keep_words)
+ if "xlm-roberta-large" in self.model_name:
+ for i in range(len(words)):
+ words[i] = words[i].lstrip("▁")
+
+ compressed_chunk_list.append(keep_str)
+ word_list.append(words[:])
+ word_label_list.append(word_labels[:])
+
+ compressed_context_list = []
+ original_word_list = []
+ original_word_label_list = []
+ prev_idx = 0
+ for chunk_list in context_list:
+ n_chunk = len(chunk_list)
+ compressed_context_list.append(
+ "".join(compressed_chunk_list[prev_idx : prev_idx + n_chunk])
+ )
+ original_word_list.append([])
+ original_word_label_list.append([])
+ for i in range(n_chunk):
+ original_word_list[-1].extend(word_list[prev_idx + i])
+ original_word_label_list[-1].extend(word_label_list[prev_idx + i])
+ prev_idx = prev_idx + n_chunk
+
+ return compressed_context_list, original_word_list, original_word_label_list
diff --git a/llmlingua/utils.py b/llmlingua/utils.py
new file mode 100644
index 0000000..0e3ce16
--- /dev/null
+++ b/llmlingua/utils.py
@@ -0,0 +1,109 @@
+import os
+import random
+import string
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+
+class TokenClfDataset(Dataset):
+ def __init__(
+ self,
+ texts,
+ max_len=512,
+ tokenizer=None,
+ model_name="bert-base-multilingual-cased",
+ ):
+ self.len = len(texts)
+ self.texts = texts
+ self.tokenizer = tokenizer
+ self.max_len = max_len
+ self.model_name = model_name
+ if "bert-base-multilingual-cased" in model_name:
+ self.cls_token = "[CLS]"
+ self.sep_token = "[SEP]"
+ self.unk_token = "[UNK]"
+ self.pad_token = "[PAD]"
+ self.mask_token = "[MASK]"
+ elif "xlm-roberta-large" in model_name:
+ self.bos_token = ""
+ self.eos_token = ""
+ self.sep_token = ""
+ self.cls_token = ""
+ self.unk_token = ""
+ self.pad_token = ""
+ self.mask_token = ""
+ else:
+ raise NotImplementedError()
+
+ def __getitem__(self, index):
+ text = self.texts[index]
+ tokenized_text = self.tokenizer.tokenize(text)
+
+ tokenized_text = (
+ [self.cls_token] + tokenized_text + [self.sep_token]
+ ) # add special tokens
+
+ if len(tokenized_text) > self.max_len:
+ tokenized_text = tokenized_text[: self.max_len]
+ else:
+ tokenized_text = tokenized_text + [
+ self.pad_token for _ in range(self.max_len - len(tokenized_text))
+ ]
+
+ attn_mask = [1 if tok != self.pad_token else 0 for tok in tokenized_text]
+
+ ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)
+
+ return {
+ "ids": torch.tensor(ids, dtype=torch.long),
+ "mask": torch.tensor(attn_mask, dtype=torch.long),
+ }
+
+ def __len__(self):
+ return self.len
+
+
+def seed_everything(seed: int):
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def is_begin_of_new_word(token, model_name, force_tokens, token_map):
+ if "bert-base-multilingual-cased" in model_name:
+ if token.lstrip("##") in force_tokens or token.lstrip("##") in set(
+ token_map.values()
+ ):
+ return True
+ return not token.startswith("##")
+ elif "xlm-roberta-large" in model_name:
+ if (
+ token in string.punctuation
+ or token in force_tokens
+ or token in set(token_map.values())
+ ):
+ return True
+ return token.startswith("▁")
+ else:
+ raise NotImplementedError()
+
+
+def replace_added_token(token, token_map):
+ for ori_token, new_token in token_map.items():
+ token = token.replace(new_token, ori_token)
+ return token
+
+
+def get_pure_token(token, model_name):
+ if "bert-base-multilingual-cased" in model_name:
+ return token.lstrip("##")
+ elif "xlm-roberta-large" in model_name:
+ return token.lstrip("▁")
+ else:
+ raise NotImplementedError()
diff --git a/llmlingua/version.py b/llmlingua/version.py
index cb0f052..a844652 100644
--- a/llmlingua/version.py
+++ b/llmlingua/version.py
@@ -5,7 +5,7 @@
_MINOR = "2"
# On master and in a nightly release the patch should be one ahead of the last
# released build.
-_PATCH = "0"
+_PATCH = "1"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..361d2a5
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,11 @@
+[tool.black]
+line-length = 88
+target-version = ['py38']
+include = '\.pyi?$'
+
+[tool.isort]
+atomic = true
+profile = "black"
+line_length = 88
+skip_gitignore = true
+known_first_party = ["llmlingua"]
diff --git a/tests/test_llmlingua.py b/tests/test_llmlingua.py
index 0416f1e..60f766e 100644
--- a/tests/test_llmlingua.py
+++ b/tests/test_llmlingua.py
@@ -1,3 +1,6 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
import unittest
from llmlingua import PromptCompressor
diff --git a/tests/test_llmlingua2.py b/tests/test_llmlingua2.py
new file mode 100644
index 0000000..cfd111e
--- /dev/null
+++ b/tests/test_llmlingua2.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import unittest
+
+from llmlingua import PromptCompressor
+
+
+class LLMLingua2Tester(unittest.TestCase):
+ """
+ End2end Test for LLMLingua-2
+ """
+
+ PROMPT = "John: So, um, I've been thinking about the project, you know, and I believe we need to, uh, make some changes. I mean, we want the project to succeed, right? So, like, I think we should consider maybe revising the timeline.\n\nSarah: I totally agree, John. I mean, we have to be realistic, you know. The timeline is, like, too tight. You know what I mean? We should definitely extend it."
+ COMPRESSED_SINGLE_CONTEXT_PROMPT = "John: thinking project believe need make changes. want project succeed? consider revising timeline.\n\n Sarah agree. be realistic. timeline too tight.? extend."
+ COMPRESSED_MULTIPLE_CONTEXT_PROMPT = "John: So, I've been thinking about project believe we need to make changes. we want project to succeed, right? think we should consider maybe revising timeline."
+
+ GSM8K_PROMPT = "Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?\nLet's think step by step\nAngelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.\nFor the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.\nAngelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.\nHowever, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.\nThey also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.\nAnd they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.\nSo Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.\nThey want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75\nThey will need to plan to study 4 days to allow for all the time they need.\nThe answer is 4\n\nQuestion: You can buy 4 apples or 1 watermelon for the same price. You bought 36 fruits evenly split between oranges, apples and watermelons, and the price of 1 orange is $0.50. How much does 1 apple cost if your total bill was $66?\nLet's think step by step\nIf 36 fruits were evenly split between 3 types of fruits, then I bought 36/3 = 12 units of each fruit\nIf 1 orange costs $0.50 then 12 oranges will cost $0.50 * 12 = $6\nIf my total bill was $66 and I spent $6 on oranges then I spent $66 - $6 = $60 on the other 2 fruit types.\nAssuming the price of watermelon is W, and knowing that you can buy 4 apples for the same price and that the price of one apple is A, then 1W=4A\nIf we know we bought 12 watermelons and 12 apples for $60, then we know that $60 = 12W + 12A\nKnowing that 1W=4A, then we can convert the above to $60 = 12(4A) + 12A\n$60 = 48A + 12A\n$60 = 60A\nThen we know the price of one apple (A) is $60/60= $1\nThe answer is 1"
+ GSM8K_150TOKENS_COMPRESSED_SINGLE_CONTEXT_PROMPT = "Question: Angelo Melanie plan test 2 chapters 4 worksheets 3 hours each chapter 1.5 hours each worksheet study 4 hours day how days 10-minute break every 3 10-minute snack breaks 30 minutes lunch\n\n dedicate 3 hours 2 chapters 3 2 = 6 hours total\n worksheets 1.5 hours each worksheet 1.5 4 = 6 hours total\n 12 hours study 4 hours a day 12 / 4 = 3 days\n breaks lunch 10-minute break 12 hours 10 = 120 minutes\n 3 10-minute snack breaks 3 10 = 30 minutes\n 30 minutes lunch 120 + 30 + 30 = 180 minutes 180 / 60 = 3 extra hours\n 12 hours study + 3 hours breaks = 15 hours total\n 4 hours each day 15 / 4 = 3.75\n 4 days\nThe answer is 4"
+ GSM8K_150TOKENS_COMPRESSED_MULTIPLE_CONTEXT_PROMPT = "4 apples 1 watermelon 36 fruits oranges watermelons 1 orange $0.50 1 apple bill $66\n\n 36 fruits 3 36/3 = 12 units\n 1 orange $0.50 12 oranges $0.50 * 12 = $6\n total bill $66 spent $6 oranges $66 - $6 = $60 other 2\n watermelon W 4 apples one apple A 1W=4A\n 12 watermelons 12 apples $60 $60 = 12W + 12A\n $60 = 12(4A + 12A\n = 48A + 12A\n = 60A\n one apple $60/60= $1\nThe answer is 1"
+
+ MEETINGBANK_PROMPT = "Item 28 Report from Development. Services Recommendation to declare ordinance amending the Land Use District Map from institutional to IRP 13 read and adopted as read District eight. Councilman Austin. So moved. Wonderful. And I want to ask Councilman Andrews so any member of the public that wishes to address item 28 saying none, members, cast your vote. Oh, I'm sorry, sir. I did not see you. Can we? I know this sounds picky and stupid. But this is an illogical motion because you haven't yet created ARP 13. By the way, unlike some other speakers, I will furnish you my name. I'm Joe Weinstein. I did speak last week. I do not like to come down here again to talk on the same subjects. But. There is a minor little matter. As to whether a. The proposed zoning is a good idea. And B, whether. The project, which it is intended. To permit. In fact. Meets the specifications of the zoning. I have not check that out, but someone else did raise that question and there may be some question as to whether all of the conditions of that zoning have, in fact, been met by the details of this project. This particular zoning, perhaps in the abstract, need not be a bad idea, but the way you see it realized in the project. Is not a very good idea. You could have the same density and more without destroying the usability, the usable green space that this design does. Because really, although it looks impressive from a top down view, it looks like you see plenty of green space between the buildings, that that space is pretty well wasted and useless because the buildings are high enough to pretty well shade and dominate the green space that's in that project. So I'm not saying that the density that you're going for is a bad thing. But doing it in this way doesn't work, and any zoning that just permits this without further control is not a good idea. Thank you. Okay. Thank you, sir. Members, please cast your vote. Councilman Andrew's motion carries. Next time, please. Report from Development Services recommendation to declare ordinance amending the Land Use District Map from institutional to park red and adopted as Red District eight."
+ MEETINGBANK_150TOKENS_COMPRESSED_SINGLE_CONTEXT_PROMPT = "Item 28 Report Development. Services declare ordinance amending Land Use District Map institutional IRP 13 adopted District eight. Councilman Austin. moved ask Councilman Andrews public address item 28 cast vote.?. illogical motion created ARP 13. Joe Weinstein. last week. subjects. minor matter. proposed zoning good idea. project. Meets specifications zoning. conditions zoning met details project. zoning not bad. not good. same density more without destroying green space. green space buildings wasted useless buildings shade dominate green space. not density bad. doesn't work zoning permits without control not good idea. Thank you. cast vote. Councilman Andrew's motion carries. Next time. Report Development Services declare ordinance amending Land Use District Map institutional to park red adopted District eight"
+
+ LONGBENCH_PROMPT_LIST = [
+ "新闻内容:\n(服务·健康)专家提醒:寒冷气候易诱发心脑血管疾病\n新华社海口2月9日专电(张苏民、李建国)海口市疾病预防控制中心专家介绍,持续的寒冷气候是心脑血管疾病的杀手,尤其患有高血压或高血脂疾病的老人更应做好防范,防止脑中风发生。\n 在寒冷的气候环境当中要注意保暖,增添衣服,饮食以清淡为主,多食用蔬菜,忌暴食荤类。尤其过年时,切忌熬夜,平时要加强身体锻炼,劳逸结合。除此之外,冬季还是呼吸道传染病暴发和流行的季节,应该注意预防流感、麻疹、流脑、水痘等呼吸道传染病的发生。\n 专家还指出,由于寒冷气候影响,人们习惯门窗紧闭,空气不对流,一旦有传染源传入,极容易造成疾病的暴发。春节期间,一些商场或公共娱乐场所人群密集,有关单位应加强通风。(完)\n类别:医药、卫生",
+ "\n\n新闻内容:\n李明波在恩施调研时强调 大力推进基层党内民主建设\n本报讯 (记者吴畏、通讯员曾言、周恩祖)11日至13日,省委常委、秘书长李明波到恩施州调研基层党建工作时强调,要以增强党的创新活力、巩固党的团结统一为目标,以改革创新精神大力抓好基层党内民主建设。\n 李明波视察了非公有制企业党建、党代表常任制、基层党务公开、以党内和谐推进社区和谐等党建工作现场,与基层党务工作者座谈。李明波强调,在新形势下,要把握好民主进程与经济社会发展、尊重党员主体地位与提高党员民主素质、履行党员民主权利与保证党的统一意志、发挥党员民主监督作用与加强党纪教育管理等的关系,进一步深入探索,在丰富形式、拓宽渠道、完善机制等方面取得更大成绩。\n类别:政治",
+ "\n\n新闻内容:\n第38届世界贸易中心年会及经贸洽谈会\n第38届世界贸易中心年会将于2007年10月21至24日在美国路易斯\n安那州首府新奥尔良召开。该会由美国纽约世界贸易中心总部和美国贸\n易服务管理总局、新奥尔良世贸中心共同举办,届时将有来自60多个国\n家和地区的经贸代表团约600余人与会。天津贸促会与天津世贸中心协\n会将共同组织天津经贸代表团赴美国参加“世贸中心2007年年会及经贸\n洽谈会”。\n 联系人:王岭 刘鹏\n 电话:022-2520231725202123\n 传真:022-25201975\n 地址:天津经济技术开发区宏达街19号A区2楼\n类别:商业、外贸、海关",
+ "\n\n新闻内容:\n(全运会)第十一届全运会开闭幕时间确定\n新华社济南6月5日体育专电(记者赵仁伟)第十一届全国运动会组委会5日在济南宣布,十一运会将于今年10月16日在济南奥体中心开幕,闭幕时间为10月28日。\n 十一运会组委会常务副秘书长、山东省体育局局长张洪涛介绍,十一运会的比赛项目共设33个大项、43个分项、362个小项,其中包括28个夏季奥运会项目、4个冬季项目以及武术项目。与2005年十运会相比,大项增加了1个,即自由式滑雪;小项增加了5个,分别是自由式滑雪男子个人、女子个人,女子水球项目,足球男子16岁以下组和女子18岁以下组。\n 在十一运会全部362个小项中,马拉松男、女2个小项的比赛在北京举办,速度滑冰4个小项、自由式滑雪2个小项的比赛分别在沈阳和长春举办,其余354个小项的比赛在山东省17个赛区举行。其中,济南赛区共举办小项212个,青岛48个,日照40个,滨州28个,枣庄8个,菏泽7个,威海5个,烟台、德州各3个;淄博、东营、潍坊、济宁、泰安、莱芜、临沂、聊城8个赛区只举办小组赛和第四名以后的比赛,不产生金牌。\n 张洪涛介绍,十一运会冰雪项目已于1月至4月举行,占全部小项的4.4%。因部分夏季项目的世界锦标赛或国际重要赛事的时间与十一运会比赛时间冲突或相距较近,国家体育总局确定把这些项目的比赛安排在开幕式前举行,共有15个项目、80个小项,占全部小项的22.1%。(完)\n类别:体育",
+ "\n\n新闻内容:\n(教育)河北整顿公办初中、小学招收择校生\n(教育)河北整顿公办初中、小学招收择校生\n 新华社石家庄3月12日电(冯月静)记者从河北省教育纪检监察审计工作会议上了解到,从今年起,河北省不再审批新的改制学校。对已审批的改制学校进行一次全面整顿和规范,重点解决公办初中、小学以改制为名或以民办为名举办“校中校”“校中班”高收费问题。\n 据了解,河北省规定达不到要求的,要限期整改;年内仍达不到标准要求的,一律停止招生。公办学校一律不准搞“一校两制”,更不准以改制为名高收费。\n 同时,今年秋季新学年开始,设区市市区的公办省级示范性普通高中(含在县镇办学的市直属省级示范性高中)择校生比例最高限额由原定的40%一律下调为30%。严禁学校擅自扩大择校生招生比例、降低录取分数线、提高收费标准或在限定金额外加收任何其他费用。(完)\n类别:教育",
+ "\n\n新闻内容:\n(服务·关注“过劳死”) “过劳死”青睐什么人?\n人?\n 新华社郑州3月16日专电(记者李丽静) 有关专家\n研究表明:受教育程度高、中青年、女性是“过劳死”这\n一疾病的危险人群。这是因为这些人事业上强力拼搏,生\n活负荷过重,自身经常处于紧张状态之中,过度疲劳难以\n避免。\n 随着社会竞争的日趋激烈,该病也越来越多地困扰着\n我国的都市人。据一项在上海、无锡、深圳等地对\n1197位中年人健康状况调查显示:其中66%的人有\n失眠、多梦、不易入睡等现象;62%的人经常腰酸背痛;\n58%的人一干活就累;57%的人爬楼时感到吃力或记\n忆力明显减退;48%的人皮肤干燥、瘙痒、面色晦暗、\n脾气暴躁、焦虑。据国家有关部门的一项调查结果表明,\n慢性疲劳综合征在城市新兴行业人群中的发病率为10%\n至20%,在科技、新闻、广告、公务人员、演艺、出租\n车司机等行业中发病率则更高。\n 有关专家通过统计认为,“过劳死”特别“青睐”三\n种人:\n 第一种是有钱但不知保养的人。这部分人“富裕”的\n背后,往往有一条铺满辛酸的路。由于对贫穷的恐惧,使\n他们对财富永远不满足。为了追逐更多的财富,即使赴汤\n蹈火也在所不辞,而对他们最初惟一的资本———身体,\n则很不在乎。 \n 第二种是有事业心,特别是称得上“工作狂”的人。\n主要以从事科研、教学、新型高科技,如网络等职业者居\n多。\n 第三种是有家族遗传背景者。如果父母亲、爷爷奶奶\n等直系亲属中有心绞痛、心肌梗死、脑中风的患者,就要\n特别小心了,千万别让自己累着,否则很有可能在年轻时\n就诱发疾病。\n 而在对“过劳死”人群深入研究中发现,猝死直接死\n因的前5位是冠状动脉疾病、主动脉瘤、心瓣膜病、心肌\n病和脑出血。一些无症状冠心病,特别是无症状心肌梗死\n是首要的危险因素,一般的体检和心电图不易发现隐性冠\n心病。一旦发作,措手不及。此外,高血压也是一个潜在\n的危险因素。在遇到某些诱因时,便会引发高血压、脑中\n风等。(完)\n类别:医药、卫生",
+ "\n\n新闻内容:\n五项措施应对技术性贸易壁垒\n调查结果显示,2006年我国有31\n .4%的出口企业受到国外技术性贸易措施不同程度的影响,比2005年增长6.3个百分点;全年出口贸易直接损失359.20亿美元,占同期出口额的3.71%,企业新增成本191.55亿美元。\n 会议通报的情况显示,对中国企业出口影响较大的技术性贸易措施类型集中在认证要求、技术标准要求、有毒有害物质限量要求、包装及材料的要求和环保要求(包括节能及产品回收),食品中农兽药残留要求、重金属等有害物质限量要求、细菌等卫生指标要求、食品标签要求和食品接触材料的要求等方面。受国外技术性贸易措施影响较大的行业排在前五位的是机电、农食产品、化矿、塑料皮革和纺织鞋帽。\n 会议提出了加强应对的5点意见。一是要强化进出口质量监管措施,在“严”字上下功夫,重点从源头上抓好农兽药残留、有毒化学物质残留、微生物等问题,同时要完善监管机制,提高检测能力,要检得出,检得快,检得准。二是要加快实施技术标准战略,在“高”字上下功夫,不断提高采标率,加快标准的制修订步伐。三是要加大信息共享力度,在“准”字上下功夫,各部门要密切配合,建立沟通机制,做到信息资源的充分利用。四是要果断迅速应对突发事件,在“快”字上下功夫。五是要加强技术性贸易措施的积极应对,在“实”字上下功夫,协调配合、相互支持。\n类别:商业、外贸、海关",
+ "\n\n新闻内容:\n(新华时评·奥运会倒计时一百天)让我们共同守护奥林匹克精神\n新华社北京4月30日电 题:让我们共同守护奥林匹克精神\n 新华社记者张旭\n 在北京奥运会倒计时一百天之际,奥运圣火结束在其他国家的传递进入中国香港。在这两个重要时间节点重合之时,让我们以奥林匹克精神为依归,回味今年以来围绕北京奥运的风风雨雨,并以百倍的努力在接下来的日子里守护这一美好理想。\n 奥林匹克运动会是古希腊人的体育盛会,许多比赛项目源于古希腊文化。顾拜旦说:“古希腊人之所以组织竞赛活动,不仅仅只是为了锻炼体格和显示一种廉价的壮观场面,更是为了教育人”。更高更快更强并不是现代奥林匹克运动的全部价值诉求。现代奥林匹克运动经过了一百年的历史变迁,向世界传达的精神与主题始终如一,那就是在共同创造、共同分享、平等友爱的旗帜下,展现人类最美好的情感。奥林匹克是迄今为止人类社会不同种族、地域乃至不同意识形态间最大的交集。\n 2001年7月13日,时任国际奥委会主席的萨马兰奇宣布北京取得2008年奥运会主办权,现代奥林匹克运动从奥林匹亚来到万里长城。7年后的春天,当奥运圣火开始在中国境外传递时,妖魔化中国的舆论攻势和扰乱奥运火炬传递的暴力举动让海内外目光聚焦中国。我们可以肯定地说,这些人在为一己之私对奥林匹克精神进行亵渎。\n 北京奥运圣火一路走来,虽然遇到了噪音和干扰,但更多面对的还是像火一样热情的世界人民和对奥林匹克精神充分尊重的各国人士。他们因为懂得尊重奥林匹克精神,因此也能够享受奥林匹克带来的快乐。\n 2008年4月30日,“北京欢迎你”的歌声回荡在有着近600年历史的紫禁城太庙上空。8月8日,中国人民将第一次以东道主的身份在北京承办举世瞩目的奥林匹克运动会。北京奥运会对中国来说不仅仅是一次体育盛会,更是一次与世界各国开展文化交流的机会。如同当年奥林匹亚为神圣的无战争区域一样,体育竞技的目标是为了全世界的和平与发展。北京奥运会也完全可以成为世界各种文明一个共同的精神家园,通过沟通交流,达到良性互动。\n 奥运会的脚步声离我们越来越近的时候,奥林匹克运动正在为13亿中国人民所熟悉,奥林匹克精神也继续在世界范围内承载起人类追求幸福生活的梦想。中国人民真诚地邀请各国运动员、教练员和朋友们参与2008年北京奥运会。中国人民同时真诚地邀请全世界热爱奥林匹克精神和奥林匹克运动的人们一起,共同守护这一人类美好理想,让它在北京奥运会上开放出更加美丽的花朵。(完)\n类别:体育",
+ "\n\n新闻内容:\n海口“接管”省 特殊教育 学校\n创建于1989年的海南省特殊教育学校原属省教育厅直属正处级事业单位,为海南省惟一一所全日寄宿的公立特殊教育学校。\n 我市“接管”省特殊教育学校之后,将继续面向全省招收视障、听障两类适龄儿童,优化教育布局调整,促进特殊教育又好又快发展。\n类别:教育",
+ "\n\n新闻内容:\n9月7日特稿(加1)(美国-大学流感)\n美一大学两千学生恐染流感\n 马震\n 美国华盛顿州立大学大约2000名学生报告甲型H1N1流感症状。校方和医护人员说,这可能是最严重的一起大学生感染新型流感事件。\n (小标题)人数众多\n 这所大学位于华盛顿州普尔曼,主校区大约有1.9万名学生。据美国《纽约时报》网络版6日报道,华盛顿州注册护士萨莉·雷德曼证实了大约2000名华盛顿州立大学学生报告流感症状一事。\n 雷德曼在华盛顿州立大学学生医疗部门工作。她说,流感暴发情况出现在8月21日,那时学校还没开学。但如今为学生提供医疗服务的部门总是门庭若市。有一天,大约有200名学生就诊或给医疗机构打电话报告喉咙疼、发烧、咳嗽等症状。\n 华盛顿州立大学所在惠特曼县的卫生部门官员说,州实验室上周的检测结果显示,这所大学的疫情确实是因甲型H1N1流感病毒引起。\n 学校现已开学。法新社本月6日报道,学校上周开了关于流感疫情的博客,博客上最新的信息说:“秋季学期的前10天,我们估计已与大约2000名有流感症状的人联络。”\n 校方管理人员说,一些学生可能到社区医院就诊,一些学生可能居家自我治疗。校方无法掌握这些人的人数,已要求当地卫生部门提供相关数据,以便校方更好了解疫情情况。\n (小标题)无一死亡\n 华盛顿州立大学已根据国家疾病控制和预防中心的防流感指南向学生提供咨询服务,以避免疫情进一步加重。学校还向学生发放了一些防流感的药品和护具等。\n 为防止甲型流感传播,美国的一些大学已建立起隔离机制,但华盛顿州立大学没有类似机制。雷德曼说,在华盛顿州立大学上报的大部分流感疫情案例中,疑似染病的学生被要求待在居所内休息并吃退烧药。如果这些人在不吃退烧药24小时后体温仍旧正常,就可以正常来上课。\n 美国已有593例与甲型流感有关的死亡病例,但华盛顿州立大学尚未发现一起死亡病例。到目前为止,学生的流感症状相对温和,只有两个不是学生的患者入院治疗。\n 校方在声明中说:“我校患者中的绝大部分症状温和,通常3到5天就能见强。”\n (小标题)担心传播\n 华盛顿州立大学大规模流感疫情出现前,美国大学健康协会于8月28日对165所大学实施了流感疫情调查。调查结果显示,全国超过2000名学生报告说有甲型流感症状。\n 惠特曼县公共卫生部门负责人蒂莫西·穆迪认为本月晚些时候开学的其他大学可能会遭遇类似华盛顿州立大学的情况,而地方医疗机构会担心疫情可能向校外蔓延。\n 国家疾病控制和预防中心主任托马斯·弗里登6日接受美国有线电视新闻网采访时说,学校医务人员本学年报告的流感数字不同寻常。疾病控制和预防中心此前未遭遇过8月和9月数字增长这么快的情况。\n 国家疾病控制和预防中心现在特别重视流感疫情。弗里登说:“如果它的致命性增加,可能会造成特别严重的情形,可能会给上学和上班的人带来特别多麻烦。”(完)(新华社供本报特稿)\n 关键词:华盛顿州立大学(Washington State University)\n类别:医药、卫生",
+ "\n\n新闻内容:\n在国防教育的落实上下功夫\n在国防教育的落实上下功夫 赵荣\n 加强全民国防教育是增强国防观念和忧患意识、促进国防和军队建设的基础性工程。鉴此,在今后的实践中,要坚持以科学发展观为指导,科学谋划、创新形式、狠抓落实,使全民国防教育深入人心,扎实有效地开展下去。\n 抓好责任落实。《国防教育法》第三章第十八条规定:各地区各部门的领导人员应当依法履行组织、领导本地区、本部门开展国防教育的职责。因而,要使全民国防教育扎实有效地开展下去,各级领导和职能部门要依法负起抓好全民国防教育的责任,对本地区、本单位、本行业的国防教育,从计划安排到组织实施都要认真负责地抓好落实。\n 抓好人员落实。国防教育是面向全民的教育,它的开展必须面向全社会,而不能只针对个别地区、个别单位和个别人员。因而,各地要对一切有接受能力的公民实施国防教育,以提高全民的政治、思想和道德素质,使全体公民积极争当热爱祖国、热爱国防的好公民。\n 抓好效果落实。国防教育的开展,效果的落实极为重要。为此,教育中应着重抓好国防理论、国防精神、国防知识、国防历史、国防技能、国防法制的教育,以强化爱国精神、增长国防知识、强化国防观念。通过教育,使全体公民进一步了解我国安全面临的新形势、世界军事变革的新发展、我国国防和军队建设面临的新挑战、以及在对国防建设中应承担的义务和责任等,不断提高他们支持和关心国防建设的积极性和自觉性。\n (来源:中国国防报 发布时间: 2007-11-22 08:19)\n类别:军事",
+ "\n\n新闻内容:\n中国又一学者当选瑞典皇家工程科学院外籍院士\n新华社北京8月20日电 北京航空航天大学中国循环经济研究中心主任、北京循环经济促进会会长吴季松教授,日前被瑞典皇家工程科学院全体大会选为该院外籍院士。\n 作为改革开放后首批出国访问学者之一,吴季松曾在欧洲原子能联营法国原子能委员会研究受控热核聚变,还曾任中国常驻联合国教科文组织代表团参赞衔副代表、联合国教科文组织科技部门高技术与环境顾问。1985至1986年,主持联合国教科文组织“多学科综合研究应用于经济发展”专题研究,并由联合国教科文组织发表项目研究报告创意知识经济。\n 他在中国科技和产业领域作出了多项贡献,主要包括:创意“知识经济”并将科技园区的实践介绍到中国、提出修复生态系统理论并主持制定水资源规划、创立新循环经济学等。\n 瑞典皇家工程科学院创建于1919年,是世界上第一个工程院,现有机械工程、电机工程等学部。该院参与相关诺贝尔奖项的提名和评审工作。目前共有院士(含外籍院士)近1100人,来自中国的外籍院士包括宋健、徐冠华等。(完)\n类别:科学技术",
+ ]
+ LONGBENCH_1000TOKENS_COMPRESSED_MULTIPLE_CONTEXT_PROMPT = "\n 新闻内容 第38届世界贸易中心年会及经贸洽谈会\n 安那州首府新奥尔良召开。\n 易服务管理总局、新奥尔良世贸中心共同举办\n 家和地区的经贸代表团约600余人与会。 天津贸促会与天津世贸中心协\n 会将共同组织天津经贸代表团赴美国参加“世贸中心2007年年会及经贸\n 洽谈会”。\n 联系人:王岭 刘鹏\n 电话:022-2520231725202123\n 传真:022-25201975\n 地址:天津经济 技术开发区宏达街19号A区2楼\n类别:商业、外贸、海关\n\n\n 新闻内容\n 海口“接管”省 特殊教育 学校\n 创建于1989年的海南省特殊教育 学校原属省教育 厅直属正处级事业单位,为海南省惟一一所全日寄宿的公立特殊教育 学校。\n教育 学校之后,将继续面向全省招收视障、听障两类适龄儿童教育 布局调整教育。\n类别:教育\n\n\n 中国又一学者当选瑞典皇家工程科学院外籍院士\n 新华社北京8月20日电 北京航空航天大学中国循环经济 研究中心主任、北京循环经济 促进会会长吴季松教授,日前被瑞典皇家工程科学院全体大会选为该院外籍院士。\n 作为改革开放后首批出国访问学者之一,吴季松曾在欧洲原子能联营法国原子能委员会研究受控热核聚变,还曾任中国常驻联合国教科文组织代表团参赞衔副代表、联合国教科文组织科技部门高技术与环境顾问。 1985至1986年,主持联合国教科文组织“多学科综合研究应用于经济 发展”专题研究经济。\n:创意“知识经济 ”并将科技园区的实践介绍到中国、提出修复生态系统理论并主持制定水资源规划、创立新循环经济 学等。\n 瑞典皇家工程科学院创建于1919年,是世界上第一个工程院,现有机械工程、电机工程等学部。 目前共有院士(含外籍院士)近1100人,来自中国的外籍院士包括宋健、徐冠华等。\n类别:科学技术"
+
+ def __init__(self, *args, **kwargs):
+ super(LLMLingua2Tester, self).__init__(*args, **kwargs)
+ self.llmlingua = PromptCompressor(
+ model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
+ device_map="cpu",
+ use_llmlingua2=True,
+ )
+
+ def test_general_compress_prompt(self):
+ compressed_prompt = self.llmlingua.compress_prompt(
+ self.PROMPT,
+ rate=0.33,
+ force_tokens=["\n", ".", "!", "?"],
+ drop_consecutive=False,
+ force_reserve_digit=False,
+ )
+ self.assertEqual(
+ compressed_prompt["compressed_prompt"],
+ self.COMPRESSED_SINGLE_CONTEXT_PROMPT,
+ )
+ self.assertEqual(compressed_prompt["origin_tokens"], 98)
+ self.assertEqual(compressed_prompt["compressed_tokens"], 30)
+ self.assertEqual(compressed_prompt["ratio"], "3.3x")
+ self.assertEqual(compressed_prompt["rate"], "30.6%")
+
+ compressed_prompt = self.llmlingua.compress_prompt(
+ self.PROMPT.split("\n\n"),
+ target_token=40,
+ use_context_level_filter=True,
+ force_tokens=["\n", ".", "!", "?"],
+ drop_consecutive=False,
+ force_reserve_digit=False,
+ )
+ self.assertEqual(
+ compressed_prompt["compressed_prompt"],
+ self.COMPRESSED_MULTIPLE_CONTEXT_PROMPT,
+ )
+ self.assertEqual(compressed_prompt["origin_tokens"], 98)
+ self.assertEqual(compressed_prompt["compressed_tokens"], 34)
+ self.assertEqual(compressed_prompt["ratio"], "2.9x")
+ self.assertEqual(compressed_prompt["rate"], "34.7%")
+
+ # Single Context
+ compressed_prompt = self.llmlingua.compress_prompt(
+ self.GSM8K_PROMPT.split("\n\n")[0],
+ target_token=170,
+ force_tokens=[
+ "+",
+ "-",
+ "*",
+ "×",
+ "/",
+ "÷",
+ "=",
+ "The answer is",
+ "\n",
+ "Question:",
+ ],
+ drop_consecutive=False,
+ force_reserve_digit=True,
+ )
+ self.assertEqual(
+ compressed_prompt["compressed_prompt"],
+ self.GSM8K_150TOKENS_COMPRESSED_SINGLE_CONTEXT_PROMPT,
+ )
+ self.assertEqual(compressed_prompt["origin_tokens"], 422)
+ self.assertEqual(compressed_prompt["compressed_tokens"], 203)
+ self.assertEqual(compressed_prompt["ratio"], "2.1x")
+ self.assertEqual(compressed_prompt["rate"], "48.1%")
+
+ # Single Context
+ compressed_prompt = self.llmlingua.compress_prompt(
+ self.MEETINGBANK_PROMPT.split("\n\n")[0],
+ target_token=150,
+ force_tokens=["\n", ".", "?", "!"],
+ drop_consecutive=True,
+ force_reserve_digit=False,
+ )
+ self.assertEqual(
+ compressed_prompt["compressed_prompt"],
+ self.MEETINGBANK_150TOKENS_COMPRESSED_SINGLE_CONTEXT_PROMPT,
+ )
+ self.assertEqual(compressed_prompt["origin_tokens"], 464)
+ self.assertEqual(compressed_prompt["compressed_tokens"], 154)
+ self.assertEqual(compressed_prompt["ratio"], "3.0x")
+ self.assertEqual(compressed_prompt["rate"], "33.2%")
+
+ # Multiple Context
+ compressed_prompt = self.llmlingua.compress_prompt(
+ self.GSM8K_PROMPT.split("\n\n"),
+ target_token=150,
+ use_context_level_filter=True,
+ force_tokens=["+", "-", "*", "×", "/", "÷", "=", "The answer is", "\n"],
+ drop_consecutive=False,
+ force_reserve_digit=True,
+ )
+ self.assertEqual(
+ compressed_prompt["compressed_prompt"],
+ self.GSM8K_150TOKENS_COMPRESSED_MULTIPLE_CONTEXT_PROMPT,
+ )
+ self.assertEqual(compressed_prompt["origin_tokens"], 726)
+ self.assertEqual(compressed_prompt["compressed_tokens"], 161)
+ self.assertEqual(compressed_prompt["ratio"], "4.5x")
+ self.assertEqual(compressed_prompt["rate"], "22.2%")
+
+ # Multiple Context
+ compressed_prompt = self.llmlingua.compress_prompt(
+ self.LONGBENCH_PROMPT_LIST,
+ target_token=1000,
+ use_context_level_filter=True,
+ force_tokens=[
+ "\n",
+ "。",
+ ":",
+ "?",
+ "类别:",
+ "农业、农村",
+ "军事",
+ "文学、艺术",
+ "体育",
+ "传媒业",
+ "电子信息产业",
+ "文化、休闲娱乐",
+ "社会、劳动",
+ "经济",
+ "服务业、旅游业",
+ "环境、气象",
+ "能源、水务、水利",
+ "财政、金融",
+ "教育",
+ "科学技术",
+ "对外关系、国际关系",
+ "矿业、工业",
+ "政治",
+ "交通运输、邮政、物流",
+ "灾难、事故",
+ "基本建设、建筑业、房地产",
+ "医药、卫生",
+ "法律、司法",
+ "商业、外贸、海关",
+ ],
+ drop_consecutive=True,
+ force_reserve_digit=False,
+ )
+ self.assertEqual(
+ compressed_prompt["compressed_prompt"],
+ self.LONGBENCH_1000TOKENS_COMPRESSED_MULTIPLE_CONTEXT_PROMPT,
+ )
+ self.assertEqual(compressed_prompt["origin_tokens"], 8389)
+ self.assertEqual(compressed_prompt["compressed_tokens"], 870)
+ self.assertEqual(compressed_prompt["ratio"], "9.6x")
+ self.assertEqual(compressed_prompt["rate"], "10.4%")
diff --git a/tests/test_longllmlingua.py b/tests/test_longllmlingua.py
index 7fe906c..2763cc9 100644
--- a/tests/test_longllmlingua.py
+++ b/tests/test_longllmlingua.py
@@ -1,3 +1,6 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
import unittest
from llmlingua import PromptCompressor