From 5c368eec20ad6e6ae784d8613d3798fcb5431a48 Mon Sep 17 00:00:00 2001 From: Mustafa Eyceoz Date: Wed, 2 Oct 2024 14:58:42 -0400 Subject: [PATCH] Fix pretrain token list->int for masking Signed-off-by: Mustafa Eyceoz --- src/instructlab/training/data_process.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index 6bd05d9b..2e6cd393 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -310,7 +310,7 @@ def main(args: DataProcessArgs): print("\033[92mCategorizing training data type...\033[0m") data_with_input_ids = data_with_input_ids.map( lambda x: { - "is_pretrain": get_sp_token(tokenizer, "<|pretrain|>") in x["input_ids"] + "is_pretrain": get_sp_token(tokenizer, "<|pretrain|>")[0] in x["input_ids"] }, num_proc=NUM_PROC, ) @@ -320,8 +320,8 @@ def main(args: DataProcessArgs): user_tokens=user_tk, assist_tokens=assistant_tk, system_tokens=system_tk, - pretrain_token=get_sp_token(tokenizer, "<|pretrain|>"), - pretrain_end_token=get_sp_token(tokenizer, "<|/pretrain|>"), + pretrain_token=get_sp_token(tokenizer, "<|pretrain|>")[0], + pretrain_end_token=get_sp_token(tokenizer, "<|/pretrain|>")[0], ) print("\033[92munmasking the appropriate message content...\033[0m") data_with_labels = data_with_input_ids.map(