Skip to content

Commit

Permalink
support for more chat templates
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Feb 22, 2024
1 parent 390ac6a commit 593ca62
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/autotrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@
warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")


__version__ = "0.6.97.dev0"
__version__ = "0.6.98.dev0"
1 change: 1 addition & 0 deletions src/autotrain/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
block_size=1024,
epochs=3,
padding="right",
chat_template="none",
).model_dump()

PARAMS["text-classification"] = TextClassificationParams(
Expand Down
10 changes: 5 additions & 5 deletions src/autotrain/cli/run_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,12 @@ def register_subcommand(parser: ArgumentParser):
"alias": ["--dpo-beta"],
},
{
"arg": "--apply_chat_template",
"arg": "--chat_template",
"help": "Apply chat template",
"required": False,
"action": "store_true",
"alias": ["--apply-chat-template"],
"default": None,
"alias": ["--chat-template"],
"choices": ["tokenizer", "chatml", "zephyr"],
},
{
"arg": "--padding",
Expand Down Expand Up @@ -284,7 +285,6 @@ def __init__(self, args):
"merge_adapter",
"use_flash_attention_2",
"disable_gradient_checkpointing",
"apply_chat_template",
]
for arg_name in store_true_arg_names:
if getattr(self.args, arg_name) is None:
Expand Down Expand Up @@ -374,7 +374,7 @@ def run(self):
model_ref=self.args.model_ref,
dpo_beta=self.args.dpo_beta,
prompt_text_column=self.args.prompt_text_column,
apply_chat_template=self.args.apply_chat_template,
chat_template=self.args.chat_template,
padding=self.args.padding,
)

Expand Down
57 changes: 51 additions & 6 deletions src/autotrain/trainers/clm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import sys
from enum import Enum
from functools import partial

import torch
Expand Down Expand Up @@ -29,6 +30,32 @@
from autotrain.trainers.common import monitor, pause_space, remove_autotrain_data, save_training_params


class ZephyrSpecialTokens(str, Enum):
    """Special-token strings for the Zephyr chat format.

    The ``str`` mixin makes every member compare equal to its raw token
    text; ``.value`` yields the plain string.
    """

    # Role markers that open each turn of a conversation.
    USER = "<|user|>"
    ASSISTANT = "<|assistant|>"
    SYSTEM = "<|system|>"
    # Sequence-level tokens.
    EOS_TOKEN = "</s>"
    BOS_TOKEN = "<s>"
    PAD_TOKEN = "<pad>"

    @classmethod
    def list(cls):
        """Return every member's token string, in definition order."""
        token_values = []
        for member in cls:
            token_values.append(member.value)
        return token_values


class ChatmlSpecialTokens(str, Enum):
    """Special-token strings for the ChatML chat format.

    The ``str`` mixin makes every member compare equal to its raw token
    text; ``.value`` yields the plain string.
    """

    # Role markers: the ``<|im_start|>`` opener fused with the role name.
    USER = "<|im_start|>user"
    ASSISTANT = "<|im_start|>assistant"
    SYSTEM = "<|im_start|>system"
    # Sequence-level tokens.
    EOS_TOKEN = "<|im_end|>"
    BOS_TOKEN = "<s>"
    PAD_TOKEN = "<pad>"

    @classmethod
    def list(cls):
        """Return every member's token string, in definition order."""
        token_values = []
        for member in cls:
            token_values.append(member.value)
        return token_values


def parse_args():
# get training_config.json from the end user
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -93,19 +120,37 @@ def train(config):
if config.padding not in ("left", "right"):
config.padding = None

if config.apply_chat_template and config.trainer == "default":
logger.warning("apply_chat_template is not supported for default trainer. Skipping...")
config.apply_chat_template = False

is_deepspeed_enabled = os.environ.get("ACCELERATE_USE_DEEPSPEED", "False").lower() == "true"

if config.repo_id is None and config.username is not None:
config.repo_id = f"{config.username}/{config.project_name}"

train_data, valid_data = process_input_data(config)
tokenizer = AutoTokenizer.from_pretrained(config.model, token=config.token, trust_remote_code=True)

if config.apply_chat_template and config.trainer != "default":
special_tokens = None
chat_template = None
if config.chat_template == "chatml":
special_tokens = ChatmlSpecialTokens
chat_template = utils.CHATML_CHAT_TEMPLATE
elif config.chat_template == "zephyr":
special_tokens = ZephyrSpecialTokens
chat_template = utils.ZEPHYR_CHAT_TEMPLATE

if special_tokens is not None:
tokenizer = AutoTokenizer.from_pretrained(
config.model,
pad_token=special_tokens.PAD_TOKEN.value,
bos_token=special_tokens.BOS_TOKEN.value,
eos_token=special_tokens.EOS_TOKEN.value,
additional_special_tokens=special_tokens.list(),
token=config.token,
trust_remote_code=True,
)
tokenizer.chat_template = chat_template
else:
tokenizer = AutoTokenizer.from_pretrained(config.model, token=config.token, trust_remote_code=True)

if config.chat_template in ("chatml", "zephyr", "tokenizer"):
train_data = train_data.map(
utils.apply_chat_template,
fn_kwargs={
Expand Down
2 changes: 1 addition & 1 deletion src/autotrain/trainers/clm/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class LLMTrainingParams(AutoTrainParams):
weight_decay: float = Field(0.0, title="Weight decay")
max_grad_norm: float = Field(1.0, title="Max gradient norm")
seed: int = Field(42, title="Seed")
apply_chat_template: bool = Field(False, title="Apply chat template")
chat_template: Optional[str] = Field(None, title="Chat template, one of: None, zephyr, chatml or tokenizer")

# peft
quantization: Optional[str] = Field(None, title="int4, int8, or None")
Expand Down
14 changes: 9 additions & 5 deletions src/autotrain/trainers/clm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from autotrain import logger


DEFAULT_CHAT_TEMPLATE = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"
ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"


IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
Expand Down Expand Up @@ -131,6 +133,8 @@ def get_target_modules(config):
return TARGET_MODULES.get(config.model)
if config.target_modules.strip() == "":
return TARGET_MODULES.get(config.model)
if config.target_modules.strip().lower() == "all-linear":
return "all-linear"
return config.target_modules.split(",")


Expand Down Expand Up @@ -222,19 +226,19 @@ def pause_endpoint(params):
project_name = endpoint_id.split("/")[1]
api_url = f"https://api.endpoints.huggingface.cloud/v2/endpoint/{username}/{project_name}/pause"
headers = {"Authorization": f"Bearer {params.token}"}
r = requests.post(api_url, headers=headers)
r = requests.post(api_url, headers=headers, timeout=30)
return r.json()


def create_peft_handler(config):
txt = PEFT_HANDLER.strip()
with open(f"{config.project_name}/handler.py", "w") as f:
with open(f"{config.project_name}/handler.py", "w", encoding="utf-8") as f:
f.write(txt)


def create_requirements_txt(config):
txt = REQUIREMENTS_TXT.strip()
with open(f"{config.project_name}/requirements.txt", "w") as f:
with open(f"{config.project_name}/requirements.txt", "w", encoding="utf-8") as f:
f.write(txt)


Expand All @@ -244,7 +248,7 @@ def apply_chat_template(
config,
):
# kudos to Hugging Face H4 Team for this snippet
if config.trainer == "sft":
if config.trainer in ("default", "sft"):
messages = example[config.text_column]
if isinstance(messages, str):
messages = ast.literal_eval(messages)
Expand Down

0 comments on commit 593ca62

Please sign in to comment.