@staticmethod
def crop_to_min_pad(
    input_ids_list: List[torch.Tensor], pad_token_id: int
) -> torch.Tensor:
    """
    Crop trailing padding from a batch of already-padded input_ids tensors.

    Every sequence is cut to the length of the longest non-padded content in the
    batch, so the batch keeps the minimum amount of padding needed to stay
    rectangular and no real token is dropped.

    Args:
        input_ids_list (list of torch.Tensor): List of 1-D input_ids tensors, where each tensor represents a sequence.
            All tensors are expected to share the same (padded) length.
        pad_token_id (int): ID of the padding token used in input_ids tensors.

    Returns:
        torch.Tensor: Stacked tensor with one row per input sequence; the number of columns is the
            longest content length found in the batch (0 if every token is padding).

    Raises:
        ValueError: If input_ids_list is empty.
        RuntimeError: Raised by torch.stack if the cropped sequences do not all have the same length
            (e.g. an input shorter than the longest content in the batch).

    Example:
        >>> input_ids_list = [
        ...     torch.tensor([101, 2054, 2003, 0, 0, 0, 0, 0, 0, 0]),
        ...     torch.tensor([101, 2023, 2003, 1037, 1999, 0, 0, 0, 0, 0]),
        ...     torch.tensor([101, 2002, 0, 0, 0, 0, 0, 0, 0, 0]),
        ... ]
        >>> pad_token_id = 0
        >>> cropped_batch = crop_to_min_pad(input_ids_list, pad_token_id)
        >>> print(cropped_batch)
        tensor([[ 101, 2054, 2003,    0,    0],
                [ 101, 2023, 2003, 1037, 1999],
                [ 101, 2002,    0,    0,    0]])

    Note:
        This function assumes that the input_ids tensors are already padded on the right. Only
        trailing padding is removed; a pad_token_id that appears before the last real token of a
        sequence is treated as content and kept.
    """
    # len() rather than truthiness: the argument may be a tensor-like batch,
    # whose truth value is ambiguous.
    if len(input_ids_list) == 0:
        raise ValueError("input_ids_list must contain at least one tensor")

    def _content_length(ids: torch.Tensor) -> int:
        # Position of the last non-pad token + 1. Counting pad tokens instead
        # would under-estimate the length when the pad id also occurs inside
        # the content.
        non_pad_positions = (ids != pad_token_id).nonzero()
        if non_pad_positions.numel() == 0:
            return 0
        return int(non_pad_positions.max().item()) + 1

    # The longest content decides the crop point; using the shortest one would
    # truncate real tokens of every longer sequence.
    keep_length = max(_content_length(ids) for ids in input_ids_list)
    cropped_sequences = [ids[:keep_length] for ids in input_ids_list]
    return torch.stack(cropped_sequences, dim=0)