mv StopWordsCriteria under lazy_import (#6128)
anakin87 authored Oct 19, 2023
1 parent 025418c · commit fe261b9
Showing 1 changed file with 43 additions and 45 deletions.
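
The change moves the StopWordsCriteria class inside the with transformers_import: block, so the class (a subclass of transformers' StoppingCriteria) is only defined when the optional transformers dependency can actually be imported. As a minimal, self-contained sketch of how a lazy-import guard of this shape can work (simplified for illustration; Haystack's real LazyImport lives in haystack.preview.lazy_imports and also exposes a check() method that components call at runtime):

from typing import Optional


class LazyImport:
    """Context manager that swallows an ImportError and replays it on check()."""

    def __init__(self, message: str) -> None:
        self.message = message
        self.import_error: Optional[BaseException] = None

    def __enter__(self) -> "LazyImport":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        if exc_type is not None and issubclass(exc_type, ImportError):
            self.import_error = exc_value
            return True  # suppress the error so the module still imports cleanly
        return False

    def check(self) -> None:
        # Fail loudly at use time if the optional dependency was missing.
        if self.import_error is not None:
            raise ImportError(self.message) from self.import_error


with LazyImport(message="Run 'pip install transformers'") as transformers_import:
    from transformers import StoppingCriteria

    # Anything defined inside this block, like StopWordsCriteria in the diff
    # below, only comes into existence when transformers is installed.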
@@ -5,6 +5,15 @@
 from haystack.preview import component, default_to_dict
 from haystack.preview.lazy_imports import LazyImport
 
+logger = logging.getLogger(__name__)
+
+SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
+
+with LazyImport(
+    message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
+) as torch_import:
+    import torch
+
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import model_info
     from transformers import (
@@ -15,15 +24,41 @@
         PreTrainedTokenizerFast,
     )
 
-with LazyImport(
-    message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
-) as torch_import:
-    import torch
-
-logger = logging.getLogger(__name__)
-
-SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
-
+    class StopWordsCriteria(StoppingCriteria):
+        """
+        Stops text generation if any one of the stop words is generated.
+        Note: When a stop word is encountered, the generation of new text is stopped.
+        However, if the stop word is in the prompt itself, it can stop generating new text
+        prematurely after the first token. This is particularly important for LLMs designed
+        for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
+        the output includes both the new text and the original prompt. Therefore, it's important
+        to make sure your prompt has no stop words.
+        """
+
+        def __init__(
+            self,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            stop_words: List[str],
+            device: Union[str, torch.device] = "cpu",
+        ):
+            super().__init__()
+            encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
+            self.stop_ids = encoded_stop_words.input_ids.to(device)
+
+        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+            for stop_id in self.stop_ids:
+                found_stop_word = self.is_stop_word_found(input_ids, stop_id)
+                if found_stop_word:
+                    return True
+            return False
+
+        def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
+            generated_text_ids = generated_text_ids[-1]
+            len_generated_text_ids = generated_text_ids.size(0)
+            len_stop_id = stop_id.size(0)
+            result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
+            return result
 
 
 @component
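
For reference, a hedged usage sketch of how a criteria object like this plugs into transformers text generation (the model name, prompt, and stop word are illustrative, and StopWordsCriteria is assumed to be in scope from the definition above):

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# gpt2 ships without a pad token, and the criteria's tokenizer call uses padding=True
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Stop generation as soon as the (illustrative) stop word is produced.
criteria = StopWordsCriteria(tokenizer=tokenizer, stop_words=["Paris"], device="cpu")

inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    max_new_tokens=30,
    stopping_criteria=StoppingCriteriaList([criteria]),
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))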
@@ -186,40 +221,3 @@ def run(self, prompt: str):
         replies = [reply.replace(stop_word, "").rstrip() for reply in replies for stop_word in self.stop_words]
 
         return {"replies": replies}
-
-
-class StopWordsCriteria(StoppingCriteria):
-    """
-    Stops text generation if any one of the stop words is generated.
-    Note: When a stop word is encountered, the generation of new text is stopped.
-    However, if the stop word is in the prompt itself, it can stop generating new text
-    prematurely after the first token. This is particularly important for LLMs designed
-    for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
-    the output includes both the new text and the original prompt. Therefore, it's important
-    to make sure your prompt has no stop words.
-    """
-
-    def __init__(
-        self,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        stop_words: List[str],
-        device: Union[str, torch.device] = "cpu",
-    ):
-        super().__init__()
-        encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
-        self.stop_ids = encoded_stop_words.input_ids.to(device)
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        for stop_id in self.stop_ids:
-            found_stop_word = self.is_stop_word_found(input_ids, stop_id)
-            if found_stop_word:
-                return True
-        return False
-
-    def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
-        generated_text_ids = generated_text_ids[-1]
-        len_generated_text_ids = generated_text_ids.size(0)
-        len_stop_id = stop_id.size(0)
-        result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
-        return result
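
The suffix check in is_stop_word_found compares the last len(stop_id) generated token ids against the encoded stop word. A tiny worked example with hand-picked ids (illustrative values, no tokenizer involved):

import torch

generated = torch.tensor([[101, 7592, 2088, 102]])  # batch with one sequence
stop_id = torch.tensor([2088, 102])                 # ids of one encoded stop word

last_sequence = generated[-1]                       # tensor([ 101, 7592, 2088,  102])
suffix = last_sequence[last_sequence.size(0) - stop_id.size(0) :]
print(bool(all(suffix.eq(stop_id))))                # True, so generation stops here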
