Source code for data_juicer.analysis.collector
from itertools import chain

from data_juicer.format import load_formatter
from data_juicer.utils.lazy_loader import LazyLoader

torch = LazyLoader('torch', 'torch')
transformers = LazyLoader('transformers', 'transformers')


class TextTokenDistCollector(object):
    """Tokenize a given dataset with a specified tokenizer and collect
    the distribution of its tokens.
    """
    def __init__(self, tokenizer):
        """
        Initialization method.

        :param tokenizer: name of the tokenizer on HuggingFace.
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            tokenizer, trust_remote_code=True)
        self.vocab_size = len(self.tokenizer)
    def collect(self,
                data_path,
                text_key,
                num_proc=1) -> 'torch.distributions.Categorical':
        """
        Tokenize the input dataset and collect its token distribution.

        :param data_path: path to the input dataset.
        :param text_key: field key of the text to be tokenized and
            counted.
        :param num_proc: number of processes used to count tokens.
        :return: token distribution.
        """

        formatter = load_formatter(data_path)
        dataset = formatter.load_dataset(num_proc=num_proc)
        assert text_key in dataset.features, \
            f'{text_key} not found in dataset'

        def prepare_tokenizer(
            tokenizer,
            text_key,
        ):
            """
            Prepare a tokenization function for the dataset.

            :param tokenizer: a tokenizer to tokenize samples.
            :param text_key: field key of the text to be tokenized
                and counted.
            """

            def _tokenize_fn(example):
                example = tokenizer(example[text_key],
                                    add_special_tokens=False)
                return example

            return _tokenize_fn

        tokenize_proc = prepare_tokenizer(self.tokenizer, text_key)
        dataset = dataset.map(tokenize_proc,
                              num_proc=num_proc,
                              desc=f'tokenize {data_path.split("/")[-1]}')

        # count occurrences of each token id over the whole dataset
        token_count = torch.zeros(self.vocab_size, dtype=torch.int64)
        token_ids = torch.tensor(
            list(chain.from_iterable(dataset['input_ids'])))
        indices, counts = token_ids.unique(return_counts=True)
        token_count.scatter_(0, indices, counts.to(token_count.dtype))
        # Categorical normalizes the raw counts into a probability
        # distribution over the vocabulary
        dist = torch.distributions.Categorical(token_count)
        return dist
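
A minimal usage sketch (the tokenizer name, dataset path, and text field here are illustrative assumptions, not values prescribed by this module):

    # hypothetical inputs: any HuggingFace tokenizer name and any
    # dataset path / text field supported by load_formatter would do
    collector = TextTokenDistCollector('gpt2')
    dist = collector.collect('path/to/dataset.jsonl',
                             text_key='text',
                             num_proc=4)
    # the returned torch.distributions.Categorical exposes the
    # normalized token probabilities and derived statistics
    probs = dist.probs
    entropy = dist.entropy()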