comment extract method
bminixhofer committed Dec 20, 2023
1 parent 7fccec7 commit 26942ef
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion wtpsplit/extract.py
@@ -58,7 +58,6 @@ def __call__(self, hashed_ids, attention_mask, language_ids=None):

return {"logits": logits}

# TODO: comment
def extract(
batch_of_texts,
model,
@@ -69,6 +68,15 @@ def extract(
pad_last_batch=False,
verbose=False,
):
"""
Computes logits for the given batch of texts by:
1. slicing the texts into chunks of size `block_size`.
2. running a model forward pass on every chunk.
3. stitching the predictions back together by averaging the logits of overlapping chunks.
Regarding step 1: each text is sliced into partially overlapping chunks by moving forward `stride` characters at a time (think conv1d).
"""

text_lengths = [len(text) for text in batch_of_texts]
# reduce block size if possible
block_size = min(block_size, max(text_lengths))
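As a standalone illustration of the slicing described in the docstring (a minimal sketch; `toy_chunks` is a hypothetical helper, not part of wtpsplit, and the real code below does the same bookkeeping in-place instead of materialising substrings):

def toy_chunks(text, block_size, stride):
    chunks = []
    for j in range(0, len(text), stride):
        chunks.append(text[j : j + block_size])
        if j + block_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks

assert toy_chunks("abcdefgh", block_size=4, stride=2) == ["abcd", "cdef", "efgh"]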
@@ -77,13 +85,22 @@
downsampling_rate = getattr(model.config, "downsampling_rate", 1)
block_size = math.ceil(block_size / downsampling_rate) * downsampling_rate

# total number of forward passes
num_chunks = sum(math.ceil(max(length - block_size, 0) / stride) + 1 for length in text_lengths)
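# worked example (illustrative numbers): length=10, block_size=6, stride=2
# -> windows start at 0, 2 and 4, so ceil(max(10 - 6, 0) / 2) + 1 = 3 chunks;
# a text shorter than block_size always yields exactly 1 chunk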

# preallocate a buffer for all input hashes & attention masks
input_hashes = np.zeros((num_chunks, block_size, model.config.num_hash_functions), dtype=np.int64)
attention_mask = np.zeros((num_chunks, block_size), dtype=np.float32)

# locs keeps track of the location of every chunk with a 3-tuple (text_idx, char_start, char_end) that
# indexes back into batch_of_texts
locs = np.zeros((num_chunks, 3), dtype=np.int32)
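# e.g. (hypothetical values) locs[k] = (0, 128, 384) means chunk k covers batch_of_texts[0][128:384]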

# this is equivalent to (but faster than) np.array([ord(c) for c in "".join(batch_of_texts)])
codec = "utf-32-le" if sys.byteorder == "little" else "utf-32-be"
ordinals = np.frombuffer(bytearray("".join(batch_of_texts), encoding=codec), dtype=np.int32)
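# a quick REPL check of that equivalence (illustrative, works for non-ASCII too):
# np.frombuffer(bytearray("héllo", encoding=codec), dtype=np.int32) matches
# np.array([ord(c) for c in "héllo"]) element-wise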

# hash encode all ids
flat_hashed_ids = hash_encode(ordinals,
num_hashes=model.config.num_hash_functions,
num_buckets=model.config.num_hash_buckets)
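Conceptually, hash_encode maps every code point to num_hash_functions bucket ids. A minimal sketch of one such multiplicative-hashing scheme (illustrative only; `toy_hash_encode` is made up here, and wtpsplit's actual hash_encode may use a different scheme):

import numpy as np

def toy_hash_encode(ordinals, num_hashes, num_buckets):
    # one multiplicative hash per hash function; the primes are arbitrary and
    # this toy version supports at most 4 hash functions
    primes = np.array([31, 53, 97, 193][:num_hashes], dtype=np.int64)
    # result shape: (len(ordinals), num_hashes), values in [0, num_buckets)
    return (ordinals[:, None].astype(np.int64) * primes[None, :]) % num_buckets

toy_hash_encode(np.array([104, 105]), num_hashes=2, num_buckets=8)  # [[0, 0], [7, 5]]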
@@ -92,6 +109,7 @@

for i in range(len(batch_of_texts)):
for j in range(0, text_lengths[i], stride):
# for every chunk, assign input hashes, attention mask and loc
start, end = j, j + block_size
done = False

@@ -114,6 +132,7 @@
assert current_chunk == num_chunks
n_batches = math.ceil(len(input_hashes) / batch_size)

# containers for the final logits
all_logits = [
np.zeros(
(
@@ -124,6 +143,7 @@
)
for length in text_lengths
]
# container for the number of chunks that any character was part of (to average chunk predictions)
all_counts = [np.zeros(length, dtype=np.int16) for length in text_lengths]

uses_lang_adapters = getattr(model.config, "language_adapter", "off") == "on"
@@ -141,6 +161,7 @@
else:
language_ids = None

# forward passes through all chunks
for batch_idx in tqdm(range(n_batches), disable=not verbose):
start, end = batch_idx * batch_size, min(len(input_hashes), (batch_idx + 1) * batch_size)
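# e.g. (illustrative) batch_size=4 with 10 chunks -> n_batches=3, slices [0:4], [4:8], [8:10]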

@@ -166,6 +187,7 @@
all_logits[original_idx][start_char_idx:end_char_idx] += logits[i - start, : end_char_idx - start_char_idx]
all_counts[original_idx][start_char_idx:end_char_idx] += 1

# so far, logits are summed, so we average them here
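# tiny numeric example (illustrative values): a character covered by 2 chunks
# whose summed logit is 0.6 is averaged to 0.6 / 2 = 0.3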
all_logits = [(logits / counts[:, None]).astype(np.float16) for logits, counts in zip(all_logits, all_counts)]

return all_logits
