diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index d82ccd25..77cfb650 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -14,10 +14,18 @@ # limitations under the License. from typing import Any, Optional +from pydantic import ValidationError + from neo4j_graphrag.exceptions import LLMGenerationError from .base import LLMInterface -from .types import LLMResponse +from .types import LLMResponse, SystemMessage, UserMessage, MessageList + +try: + import ollama + from ollama import ChatResponse +except ImportError: + ollama = None class OllamaLLM(LLMInterface): @@ -25,16 +33,15 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): - try: - import ollama - except ImportError: + if ollama is None: raise ImportError( "Could not import ollama Python client. " "Please install it with `pip install ollama`." ) - super().__init__(model_name, model_params, **kwargs) + super().__init__(model_name, model_params, system_instruction, **kwargs) self.ollama = ollama self.client = ollama.Client( **kwargs, @@ -43,32 +50,43 @@ def __init__( **kwargs, ) - def invoke(self, input: str) -> LLMResponse: + def get_messages( + self, input: str, chat_history: Optional[list[Any]] = None + ) -> ChatResponse: + messages = [] + if self.system_instruction: + messages.append(SystemMessage(content=self.system_instruction).model_dump()) + if chat_history: + try: + MessageList(messages=chat_history) + except ValidationError as e: + raise LLMGenerationError(e.errors()) from e + messages.extend(chat_history) + messages.append(UserMessage(content=input).model_dump()) + return messages + + def invoke( + self, input: str, chat_history: Optional[list[Any]] = None + ) -> LLMResponse: try: response = self.client.chat( model=self.model_name, - messages=[ - { - "role": "user", - "content": input, - }, - ], + messages=self.get_messages(input, chat_history), + options=self.model_params, ) content = response.message.content or "" return LLMResponse(content=content) except self.ollama.ResponseError as e: raise LLMGenerationError(e) - async def ainvoke(self, input: str) -> LLMResponse: + async def ainvoke( + self, input: str, chat_history: Optional[list[Any]] = None + ) -> LLMResponse: try: response = await self.async_client.chat( model=self.model_name, - messages=[ - { - "role": "user", - "content": input, - }, - ], + messages=self.get_messages(input, chat_history), + options=self.model_params, ) content = response.message.content or "" return LLMResponse(content=content) diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py index f7be308b..72d5b71d 100644 --- a/tests/unit/llm/test_ollama_llm.py +++ b/tests/unit/llm/test_ollama_llm.py @@ -12,35 +12,124 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from unittest.mock import MagicMock, Mock, patch import ollama import pytest +from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.ollama_llm import OllamaLLM -def get_mock_ollama() -> MagicMock: - mock = MagicMock() - mock.ResponseError = ollama.ResponseError - return mock - - -@patch("builtins.__import__", side_effect=ImportError) -def test_ollama_llm_missing_dependency(mock_import: Mock) -> None: +@patch("neo4j_graphrag.llm.ollama_llm.ollama", None) +def test_ollama_llm_missing_dependency() -> None: with pytest.raises(ImportError): OllamaLLM(model_name="gpt-4o") -@patch("builtins.__import__") -def test_ollama_llm_happy_path(mock_import: Mock) -> None: - mock_ollama = get_mock_ollama() - mock_import.return_value = mock_ollama +@patch("neo4j_graphrag.llm.ollama_llm.ollama") +def test_ollama_llm_happy_path(mock_ollama: Mock) -> None: + mock_ollama.Client.return_value.chat.return_value = MagicMock( + message=MagicMock(content="ollama chat response"), + ) + model = "gpt" + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + question = "What is graph RAG?" + llm = OllamaLLM( + model, + model_params=model_params, + system_instruction=system_instruction, + ) + + res = llm.invoke(question) + assert isinstance(res, LLMResponse) + assert res.content == "ollama chat response" + messages = [ + {"role": "system", "content": system_instruction}, + {"role": "user", "content": question}, + ] + llm.client.chat.assert_called_once_with( + model=model, messages=messages, options=model_params + ) + + +@patch("neo4j_graphrag.llm.ollama_llm.ollama") +def test_ollama_invoke_with_chat_history_happy_path(mock_ollama: Mock) -> None: + mock_ollama.Client.return_value.chat.return_value = MagicMock( + message=MagicMock(content="ollama chat response"), + ) + model = "gpt" + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + llm = OllamaLLM( + model, + model_params=model_params, + system_instruction=system_instruction, + ) + chat_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + response = llm.invoke(question, chat_history) + assert response.content == "ollama chat response" + messages = [{"role": "system", "content": system_instruction}] + messages.extend(chat_history) + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_once_with( + model=model, messages=messages, options=model_params + ) + + +@patch("neo4j_graphrag.llm.ollama_llm.ollama") +def test_ollama_invoke_with_chat_history_validation_error( + mock_ollama: Mock, +) -> None: mock_ollama.Client.return_value.chat.return_value = MagicMock( message=MagicMock(content="ollama chat response"), ) - llm = OllamaLLM(model_name="gpt") + mock_ollama.ResponseError = ollama.ResponseError + model = "gpt" + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + llm = OllamaLLM( + model, + model_params=model_params, + system_instruction=system_instruction, + ) + chat_history = [ + {"role": "human", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + question = "What about next season?" + + with pytest.raises(LLMGenerationError) as exc_info: + llm.invoke(question, chat_history) + assert "Input should be 'user', 'assistant' or 'system" in str(exc_info.value) + + +@pytest.mark.asyncio +@patch("neo4j_graphrag.llm.ollama_llm.ollama") +async def test_ollama_ainvoke_happy_path(mock_ollama: Mock) -> None: + async def mock_chat_async(*args: Any, **kwargs: Any) -> MagicMock: + return MagicMock( + message=MagicMock(content="ollama chat response"), + ) + + mock_ollama.AsyncClient.return_value.chat = mock_chat_async + model = "gpt" + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + question = "What is graph RAG?" + llm = OllamaLLM( + model, + model_params=model_params, + system_instruction=system_instruction, + ) - res = llm.invoke("my text") + res = await llm.ainvoke(question) assert isinstance(res, LLMResponse) assert res.content == "ollama chat response"