diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py
index 98a544f5..15339b2a 100644
--- a/wtpsplit/extract.py
+++ b/wtpsplit/extract.py
@@ -58,7 +58,6 @@ def __call__(self, hashed_ids, attention_mask, language_ids=None):
         return {"logits": logits}
 
 
-# TODO: comment
 def extract(
     batch_of_texts,
     model,
@@ -69,6 +68,15 @@ def extract(
     pad_last_batch=False,
     verbose=False,
 ):
+    """
+    Computes logits for the given batch of texts by:
+        1. slicing the texts into chunks of size `block_size`.
+        2. passing every chunk through the model's forward pass.
+        3. stitching predictions back together by averaging chunk logits.
+
+    ad 1.: text is sliced into partially overlapping chunks by moving forward by a `stride` parameter (think conv1d).
+    """
+
     text_lengths = [len(text) for text in batch_of_texts]
     # reduce block size if possible
     block_size = min(block_size, max(text_lengths))
@@ -77,13 +85,22 @@
     downsampling_rate = getattr(model.config, "downsampling_rate", 1)
     block_size = math.ceil(block_size / downsampling_rate) * downsampling_rate
 
+    # total number of forward passes
     num_chunks = sum(math.ceil(max(length - block_size, 0) / stride) + 1 for length in text_lengths)
+
+    # preallocate a buffer for all input hashes & attention masks
     input_hashes = np.zeros((num_chunks, block_size, model.config.num_hash_functions), dtype=np.int64)
     attention_mask = np.zeros((num_chunks, block_size), dtype=np.float32)
+
+    # locs keep track of the location of every chunk with a 3-tuple (text_idx, char_start, char_end) that indexes
+    # back into the batch_of_texts
     locs = np.zeros((num_chunks, 3), dtype=np.int32)
 
+    # this is equivalent to (but faster than) np.array([ord(c) for c in "".join(batch_of_texts)])
     codec = "utf-32-le" if sys.byteorder == "little" else "utf-32-be"
     ordinals = np.frombuffer(bytearray("".join(batch_of_texts), encoding=codec), dtype=np.int32)
+
+    # hash encode all ids
     flat_hashed_ids = hash_encode(
         ordinals, num_hashes=model.config.num_hash_functions, num_buckets=model.config.num_hash_buckets
     )
@@ -92,6 +109,7 @@
 
     for i in range(len(batch_of_texts)):
         for j in range(0, text_lengths[i], stride):
+            # for every chunk, assign input hashes, attention mask and loc
             start, end = j, j + block_size
 
             done = False
@@ -114,6 +132,7 @@
     assert current_chunk == num_chunks
 
     n_batches = math.ceil(len(input_hashes) / batch_size)
+    # containers for the final logits
     all_logits = [
         np.zeros(
             (
@@ -124,6 +143,7 @@
         )
         for length in text_lengths
     ]
+    # container for the number of chunks that any character was part of (to average chunk predictions)
     all_counts = [np.zeros(length, dtype=np.int16) for length in text_lengths]
 
     uses_lang_adapters = getattr(model.config, "language_adapter", "off") == "on"
@@ -141,6 +161,7 @@
     else:
         language_ids = None
 
+    # forward passes through all chunks
    for batch_idx in tqdm(range(n_batches), disable=not verbose):
         start, end = batch_idx * batch_size, min(len(input_hashes), (batch_idx + 1) * batch_size)
 
@@ -166,6 +187,7 @@
             all_logits[original_idx][start_char_idx:end_char_idx] += logits[i - start, : end_char_idx - start_char_idx]
             all_counts[original_idx][start_char_idx:end_char_idx] += 1
 
+    # so far, logits are summed, so we average them here
     all_logits = [(logits / counts[:, None]).astype(np.float16) for logits, counts in zip(all_logits, all_counts)]
 
     return all_logits
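
Note for reviewers: a minimal sketch of the slice / forward / average scheme the new docstring describes, reduced to a single text and one logit per character. `toy_model` and `extract_single` are made-up stand-ins for illustration, not part of the patch; the real code additionally batches chunks across texts and hash-encodes characters before the forward pass. Assumes `stride <= block_size` so every character is covered by at least one chunk.

import math

import numpy as np


def toy_model(chunk):
    # stand-in for the real forward pass: one logit per character
    return np.array([ord(c) % 7 for c in chunk], dtype=np.float32)


def extract_single(text, block_size=8, stride=4):
    length = len(text)
    block_size = min(block_size, length)  # reduce block size if possible

    # same chunk-count formula as the patch: one chunk per stride step, plus the initial chunk
    num_chunks = math.ceil(max(length - block_size, 0) / stride) + 1

    logits = np.zeros(length, dtype=np.float32)
    counts = np.zeros(length, dtype=np.int16)

    chunks_done = 0
    for j in range(0, length, stride):
        start, end = j, j + block_size
        if end >= length:
            # clamp the last chunk so it ends exactly at the text boundary
            start, end = length - block_size, length
        logits[start:end] += toy_model(text[start:end])
        counts[start:end] += 1
        chunks_done += 1
        if end == length:
            break
    assert chunks_done == num_chunks

    # characters covered by several overlapping chunks get the mean of their chunk logits
    return logits / counts

With `stride` smaller than `block_size`, interior characters fall into several overlapping chunks; that per-character coverage is exactly what the patch tracks in `all_counts` before dividing.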
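The "equivalent to (but faster than)" claim in the new ordinals comment can also be checked in isolation. A standalone demonstration, not part of the patch:

import sys

import numpy as np

text = "Hello, 世界"

# the endian-specific codec emits no BOM, so the buffer is exactly 4 bytes per
# character and can be reinterpreted as native int32 without a byte swap
codec = "utf-32-le" if sys.byteorder == "little" else "utf-32-be"
fast = np.frombuffer(bytearray(text, encoding=codec), dtype=np.int32)
slow = np.array([ord(c) for c in text], dtype=np.int32)

assert np.array_equal(fast, slow)
print(fast.tolist())  # [72, 101, 108, 108, 111, 44, 32, 19990, 30028]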