Skip to content

Commit

Permalink
Fix filter
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Jan 17, 2024
1 parent 437c838 commit 8514a82
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions scripts/prepare_tulu_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ def main(opts) -> None:

dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train")

log.info("Filtering dataset...")
dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc) # type: ignore

log.info("Tokenizing dataset...")
dataset = dataset.map(
partial(preprocess, tokenizer=tokenizer, max_seq_len=opts.seq_len),
Expand All @@ -37,12 +34,17 @@ def main(opts) -> None:
num_proc=opts.num_proc, # type: ignore
)

log.info("Filtering dataset...")
n = len(dataset) # type: ignore
dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc) # type: ignore
log.info(f"Filtered out {n - len(dataset):,d} examples")

log.info("Counting tokens...")
total_tokens = 0
for ex in track(dataset):
assert len(ex["input_ids"]) == opts.seq_len # type: ignore
total_tokens += len(ex["input_ids"]) # type: ignore
log.info(f"{total_tokens:,d}")
log.info(f"Total tokens: {total_tokens:,d}")

log.info(f"Saving results to '{opts.output_dir}'...")
output_dir = Path(opts.output_dir)
Expand All @@ -67,10 +69,7 @@ def main(opts) -> None:


def filter(example):
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were stripped, so it appears to interleave two versions of the
# body — confirm against the real file before relying on either.
# Presumably the removed version: keep an example as soon as any of its
# messages has the "assistant" role, otherwise drop it.
for msg in example["messages"]:
if msg["role"] == "assistant":
return True
return False
# Presumably the added version: keep an example only when its "n_labels"
# value is positive (preprocess in this view sets n_labels = sum(label_mask)).
return example["n_labels"] > 0


def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
Expand Down Expand Up @@ -104,8 +103,9 @@ def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
label_mask += [False] * pad_len

assert len(input_ids) == len(label_mask)
n_labels = sum(label_mask)

return {"input_ids": input_ids, "label_mask": label_mask}
return {"input_ids": input_ids, "label_mask": label_mask, "n_labels": n_labels}


def get_parser() -> ArgumentParser:
Expand Down

0 comments on commit 8514a82

Please sign in to comment.