本仓库利用peft库与transformers.Trainer,实现对ChatGLM2的简单4-bit/8-bit LoRA微调。(其它LLM应该也行,只要稍作修改)
This repo uses peft and transformers.Trainer to achieve simple 4-bit/8-bit LoRA fine-tuning for ChatGLM2. (You can also use this repo for other LLM with minor modifications)
$ pip install -r requirement.txt
requirement.txt:
datasets==2.13.1
protobuf
transformers==4.30.2
cpm_kernels
torch>=2.0
mdtex2html
sentencepiece
accelerate
git+https://github.com/huggingface/peft.git
bitsandbytes
loralib
scipy
文件config.py参数如下:
- MICRO_BATCH_SIZE,每块GPU的batch size大小。
- BATCH_SIZE,真正的batch size,当每个batch的处理样本数达到BATCH_SIZE时,进行梯度更新。
- EPOCHS,总训练代数。
- WARMUP_STEPS,预热步数。
- LEARNING_RATE,学习率。
- CONTEXT_LEN,context字段截断长度(对应json文件的context)。
- TARGET_LEN,target字段截断长度(对应json文件的target)。
- TEXT_LEN,text字段截断长度(对应txt文件的文本)。
- LORA_R,LoRA低秩的秩数。
- LORA_ALPHA,LoRA的alpha。
- LORA_DROPOUT,LoRA层的Dropout率。
- MODEL_NAME,模型名称(huggingface仓库地址)。
- LOGGING_STEPS,日志步数,即训练的时候输出loss的间隔步数。
- OUTPUT_DIR,输出LoRA权重的存放文件夹位置。
- DATA_PATH,数据集文件位置。
- DATA_TYPE,数据集文件类型,可选json或txt。
- SAVE_STEPS,保存LoRA权重的间隔步数。
- SAVE_TOTAL_LIMIT,保存LoRA权重文件的总数(不包括最终权重)。
- PROMPT,推理时的prompt。
- TEMPERATURE,推理时的温度,调整模型的创造力。
- LORA_CHECKPOINT_DIR,待推理LoRA权重的文件夹位置。
- BIT_4,使用4bit量化+LoRA微调。
- BIT_8,使用8bit量化+LoRA微调。
The parameters in config.py are as follows:
- MICRO_BATCH_SIZE,batch size on each device。
- BATCH_SIZE,when the number of processed samples in each split batch reaches BATCH_SIZE, update the gradient.
- EPOCHS,training epochs。
- WARMUP_STEPS,warmup steps。
- LEARNING_RATE,learning rate of fine-tuning。
- CONTEXT_LEN,truncation length of context (in json)。
- TARGET_LEN,truncation length of target (in json)。
- TEXT_LEN,truncation length of text (in txt)。
- LORA_R,LoRA low rank。
- LORA_ALPHA,LoRA Alpha。
- LORA_DROPOUT,LoRA dropout。
- MODEL_NAME,model name (huggingface repo address)。
- LOGGING_STEPS,the number of interval steps for outputting loss during training。
- OUTPUT_DIR,the storage folder location for LoRA weights。
- DATA_PATH,the location of your dataset file。
- DATA_TYPE,the type of your dataset file, including json and txt。
- SAVE_STEPS,the number of interval steps to save LoRA weights。
- SAVE_TOTAL_LIMIT,the total number of LoRA weight files saved (excluding the final one)。
- PROMPT,your prompt when inference。
- TEMPERATURE,the temperature when inference, adjusting the creativity of LLM。
- LORA_CHECKPOINT_DIR,folder location for LoRA weights to be inferred。
- BIT_4,use 4-bit。
- BIT_8,use 8-bit。
json文件格式如下:
The JSON file format is as follows:
{"context":question1, "target":answer1}{"context":question2, "target":answer2}...
txt文件格式如下:
The txt file format is as follows:
sentence1
sentence2
sentence3
...
$ sh train.sh
train.sh:
python main.py \
--MICRO_BATCH_SIZE 8 \
--BATCH_SIZE 16 \
--EPOCHS 50 \
--LEARNING_RATE 5e-4 \
--CONTEXT_LEN 64 \
--TARGET_LEN 192 \
--LORA_R 16 \
--LORA_DROPOUT 0.5 \
--MODEL_NAME THUDM/chatglm2-6b \
--OUTPUT_DIR ./output_model \
--DATA_PATH ./new_train.json \
--DATA_TYPE json \
--SAVE_STEPS 1000 \
--BIT_4
$ sh inference.sh
inference.sh:
python inference.py \
--CONTEXT_LEN 256 \
--MODEL_NAME THUDM/chatglm2-6b \
--LORA_CHECKPOINT_DIR ./output_model/checkpoint-4000/ \
--BIT_4 \
--PROMPT "put your prompt here"
THUDM/ChatGLM2-6B: ChatGLM2-6B: An Open Bilingual Chat LLM | 开源双语对话语言模型 (github.com)
mymusise/ChatGLM-Tuning: 一种平价的chatgpt实现方案, 基于ChatGLM-6B + LoRA (github.com)
- [2023/08/02]:更新了LoraConfig的target_modules。
- [2023/07/27]:对于QA的训练,更新loss的计算目标,只计算问题部分(json里面的target字段)的loss。
- [2023/07/25]:添加4-bit量化LoRA训练。
- [2023/07/24]:添加eos_token_id,解决重复输出的问题。