diff --git a/haystack/preview/components/generators/hugging_face/hugging_face_local.py b/haystack/preview/components/generators/hugging_face/hugging_face_local.py
index a41348e67b..8f1d9874f7 100644
--- a/haystack/preview/components/generators/hugging_face/hugging_face_local.py
+++ b/haystack/preview/components/generators/hugging_face/hugging_face_local.py
@@ -5,6 +5,15 @@
 from haystack.preview import component, default_to_dict
 from haystack.preview.lazy_imports import LazyImport
 
+logger = logging.getLogger(__name__)
+
+SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
+
+with LazyImport(
+    message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
+) as torch_import:
+    import torch
+
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import model_info
     from transformers import (
@@ -15,15 +24,41 @@
         PreTrainedTokenizerFast,
     )
 
-with LazyImport(
-    message="PyTorch is needed to run this component. Please install it by following the instructions at https://pytorch.org/"
-) as torch_import:
-    import torch
-
-logger = logging.getLogger(__name__)
-
+class StopWordsCriteria(StoppingCriteria):
+    """
+    Stops text generation if any one of the stop words is generated.
+
+    Note: When a stop word is encountered, the generation of new text is stopped.
+    However, if the stop word is in the prompt itself, it can stop generating new text
+    prematurely after the first token. This is particularly important for LLMs designed
+    for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
+    the output includes both the new text and the original prompt. Therefore, it's important
+    to make sure your prompt has no stop words.
+    """
 
-SUPPORTED_TASKS = ["text-generation", "text2text-generation"]
+    def __init__(
+        self,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        stop_words: List[str],
+        device: Union[str, torch.device] = "cpu",
+    ):
+        super().__init__()
+        encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
+        self.stop_ids = encoded_stop_words.input_ids.to(device)
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        for stop_id in self.stop_ids:
+            found_stop_word = self.is_stop_word_found(input_ids, stop_id)
+            if found_stop_word:
+                return True
+        return False
+
+    def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
+        generated_text_ids = generated_text_ids[-1]
+        len_generated_text_ids = generated_text_ids.size(0)
+        len_stop_id = stop_id.size(0)
+        result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
+        return result
 
 
 @component
@@ -186,40 +221,3 @@ def run(self, prompt: str):
 
         replies = [reply.replace(stop_word, "").rstrip() for reply in replies for stop_word in self.stop_words]
         return {"replies": replies}
-
-
-class StopWordsCriteria(StoppingCriteria):
-    """
-    Stops text generation if any one of the stop words is generated.
-
-    Note: When a stop word is encountered, the generation of new text is stopped.
-    However, if the stop word is in the prompt itself, it can stop generating new text
-    prematurely after the first token. This is particularly important for LLMs designed
-    for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
-    the output includes both the new text and the original prompt. Therefore, it's important
-    to make sure your prompt has no stop words.
-    """
-
-    def __init__(
-        self,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        stop_words: List[str],
-        device: Union[str, torch.device] = "cpu",
-    ):
-        super().__init__()
-        encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
-        self.stop_ids = encoded_stop_words.input_ids.to(device)
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        for stop_id in self.stop_ids:
-            found_stop_word = self.is_stop_word_found(input_ids, stop_id)
-            if found_stop_word:
-                return True
-        return False
-
-    def is_stop_word_found(self, generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
-        generated_text_ids = generated_text_ids[-1]
-        len_generated_text_ids = generated_text_ids.size(0)
-        len_stop_id = stop_id.size(0)
-        result = all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
-        return result
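For reviewers, here is a minimal usage sketch (not part of this diff) of how the relocated `StopWordsCriteria` plugs into a `transformers` text-generation pipeline. The model name, prompt, and stop word are illustrative placeholders; note that `StopWordsCriteria` encodes its stop words with `padding=True`, and gpt2's tokenizer ships without a pad token, so one has to be assigned first.

```python
# Illustrative sketch only -- model name, prompt, and stop word are placeholders.
from transformers import AutoTokenizer, StoppingCriteriaList, pipeline

model_name = "gpt2"  # assumption: any model for a SUPPORTED_TASKS task would do
tokenizer = AutoTokenizer.from_pretrained(model_name)
# StopWordsCriteria tokenizes stop words with padding=True; gpt2 has no pad
# token by default, so reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

stopping_criteria = StoppingCriteriaList(
    [StopWordsCriteria(tokenizer=tokenizer, stop_words=["unambiguously"])]
)

generator = pipeline("text-generation", model=model_name, tokenizer=tokenizer)
# Extra keyword arguments are forwarded to model.generate(), so generation
# halts as soon as a stop word sequence is produced.
output = generator("Berlin is the capital of", max_new_tokens=30, stopping_criteria=stopping_criteria)
print(output[0]["generated_text"])
```

Because `is_stop_word_found` compares only the trailing `len(stop_id)` tokens of the generated sequence against each encoded stop word, multi-token stop words are matched as whole sequences rather than token by token; the component's `run()` then strips any stop words that made it into the replies.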