Support LoRA ChatGLM with Alpaca Dataset (#11580)
* Support LoRA ChatGLM with Alpaca Dataset

* refine

* fix

* add 2-card alpaca
Uxito-Ada authored Jul 16, 2024
1 parent 99c2274 commit 365adad
Showing 6 changed files with 162 additions and 20 deletions.
@@ -31,20 +31,30 @@ source /opt/intel/oneapi/setvars.sh

### 3. LoRA Fine-Tune on ChatGLM3-6B

First, download the dataset: we use `AdvertiseGen` to finetune ChatGLM3-6B in the following, and please now get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the below script:
First, you have two options for the dataset:

1. `AdvertiseGen`: download it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then process the dataset with the script below:

```bash
python process_advertise_gen_dataset.py
```

Then `./AdvertiseGen` will be converted to `./AdvertiseGen_fix`. The dataset is now ready, and we can start LoRA fine-tuning on ChatGLM3-6B.

2. `Alpaca`: We also support [yahma/alpaca-cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned), which contains generated instructions and demonstrations. It does not require preprocessing; you can directly run the fine-tuning scripts below (see the optional inspection sketch after this list).
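
If you want a quick look at the Alpaca data before fine-tuning, the minimal sketch below (assuming the `datasets` library is already installed in your environment) loads the dataset and prints one record; it is optional and not part of the fine-tuning scripts.

```python
# Optional sketch: peek at the yahma/alpaca-cleaned dataset (not required for fine-tuning).
from datasets import load_dataset

data = load_dataset("yahma/alpaca-cleaned")
print(data["train"].column_names)  # instruction / input / output fields
print(data["train"][0])            # one raw record, before any prompt formatting
```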

#### 3.1. Fine-Tune with a Single Arc Card

Start the fine-tuning by:
1. For `AdvertiseGen`, start the fine-tuning by:

```bash
bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_1_arc_card.sh
```

2. For `Alpaca`, start the fine-tuning by:

```bash
bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
bash lora_finetuning_chatglm3_6b_on_alpaca_with_1_arc_card.sh
```

Then, you will get output as below:
@@ -145,6 +155,14 @@ Training completed. Do not forget to share your model on huggingface.co/models =

Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by:

1. `AdvertiseGen` dataset:

```bash
bash lora_finetuning_chatglm3_6b_on_advertise_gen_with_2_arc_cards.sh
```

2. `Alpaca` dataset:

```bash
bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
bash lora_finetuning_chatglm3_6b_on_alpaca_with_2_arc_cards.sh
```
@@ -65,8 +65,15 @@
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq

from transformers import Trainer
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer

current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..', '..')
import sys
sys.path.append(common_util_path)
from common.utils import get_train_val_data, Prompter
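# Shared helpers: Prompter holds the Alpaca prompt template, and get_train_val_data tokenizes
# the records (honoring cutoff_len / add_eos_token) and splits off a validation set of
# val_set_size examples (see common/utils.py).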

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
app = typer.Typer(pretty_exceptions_show_locals=False)
@@ -247,7 +254,7 @@ def _load_datasets(
    return dataset_dct


class DataManager(object):
class AdvertiseGenDataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

@@ -283,6 +290,52 @@ def get_dataset(
            num_proc=self._num_proc,
        )
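
# --- Alpaca support ---------------------------------------------------------
# AlpacaDataConfig bundles the tokenization settings, and AlpacaDataManager loads the data
# (a local .json/.jsonl file or a Hugging Face dataset id), tokenizes it through
# get_train_val_data, and exposes ready-made train/validation splits. Its get_dataset
# ignores process_fn, since the records are already tokenized.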

class AlpacaDataConfig(object):
    def __init__(self, tokenizer, prompter, train_on_inputs,
                 add_eos_token, cutoff_len, val_set_size, seed):
        self.tokenizer = tokenizer
        self.prompter = prompter
        self.train_on_inputs = train_on_inputs
        self.add_eos_token = add_eos_token
        self.cutoff_len = cutoff_len
        self.val_set_size = val_set_size
        self.seed = seed


class AlpacaDataManager(object):
    def __init__(self, data_dir: str, data_config: AlpacaDataConfig):
        if data_dir.endswith(".json") or data_dir.endswith(".jsonl"):
            data = load_dataset("json", data_files=data_dir)
        else:
            data = load_dataset(data_dir)
        self.train_data, self.val_data = get_train_val_data(
            data,
            data_config.tokenizer,
            data_config.prompter,
            data_config.train_on_inputs,
            data_config.add_eos_token,
            data_config.cutoff_len,
            data_config.val_set_size,
            seed=data_config.seed)
        self.train_data = self.train_data.remove_columns(
            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])
        self.val_data = self.val_data.remove_columns(
            ['output', 'input', 'instruction', 'attention_mask', 'position_ids'])

    def get_dataset(
        self,
        split: NamedSplit,
        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
        batched: bool = True,
        remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        if split == Split.TRAIN:
            return self.train_data
        elif split == Split.VALIDATION:
            return self.val_data
        else:
            return None


def print_model_size(model: PreTrainedModel):
print("--> Model")
@@ -484,7 +537,17 @@ def main(
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    data_manager = DataManager(data_dir, ft_config.data_config)
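    # Pick the data manager from the dataset path/name; the Alpaca settings below
    # (prompt template, cutoff_len, val_set_size, etc.) are the values hard-coded by this script.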
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if 'AdvertiseGen' in data_dir:
        data_manager = AdvertiseGenDataManager(data_dir, ft_config.data_config)
    elif 'alpaca' in data_dir:
        data_config = AlpacaDataConfig(tokenizer=tokenizer, prompter=Prompter("alpaca"),
                                       train_on_inputs=True, add_eos_token=False,
                                       cutoff_len=256, val_set_size=2000, seed=42)
        data_manager = AlpacaDataManager(data_dir, data_config)
    else:
        raise NotImplementedError("Unsupported dataset: only AdvertiseGen and Alpaca are currently supported")

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
@@ -530,38 +593,47 @@
    # turn model to fp32
    _prepare_model_for_training(model, ft_config.training_args.use_cpu)

    ft_config.training_args.generation_config.pad_token_id = (
        tokenizer.pad_token_id
    )
    ft_config.training_args.generation_config.eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command('<|user|>'),
        tokenizer.get_command('<|observation|>'),
    ]
    if 'AdvertiseGen' in data_dir:
        ft_config.training_args.generation_config.pad_token_id = (
            tokenizer.pad_token_id
        )
        ft_config.training_args.generation_config.eos_token_id = [
            tokenizer.eos_token_id,
            tokenizer.get_command('<|user|>'),
            tokenizer.get_command('<|observation|>'),
        ]
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    use_tokenizer = True
    if ft_config.peft_config is not None:
        use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
    if 'AdvertiseGen' in data_dir:
        use_tokenizer = True
        if ft_config.peft_config is not None:
            use_tokenizer = False if ft_config.peft_config.peft_type == "LORA" else True
    else:
        use_tokenizer = False

    # Add below L544-L546 to enable finetuning on 2 Intel Arc XPU cards on top of oneccl and deepspeed
    if deepspeed_config_file != '':
        ft_config.training_args.ddp_backend = "ccl"
        ft_config.training_args.deepspeed = deepspeed_config_file
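
    # For Alpaca the records are already fully tokenized and no generation-based metrics are
    # computed, so the plain transformers Trainer is enough; AdvertiseGen keeps Seq2SeqTrainer.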

    trainer = Seq2SeqTrainer(
    from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq

    BASE_TRAINER = Trainer if 'alpaca' in data_dir else Seq2SeqTrainer

    trainer = BASE_TRAINER(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding='longest',
            return_tensors='pt',
            padding=True if 'alpaca' in data_dir else 'longest',
            pad_to_multiple_of=8 if 'alpaca' in data_dir else None,
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset.select(list(range(50))),
        tokenizer=tokenizer if use_tokenizer else None,  # LORA does not need tokenizer
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer) if 'AdvertiseGen' in data_dir else None,
    )

    if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:
@@ -0,0 +1,23 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export BIGDL_CHECK_DUPLICATE_IMPORT=0

# You can also set the remote model repository to a local model path
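# Positional arguments: <dataset path or HF dataset id> <base model path or HF repo id> <LoRA config YAML>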
python lora_finetune_chatglm.py \
yahma/alpaca-cleaned \
THUDM/chatglm3-6b \
./lora_config.yaml
@@ -0,0 +1,29 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export MASTER_ADDR=127.0.0.1
export OMP_NUM_THREADS=6
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export BIGDL_CHECK_DUPLICATE_IMPORT=0
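# FI_PROVIDER / CCL_ATL_TRANSPORT route oneCCL traffic over TCP/OFI for the 2-rank job below;
# mpirun -n 2 launches one training process per Arc card, and the extra DeepSpeed JSON argument
# enables data-parallel fine-tuning.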

# You can also set the remote model repository to a local model path
mpirun -n 2 \
python lora_finetune_chatglm.py \
yahma/alpaca-cleaned \
THUDM/chatglm3-6b \
./lora_config.yaml \
./deepspeed_config.json
