clm fixes
abhishekkrthakur committed Feb 22, 2024
1 parent 593ca62 commit 64a4107
Showing 4 changed files with 4 additions and 4 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -43,3 +43,4 @@ hf-transfer
 pyngrok==7.0.3
 authlib==1.3.0
 itsdangerous==2.1.2
+seqeval==1.2.2
2 changes: 1 addition & 1 deletion src/autotrain/__init__.py
@@ -30,4 +30,4 @@
 warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")
 
 
-__version__ = "0.6.98.dev0"
+__version__ = "0.6.99.dev0"
4 changes: 1 addition & 3 deletions src/autotrain/trainers/clm/__main__.py
@@ -149,6 +149,7 @@ def train(config):
         tokenizer.chat_template = chat_template
     else:
         tokenizer = AutoTokenizer.from_pretrained(config.model, token=config.token, trust_remote_code=True)
+        tokenizer.chat_template = utils.DEFAULT_CHAT_TEMPLATE
 
     if config.chat_template in ("chatml", "zephyr", "tokenizer"):
         train_data = train_data.map(
@@ -167,9 +168,6 @@ def train(config):
             },
         )
 
-    if tokenizer.chat_template is None and config.trainer != "default":
-        tokenizer.chat_template = utils.DEFAULT_CHAT_TEMPLATE
-
     if tokenizer.model_max_length > 2048:
         tokenizer.model_max_length = config.model_max_length

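With this change the fallback template is attached to the tokenizer right after it is loaded, instead of only later and only for non-default trainers. A minimal sketch of the resulting behaviour, assuming a tokenizer that ships without its own chat template ("gpt2" below is only an illustrative stand-in for config.model):

from transformers import AutoTokenizer

from autotrain.trainers.clm import utils

# Illustrative only: any tokenizer without a built-in chat template behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = utils.DEFAULT_CHAT_TEMPLATE  # fallback now set immediately after loading

text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
)
# text == "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"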
1 change: 1 addition & 0 deletions src/autotrain/trainers/clm/utils.py
@@ -11,6 +11,7 @@
 from autotrain import logger
 
 
+DEFAULT_CHAT_TEMPLATE = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"
 ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

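The new DEFAULT_CHAT_TEMPLATE is a ChatML-style Jinja template. A quick way to see what it produces is to render it directly with jinja2 (transformers renders chat templates through a sandboxed Jinja environment, but the inputs are the same: messages and add_generation_prompt); this is a sketch, not part of the commit:

from jinja2 import Template

from autotrain.trainers.clm.utils import DEFAULT_CHAT_TEMPLATE

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# Render the template with the generation prompt appended for the assistant turn.
print(Template(DEFAULT_CHAT_TEMPLATE).render(messages=messages, add_generation_prompt=True))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant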