Filter out examples that don't have an assistant message
epwalsh committed Jan 17, 2024
1 parent b36f8fb commit 829f090
Showing 1 changed file with 13 additions and 3 deletions.
scripts/prepare_tulu_data.py (16 changes: 13 additions & 3 deletions)
@@ -26,8 +26,11 @@ def main(opts) -> None:
 
     dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train")
 
+    log.info("Filtering dataset...")
+    dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc)  # type: ignore
+
     log.info("Tokenizing dataset...")
-    preprocessed = dataset.map(
+    dataset = dataset.map(
         partial(preprocess, tokenizer=tokenizer, max_seq_len=opts.seq_len),
         batched=False,
         remove_columns=["dataset", "id", "messages"],
@@ -36,7 +39,7 @@ def main(opts) -> None:
 
     log.info("Counting tokens...")
     total_tokens = 0
-    for ex in track(preprocessed):
+    for ex in track(dataset):
         assert len(ex["input_ids"]) == opts.seq_len  # type: ignore
         total_tokens += len(ex["input_ids"])  # type: ignore
     log.info(f"{total_tokens:,d}")
@@ -52,7 +55,7 @@ def main(opts) -> None:
         str(output_dir / "label_mask.npy"), dtype=np.bool_, mode="w+", shape=(total_tokens,)
     )
     offset = 0
-    for ex in track(preprocessed):
+    for ex in track(dataset):
         ex_len = len(ex["input_ids"])  # type: ignore
         input_ids_file[offset : offset + ex_len] = ex["input_ids"]  # type: ignore
         label_mask_file[offset : offset + ex_len] = ex["label_mask"]  # type: ignore
@@ -63,6 +66,13 @@ def main(opts) -> None:
     log.info("Done!")
 
 
+def filter(example):
+    for msg in example["messages"]:
+        if msg["role"] == "assistant":
+            return True
+    return False
+
+
 def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
     input_ids = [tokenizer.eos_token_id]
     label_mask = [False]
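For context (not part of the commit): the new predicate keeps an example only if its conversation contains at least one assistant turn, presumably because an example with no assistant message contributes no supervised tokens after masking. A minimal, self-contained sketch of the same check against a toy in-memory dataset; the name has_assistant_message is ours, chosen to avoid shadowing Python's builtin filter:

import datasets as ds

def has_assistant_message(example):
    # Same logic as the commit's `filter` function.
    return any(msg["role"] == "assistant" for msg in example["messages"])

toy = ds.Dataset.from_list([
    {"messages": [{"role": "user", "content": "hi"},
                  {"role": "assistant", "content": "hello"}]},
    {"messages": [{"role": "user", "content": "no reply here"}]},
])

filtered = toy.filter(has_assistant_message, batched=False)
print(len(filtered))  # 1 -- the example without an assistant turn is dropped

Passing num_proc, as the script does, only parallelizes the predicate across worker processes; the result is the same.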

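Relatedly, the loop over track(dataset) packs every tokenized example into flat memory-mapped arrays. A rough sketch of that pattern under stated assumptions: toy data, and np.uint16 for input_ids (the visible hunk only shows the label_mask memmap's dtype):

import numpy as np

seq_len = 4  # stand-in for opts.seq_len; every example has exactly this many tokens
examples = [
    {"input_ids": [1, 5, 9, 2], "label_mask": [False, True, True, False]},
    {"input_ids": [1, 7, 3, 2], "label_mask": [False, True, True, False]},
]
total_tokens = sum(len(ex["input_ids"]) for ex in examples)

# Preallocate flat on-disk arrays, then copy each example in at a running offset.
input_ids_file = np.memmap("input_ids.npy", dtype=np.uint16, mode="w+", shape=(total_tokens,))
label_mask_file = np.memmap("label_mask.npy", dtype=np.bool_, mode="w+", shape=(total_tokens,))

offset = 0
for ex in examples:
    ex_len = len(ex["input_ids"])
    input_ids_file[offset : offset + ex_len] = ex["input_ids"]
    label_mask_file[offset : offset + ex_len] = ex["label_mask"]
    offset += ex_len

input_ids_file.flush()  # make sure the buffers hit disk
label_mask_file.flush()

Because every example is exactly seq_len tokens long (the assert in the counting loop enforces this), a consumer can later reshape the flat array to (num_examples, seq_len) without storing per-example offsets.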