diff --git a/integrations/llama_cpp/pyproject.toml b/integrations/llama_cpp/pyproject.toml index 563af391d..a90118ee4 100644 --- a/integrations/llama_cpp/pyproject.toml +++ b/integrations/llama_cpp/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools", + "transformers[sentencepiece]" ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py index cac9235bd..10b20d363 100644 --- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from .chat.chat_generator import LlamaCppChatGenerator from .generator import LlamaCppGenerator -__all__ = ["LlamaCppGenerator"] +__all__ = ["LlamaCppGenerator", "LlamaCppChatGenerator"] diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py new file mode 100644 index 000000000..e305c2a3d --- /dev/null +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py @@ -0,0 +1,124 @@ +import logging +from typing import Any, Dict, List, Optional + +from haystack import component +from haystack.dataclasses import ChatMessage, ChatRole +from llama_cpp import Llama +from llama_cpp.llama_tokenizer import LlamaHFTokenizer + +logger = logging.getLogger(__name__) + + +@component +class LlamaCppChatGenerator: + """ + Provides an interface to generate text using LLM via llama.cpp. + + [llama.cpp](https://github.com/ggerganov/llama.cpp) is a project written in C/C++ for efficient inference of LLMs. + It employs the quantized GGUF format, suitable for running these models on standard machines (even without GPUs). + + Usage example: + ```python + from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator + user_message = [ChatMessage.from_user("Who is the best American actor?")] + generator = LlamaCppGenerator(model="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512) + + print(generator.run(user_message, generation_kwargs={"max_tokens": 128})) + # {"replies": [ChatMessage(content="John Cusack", role=, name=None, meta={...}]} + ``` + """ + + def __init__( + self, + model: str, + n_ctx: Optional[int] = 0, + n_batch: Optional[int] = 512, + model_kwargs: Optional[Dict[str, Any]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + :param model: The path of a quantized model for text generation, for example, "zephyr-7b-beta.Q4_0.gguf". + If the model path is also specified in the `model_kwargs`, this parameter will be ignored. + :param n_ctx: The number of tokens in the context. When set to 0, the context will be taken from the model. + :param n_batch: Prompt processing maximum batch size. + :param model_kwargs: Dictionary containing keyword arguments used to initialize the LLM for text generation. + These keyword arguments provide fine-grained control over the model loading. + In case of duplication, these kwargs override `model`, `n_ctx`, and `n_batch` init parameters. + For more information on the available kwargs, see + [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__). + :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. + For more information on the available kwargs, see + [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion). + """ + + model_kwargs = model_kwargs or {} + generation_kwargs = generation_kwargs or {} + + if "hf_tokenizer_path" in model_kwargs: + tokenizer = LlamaHFTokenizer.from_pretrained(model_kwargs["hf_tokenizer_path"]) + model_kwargs["tokenizer"] = tokenizer + + # check if the model_kwargs contain the essential parameters + # otherwise, populate them with values from init parameters + model_kwargs.setdefault("model_path", model) + model_kwargs.setdefault("n_ctx", n_ctx) + model_kwargs.setdefault("n_batch", n_batch) + + self.model_path = model + self.n_ctx = n_ctx + self.n_batch = n_batch + self.model_kwargs = model_kwargs + self.generation_kwargs = generation_kwargs + self.model = None + + def warm_up(self): + if self.model is None: + self.model = Llama(**self.model_kwargs) + + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Run the text generation model on the given list of ChatMessages. + + :param messages: + A list of ChatMessage instances representing the input messages. + :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. + For more information on the available kwargs, see + [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion). + :returns: A dictionary with the following keys: + - `replies`: The responses from the model + """ + if self.model is None: + error_msg = "The model has not been loaded. Please call warm_up() before running." + raise RuntimeError(error_msg) + + if not messages: + return {"replies": []} + + updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + formatted_messages = [msg.to_openai_format() for msg in messages] + + response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs) + replies = [ + ChatMessage( + content=choice["message"]["content"], + role=ChatRole[choice["message"]["role"].upper()], + name=None, + meta={ + "response_id": response["id"], + "model": response["model"], + "created": response["created"], + "index": choice["index"], + "finish_reason": choice["finish_reason"], + "usage": response["usage"], + }, + ) + for choice in response["choices"] + ] + + for reply, choice in zip(replies, response["choices"]): + tool_calls = choice.get("message", {}).get("tool_calls", []) + if tool_calls: + reply.meta["tool_calls"] = tool_calls + reply.name = tool_calls[0]["function"]["name"] if tool_calls else None + return {"replies": replies} diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py new file mode 100644 index 000000000..5666f109a --- /dev/null +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -0,0 +1,483 @@ +import json +import os +import urllib.request +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +from haystack import Document, Pipeline +from haystack.components.builders.dynamic_chat_prompt_builder import DynamicChatPromptBuilder +from haystack.components.retrievers.in_memory import InMemoryBM25Retriever +from haystack.dataclasses import ChatMessage, ChatRole +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator + + +@pytest.fixture +def model_path(): + return Path(__file__).parent / "models" + + +def download_file(file_link, filename, capsys): + # Checks if the file already exists before downloading + if not os.path.isfile(filename): + urllib.request.urlretrieve(file_link, filename) # noqa: S310 + with capsys.disabled(): + print("\nModel file downloaded successfully.") + else: + with capsys.disabled(): + print("\nModel file already exists.") + + +class TestLlamaCppChatGenerator: + @pytest.fixture + def generator(self, model_path, capsys): + gguf_model_path = ( + "https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + ) + filename = "openchat-3.5-1210.Q3_K_S.gguf" + + # Download GGUF model from HuggingFace + download_file(gguf_model_path, str(model_path / filename), capsys) + + model_path = str(model_path / filename) + generator = LlamaCppChatGenerator(model=model_path, n_ctx=8192, n_batch=512) + generator.warm_up() + return generator + + @pytest.fixture + def generator_mock(self): + mock_model = MagicMock() + generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=2048, n_batch=512) + generator.model = mock_model + return generator, mock_model + + def test_default_init(self): + """ + Test default initialization parameters. + """ + generator = LlamaCppChatGenerator(model="test_model.gguf") + + assert generator.model_path == "test_model.gguf" + assert generator.n_ctx == 0 + assert generator.n_batch == 512 + assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 0, "n_batch": 512} + assert generator.generation_kwargs == {} + + def test_custom_init(self): + """ + Test custom initialization parameters. + """ + generator = LlamaCppChatGenerator( + model="test_model.gguf", + n_ctx=8192, + n_batch=512, + ) + + assert generator.model_path == "test_model.gguf" + assert generator.n_ctx == 8192 + assert generator.n_batch == 512 + assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 8192, "n_batch": 512} + assert generator.generation_kwargs == {} + + def test_ignores_model_path_if_specified_in_model_kwargs(self): + """ + Test that model_path is ignored if already specified in model_kwargs. + """ + generator = LlamaCppChatGenerator( + model="test_model.gguf", + n_ctx=8192, + n_batch=512, + model_kwargs={"model_path": "other_model.gguf"}, + ) + assert generator.model_kwargs["model_path"] == "other_model.gguf" + + def test_ignores_n_ctx_if_specified_in_model_kwargs(self): + """ + Test that n_ctx is ignored if already specified in model_kwargs. + """ + generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 8192}) + assert generator.model_kwargs["n_ctx"] == 8192 + + def test_ignores_n_batch_if_specified_in_model_kwargs(self): + """ + Test that n_batch is ignored if already specified in model_kwargs. + """ + generator = LlamaCppChatGenerator( + model="test_model.gguf", n_ctx=8192, n_batch=512, model_kwargs={"n_batch": 1024} + ) + assert generator.model_kwargs["n_batch"] == 1024 + + def test_raises_error_without_warm_up(self): + """ + Test that the generator raises an error if warm_up() is not called before running. + """ + generator = LlamaCppChatGenerator(model="test_model.gguf", n_ctx=512, n_batch=512) + with pytest.raises(RuntimeError): + generator.run("What is the capital of China?") + + def test_run_with_empty_message(self, generator_mock): + """ + Test that an empty message returns an empty list of replies. + """ + generator, _ = generator_mock + result = generator.run([]) + assert isinstance(result["replies"], list) + assert len(result["replies"]) == 0 + + def test_run_with_valid_message(self, generator_mock): + """ + Test that a valid message returns a list of replies. + """ + generator, mock_model = generator_mock + mock_output = { + "id": "unique-id-123", + "model": "Test Model Path", + "created": 1715226164, + "choices": [ + {"index": 0, "message": {"content": "Generated text", "role": "assistant"}, "finish_reason": "stop"} + ], + "usage": {"prompt_tokens": 14, "completion_tokens": 57, "total_tokens": 71}, + } + mock_model.create_chat_completion.return_value = mock_output + result = generator.run(messages=[ChatMessage.from_system("Test")]) + assert isinstance(result["replies"], list) + assert len(result["replies"]) == 1 + assert isinstance(result["replies"][0], ChatMessage) + assert result["replies"][0].content == "Generated text" + assert result["replies"][0].role == ChatRole.ASSISTANT + + def test_run_with_generation_kwargs(self, generator_mock): + """ + Test that a valid message and generation kwargs returns a list of replies. + """ + generator, mock_model = generator_mock + mock_output = { + "id": "unique-id-123", + "model": "Test Model Path", + "created": 1715226164, + "choices": [ + {"index": 0, "message": {"content": "Generated text", "role": "assistant"}, "finish_reason": "length"} + ], + "usage": {"prompt_tokens": 14, "completion_tokens": 57, "total_tokens": 71}, + } + mock_model.create_chat_completion.return_value = mock_output + generation_kwargs = {"max_tokens": 128} + result = generator.run([ChatMessage.from_system("Write a 200 word paragraph.")], generation_kwargs) + assert result["replies"][0].content == "Generated text" + assert result["replies"][0].meta["finish_reason"] == "length" + + @pytest.mark.integration + def test_run(self, generator): + """ + Test that a valid message returns a list of replies. + """ + questions_and_answers = [ + ("What's the capital of France?", "Paris"), + ("What is the capital of Canada?", "Ottawa"), + ("What is the capital of Ghana?", "Accra"), + ] + + for question, answer in questions_and_answers: + chat_message = ChatMessage.from_system( + f"GPT4 Correct User: Answer in a single word. {question} <|end_of_turn|>\n GPT4 Correct Assistant:" + ) + result = generator.run([chat_message]) + + assert "replies" in result + assert isinstance(result["replies"], list) + assert len(result["replies"]) > 0 + assert any(answer.lower() in reply.content.lower() for reply in result["replies"]) + + @pytest.mark.integration + def test_run_rag_pipeline(self, generator): + """ + Test that a valid message returns a list of replies. + """ + document_store = InMemoryDocumentStore() + documents = [ + Document(content="There are over 7,000 languages spoken around the world today."), + Document( + content="""Elephants have been observed to behave in a way that indicates a high + level of self-awareness, such as recognizing themselves in mirrors.""" + ), + Document( + content="""In certain parts of the world, like the Maldives, Puerto Rico, + and San Diego, you can witness the phenomenon of bioluminescent waves.""" + ), + ] + document_store.write_documents(documents=documents) + + pipeline = Pipeline() + pipeline.add_component( + instance=InMemoryBM25Retriever(document_store=document_store, top_k=1), + name="retriever", + ) + pipeline.add_component( + instance=DynamicChatPromptBuilder(runtime_variables=["query", "documents"]), name="prompt_builder" + ) + pipeline.add_component(instance=generator, name="llm") + pipeline.connect("retriever.documents", "prompt_builder.documents") + pipeline.connect("prompt_builder.prompt", "llm.messages") + + question = "How many languages are there?" + location = "Puerto Rico" + system_message = ChatMessage.from_system( + "You are a helpful assistant giving out valuable information to tourists." + ) + messages = [ + system_message, + ChatMessage.from_user( + """ + Given these documents and given that I am currently in {{ location }}, answer the question.\nDocuments: + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + + \nQuestion: {{query}} + \nAnswer: + """ + ), + ] + question = "Can I see bioluminescent waves at my current location?" + result = pipeline.run( + data={ + "retriever": {"query": question}, + "prompt_builder": { + "template_variables": {"location": location}, + "prompt_source": messages, + "query": question, + }, + } + ) + + replies = result["llm"]["replies"] + assert len(replies) > 0 + assert any("bioluminescent waves" in reply.content for reply in replies) + assert all(reply.role == ChatRole.ASSISTANT for reply in replies) + + @pytest.mark.integration + def test_json_constraining(self, generator): + """ + Test that the generator can output valid JSON. + """ + messages = [ChatMessage.from_system("Output valid json only. List 2 people with their name and age.")] + json_schema = { + "type": "object", + "properties": { + "people": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "number"}, + }, + }, + }, + }, + "required": ["people"], + } + + result = generator.run( + messages=messages, + generation_kwargs={ + "response_format": {"type": "json_object", "schema": json_schema}, + }, + ) + + assert "replies" in result + assert isinstance(result["replies"], list) + assert len(result["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in result["replies"]) + for reply in result["replies"]: + assert json.loads(reply.content) + assert isinstance(json.loads(reply.content), dict) + assert "people" in json.loads(reply.content) + assert isinstance(json.loads(reply.content)["people"], list) + assert all(isinstance(person, dict) for person in json.loads(reply.content)["people"]) + assert all("name" in person for person in json.loads(reply.content)["people"]) + assert all("age" in person for person in json.loads(reply.content)["people"]) + assert all(isinstance(person["name"], str) for person in json.loads(reply.content)["people"]) + assert all(isinstance(person["age"], int) for person in json.loads(reply.content)["people"]) + + +class TestLlamaCppChatGeneratorFunctionary: + def get_current_temperature(self, location): + """Get the current temperature in a given location""" + if "tokyo" in location.lower(): + return json.dumps({"location": "Tokyo", "temperature": "10", "unit": "celsius"}) + elif "san francisco" in location.lower(): + return json.dumps({"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}) + elif "paris" in location.lower(): + return json.dumps({"location": "Paris", "temperature": "22", "unit": "celsius"}) + else: + return json.dumps({"location": location, "temperature": "unknown"}) + + @pytest.fixture + def generator(self, model_path, capsys): + gguf_model_path = ( + "https://huggingface.co/meetkai/functionary-small-v2.4-GGUF/resolve/main/functionary-small-v2.4.Q4_0.gguf" + ) + filename = "functionary-small-v2.4.Q4_0.gguf" + download_file(gguf_model_path, str(model_path / filename), capsys) + model_path = str(model_path / filename) + hf_tokenizer_path = "meetkai/functionary-small-v2.4-GGUF" + generator = LlamaCppChatGenerator( + model=model_path, + n_ctx=8192, + n_batch=512, + model_kwargs={ + "chat_format": "functionary-v2", + "hf_tokenizer_path": hf_tokenizer_path, + }, + ) + generator.warm_up() + return generator + + @pytest.mark.integration + def test_function_call(self, generator): + tools = [ + { + "type": "function", + "function": { + "name": "get_user_info", + "parameters": { + "type": "object", + "properties": { + "username": {"type": "string", "description": "The username to retrieve information for."} + }, + "required": ["username"], + }, + "description": "Retrieves detailed information about a user.", + }, + } + ] + tool_choice = {"type": "function", "function": {"name": "get_user_info"}} + + messages = [ + ChatMessage.from_user("Get information for user john_doe"), + ] + response = generator.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": tool_choice}) + + assert "tool_calls" in response["replies"][0].meta + tool_calls = response["replies"][0].meta["tool_calls"] + assert len(tool_calls) > 0 + assert tool_calls[0]["function"]["name"] == "get_user_info" + assert "username" in json.loads(tool_calls[0]["function"]["arguments"]) + assert response["replies"][0].role == ChatRole.ASSISTANT + + def test_function_call_and_execute(self, generator): + messages = [ChatMessage.from_user("What's the weather like in San Francisco?")] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_temperature", + "description": "Get the current temperature in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + + response = generator.run(messages=messages, generation_kwargs={"tools": tools}) + + available_functions = { + "get_current_temperature": self.get_current_temperature, + } + + assert "replies" in response + assert len(response["replies"]) > 0 + + first_reply = response["replies"][0] + assert "tool_calls" in first_reply.meta + tool_calls = first_reply.meta["tool_calls"] + + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + assert function_name in available_functions + function_response = available_functions[function_name](**function_args) + function_message = ChatMessage.from_function(function_response, function_name) + messages.append(function_message) + + second_response = generator.run(messages=messages) + print(second_response) + assert "replies" in second_response + assert len(second_response["replies"]) > 0 + assert any("San Francisco" in reply.content for reply in second_response["replies"]) + assert any("72" in reply.content for reply in second_response["replies"]) + + +class TestLlamaCppChatGeneratorChatML: + + @pytest.fixture + def generator(self, model_path, capsys): + gguf_model_path = ( + "https://huggingface.co/TheBloke/openchat-3.5-1210-GGUF/resolve/main/openchat-3.5-1210.Q3_K_S.gguf" + ) + filename = "openchat-3.5-1210.Q3_K_S.gguf" + download_file(gguf_model_path, str(model_path / filename), capsys) + model_path = str(model_path / filename) + generator = LlamaCppChatGenerator( + model=model_path, + n_ctx=8192, + n_batch=512, + model_kwargs={ + "chat_format": "chatml-function-calling", + }, + ) + generator.warm_up() + return generator + + @pytest.mark.integration + def test_function_call_chatml(self, generator): + messages = [ + ChatMessage.from_system( + """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, + detailed, and polite answers to the user's questions. The assistant calls functions with appropriate + input when necessary""" + ), + ChatMessage.from_user("Extract Jason is 25 years old"), + ] + + tools = [ + { + "type": "function", + "function": { + "name": "UserDetail", + "parameters": { + "type": "object", + "title": "UserDetail", + "properties": { + "name": {"title": "Name", "type": "string"}, + "age": {"title": "Age", "type": "integer"}, + }, + "required": ["name", "age"], + }, + }, + } + ] + + tool_choice = {"type": "function", "function": {"name": "UserDetail"}} + + response = generator.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": tool_choice}) + for reply in response["replies"]: + assert "tool_calls" in reply.meta + tool_calls = reply.meta["tool_calls"] + assert len(tool_calls) > 0 + assert tool_calls[0]["function"]["name"] == "UserDetail" + assert "name" in json.loads(tool_calls[0]["function"]["arguments"]) + assert "age" in json.loads(tool_calls[0]["function"]["arguments"]) + assert "Jason" in json.loads(tool_calls[0]["function"]["arguments"])["name"] + assert 25 == json.loads(tool_calls[0]["function"]["arguments"])["age"]