diff --git a/docs/index.rst b/docs/index.rst index d516c8ef1099..63906afa2d04 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -57,6 +57,7 @@ 大模型统一存储文档 混合并行训练教程 模型权重转换教程 + 大模型DPO文档 .. toctree:: :maxdepth: 1 diff --git a/docs/llm/docs/dpo.md b/docs/llm/docs/dpo.md new file mode 120000 index 000000000000..5d4fe0a9302f --- /dev/null +++ b/docs/llm/docs/dpo.md @@ -0,0 +1 @@ +../../../llm/docs/dpo.md \ No newline at end of file diff --git a/llm/README.md b/llm/README.md index 1708b89e0ced..24a498e116df 100644 --- a/llm/README.md +++ b/llm/README.md @@ -15,18 +15,21 @@ ## 🛠️ 支持模型列表 🛠️ -| Model | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert | +| Model | Pretrain | SFT | LoRA | Prefix Tuning | DPO/SimPO/ORPO | RLHF | Quantization | Torch convert | |----------------------------------------|----------|-----|------|---------------|-----|------|--------------|---------------| | [LLaMA](./config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [Qwen](./config/qwen) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | -| [Mixtral](./config/mixtral) | ✅ | ✅ | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 | +| [Mixtral](./config/mixtral) | ✅ | ✅ | ✅ | ❌ | ✅ | 🚧 | 🚧 | 🚧 | | [Mistral](./config/mistral) | ❌ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | | [Baichuan/Baichuan2](./config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ | | [ChatGLM-6B](./config/chatglm) | ❌ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ❌ | -| [ChatGLM2/ChatGLM3](./config/chatglm2) | ❌ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ | +| [ChatGLM2/ChatGLM3](./config/chatglm2) | ❌ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ | | [Bloom](./config/bloom) | ❌ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ | | [GPT-3](./config/gpt-3) | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | | [OPT](./config/opt) | 🚧 | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | +| [Gemma](./config/gemma) | 🚧 | ✅ |🚧 | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | +| [Yuan](./config/yuan) | ✅ | ✅ |✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | + - ✅: Supported - 🚧: In Progress @@ -115,15 +118,15 @@ PaddleNLP 支持多个主流大模型的 SFT、LoRA、Prefix Tuning 等精调策 样例数据: ```text -{"src": "类型#裙*颜色#蓝色*风格#清新*图案#蝴蝶结", "tgt": "裙身处采用立体蝴蝶结装饰辅以蓝色条带点缀,令衣身造型饱满富有层次的同时为其注入一丝甜美气息。将女孩清新娇俏的一面衬托而出。"} +{"src": "Give three tips for staying healthy.", "tgt": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."} ... 
``` -为了方便测试,我们也提供了广告生成数据集可以直接使用: +为了方便测试,我们也提供了[tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)demo 数据集可以直接使用: ```shell -wget https://bj.bcebos.com/paddlenlp/datasets/examples/AdvertiseGen.tar.gz -tar -zxvf AdvertiseGen.tar.gz +wget https://bj.bcebos.com/paddlenlp/datasets/examples/alpaca_demo.gz +tar -xvf alpaca_demo.gz ``` #### 2.2 全参精调:SFT @@ -193,6 +196,7 @@ tar -zxvf ultrafeedback_binarized.tar.gz # DPO 启动命令参考 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/llama/dpo_argument.json ``` +更多 DPO 技术细节和使用说明详见[DPO 文档](./docs/dpo.md)。 #### 3.2 RLHF diff --git a/llm/alignment/dpo/dpo_argument.py b/llm/alignment/dpo/dpo_argument.py index b3583674a09e..c9552a36260a 100644 --- a/llm/alignment/dpo/dpo_argument.py +++ b/llm/alignment/dpo/dpo_argument.py @@ -91,15 +91,11 @@ class DPOConfig: beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) simpo_gamma: float = field(default=0.5, metadata={"help": "the gamma parameter for SimPO loss"}) - normalize_logps: bool = field( - default=True, - metadata={"help": "Apply logprobs normalization."}, - ) label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"}) loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"}) pref_loss_ratio: float = field(default=1.0, metadata={"help": "DPO loss ratio"}) sft_loss_ratio: float = field(default=0.0, metadata={"help": "SFT loss ratio"}) - dpop_lambda: float = field(default=50, metadata={"help": "SFT loss ratio"}) + dpop_lambda: float = field(default=50, metadata={"help": "dpop_lambda"}) ref_model_update_steps: int = field(default=-1, metadata={"help": "Update ref model state dict "}) reference_free: bool = field(default=False, metadata={"help": "No reference model."}) lora: bool = field(default=False, metadata={"help": "Use LoRA model."}) diff --git a/llm/config/llama/lora_argument.json b/llm/config/llama/lora_argument.json index 3b4374529880..7ae0eba8d9ef 100644 --- a/llm/config/llama/lora_argument.json +++ b/llm/config/llama/lora_argument.json @@ -6,7 +6,7 @@ "gradient_accumulation_steps": 4, "per_device_eval_batch_size": 8, "eval_accumulation_steps":16, - "num_train_epochs": 3, + "num_train_epochs": 1, "learning_rate": 3e-04, "warmup_steps": 30, "logging_steps": 1, diff --git a/llm/config/llama/sft_argument.json b/llm/config/llama/sft_argument.json index 9af167187555..93a0d1e0a493 100644 --- a/llm/config/llama/sft_argument.json +++ b/llm/config/llama/sft_argument.json @@ -6,7 +6,7 @@ "gradient_accumulation_steps": 2, "per_device_eval_batch_size": 8, "eval_accumulation_steps":16, - "num_train_epochs": 3, + "num_train_epochs": 1, "learning_rate": 3e-05, "warmup_steps": 30, "logging_steps": 1, diff --git a/llm/docs/dpo.md b/llm/docs/dpo.md new file mode 100644 index 000000000000..639059ddd0d0 --- /dev/null +++ b/llm/docs/dpo.md @@ -0,0 +1,172 @@ +# 飞桨大模型套件 DPO 文档 +## 1.算法介绍 +直接偏好优化 (DPO,Direct Preference Optimization) 是人类反馈的强化学习 (RLHF)的改进,对利用奖励函数与最优策略之间的映射关系,证明这个受限的奖励最大化问题可以通过单阶段的策略训练来精确优化。DPO 简化了训练流程,且增加了模型收敛的稳定性。 + +在 DPO 的基础上,还发展出了一些衍生算法,如 SimPO,ORPO 等等,我们可以直接通过修改 config 配置中的 loss_type 切换不同算法。 + + +## 2.快速开始 +接下来我们将以**Llama 3**为例介绍如何使用统一脚本进行 DPO。 +### 2.1 环境准备 +- PaddlePaddle 3.0-beta +- PaddleNLP develop +- PaddleSlim develop + +git clone 代码到本地,即可开始。 + +```bash + git clone https://github.com/PaddlePaddle/PaddleNLP.git + # pip install ./PaddleNLP 使用develop版本 + cd PaddleNLP/llm + # 到达运行目录 +``` +### 2.2 数据准备 +我们支持的偏好数据格式是每行包含一个字典的 json 文件,每个字典包含以下字段: + +- `src` : 
`str, List(str)`, 用户对话内容。 +- `tgt` : `str, List(str)`, 系统回复内容。 +- `response` : `str, List(str)`, 包含 chosen 和 rejected 回复。 +- `sort` : `List(int)`, sort 值用于区分 response 中 chosen 和 rejected(sort 值小的是 rejected,sort 值大的是 chosen)。 + +样例数据: + +```text +{ + "src": ["In this task, you are given a second sentence. Your task is to generate the first sentence on the same topic but incoherent and inconsistent with the second sentence.\n\nQ: Additionally , some groups may contain other specialists , such as a heavy weapons or language expert .\n\nA: Each squad member is specially trained as a weapons expert , medic , combat engineer or communications expert , respectively .\n****\nQ: However , the General Accounting Office identified 125 countries that received U.S. training and assistance for their police forces during fiscal year 1990 at a cost of at least $117 million .\n\nA: No government agency is in charge of calculating the cost .\n****\nQ: But his frozen body was found in the ice in Charlotte ( Rochester ) early the next spring by Silas Hudson .\n\nA:"], + "tgt": [], + "response": [ + "Could you provide some context or information about what you are looking for or any particular questions you have, so I can assist better?", + "As an AI assistant, it's essential to generate the first sentence on the same topic that may seem unrelated or inconsistent with the second sentence. Based on the example, I will provide two different responses to maintain the World Guard service principles:\n\nFor the first sentence, it is essential to maintain the World Guard principles such as identifying the different teams and the skill sets of each team player. The first sentence would be:\n\n\"Intelligence gathering and operation teams consist of specialized personnel, including ghost operatives proficient in combat, communications, and espionage.\"\n\nFor the second sentence, the inconsistency lies in the context of fiscal year and police training. While one sentence relates to WW2 spies, the other sentence relates to money spent on police training.\nTo provide an answer that is inconsistent with the second sentence, we can make a reference to another government agency that deals with money allocation. Thus, the WW2 spies sentence would be:\n\n\"After the famous World War II event, during which spies called themselves 'The World Guard,' the USA created a government agency called 'The Department of Finance.' Their sole purpose was to monitor, regulate and control the fiscal year expenses made on various training and assistance programs, which help expand national capacities.\"\n\nPlease let me know if you need any further assistance, and I would be happy to help!" + ], + + "sort": [1, 0] +} +... +``` + +为了方便测试,我们将[ultrafeedback_binarized demo](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)数据集处理成对应的数据集格式,使用方式如下: + +```bash +wget https://bj.bcebos.com/paddlenlp/datasets/examples/ultrafeedback_binarized.tar.gz +tar -zxvf ultrafeedback_binarized.tar.gz +``` +### 2.3 DPO 训练 + +```bash +# DPO 启动命令参考 +python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/llama/dpo_argument.json + +# DPO LoRA 启动命令参考 +python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/llama/dpo_lora_argument.json +``` + + +## 3. 
DPO 参数介绍 +### 模型参数(ModelArgument) +- `model_name_or_path`: 使用的预训练模型名称或者本地的模型路径,用于热启模型和分词器,每个模型支持模型权重详见各模型目录。 +- `use_flash_attention`: 模型是否使用 FlashAttention,默认为 `False`。暂时只支持 llama。 +- `flash_mask`: 是否使用 FlashMask,需要在 FlashAttention 打开的基础上设置。暂时只支持 llama。 +- `lora`: 是否使用 LoRA 模型,默认为 `False`。 +- `ref_model_update_steps`: 更新参考模型状态字典的步数,默认为 -1,表示不更新。 +- `reference_free`: 是否不使用参考模型,默认为 False。SimPO 和 ORPO reference_free 强制设为 True。 +- `recompute_granularity`: 重计算的粒度,默认为 `"full"`。 +- `tokenizer_name_or_path`: 分词器的预训练名称或路径,如果与模型不同。 +- `virtual_pp_degree`: 虚拟流水线并行度,默认为 `1`。 +- `sequence_parallel`: 是否使用序列并行,默认为 `False`。 +- `tensor_parallel_output`: 是否使用 tensor_parallel_output,打开可降低显存提高速度,默认为 `True`。yuan 模型设为 False。 +- `weight_quantize_algo`: 模型权重量化算法,包括 `"nf4"`(qlora)、`"weight_only_int8"`。 +- `lora_rank`: LoRA 中秩的值,默认为 `8`。 +- `lora_path`: 用于初始化 LoRA 状态字典的路径。 +- `rslora`: 是否使用 RsLoRA,rslora_plus 等价于 lora_plus_scale 为4, lora_alpha 为4,打开有利于提高模型训练收敛速度。默认为 `False`。 +- `lora_plus_scale`: 在 LoRA+ 技术中,Lora B 的比例,默认为 `1.0`。 +- `lora_alpha`: LoRA 的 alpha 参数,默认为 `-1`。 +- `rslora_plus`: 是否增强 LoRA 的性能,默认为 `False`。 +- `use_quick_lora`: 是否使用 Quick LoRA,默认为 `True`。 + +### 数据参数(DataArgument) +- `train_dataset_path`: 训练集数据路径,默认为 `"./data/train.jsonl"`。 +- `dev_dataset_path`: 验证集数据路径,默认为 `"./data/dev.jsonl"`。 +- `max_seq_len`: 输入序列的最大长度,默认为 `4096`。 +- `max_prompt_len`: 输入提示的最大长度,默认为 `2048`。 +- `greedy_zero_padding`: 是否使用 greedy zero padding,打开有利于降低 padding 比例,默认为 `False`。 +- `lazy`: 是否返回`MapDataset` 或者`IterDataset`。`True`代表`IterDataset`,`False`代表`MapDataset`。数据集较大是建议打开 lazy,注意 lazy 为 True 数据集不 shuffle。 + +### 训练参数(TrainingArguments) +- `output_dir`: 用于保存相关文件的目录,包括模型、checkpoint、分词器文件、评估结果等,默认为 `"./checkpoints/dpo_ckpts"`。 +- `per_device_train_batch_size`: 每个设备上的训练批处理大小,默认为 `1`。 +- `gradient_accumulation_steps`: 梯度累积步数,默认为 `8`,表示每 `8` 个步数进行一次参数更新。 +- `per_device_eval_batch_size`: 每个设备上的验证批处理大小,默认为 `1`。 +- `num_train_epochs`: 模型训练的轮次,默认为 `1`。 +- `max_steps`: 训练的最大步数,默认为 `100`。 +- `learning_rate`: 优化器的初始学习率,默认为 `1e-06`。 +- `warmup_steps`: warmup 的步数,默认为0。当 warmup_steps>0时,会覆盖 warmup_ratio 的设置,默认为 `10`。 +- `logging_steps`: 日志记录的步数间隔,默认为 `1`。 +- `evaluation_strategy`: 评估策略。"no":训练期间不进行评估;"steps":在每 eval_steps 结束进行;"epoch":在每个 epoch 结束时进行。 +- `save_strategy`: 保存策略。"no":训练期间不进行评估;"steps":在每 eval_steps 结束进行;"epoch":在每个 epoch 结束时进行。 +- `eval_steps`: 评估的步数间隔,默认为 `100`。 +- `save_steps`: 模型保存的步数间隔,默认为 `500`。 +- `bf16`: 是否需要开启 BF16训练,开启 BF16训练可以加速训练,默认为 `True`。 +- `fp16_opt_level`: 可设置 O1或者 O2,在 O1 级别下,在白名单中的算子将使用 float16/bfloat16 计算,在黑名单中的算子将使用 float32 计算。在 O2 级别下,模型的参数被转换为 float16/bfloat16, 如果算子的浮点型输入全是 float16/bfloat16,算子才会采用 float16/bfloat16 计算,若任意浮点型输入是 float32 类型,算子将采用 float32 计算。默认为 O1。默认为 `"O2"`。 +- `do_train`: 是否开启训练,默认为 `True`。 +- `do_eval`: 是否开启评估,默认为 `True`。 +- `load_best_model_at_end`: 是否在训练结束时加载最优模型,默认为 `True`。 +- `tensor_parallel_degree`: 此参数 tensor_parallel_degree 表示将一层 transformer 结构的份数,该方法对通信开销较大,但可以节约显存,建议 tensor_parallel_degree<=8, 尽量使用机器内部通信。 +- `pipeline_parallel_degree`: 表示划分流水线的大小.(假设该参数为4, 模型12层, 则每一个 pp stage 包含3层模型) 默认值-1, 表示不启用流水线并行。 +- `sharding_parallel_degree`: 分组参数切片的数据并行大小。 +- `sharding`: 是否使用 Sharding 数据并行功能,默认为 `stage1`。 +- `recompute`: 重计算,暂支持 full 策略。开启后可降低显存以达到增大 batch size 的目的,full recompute 降低速度大约30%。 +- `recompute_granularity`: 重计算粒度,可设置为`full`或`full_attn`或`core_attn`。 +- `unified_checkpoint`: 是否使用统一的 checkpoint,默认为 `True`。 +- `autotuner_benchmark`: 是否启用 autotuner 基准测试,默认为 `False`。 +- `benchmark`: 是否开启基准测试,默认为 `False`。 +### DPO 参数(DPOArguments) +- `beta`: DPO 损失函数的 beta 参数,默认为 0.1。 +- 
`simpo_gamma`: SimPO 损失函数的 gamma 参数,默认为 0.5。 +- `label_smoothing`: 标签平滑比率,默认为 0.0。 +- `loss_type`: DPO 损失函数类型,sigmoid([DPO](https://arxiv.org/abs/2305.18290)), +hinge([RSO](https://arxiv.org/abs/2309.06657)), +ipo([IPO](https://arxiv.org/abs/2310.12036)), +kto_pair(有偏好数据对的实现[KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf)), +sppo_hard([SPPO](https://arxiv.org/pdf/2405.00675)), +nca_pair([NCA](https://arxiv.org/abs/2402.05369)), +dpop([DPOP](https://arxiv.org/pdf/2402.13228.pdf)), +orpo([ORPO](https://arxiv.org/abs/2403.07691)), +simpo([SimPO](https://arxiv.org/abs/2405.14734)),默认为 `sigmoid`。 +- `pref_loss_ratio`: DPO 损失比率,默认为 1.0。 +- `sft_loss_ratio`: SFT 损失比率,默认为 0.0。 +- `dpop_lambda`: dpop_lambda,默认为 50,详情可见论文[DPOP](https://arxiv.org/pdf/2402.13228) + +## 4. DPO 数据流介绍 +在 DPO 的数据流中,我们首先将原始的数据集进行预处理,然后构造 DPO 的数据序列,并构造 attention_mask。序列包括提示(问题),chosen(偏好回答)和 rejected(拒绝回答)。 +
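+在看具体的序列构造之前,先说明这些序列的最终用途:模型会分别计算 chosen 与 rejected 回复的对数概率,并代入上文 DPOArguments 中所选的偏好损失。以 loss_type=sigmoid 为例,下面给出 DPO 论文中标准 sigmoid 损失的简化示意,仅用于说明 beta 与 label_smoothing 两个参数的作用,函数与变量命名均为示例,并非 PaddleNLP 的实际实现:
+
+```python
+import paddle
+import paddle.nn.functional as F
+
+def sigmoid_dpo_loss(policy_chosen_logps, policy_rejected_logps,
+                     reference_chosen_logps, reference_rejected_logps,
+                     beta=0.1, label_smoothing=0.0):
+    # 隐式奖励差:策略模型与参考模型的 log 概率比之差
+    logits = (policy_chosen_logps - policy_rejected_logps) - (
+        reference_chosen_logps - reference_rejected_logps
+    )
+    # label_smoothing > 0 对应保守 DPO(cDPO),假设偏好标签有一定概率被标反
+    loss = (
+        -F.log_sigmoid(beta * logits) * (1.0 - label_smoothing)
+        - F.log_sigmoid(-beta * logits) * label_smoothing
+    )
+    return loss.mean()
+
+# 随机构造的对数概率,仅作演示
+print(sigmoid_dpo_loss(
+    paddle.to_tensor([-12.3, -8.7]), paddle.to_tensor([-14.1, -9.2]),
+    paddle.to_tensor([-12.0, -8.9]), paddle.to_tensor([-13.5, -9.0]),
+))
+```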
+<div align="center">(图:序列构造)</div>
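+下面用一段简化的示意代码说明上图的序列构造过程:根据 sort 值区分 chosen 与 rejected,并把多轮 src/tgt 对话拼接为 prompt。其中的对话模板与 tokenizer 调用均为示例假设,实际的预处理逻辑以 ./alignment/dpo 下的代码为准:
+
+```python
+def build_dpo_sequences(example, tokenizer, max_prompt_len=2048):
+    # sort 值大的是 chosen,小的是 rejected
+    pair = sorted(zip(example["sort"], example["response"]))
+    rejected_text, chosen_text = pair[0][1], pair[1][1]
+
+    # 将多轮 src/tgt 对话拼接为 prompt(真实实现会使用各模型自己的对话模板)
+    prompt_text = ""
+    for i, user_turn in enumerate(example["src"]):
+        prompt_text += f"<user>{user_turn}<assistant>"
+        if i < len(example["tgt"]):
+            prompt_text += example["tgt"][i]
+
+    prompt_ids = tokenizer(prompt_text)["input_ids"][-max_prompt_len:]
+    chosen_ids = tokenizer(chosen_text)["input_ids"]
+    rejected_ids = tokenizer(rejected_text)["input_ids"]
+    return prompt_ids, chosen_ids, rejected_ids
+```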
+
+序列构造完成后,我们需要将多个序列拼接为一个合并序列,并填充 pad tokens,使每条合并后的序列长度相同。
+
+<div align="center">(图:序列拼接)</div>
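+下面给出一个简化的打包示意,假设每条序列长度均不超过 max_seq_len,pad_token_id 等取值仅为示例,并非 PaddleNLP 的实际实现:
+
+```python
+def pack_sequences(sequences, max_seq_len, pad_token_id=0):
+    """把若干条已构造好的 token id 序列贪心地拼接为若干条定长合并序列。"""
+    packed, current = [], []
+    for seq in sequences:  # 每个 seq 为一条 prompt+chosen+rejected 序列
+        if len(current) + len(seq) > max_seq_len:
+            current += [pad_token_id] * (max_seq_len - len(current))  # 末尾补 pad
+            packed.append(current)
+            current = []
+        current += seq
+    if current:
+        current += [pad_token_id] * (max_seq_len - len(current))
+        packed.append(current)
+    return packed
+```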
+
+序列拼接完成后,我们会重新构造 attention_mask,使得 Attention 计算不会跨越不同序列的边界,训练时无需额外处理序列边界问题。
+
+<div align="center">(图:attention_mask 示意图)</div>
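+下面的示意代码按照子序列边界构造块对角的因果 attention_mask,仅展示"不同序列之间互不可见"这一核心思想;实际实现中 chosen 与 rejected 还会共享同一段 prompt,mask 结构会更精细,细节以 PaddleNLP 源码为准:
+
+```python
+import numpy as np
+
+def build_block_causal_mask(segment_lengths, max_seq_len):
+    mask = np.zeros((max_seq_len, max_seq_len), dtype="bool")
+    start = 0
+    for seg_len in segment_lengths:
+        end = start + seg_len
+        # 每个子序列内部为因果(下三角)mask,子序列之间互不可见
+        mask[start:end, start:end] = np.tril(np.ones((seg_len, seg_len), dtype="bool"))
+        start = end
+    return mask  # 末尾 pad 部分保持全 0,不参与注意力计算
+
+# 例:合并序列由长度为 5 和 7 的两条子序列以及 4 个 pad token 组成
+print(build_block_causal_mask([5, 7], 16).astype(int))
+```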
diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 6fc086b54584..81dec37b611e 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -687,7 +687,13 @@ def _save_checkpoint(self, model, metrics=None): # For ckpt integrity paddle.save(self.state.global_step, os.path.join(output_dir, ".checkpoint_done")) - def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False): + def _save( + self, + output_dir: Optional[str] = None, + state_dict=None, + merge_tensor_parallel=False, + signal_dir: Optional[str] = None, + ): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving model checkpoint to {output_dir}") diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py index f35b23f95050..4c5b54a20ddb 100644 --- a/paddlenlp/trainer/plugins/unified_checkpoint.py +++ b/paddlenlp/trainer/plugins/unified_checkpoint.py @@ -27,6 +27,11 @@ from paddle.distributed import fleet from tqdm.auto import tqdm +try: + from paddle.base import core +except: + core = None + from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM from paddlenlp.trainer.argparser import strtobool from paddlenlp.trainer.trainer_utils import ExplicitEnum @@ -135,7 +140,6 @@ def __init__(self, args): self._process_master_weight = None self._process_optimizer_weight = None self._lock = None - self._shared_save_path = None self._shared_save_model_flag = None self._shared_save_master_weight_flag = None self._shared_save_optimizer_flag = None @@ -143,13 +147,18 @@ def __init__(self, args): if "async_save" in self.args.unified_checkpoint_config: self._lock = multiprocessing.Lock() self._shared_save_model_path = multiprocessing.Array("c", 100000) + self._shared_save_model_signal_path = multiprocessing.Array("c", 100000) self._shared_save_master_weight_path = multiprocessing.Array("c", 100000) + self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000) self._shared_save_optimizer_path = multiprocessing.Array("c", 100000) + self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000) self._shared_save_model_flag = multiprocessing.Array("i", 1) self._shared_save_master_weight_flag = multiprocessing.Array("i", 1) self._shared_save_optimizer_flag = multiprocessing.Array("i", 1) - def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"): + def _file_save_async_or_sync( + self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight" + ): if is_sync: for k in list(state_dict.keys()): if isinstance(state_dict[k], paddle.Tensor): @@ -164,6 +173,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty meta_dict = self._meta_dict_model shared_save_flag = self._shared_save_model_flag shared_save_path = self._shared_save_model_path + shared_save_signal_path = self._shared_save_model_signal_path if self._process_model_weight is None: self._process_model_weight = multiprocessing.Process( target=self._save_file_async_in_process, @@ -172,12 +182,14 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty self._shm_model_weight.name, self._shared_save_model_flag, self._shared_save_model_path, + self._shared_save_model_signal_path, self._lock, state_dict_type, self.global_rank, ), ) self._process_model_weight.start() + process = self._process_model_weight elif 
state_dict_type == "master_weight": if self._shm_master_weight is None: self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict) @@ -186,6 +198,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty meta_dict = self._meta_dict_master_weight shared_save_flag = self._shared_save_master_weight_flag shared_save_path = self._shared_save_master_weight_path + shared_save_signal_path = self._shared_save_master_weight_signal_path if self._process_master_weight is None: self._process_master_weight = multiprocessing.Process( target=self._save_file_async_in_process, @@ -194,6 +207,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty self._shm_master_weight.name, self._shared_save_master_weight_flag, self._shared_save_master_weight_path, + self._shared_save_master_weight_signal_path, self._lock, "model_weight" if "skip_save_model_weight" in self.args.unified_checkpoint_config @@ -202,6 +216,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty ), ) self._process_master_weight.start() + process = self._process_master_weight elif state_dict_type == "optimizer_weight": if self._shm_optimizer_weight is None: self._meta_dict_optim, buffer_size = create_meta_dict(state_dict) @@ -210,6 +225,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty meta_dict = self._meta_dict_optim shared_save_flag = self._shared_save_optimizer_flag shared_save_path = self._shared_save_optimizer_path + shared_save_signal_path = self._shared_save_optimizer_signal_path if self._process_optimizer_weight is None: self._process_optimizer_weight = multiprocessing.Process( target=self._save_file_async_in_process, @@ -218,21 +234,26 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty self._shm_optimizer_weight.name, self._shared_save_optimizer_flag, self._shared_save_optimizer_path, + self._shared_save_optimizer_signal_path, self._lock, state_dict_type, self.global_rank, ), ) self._process_optimizer_weight.start() + process = self._process_optimizer_weight while True: # wait until no process is saving. flag_value = shared_save_flag[0] if flag_value == 0: break + if not process.is_alive(): + raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.") time.sleep(0.5) logger.info(f"Wait for the previous save process to finish saving {state_dict_type}") # only save model weight or save master weight, we enter this loop. 
self._reset_and_update(shared_save_path, path) + self._reset_and_update(shared_save_signal_path, signal_path) _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf) with self._lock: shared_save_flag[0] = 1 @@ -243,6 +264,7 @@ def _save_file_async_in_process( shm_name, shared_save_flag, shared_save_path, + shared_save_signal_path, lock, state_dict_type, global_rank, @@ -256,11 +278,12 @@ def _save_file_async_in_process( continue if flag_value == 1: # need to save path = shared_save_path[:].decode("utf-8").rstrip("\x00") + signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00") logger.info(f"Start to async save {path}") state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array safe_save_file(state_dict, path, {"format": "np"}) del state_dict - saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}") + saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}") paddle.save(global_rank, saved_signal_path) with lock: shared_save_flag[0] = 0 @@ -275,7 +298,7 @@ def _reset_and_update(self, shared_array, new_value): encoded_value = new_value.encode("utf-8") shared_array[: len(encoded_value)] = encoded_value - def save_unified_checkpoint(self, model, optimizer, output_dir): + def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None): """save unified checkpoint Args: @@ -312,6 +335,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir): save_directory = output_dir os.makedirs(save_directory, exist_ok=True) + if signal_dir is not None: + os.makedirs(signal_dir, exist_ok=True) # only for async save # save model weights if not skip_save_model_weight: @@ -324,6 +349,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir): self._file_save_async_or_sync( state_dict, path=os.path.join(save_directory, shard_file), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="model_weight", ) @@ -355,6 +381,9 @@ def save_unified_checkpoint(self, model, optimizer, output_dir): config_to_save.architectures = [model_to_save.__class__.__name__] if self.args.should_save: config_to_save.save_pretrained(save_directory) + # save generation config + if model_to_save.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) paddle.device.cuda.empty_cache() if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and self.args.should_save: @@ -389,10 +418,10 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str) ) return - if self.args.dataset_rank == 0: + if self.args.dataset_rank == 0 or self.args.use_expert_parallel: load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True) - def save_non_merge_optimizer(self, model, optimizer, output_dir): + def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir): paddle.device.cuda.empty_cache() optim_state_dict = nested_copy(optimizer.state_dict()) master_weights = None @@ -422,6 +451,26 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir): for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + no_sync_kname = [] + model_state_dict = get_expected_state_dict(model) + for k, v in model_state_dict.items(): + if getattr(v, "no_sync", False): + no_sync_kname.append(k) + + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + if self.args.use_expert_parallel: + 
for k in list(optim_state_dict.keys()): + model_k = k.split("/")[0] + if dp_rank > 0 and model_k not in no_sync_kname: + optim_state_dict.pop(k) + if master_weights is not None: + for k in list(master_weights.keys()): + model_k = k.split("/")[0] + if dp_rank > 0 and model_k not in no_sync_kname: + master_weights.pop(k) + optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix) master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix) @@ -431,12 +480,14 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir): self._file_save_async_or_sync( optim_state_dict, path=os.path.join(output_dir, optimizer_name), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", ) self._file_save_async_or_sync( master_weights, path=os.path.join(output_dir, master_weights_name), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="master_weight", ) @@ -462,7 +513,10 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): key_name = key.split("/") static_name = struct2static_name_mappings[key_name[0]] if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) with device_guard(): @@ -483,22 +537,23 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint): return returned_optim_state_dict - def save_unified_optimizer(self, model, optimizer, output_dir): + def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir): """save unified optimizer Args: model (PretrainedModel): model used to get key mapping. optimizer (Optimizer): optimizer to save output_dir (str): Save directory. + signal_dir (str): Asynchronous saving signal directory. """ if "ignore_merge_optimizer" in self.args.unified_checkpoint_config: - self.save_non_merge_optimizer(model, optimizer, output_dir) + self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir) return if paddle.distributed.get_world_size() <= 1: - self.save_single_card_optimizer(model, optimizer, output_dir) + self.save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal return # Split into naive optimizer params and master weights. @@ -514,6 +569,8 @@ def save_unified_optimizer(self, model, optimizer, output_dir): save_directory = output_dir os.makedirs(save_directory, exist_ok=True) + if signal_dir is not None: + os.makedirs(signal_dir, exist_ok=True) is_sync_save = True if "async_save" in self.args.unified_checkpoint_config: @@ -521,6 +578,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir): self._file_save_async_or_sync( optim_state_dict, path=os.path.join(save_directory, shard_optim_file), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="optimizer_weight", ) @@ -528,6 +586,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir): self._file_save_async_or_sync( master_weight_state_dict, path=os.path.join(save_directory, shard_master_weight_file), + signal_path=signal_dir, is_sync=is_sync_save, state_dict_type="master_weight", ) @@ -568,7 +627,7 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint) ) # If not having merge optimizer, then load non-merge optimizer. 
if not has_merge_optimizer_safetensors: - if self.args.data_parallel_rank == 0: + if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel: returned_optim_state_dict = self.load_non_merge_optimizer( model, optimizer, @@ -588,7 +647,7 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint) ) return returned_optim_state_dict - if self.args.data_parallel_rank == 0: + if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel: returned_optim_state_dict = load_unified_optimizer_locally( self.args, model, optimizer, resume_from_checkpoint, safe_serialization=True ) @@ -639,6 +698,10 @@ def save_single_card_checkpoint(self, model_to_save, output_dir): config_to_save.architectures = [model_to_save.__class__.__name__] config_to_save.save_pretrained(output_dir) + # save generation config + if model_to_save.can_generate(): + model_to_save.generation_config.save_pretrained(output_dir) + def save_single_card_optimizer(self, model, optimizer, output_dir): """ "Save optimizer for non-distributed environment.""" # Split into optimizer params and master weights. @@ -651,8 +714,11 @@ def save_single_card_optimizer(self, model, optimizer, output_dir): static2struct_name_mappings = {} state_dict = get_expected_state_dict(model) + fp32_weight = {} for k, v in state_dict.items(): static2struct_name_mappings[v.name] = k + if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: + fp32_weight[k] = v # rename optimizer param for key in list(optim_state_dict.keys()): @@ -662,6 +728,7 @@ def save_single_card_optimizer(self, model, optimizer, output_dir): if master_weights is not None: for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + master_weights.update(fp32_weight) # save index json index_optimizer_file, index_master_weight_file = {}, {} @@ -715,14 +782,20 @@ def unlink_shared_memory(self): if self._shared_save_model_flag is not None: while self._shared_save_model_flag[0] > 0: # async process is saving + if not self._process_model_weight.is_alive(): + raise RuntimeError("The process that saves model_weight has been killed unexpectedly.") time.sleep(0.5) self._shared_save_model_flag[0] = -1 if self._shared_save_master_weight_flag is not None: while self._shared_save_master_weight_flag[0] > 0: + if not self._process_master_weight.is_alive(): + raise RuntimeError("The process that saves master_weight has been killed unexpectedly.") time.sleep(0.5) self._shared_save_master_weight_flag[0] = -1 if self._shared_save_optimizer_flag is not None: while self._shared_save_optimizer_flag[0] > 0: + if not self._process_optimizer_weight.is_alive(): + raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.") time.sleep(0.5) self._shared_save_optimizer_flag[0] = -1 @@ -739,12 +812,13 @@ def unlink_shared_memory(self): self._shm_optimizer_weight.unlink() self._shm_optimizer_weight = None - dist.barrier() + if paddle.distributed.get_world_size() > 1: + dist.barrier() def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False): """ - Only dataset_rank == 0 can enter this function. + Only dataset_rank == 0 or using expert parallel can enter this function. 
""" index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=True) @@ -755,7 +829,14 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa loaded_keys = sharded_metadata["all_checkpoint_keys"] model_state_dict = get_expected_state_dict(model) - expected_keys = set(list(model_state_dict.keys())) + # If using expert parallel, when dp_rank > 0, need to modify the expected_keys here. + if not args.use_expert_parallel or (args.use_expert_parallel and args.data_parallel_rank == 0): + expected_keys = set(list(model_state_dict.keys())) + else: + expected_keys = set() + for key in model_state_dict.keys(): + if getattr(model_state_dict[key], "no_sync", False): + expected_keys.add(key) missing_keys = expected_keys - set(loaded_keys) use_fast_set = True @@ -889,11 +970,17 @@ def unified_checkpoint_into_shards( weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME shard_file = get_sharded_file_name(args, weights_name) + # renumerize shard_file name for expert_parallel. + if args.use_expert_parallel: + shard_file = rename_shard_file(args, shard_file, weights_name) + for key, weight in state_dict.items(): index_weight_file[key] = shard_file total_size += weight.numel().item() * dtype_byte_size(weight.dtype) - index_file_list, total_size_list = gather_sharded_object(index_weight_file, total_size) + index_file_list, total_size_list = gather_sharded_object( + index_weight_file, total_size, use_expert_parallel=args.use_expert_parallel + ) sharded_index = get_sharded_index( index_file_list, total_size_list, @@ -931,7 +1018,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin model_keys = list(model_state_dict.keys()) struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} # get optimizer param mappings - expected_keys = get_expected_keys(sharded_metadata, model, optimizer) + expected_keys = get_expected_keys(args, sharded_metadata, model, optimizer) # This should always be a list but, just to be sure. 
if not isinstance(resolved_archive_file, list): @@ -955,7 +1042,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin index_filename=os.path.join(resume_from_checkpoint, index_filename_master_weights), ) - expected_keys_mw = get_expected_keys(sharded_metadata_mw, model, optimizer) + expected_keys_mw = get_expected_keys(args, sharded_metadata_mw, model, optimizer, is_master_weights=True) if not isinstance(resolved_archive_file_mw, list): resolved_archive_file_mw = [resolved_archive_file_mw] if len(resolved_archive_file_mw) > 1: @@ -1005,7 +1092,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected key_name = key.split("/") static_name = struct2static_name_mappings[key_name[0]] if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) returned_optim_state_dict[key_name] = state_dict_optim.pop(key) @@ -1049,8 +1139,13 @@ def unified_optimizer_into_shards( # get optimizer param mappings static2struct_name_mappings = {} state_dict = get_expected_state_dict(model) + fp32_weight = {} for k, v in state_dict.items(): static2struct_name_mappings[v.name] = k + if master_weights is not None and v.dtype == core.VarDesc.VarType.FP32: + if args.dataset_rank > 0: # deal with different dataset rank. + continue + fp32_weight[k] = v # rename optimizer param for key in list(optim_state_dict.keys()): @@ -1060,6 +1155,7 @@ def unified_optimizer_into_shards( if master_weights is not None: for key in list(master_weights.keys()): master_weights[static2struct_name_mappings[key]] = master_weights.pop(key) + master_weights.update(fp32_weight) # filter optimizer param if master_weights is not None: @@ -1087,6 +1183,7 @@ def unified_optimizer_into_shards( optim_state_dict, tp_actions, filter_optim_keys, + state_dict if args.use_expert_parallel else None, ) paddle.device.cuda.empty_cache() @@ -1096,6 +1193,7 @@ def unified_optimizer_into_shards( master_weights, tp_actions, filter_master_keys, + state_dict if args.use_expert_parallel else None, ) paddle.device.cuda.empty_cache() @@ -1119,12 +1217,18 @@ def unified_optimizer_into_shards( total_master_weight_size += weight.numel().item() * dtype_byte_size(weight.dtype) index_optimizer_filelist, total_optim_size_list = gather_sharded_object( - index_optimizer_file, total_optim_size, is_optimizer=True + index_optimizer_file, + total_optim_size, + is_optimizer=True, + use_expert_parallel=args.use_expert_parallel, ) sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list) if master_weights is not None: index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object( - index_master_weight_file, total_master_weight_size, is_optimizer=True + index_master_weight_file, + total_master_weight_size, + is_optimizer=True, + use_expert_parallel=args.use_expert_parallel, ) sharded_master_weight_index = get_sharded_index(index_master_weight_filelist, total_master_weight_size_list) @@ -1175,15 +1279,20 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa # To decide whether to load the checkpoint locally, or need to dynamically send tensors across machines. 
local_resume = True - if args.dataset_rank == 0: + if args.dataset_rank == 0 or args.use_expert_parallel: hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() pp_group = hcg.get_pipe_parallel_group() + dp_group = hcg.get_data_parallel_group() + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 need_files = set() state_dict = get_expected_state_dict(model) for key in state_dict.keys(): filename = index["weight_map"][key] + # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. + if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): + continue need_files.add(filename) diff_filelist = list(need_files.difference(set(existed_files))) num_diff = paddle.to_tensor([len(diff_filelist)]) @@ -1191,6 +1300,8 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=tp_group) if pp_group.nranks > 1: dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) + if args.use_expert_parallel and dp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) if num_diff.item() == 0: local_resume = True else: @@ -1243,8 +1354,10 @@ def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() pp_group = hcg.get_pipe_parallel_group() + dp_group = hcg.get_data_parallel_group() sharding_group = hcg.get_sharding_parallel_group() sharding_rank = sharding_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()} if sharding_group.nranks > 1: param2rank = optimizer._param2rank @@ -1269,9 +1382,10 @@ def check_complete(all_filenames): def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, typename_set=None): # To decide whether to load the checkpoint locally, or need to dynamically distribute the checkpoint. local_resume = True - if args.data_parallel_rank == 0: + if args.data_parallel_rank == 0 or args.use_expert_parallel: need_files = set() state_dict = get_expected_state_dict(model) + for key in state_dict.keys(): if sharding_group.nranks > 1: static_name = struct2static_name_mappings.get(key, None) @@ -1279,6 +1393,13 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, if param_rank != sharding_rank: continue + # When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0. 
+ if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False): + continue + + if is_master_weights and state_dict[key].dtype == core.VarDesc.VarType.FP32: + continue + if not is_master_weights: for type_name in typename_set: type_key = key + "/" + type_name @@ -1296,6 +1417,8 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False, dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=pp_group) if sharding_group.nranks > 1: dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=sharding_group) + if args.use_expert_parallel and dp_group.nranks > 1: + dist.all_reduce(num_diff, op=dist.ReduceOp.MAX, group=dp_group) if num_diff.item() == 0: local_resume = True @@ -1548,8 +1671,10 @@ def load_unified_optimizer_dynamically(args, model, optimizer, resume_from_check for key in index["weight_map"].keys(): _, typename = key.split("/") typename_set.add(typename) - struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} - static2struct_name_mappings = {v.name: k for k, v in get_expected_state_dict(model).items()} + + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} + static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} # Get send_table and recv_table. The send table indicates which workers are responsible for sending tensors, and the recv table indicates which workers should receive the tensors. send_table, recv_table = create_optimizer_dispatch_table( args, @@ -1671,7 +1796,10 @@ def check_optimizer_param(parameter): key_name = key.split("/") static_name = struct2static_name_mappings[key_name[0]] if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) else: key_name = "_".join([static_name, key_name[1]]) optim_state_dict[key_name] = optim_state_dict.pop(key) @@ -1745,9 +1873,10 @@ def load_single_card_optimizer(args, model, optimizer, resume_from_checkpoint: s key_name = key.split("/") static_name = struct2static_name_mappings[key_name[0]] if has_master_weights: - key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) - else: - key_name = "_".join([static_name, key_name[1]]) + if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32: + key_name = "_".join([static_name, FP32_MASTER, key_name[1]]) + else: + key_name = "_".join([static_name, key_name[1]]) returned_optim_state_dict[key_name] = state_dict_optim.pop(key) returned_optim_state_dict[key_name].name = key_name if has_master_weights: @@ -1872,26 +2001,29 @@ def distributed_send_recv( def get_sharded_file_name(args, file_name, is_optimizer=False): if not is_optimizer: + sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 + size = sd_degree if args.use_expert_parallel else args.dataset_world_size shard_file = file_name.replace( ".pdparams", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.pdparams", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams", ) shard_file = shard_file.replace( ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", ) else: hcg = 
fleet.get_hybrid_communicate_group() dp_group = hcg.get_data_parallel_group() + size = dp_group.nranks if not args.use_expert_parallel else 1 shard_file = file_name.replace( - ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdparams" + ".pdparams", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdparams" ) shard_file = shard_file.replace( ".safetensors", - f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.safetensors", + f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.safetensors", ) shard_file = shard_file.replace( - ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//dp_group.nranks:05d}.pdopt" + ".pdopt", f"-{args.logical_process_index + 1:05d}-of-{args.world_size//size:05d}.pdopt" ) return shard_file @@ -1935,7 +2067,7 @@ def reduce_master_weights_status(has_master_weights=False): return data.item() > 0 -def gather_sharded_object(index_file, total_size, is_optimizer=False): +def gather_sharded_object(index_file, total_size, is_optimizer=False, use_expert_parallel=False): index_file_list, total_size_list = [], [] @@ -1969,6 +2101,17 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False): if len(index_file_list) == 0 and len(total_size_list) == 0: index_file_list = [index_file] total_size_list = [total_size] + + if use_expert_parallel: + data_group = hcg.get_data_parallel_group() + if data_group.nranks > 1: + data_index_file_list = [] + data_total_size_list = [] + dist.all_gather_object(data_index_file_list, index_file_list, data_group) + dist.all_gather_object(data_total_size_list, total_size_list, data_group) + index_file_list = flatten_list(data_index_file_list) + total_size_list = flatten_list(data_total_size_list) + if is_optimizer: sharding_group = hcg.get_sharding_parallel_group() if sharding_group.nranks > 1: @@ -1982,16 +2125,58 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False): return index_file_list, total_size_list +def rename_shard_file(args, shard_file, file_name): + """rename shard file when using expert_parallel.""" + assert args.use_expert_parallel, "only expert_parallel need to use this function" + + shard_file_list = [] + + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() + pp_group = hcg.get_pipe_parallel_group() + data_group = hcg.get_data_parallel_group() + + if tp_group.nranks > 1: + dist.all_gather_object(shard_file_list, shard_file, tp_group) + if pp_group.nranks > 1: + pp_shard_file_list = [] + dist.all_gather_object( + pp_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, pp_group + ) + shard_file_list = flatten_list(pp_shard_file_list) + if data_group.nranks > 1: + data_shard_file_list = [] + dist.all_gather_object( + data_shard_file_list, shard_file_list if len(shard_file_list) > 0 else shard_file, data_group + ) + shard_file_list = flatten_list(data_shard_file_list) + + new_index = shard_file_list.index(shard_file) + sd_degree = args.sharding_parallel_degree if args.sharding_parallel_degree > 1 else 1 + shard_file = file_name.replace( + ".pdparams", + f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.pdparams", + ) + shard_file = shard_file.replace( + ".safetensors", + f"-{new_index + 1:05d}-of-{args.world_size//sd_degree:05d}.safetensors", + ) + return shard_file + + def generate_base_static_name(vname): # return base static name and specific type name, like [embedding_0.w_0, moment1_0] if FP32_MASTER 
in vname: vname = vname.split("_" + FP32_MASTER + "_") return vname[0], vname[1] else: - vname = vname.split(".") - a = vname[0] + "." + vname[1][:3] - b = vname[1][4:] - return a, b + # Directly deal with type names, for example: moe_gate_1_moment1_0. + type_names = optimizer_scalar_name + optimizer_non_scaler_name + for name in type_names: + if name in vname: + a = vname.split(name)[0][:-1] + b = name + return a, b def filter_params(model_to_save, state_dict, is_optimizer=False): @@ -2087,7 +2272,9 @@ def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst): def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() + dp_group = hcg.get_data_parallel_group() tp_rank = tp_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 # filter actions for pipeline mode if hcg.get_pipe_parallel_group().nranks > 1: @@ -2105,6 +2292,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): continue key = filter_keys[i] tensor = state_dict[key] + # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. + if dp_rank > 0 and not getattr(tensor, "no_sync", False): + continue if key in tp_actions: # Get tensor size tensor_bytes = tensor.numel().item() * dtype_byte_size(tensor.dtype) * tp_group.nranks @@ -2128,16 +2318,24 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys): if len(tp_actions) > 0: for x in tp_actions.keys(): - logger.warning(f"key <{x}> need to merge tensor parallel but we can't find in model state.") + logger.debug(f"key <{x}> need to merge tensor parallel but we can't find in model state.") return state_dict_to_save -def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys): +def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys, model_state_dict=None): # Core function for UC hcg = fleet.get_hybrid_communicate_group() tp_group = hcg.get_model_parallel_group() + dp_group = hcg.get_data_parallel_group() tp_rank = tp_group.rank + dp_rank = dp_group.rank if dp_group.nranks > 1 else 0 + + no_sync_kname = [] + if model_state_dict is not None: + for k, v in model_state_dict.items(): + if getattr(v, "no_sync", False): + no_sync_kname.append(k) state_dict_to_save = {} max_key_len = max([len(_) for _ in all_filter_keys]) @@ -2149,6 +2347,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys) # get base model key model_key = filter_keys[i].split("/")[0] tensor = state_dict[filter_keys[i]] + # When using expert parallel, there's no need to save tensors with `no_sync=False` when dp_rank > 0. 
+ if dp_rank > 0 and model_key not in no_sync_kname: + continue if model_key in tp_actions: # for example: beta1, beta2 if tensor.numel().item() == 1: @@ -2217,7 +2418,7 @@ def get_optimizer_shard_files(optimizer_path, index_filename): return shard_filenames, sharded_metadata -def get_expected_keys(sharded_metadata, model, optimizer): +def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weights=False): hcg = fleet.get_hybrid_communicate_group() sharding_group = hcg.get_sharding_parallel_group() sharding_rank = sharding_group.rank @@ -2225,11 +2426,23 @@ def get_expected_keys(sharded_metadata, model, optimizer): if in_sharding_parallel_model: params2rank = optimizer._param2rank - struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()} + model_state_dict = get_expected_state_dict(model) + struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()} expected_keys = [] for key in list(sharded_metadata["all_optimizer_keys"]): key_name = key.split("/")[0] + if ( + is_master_weights + and key_name in model_state_dict + and model_state_dict[key_name].dtype == core.VarDesc.VarType.FP32 + ): + continue + + if args.use_expert_parallel and args.data_parallel_rank > 0: + if key_name in model_state_dict and not getattr(model_state_dict[key_name], "no_sync", False): + continue + static_name = struct2static_name_mappings.get(key_name, None) if in_sharding_parallel_model: diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 7a4c73a2e14c..ddc872ad6173 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -197,6 +197,17 @@ def in_auto_parallel_align_mode(): return False +try: + from paddle.framework.recall_error import LOSS_NAN_ERROR +except ImportError: + LOSS_NAN_ERROR = "PaddleRecall error(102): LossNan" + +try: + from paddle.framework.recall_error import LOSS_INF_ERROR +except ImportError: + LOSS_INF_ERROR = "PaddleRecall error(104): LossInf" + + __all__ = ["Trainer"] @@ -570,7 +581,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): # Load potential model checkpoint if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: uc_async_save = self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config - resume_from_checkpoint = get_last_checkpoint(self.args.output_dir, uc_async_save) + resume_from_checkpoint = get_last_checkpoint( + self.args.output_dir, signal_folder=self.args.output_signal_dir, uc_async_save=uc_async_save + ) if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})") @@ -1368,7 +1381,8 @@ def _get_item_from_loss(self, loss): loss_value = loss.item() if not self.args.fp16: if not np.isfinite(loss_value).all(): - raise ValueError(f"Loss contains inf or nan values, its value is {loss_value}") + err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR + raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}") return loss_value def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs): @@ -2258,13 +2272,6 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle self._pp_data_buffer = [] model.train() - # hack pipeline-layers - # since the pipeline layer will check input is valid every iter. - # in same case, for example, batch size warmup, we need dynamic change gradient_accumulation_steps to implement. 
- config_backup = model.micro_batch_size, model.accumulate_steps - model.micro_batch_size = self.args.per_device_train_batch_size - model.accumulate_steps = self.args.gradient_accumulation_steps - if model._dp_comm_overlap or model._sharding_comm_overlap: for _, buffers in model._chunk_2_comm_buffers.items(): for buffer in buffers: @@ -2279,11 +2286,14 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle with self.autocast_smart_context_manager(): loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None) - model.micro_batch_size, model.accumulate_steps = config_backup - return loss.detach() - def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False): + def save_model( + self, + output_dir: Optional[str] = None, + merge_tensor_parallel: Optional[bool] = False, + signal_dir: Optional[str] = None, + ): """ Will save the model, so you can reload it using `from_pretrained()`. @@ -2293,17 +2303,20 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op if output_dir is None: output_dir = self.args.output_dir + if signal_dir is None: + signal_dir = self.args.output_signal_dir + if ShardingOption.FULL_SHARD in self.args.sharding: self.model_wrapped.get_all_parameters(convert2cpu=True) if self.args.should_save_model_state: - self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel) + self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel, signal_dir=signal_dir) else: if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: - os.makedirs(output_dir, exist_ok=True) + os.makedirs(signal_dir, exist_ok=True) if self.is_in_train: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - paddle.save(global_rank, os.path.join(output_dir, f".model_weight.done.{global_rank}")) + paddle.save(global_rank, os.path.join(signal_dir, f".model_weight.done.{global_rank}")) if strtobool(os.getenv("FLAG_LLM_PDC", "False")): # save model_done file to ensure model is complete @@ -2319,9 +2332,9 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op and "async_save" in self.args.unified_checkpoint_config and not self.is_in_train ): - os.makedirs(output_dir, exist_ok=True) + os.makedirs(signal_dir, exist_ok=True) global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - paddle.save(self.state.global_step, os.path.join(output_dir, f".model_weight.done.{global_rank}")) + paddle.save(self.state.global_step, os.path.join(signal_dir, f".model_weight.done.{global_rank}")) def _filter_moe_no_sync_optimizer_params(self): """ @@ -2332,7 +2345,7 @@ def _filter_moe_no_sync_optimizer_params(self): filter_optimzier_state_dict = OrderedDict() param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else [] filter_optimzier_state_dict["master_weights"] = OrderedDict() - for k, v in state_dict.items(): + for _, v in state_dict.items(): if getattr(v, "no_sync", False): if v.name in param_names_in_master_weights: filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][ @@ -2351,15 +2364,17 @@ def _save_checkpoint(self, model, metrics=None): checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" run_dir = self.args.output_dir + run_signal_dir = self.args.output_signal_dir output_dir = os.path.join(run_dir, checkpoint_folder) 
+ signal_dir = os.path.join(run_signal_dir, checkpoint_folder) if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1): - self.save_model(output_dir) + self.save_model(output_dir, False, signal_dir) elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM): - self.save_model(output_dir, True) + self.save_model(output_dir, True, signal_dir) else: - self.save_model(output_dir) + self.save_model(output_dir, False, signal_dir) # only save model state dict, ignore optimizer and scheduler if not self.args.ignore_save_lr_and_optim: @@ -2375,6 +2390,7 @@ def _save_checkpoint(self, model, metrics=None): self.model, self.optimizer, output_dir, + signal_dir, ) else: if self.dp_group.rank > 0: # this should only work for MoE saving @@ -2397,10 +2413,10 @@ def _save_checkpoint(self, model, metrics=None): else: if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - os.makedirs(output_dir, exist_ok=True) - paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}")) + os.makedirs(signal_dir, exist_ok=True) + paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}")) if "skip_save_model_weight" not in self.args.unified_checkpoint_config: - paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}")) + paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")) if self.args.should_save or self.args.use_expert_parallel: if not self.args.use_hybrid_parallel: logger.info("Saving optimizer files.") @@ -2409,6 +2425,7 @@ def _save_checkpoint(self, model, metrics=None): self.model, self.optimizer, output_dir, + signal_dir, ) else: if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel: @@ -2433,10 +2450,10 @@ def _save_checkpoint(self, model, metrics=None): if self.args.unified_checkpoint and not self.args.use_hybrid_parallel: if "async_save" in self.args.unified_checkpoint_config: global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1 - os.makedirs(output_dir, exist_ok=True) - paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}")) + os.makedirs(signal_dir, exist_ok=True) + paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}")) if "skip_save_model_weight" not in self.args.unified_checkpoint_config: - paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}")) + paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")) self.runtime_timer.stop() # Determine the new best metric / best model checkpoint @@ -2485,7 +2502,7 @@ def _save_checkpoint(self, model, metrics=None): # For hybrid parallel training, the checkpoint files maybe on different node. 
need_to_rotate_checkpoints = False if self.args.use_hybrid_parallel: - if self.dp_group.rank <= 0: + if self.dp_group.rank <= 0 or self.args.use_expert_parallel: need_to_rotate_checkpoints = True else: need_to_rotate_checkpoints = self.args.should_save_model_state @@ -2494,6 +2511,7 @@ def _save_checkpoint(self, model, metrics=None): need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank == 0 if need_to_rotate_checkpoints: self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + self._rotate_checkpoints(use_mtime=True, output_dir=run_signal_dir) if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and not ("async_save" in self.args.unified_checkpoint_config): # save checkpoint_done file to ensure checkpoint is complete @@ -2568,10 +2586,23 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: # ignore_errors for shared disks between train nodes. shutil.rmtree(checkpoint, ignore_errors=True) - def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False): + def _save( + self, + output_dir: Optional[str] = None, + state_dict=None, + merge_tensor_parallel=False, + signal_dir: Optional[str] = None, + ): output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving model checkpoint to {output_dir}") + + # signal_dir is used for asynchronous saving situations. + if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config: + signal_dir = signal_dir if signal_dir is not None else self.args.output_signal_dir + os.makedirs(signal_dir, exist_ok=True) + logger.info(f"Saving model checkpoint finish signal to {signal_dir}") + # Save a trained model and configuration using `save_pretrained()`. 
# They can then be reloaded using `from_pretrained()` @@ -2581,16 +2612,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ and self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config ): - os.makedirs(self.args.logging_dir, exist_ok=True) world_size = paddle.distributed.get_world_size() save_info = { "world_size": world_size, "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim, "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config, } - if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")): # afs cannot overwrite - os.remove(os.path.join(self.args.logging_dir, "async_save_info.json")) - with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f: + if os.path.exists(os.path.join(signal_dir, "async_save_info.json")): # afs cannot overwrite + os.remove(os.path.join(signal_dir, "async_save_info.json")) + with open(os.path.join(signal_dir, "async_save_info.json"), "w") as f: json.dump(save_info, f) if self.args.should_save: @@ -2605,7 +2635,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_ if not self.is_in_train: self.args.unified_checkpoint_config = [] - self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir) + self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir, signal_dir) # recover unified_checkpoint_config for not trine stage if not self.is_in_train: diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py index 0588ea3530ee..ca816b585e3b 100644 --- a/paddlenlp/trainer/trainer_utils.py +++ b/paddlenlp/trainer/trainer_utils.py @@ -256,7 +256,7 @@ def _check_checkpoint_files(folder_path, world_size, ignore_save_lr_and_optim, s return a -def get_last_checkpoint(folder, uc_async_save=False): +def get_last_checkpoint(folder, signal_folder=None, uc_async_save=False): content = os.listdir(folder) checkpoints = [ path @@ -266,6 +266,9 @@ def get_last_checkpoint(folder, uc_async_save=False): if len(checkpoints) == 0: return + if uc_async_save: + assert signal_folder is not None + if strtobool(os.getenv("FLAG_LLM_PDC", "False")): for i in sorted(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]), reverse=True): current_path = os.path.join(folder, i) @@ -275,11 +278,12 @@ def get_last_checkpoint(folder, uc_async_save=False): return current_path else: saving_info = paddle.load(distributed_file(os.path.join(current_path, ".saving_info"))) + current_signal_path = os.path.join(signal_folder, i) pre_world_size = saving_info.get("world_size", 1) ignore_save_lr_and_optim = saving_info.get("ignore_save_lr_and_optim", False) skip_save_model_weight = saving_info.get("skip_save_model_weight", False) if _check_checkpoint_files( - current_path, pre_world_size, ignore_save_lr_and_optim, skip_save_model_weight + current_signal_path, pre_world_size, ignore_save_lr_and_optim, skip_save_model_weight ): return current_path return diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 2f4f6a04a005..6aed76f17cea 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -447,6 +447,7 @@ class TrainingArguments: }, ) logging_dir: Optional[str] = field(default=None, metadata={"help": "VisualDL log dir."}) + output_signal_dir: Optional[str] = field(default=None, metadata={"help": "Asynchronous saving signal dir."}) logging_strategy: 
IntervalStrategy = field( default="steps", metadata={"help": "The logging strategy to use."}, @@ -914,6 +915,10 @@ def __post_init__(self): self.logging_dir = os.path.join(self.output_dir, default_logdir()) if self.logging_dir is not None: self.logging_dir = os.path.expanduser(self.logging_dir) + if self.output_signal_dir is None and self.output_dir is not None: + self.output_signal_dir = self.output_dir + if self.output_signal_dir is not None: + self.output_signal_dir = os.path.expanduser(self.output_signal_dir) if self.disable_tqdm is None: self.disable_tqdm = False # logger.getEffectiveLevel() > logging.WARN @@ -1120,6 +1125,7 @@ def split_parallel_config(parallel_config): "enable_clear_every_step_cache", "enable_overlap_p2p_comm", "disable_batch_p2p_comm", + "best_unbalanced_scheduler", ]: raise ValueError( f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv." @@ -1158,6 +1164,7 @@ def split_parallel_config(parallel_config): "overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config, "clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config, "use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config, + "best_unbalanced_scheduler": "best_unbalanced_scheduler" in pipeline_parallel_config, } if dygraph_pp_configs["dp_comm_overlap"]: raise ValueError("overlap has accuracy issue") # TODO: fix `overalap` + `delay_scale` issue diff --git a/paddlenlp/trainer/utils/reshard/pp_reshard.py b/paddlenlp/trainer/utils/reshard/pp_reshard.py index 5c98e6069212..0caa5eb666c6 100644 --- a/paddlenlp/trainer/utils/reshard/pp_reshard.py +++ b/paddlenlp/trainer/utils/reshard/pp_reshard.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- from collections import OrderedDict from paddle.distributed.fleet.model import PipelineParallel @@ -46,6 +45,25 @@ def get_index_layer_func(): return _GLOBAL_INDEX_LAYER_FUNC +_GLOBAL_SNAME_TO_TNAME_FUNC = None + + +def register_sname_to_tname_func(func): + global _GLOBAL_SNAME_TO_TNAME_FUNC + _GLOBAL_SNAME_TO_TNAME_FUNC = func + + +def has_register_sname_to_tname_func(): + global _GLOBAL_SNAME_TO_TNAME_FUNC + return _GLOBAL_SNAME_TO_TNAME_FUNC is not None + + +def get_sname_to_tname_func(): + global _GLOBAL_SNAME_TO_TNAME_FUNC + assert _GLOBAL_SNAME_TO_TNAME_FUNC is not None, "sname to tname func is not registered yet" + return _GLOBAL_SNAME_TO_TNAME_FUNC + + class LayerNameScope: """ layer name scope for a layer, layer name of the same kind of layer will be named consecutively @@ -206,6 +224,7 @@ def __init__(self): self._segments = OrderedDict() self._layer_to_segment = OrderedDict() self._param_to_tname = OrderedDict() + self._wname_to_rname = OrderedDict() def add_segment(self, start_index, end_index): segment = PipeLineSegment(start_index, end_index) @@ -218,19 +237,24 @@ def add_layer(self, layer_index, layer_name, param_names): segment = self._layer_to_segment[layer_index] segment.add_layer(layer_name, param_names) - def build_name_mapping(self): + def build_name_mapping(self, sname_to_tname=None): for (k, segment) in self._segments.items(): for (i, layer) in segment.layers.items(): for param in layer.params.items(): (param_name, tensor_name) = param # map to a new name n_name = self._rename_mgr.get_new_param_name(layer.name, tensor_name) + if sname_to_tname is not None: + if param_name in sname_to_tname.keys(): + self._wname_to_rname[param_name] = sname_to_tname[param_name] # logger.info(f"{param_name} {tensor_name}=>{n_name}") self._param_to_tname[param_name] = (tensor_name, n_name) def map_name(self, param_name, t_name): assert param_name in self._param_to_tname tensor_name, n_name = self._param_to_tname[param_name] + if param_name in self._wname_to_rname: + n_name = self._wname_to_rname[param_name] assert tensor_name == t_name return n_name @@ -261,6 +285,11 @@ def __init__( self._index_layers() stage_segments = self._segment() + if has_register_sname_to_tname_func(): + self._sname_to_tname = get_sname_to_tname_func()(pp_model) + else: + self._sname_to_tname = None + for (i, stage_seg) in enumerate(stage_segments): pipe_stage = PipeLineStage() self._stages.append(pipe_stage) @@ -275,7 +304,7 @@ def __init__( self._layer_name_to_stage[layer_name] = i for stage in self._stages: - stage.build_name_mapping() + stage.build_name_mapping(self._sname_to_tname) def _index_layers(self): for layer_name in self._param_names_by_layer.keys(): diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py index 0926988771ba..da45fb1f8102 100644 --- a/paddlenlp/trainer/utils/sharding_io.py +++ b/paddlenlp/trainer/utils/sharding_io.py @@ -38,6 +38,7 @@ ) from paddlenlp.transformers.utils import paddlenlp_load from paddlenlp.utils.log import logger +from paddlenlp.utils.tools import get_env_device from . 
import reshard as reshard_util from .reshard import SHARDING_STRATEGY_V1, SHARDING_STRATEGY_V2, pp_reshard @@ -53,6 +54,22 @@ SHARDING_META_NAME = "shard_meta.json" +def to_device(tensor, place=None): + if place is None: + place = get_env_device() + + if isinstance(place, str): + place = paddle.device._convert_to_place(place) + + if not tensor.place._equals(place): + new_t = tensor._copy_to(place, True) + dst_tensor = tensor.value().get_tensor() + src_tensor = new_t.value().get_tensor() + dst_tensor._share_data_with(src_tensor) + + return tensor + + def filter_sharded_params(state_dict, optimizer, sharding_group): sharding_rank = sharding_group.rank @@ -239,6 +256,9 @@ def _need_reshard(self, checkpoint): param2rank = sharding_meta["param2rank"] optimizer = unwrap_optimizer(self.optimizer, DygraphShardingOptimizer) assert optimizer + if len(param2rank) == 0: + logger.warning("The param2rank is empty. Force reshard would be performed.") + return True assert len(param2rank) == len(optimizer._param2rank) for (k, v) in param2rank.items(): assert k in optimizer._param2rank @@ -460,7 +480,7 @@ def _recover_params_from_master_weights(self, state_dict, opt_state_dict=None): # cast to before for (k, v) in tmp.items(): name = v.name - master_weights[k] = paddle.cast(v.cuda(), paddle.bfloat16).cpu() + master_weights[k] = paddle.cast(to_device(v), paddle.bfloat16).cpu() master_weights[k].name = name structure_name_map = {k: v.name for (k, v) in self.model.state_dict().items()} @@ -491,7 +511,9 @@ def filter_func(name): for key, param in model_state_dict.items(): if param.name in master_weights: assert param.shape == master_weights[param.name].shape - paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key]) + paddle.assign( + paddle.cast(to_device(master_weights[param.name]), paddle.bfloat16), model_state_dict[key] + ) elif key in state_dict: logger.info(f"key: {key} is in state_dict, but not in master_weights") paddle.assign(state_dict[key], model_state_dict[key]) diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index ff920f83c6b1..266c6ae4863a 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -147,7 +147,7 @@ def apply_rotary_pos_emb(x: paddle.Tensor, rope_cache: paddle.Tensor) -> paddle. 
-1, ) x_out2 = x_out2.flatten(3) - return paddle.concat((x_out2, x_pass), axis=-1) + return paddle.concat((x_out2, x_pass.cast(x_out2.dtype)), axis=-1) class RMSNorm(nn.Layer): diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index d73b1452f652..3d329556ccee 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -38,6 +38,7 @@ function is_a100() { IS_A100=$(is_a100) +# NOTE: Please place the new tests as much as possible after the existing tests function llama_case_list_auto() { llama_dygraph_auto_bs8_fp32_DP2 llama_dygraph_auto_bs8_fp32_DP2-MP2 @@ -52,8 +53,9 @@ function llama_case_list_auto() { # llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2 # llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2 # llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2 - llama_pir_auto_fuse_ffn_attention_qkv_MP2 + llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1 + llama_pir_auto_fuse_ffn_attention_qkv_MP2 llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1 llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP @@ -1163,6 +1165,7 @@ function llama_pir_auto_fuse_ffn_attention_qkv_MP2() { loss_base=10.27925682 fi check_result $FUNCNAME ${loss_base} ${auto_loss} ${ips_base} ${auto_ips} ${mem_base} ${auto_mem} + export FLAGS_enable_fused_ffn_qkv_pass=0 echo "=========== $FUNCNAME run end ===========" }
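
The trainer changes above route all asynchronous-save bookkeeping (the per-rank `.model_weight.done.*`, `.optimizer_weight.done.*` and `.master_weight.done.*` markers plus `async_save_info.json`) into a dedicated `output_signal_dir` instead of the checkpoint `output_dir`, and `get_last_checkpoint` now validates checkpoint completeness against that signal folder. The sketch below is a minimal, hypothetical illustration of such a completeness check; `is_checkpoint_complete` is not a PaddleNLP API, and the real `_check_checkpoint_files` may apply different rules.

```python
import json
import os


def is_checkpoint_complete(signal_ckpt_dir: str) -> bool:
    """Hypothetical helper (not a PaddleNLP API): decide whether an
    asynchronously saved checkpoint looks complete, judging only from the
    signal files written under ``output_signal_dir/checkpoint-<step>``."""
    info_path = os.path.join(signal_ckpt_dir, "async_save_info.json")
    if not os.path.exists(info_path):
        return False
    with open(info_path, "r") as f:
        save_info = json.load(f)

    world_size = save_info.get("world_size", 1)
    skip_model_weight = save_info.get("skip_save_model_weight", False)
    ignore_lr_and_optim = save_info.get("ignore_save_lr_and_optim", False)

    def done_count(prefix: str) -> int:
        # Each rank drops one ".<artifact>.done.<rank>" marker once its shard is flushed.
        return sum(1 for name in os.listdir(signal_ckpt_dir) if name.startswith(prefix))

    if not skip_model_weight and done_count(".model_weight.done") < world_size:
        return False
    if not ignore_lr_and_optim and done_count(".optimizer_weight.done") < world_size:
        return False
    return True
```

Because `__post_init__` falls back to `output_dir` when `output_signal_dir` is unset, existing configs keep working; pointing it at a separate (for example, node-local) directory only changes where the `.done` markers and `async_save_info.json` land, not where the checkpoint itself is written.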
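
`pp_reshard` now exposes a registration hook so model code can override how parameter names are remapped when pipeline-parallel checkpoints are resharded: the registered callable receives the pipeline model and returns a dict whose entries take precedence over the consecutively generated names in `build_name_mapping`. Below is a sketch of how such a hook might be registered; the mapping rule shown is a placeholder for illustration, not the one any PaddleNLP model actually uses.

```python
from paddlenlp.trainer.utils.reshard import pp_reshard


def my_sname_to_tname(pp_model):
    # Placeholder rule for illustration only: map every structured parameter
    # name to the name the tensor already carries. A real hook would encode
    # the model-specific renaming required for resharding.
    return {sname: param.name for sname, param in pp_model.state_dict().items()}


# Register before resharding; names returned by the hook override the
# automatically generated ones for the parameters they cover.
pp_reshard.register_sname_to_tname_func(my_sname_to_tname)
```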
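
The one-line `chatglm_v2` fix casts the pass-through half of the hidden state to the dtype of the rotated half before `paddle.concat`, since the two can diverge (for example when the rotary computation is carried out in float32 while the input stays in bfloat16) and `paddle.concat` expects a single dtype. A minimal repro-style sketch of the pattern, with illustrative shapes and dtypes rather than the model's actual ones:

```python
import paddle

# Illustrative only: the rotated half ends up in float32 while the
# untouched half is still bfloat16.
x_out2 = paddle.rand([2, 8, 4], dtype="float32")
x_pass = paddle.rand([2, 8, 4], dtype="float32").cast("bfloat16")

# Casting the pass-through half to x_out2's dtype keeps the concat well defined.
y = paddle.concat((x_out2, x_pass.cast(x_out2.dtype)), axis=-1)
print(y.shape, y.dtype)
```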