diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/README.md b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/README.md
index aef7accef61..cad1665acbc 100644
--- a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/README.md
+++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/README.md
@@ -31,7 +31,9 @@ source /opt/intel/oneapi/setvars.sh
 
 ### 3. LoRA Fine-Tune on ChatGLM3-6B
 
-First, download the dataset: we use `AdvertiseGen` to finetune ChatGLM3-6B in the following, and please now get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the below script:
+First, prepare the dataset. You have two options:
+
+1. `AdvertiseGen`: download it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the script below:
 
 ```bash
 python process_advertise_gen_dataset.py
@@ -39,12 +41,20 @@ python process_advertise_gen_dataset.py
 
 Then, './AdvertiseGen' will be converted to './AdvertiseGen_fix'. Now, we have prepared the dataset, and are going to start LoRA fine-tuning on ChatGLM3-6B.
 
+2. `Alpaca`: We also support [yahma/alpaca-cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned), which contains generated instructions and demonstrations. It does not require preprocessing; please directly run the corresponding script below.
+
 #### 3.1. Fine-Tune with a Single Arc Card
 
-Start the fine-tuning by:
+1. For `AdvertiseGen`, start the fine-tuning by:
+
+```bash
+bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_1_arc_card.sh
+```
+
+2. For `Alpaca`, start the fine-tuning by:
 
 ```bash
-bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
+bash lora_finetuning_chatglm3_6b_on_alpaca_with_1_arc_card.sh
 ```
 
 Then, you will get output are as below:
@@ -145,6 +155,14 @@ Training completed. Do not forget to share your model on huggingface.co/models =
 
 Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by:
 
+1. `AdvertiseGen` dataset:
+
+```bash
+bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_2_arc_cards.sh
+```
+
+2. `Alpaca` dataset:
+
 ```bash
-bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
+bash lora_finetuning_chatglm3_6b_on_alpaca_with_2_arc_cards.sh
 ```
diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetune_chatglm.py b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetune_chatglm.py
index 98ad862efe1..eb335100e0b 100644
--- a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetune_chatglm.py
+++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetune_chatglm.py
@@ -65,8 +65,15 @@
 )
 from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
+from transformers import Trainer
 from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
 
+current_dir = os.path.dirname(os.path.realpath(__file__))
+common_util_path = os.path.join(current_dir, '..', '..')
+import sys
+sys.path.append(common_util_path)
+from common.utils import get_train_val_data, Prompter
+
 
 ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
 TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 app = typer.Typer(pretty_exceptions_show_locals=False)
@@ -247,7 +254,7 @@ def _load_datasets(
     return dataset_dct
 
 
-class DataManager(object):
+class AdvertiseGenDataManager(object):
     def __init__(self, data_dir: str, data_config: DataConfig):
         self._num_proc = data_config.num_proc
 
@@ -283,6 +290,52 @@ def get_dataset(
             num_proc=self._num_proc,
         )
 
+class AlpacaDataConfig(object):
+    def __init__(self, tokenizer, prompter, train_on_inputs,
+                 add_eos_token, cutoff_len, val_set_size, seed):
+        self.tokenizer = tokenizer
+        self.prompter = prompter
+        self.train_on_inputs = train_on_inputs
+        self.add_eos_token = add_eos_token
+        self.cutoff_len = cutoff_len
+        self.val_set_size = val_set_size
+        self.seed = seed
+
+
+class AlpacaDataManager(object):
+    def __init__(self, data_dir: str, data_config: AlpacaDataConfig):
+        if data_dir.endswith(".json") or data_dir.endswith(".jsonl"):
+            data = load_dataset("json", data_files=data_dir)
+        else:
+            data = load_dataset(data_dir)
+        self.train_data, self.val_data = get_train_val_data(
+            data,
+            data_config.tokenizer,
+            data_config.prompter,
+            data_config.train_on_inputs,
+            data_config.add_eos_token,
+            data_config.cutoff_len,
+            data_config.val_set_size,
+            seed=data_config.seed)
+        self.train_data = self.train_data.remove_columns(
+            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
+        self.val_data = self.val_data.remove_columns(
+            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
+
+    def get_dataset(
+        self,
+        split: NamedSplit,
+        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
+        batched: bool = True,
+        remove_orig_columns: bool = True,
+    ) -> Optional[Dataset]:
+        if split == Split.TRAIN:
+            return self.train_data
+        elif split == Split.VALIDATION:
+            return self.val_data
+        else:
+            return None
+
 
 def print_model_size(model: PreTrainedModel):
     print("--> Model")
@@ -484,7 +537,17 @@ def main(
 ):
     ft_config = FinetuningConfig.from_file(config_file)
     tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
-    data_manager = DataManager(data_dir, ft_config.data_config)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    if 'AdvertiseGen' in data_dir:
+        data_manager = AdvertiseGenDataManager(data_dir, ft_config.data_config)
+    elif 'alpaca' in data_dir:
+        data_config = AlpacaDataConfig(tokenizer=tokenizer, prompter=Prompter("alpaca"),
+                                       train_on_inputs=True, add_eos_token=False,
+                                       cutoff_len=256, val_set_size=2000, seed=42)
+        data_manager = AlpacaDataManager(data_dir, data_config)
+    else:
+        raise NotImplementedError("Wrong dataset: currently only AdvertiseGen and Alpaca are supported")
 
     train_dataset = data_manager.get_dataset(
         Split.TRAIN,
@@ -530,38 +593,47 @@ def main(
     # turn model to fp32
     _prepare_model_for_training(model, ft_config.training_args.use_cpu)
 
-    ft_config.training_args.generation_config.pad_token_id = (
-        tokenizer.pad_token_id
-    )
-    ft_config.training_args.generation_config.eos_token_id = [
-        tokenizer.eos_token_id,
-        tokenizer.get_command('<|user|>'),
-        tokenizer.get_command('<|observation|>'),
-    ]
+    if 'AdvertiseGen' in data_dir:
+        ft_config.training_args.generation_config.pad_token_id = (
+            tokenizer.pad_token_id
+        )
+        ft_config.training_args.generation_config.eos_token_id = [
+            tokenizer.eos_token_id,
+            tokenizer.get_command('<|user|>'),
+            tokenizer.get_command('<|observation|>'),
+        ]
 
     model.gradient_checkpointing_enable()
     model.enable_input_require_grads()
 
-    use_tokenizer = True
-    if ft_config.peft_config is not None:
-        use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
+    if 'AdvertiseGen' in data_dir:
+        use_tokenizer = True
+        if ft_config.peft_config is not None:
+            use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
+    else:
+        use_tokenizer = False
 
     # Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
     if deepspeed_config_file != '':
         ft_config.training_args.ddp_backend = "ccl"
         ft_config.training_args.deepspeed = deepspeed_config_file
 
-    trainer = Seq2SeqTrainer(
+    from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
+
+    BASE_TRAINER = Trainer if 'alpaca' in data_dir else Seq2SeqTrainer
+
+    trainer = BASE_TRAINER(
         model=model,
         args=ft_config.training_args,
         data_collator=DataCollatorForSeq2Seq(
             tokenizer=tokenizer,
-            padding='longest',
             return_tensors='pt',
+            padding=True if 'alpaca' in data_dir else 'longest',
+            pad_to_multiple_of=8 if 'alpaca' in data_dir else None,
         ),
         train_dataset=train_dataset,
         eval_dataset=val_dataset.select(list(range(50))),
         tokenizer=tokenizer if use_tokenizer else None,  # LORA does not need tokenizer
-        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
+        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer) if 'AdvertiseGen' in data_dir else None,
     )
 
     if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_advertise_gen_with_1_arc_card.sh
similarity index 100%
rename from python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
rename to python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_advertise_gen_with_1_arc_card.sh
diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_advertise_gen_with_2_arc_cards.sh
similarity index 100%
rename from python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
rename to python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_advertise_gen_with_2_arc_cards.sh
diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_alpaca_with_1_arc_card.sh b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_alpaca_with_1_arc_card.sh
new file mode 100644
index 00000000000..54e1b466f37
--- /dev/null
+++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_alpaca_with_1_arc_card.sh
@@ -0,0 +1,23 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export BIGDL_CHECK_DUPLICATE_IMPORT=0
+
+# You can also set the remote model repository to a local model path
+python lora_finetune_chatglm.py \
+    yahma/alpaca-cleaned \
+    THUDM/chatglm3-6b \
+    ./lora_config.yaml
diff --git a/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_alpaca_with_2_arc_cards.sh b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_alpaca_with_2_arc_cards.sh
new file mode 100644
index 00000000000..d6447841e24
--- /dev/null
+++ b/python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_finetuning_chatglm3_6b_on_alpaca_with_2_arc_cards.sh
@@ -0,0 +1,29 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+export MASTER_ADDR=127.0.0.1
+export OMP_NUM_THREADS=6
+export FI_PROVIDER=tcp
+export CCL_ATL_TRANSPORT=ofi
+export BIGDL_CHECK_DUPLICATE_IMPORT=0
+
+# You can also set the remote model repository to a local model path
+mpirun -n 2 \
+    python lora_finetune_chatglm.py \
+    yahma/alpaca-cleaned \
+    THUDM/chatglm3-6b \
+    ./lora_config.yaml \
+    ./deepspeed_config.json
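
Beyond the Hugging Face dataset id used in the scripts above, the new `AlpacaDataManager` also accepts a local Alpaca-style `.json`/`.jsonl` file (it then loads it via `load_dataset("json", data_files=...)`), and `main()` routes to the Alpaca pipeline whenever the dataset path contains the substring `alpaca`. A minimal single-card sketch, assuming a hypothetical local file `./alpaca_data_cleaned.json` in the current directory:

```bash
export BIGDL_CHECK_DUPLICATE_IMPORT=0

# Hypothetical local dataset file; keep 'alpaca' in the path so that
# lora_finetune_chatglm.py selects the Alpaca data pipeline.
python lora_finetune_chatglm.py \
    ./alpaca_data_cleaned.json \
    THUDM/chatglm3-6b \
    ./lora_config.yaml
```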