-
Notifications
You must be signed in to change notification settings - Fork 503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fine-tuning with label mask #410
Conversation
epwalsh
commented
Jan 17, 2024
- Add support for fine-tuning with a label mask.
- Add a script for preparing Tulu V2 for fine-tuning.
- Add fine-tuning instructions to README.
- Add support for fine-tuning with a label mask. - Add a script for preparing Tulu V2 for fine-tuning. - Add fine-tuning instructions to README.
@@ -0,0 +1,111 @@ | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@hamishivi could you review this script?
scripts/prepare_tulu_data.py
Outdated
def preprocess(example, tokenizer: Tokenizer, max_seq_len: int): | ||
parts = [] | ||
for msg in example["messages"]: | ||
parts.append(f"<|{msg['role']}|>") | ||
parts.append(msg["content"]) | ||
|
||
prompt = "\n".join(parts[:-1]) + "\n" | ||
completion = parts[-1] | ||
|
||
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) | ||
completion_ids = tokenizer.encode(completion, add_special_tokens=True) | ||
|
||
input_ids = (prompt_ids + completion_ids)[:max_seq_len] | ||
label_mask = ([False] * len(prompt_ids) + [True] * len(completion_ids))[:max_seq_len] | ||
|
||
if len(input_ids) < max_seq_len: | ||
pad_len = max_seq_len - len(input_ids) | ||
input_ids += [tokenizer.pad_token_id] * pad_len | ||
label_mask += [False] * pad_len | ||
|
||
assert len(input_ids) == len(label_mask) | ||
|
||
return {"input_ids": input_ids, "label_mask": label_mask} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@hamishivi in particular this function for preprocessing/tokenizing each example.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't quite right. Actually, the content for any message from the assistant
role should be trained on, not just the final role. This is because we have some multi-turn dialogues in our dataset, and so this is important for that. This is a bit tricky to do but a reference is here: https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L292
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated to match your script
@@ -10,3 +10,19 @@ | |||
``` | |||
pip install ai2-olmo | |||
``` | |||
|
|||
## Fine-tuning |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@AkshitaB, fine-tuning instructions added here.
scripts/prepare_tulu_data.py
Outdated
def preprocess(example, tokenizer: Tokenizer, max_seq_len: int): | ||
parts = [] | ||
for msg in example["messages"]: | ||
parts.append(f"<|{msg['role']}|>") | ||
parts.append(msg["content"]) | ||
|
||
prompt = "\n".join(parts[:-1]) + "\n" | ||
completion = parts[-1] | ||
|
||
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) | ||
completion_ids = tokenizer.encode(completion, add_special_tokens=True) | ||
|
||
input_ids = (prompt_ids + completion_ids)[:max_seq_len] | ||
label_mask = ([False] * len(prompt_ids) + [True] * len(completion_ids))[:max_seq_len] | ||
|
||
if len(input_ids) < max_seq_len: | ||
pad_len = max_seq_len - len(input_ids) | ||
input_ids += [tokenizer.pad_token_id] * pad_len | ||
label_mask += [False] * pad_len | ||
|
||
assert len(input_ids) == len(label_mask) | ||
|
||
return {"input_ids": input_ids, "label_mask": label_mask} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't quite right. Actually, the content for any message from the assistant
role should be trained on, not just the final role. This is because we have some multi-turn dialogues in our dataset, and so this is important for that. This is a bit tricky to do but a reference is here: https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L292
scripts/prepare_tulu_data.py
Outdated
completion = parts[-1] | ||
|
||
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) | ||
completion_ids = tokenizer.encode(completion, add_special_tokens=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What special tokens does the olmo tokenizer add? There should be an eos token after every assistant message.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
scripts/prepare_tulu_data.py
Outdated
prompt = "\n".join(parts[:-1]) + "\n" | ||
completion = parts[-1] | ||
|
||
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found it useful to add a bos token in training (or rather, using the eos as a bos marker), but I don't think its essential.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
input_ids = (prompt_ids + completion_ids)[:max_seq_len] | ||
label_mask = ([False] * len(prompt_ids) + [True] * len(completion_ids))[:max_seq_len] | ||
|
||
if len(input_ids) < max_seq_len: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
random q: what happens when the sequence length is over your training max_seq_len? just naive truncation? (this is fine just curious)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea just naive truncation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't see anything obviously wrong.