counter.py (forked from huggingface/datatrove)
from datatrove.data import DocumentsPipeline
from datatrove.pipeline.base import PipelineStep
from datatrove.utils.tokenization import PipelineStepWithTokenizer, batched


class TokensCounter(PipelineStepWithTokenizer):
"""Count the number of tokens in each document.
This pipeline step uses the HuggingFace fast tokenizers library to count the number of tokens in each document.
It doesn't save the tokenized documents, only the token count.
Args:
tokenizer_name_or_path (str): the name or path of the tokenizer to use, from the HuggingFace tokenizers library or a local file.
count_eos_token (bool): whether to count the EOS token on each document. (basically +1 per document)
batch_size: batch size for tokenization
"""

    name = "📊 Counter"
    type = "🔢 - TOKENIZER"

    def __init__(
        self,
        tokenizer_name_or_path: str = "gpt2",  # tokenizer to use, from HF or a local file path
        count_eos_token: bool = False,  # whether to count the EOS token on each document
        batch_size: int = 10000,  # batch size for tokenization
    ):
        super().__init__()
        self.tokenizer_name_or_path = tokenizer_name_or_path
        self.count_eos_token = count_eos_token
        self.batch_size = batch_size

    def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
        """Count tokens for each document in the pipeline.

        Args:
            data (DocumentsPipeline): the documents to process
            rank (int): rank of the current task (Default value = 0)
            world_size (int): total number of tasks (Default value = 1)

        Returns:
            DocumentsPipeline: The pipeline with updated documents, each having a new or updated `token_count` in its metadata.
        """
        from tokenizers import Encoding

        # tokenize document's text in batches to go faster
        for batch in batched(data, self.batch_size):
            with self.track_time(unit="batch"):
                encoded_batch: list[Encoding] = self.tokenizer.encode_batch([document.text for document in batch])
            for document, encoded in zip(batch, encoded_batch):
                count = len(encoded.ids)
                if self.count_eos_token:
                    count += 1
                document.metadata["token_count"] = count
                self.stat_update("tokens", value=count)
                yield document
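

# Usage sketch (illustrative, not part of the upstream module): TokensCounter can also be
# exercised directly on in-memory documents. Everything below is an assumption about a
# typical call, with "gpt2" used as an example tokenizer and a made-up document:
#
#   from datatrove.data import Document
#
#   counter = TokensCounter(tokenizer_name_or_path="gpt2")
#   docs = [Document(text="hello world", id="0")]
#   for doc in counter.run(iter(docs)):
#       print(doc.metadata["token_count"])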


class LengthCounter(PipelineStep):
    """This pipeline step can be used after a TokensCounter or Tokenization step
    to build a histogram of document token lengths.

    It doesn't modify the documents; it only updates a counter in the stats for each document length.
    Note that this creates one stats entry per distinct length, so it can make your stats.json very large.
    """

    name = "📊 Document length counter"
    type = "🔢 - TOKENIZER"

    def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
        """Update the length histogram with each document's `token_count`.

        Args:
            data (DocumentsPipeline): the documents to process
            rank (int): rank of the current task (Default value = 0)
            world_size (int): total number of tasks (Default value = 1)

        Returns:
            DocumentsPipeline: the input documents, unchanged
        """
        for document in data:
            count = document.metadata["token_count"]
            self.stats[count].update(1)
            yield document
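

# Minimal pipeline sketch (illustrative, not part of the upstream module): chaining a
# reader, TokensCounter and LengthCounter behind a local executor. JsonlReader and
# LocalPipelineExecutor are standard datatrove components, but the input folder,
# logging directory and task count below are placeholder assumptions.
if __name__ == "__main__":
    from datatrove.executor import LocalPipelineExecutor
    from datatrove.pipeline.readers import JsonlReader

    executor = LocalPipelineExecutor(
        pipeline=[
            JsonlReader("data/"),  # placeholder: folder containing .jsonl input files
            TokensCounter(tokenizer_name_or_path="gpt2", count_eos_token=False),
            LengthCounter(),  # adds a per-length histogram to the stats
        ],
        tasks=1,
        logging_dir="logs/token_counting",  # placeholder logging directory
    )
    executor.run()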