From 6696abfa522ca883df9d868689871f5d41fae382 Mon Sep 17 00:00:00 2001
From: Daniel Walmsley
Date: Sat, 15 Jun 2024 11:46:06 -0700
Subject: [PATCH] Implement custom tensor.isin

---
 TTS/tts/layers/xtts/stream_generator.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
index fa8b9c730f..451c783af0 100644
--- a/TTS/tts/layers/xtts/stream_generator.py
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -23,6 +23,20 @@
 )
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger
 
+def custom_isin(elements, test_elements):
+    # Flatten the tensors
+    elements_flat = elements.view(-1)
+    test_elements_flat = test_elements.view(-1)
+
+    # Create a mask tensor
+    mask = torch.zeros_like(elements_flat, dtype=torch.bool)
+
+    # Compare each element
+    for test_element in test_elements_flat:
+        mask |= (elements_flat == test_element)
+
+    # Reshape the mask to the original elements shape
+    return mask.view(elements.shape)
 
 def setup_seed(seed):
     if seed == -1:
@@ -202,10 +216,10 @@ def generate(
             default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
 
             is_pad_token_in_inputs = (pad_token_tensor is not None) and (
-                torch.isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
+                custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor).any()
             )
             is_pad_token_not_equal_to_eos_token_id = (eos_token_tensor is None) or ~(
-                torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
+                custom_isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
            )
             can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
             attention_mask_from_padding = inputs_tensor.ne(pad_token_tensor).long()
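Not part of the patch itself: a minimal, standalone sanity check, assuming a PyTorch build where torch.isin is available as a reference. It copies custom_isin from the hunk above and asserts it matches torch.isin; the example tensors are made up for illustration and are not taken from XTTS.

    import torch

    # Copied from the patch above: a loop-based stand-in for torch.isin.
    def custom_isin(elements, test_elements):
        elements_flat = elements.view(-1)
        test_elements_flat = test_elements.view(-1)
        mask = torch.zeros_like(elements_flat, dtype=torch.bool)
        for test_element in test_elements_flat:
            mask |= (elements_flat == test_element)
        return mask.view(elements.shape)

    # Illustrative inputs (hypothetical token ids).
    inputs_tensor = torch.tensor([[1, 2, 3], [4, 2, 0]])
    pad_token_tensor = torch.tensor([0, 2])

    reference = torch.isin(elements=inputs_tensor, test_elements=pad_token_tensor)
    candidate = custom_isin(elements=inputs_tensor, test_elements=pad_token_tensor)
    assert torch.equal(reference, candidate)
    # Both give: tensor([[False,  True, False],
    #                    [False,  True,  True]])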