Add TypedDict LLMMessage
* to help with the type declaration of the message history
leila-messallem committed Dec 18, 2024
1 parent 3c55d3f commit d5a287b
Showing 17 changed files with 96 additions and 63 deletions.
5 changes: 3 additions & 2 deletions examples/customize/llms/custom_llm.py
@@ -3,6 +3,7 @@
 from typing import Any, Optional
 
 from neo4j_graphrag.llm import LLMInterface, LLMResponse
+from neo4j_graphrag.llm.types import LLMMessage
 
 
 class CustomLLM(LLMInterface):
@@ -12,7 +13,7 @@ def __init__(self, model_name: str, **kwargs: Any):
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         content: str = (
@@ -23,7 +24,7 @@ async def ainvoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         raise NotImplementedError()
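For illustration, a minimal sketch of how a typed message history is passed to this example LLM once the commit is applied; CustomLLM and its invoke signature come from the file above, while the model name and conversation content are invented:

from neo4j_graphrag.llm.types import LLMMessage

llm = CustomLLM(model_name="my-model")

# Each entry is an LLMMessage: a TypedDict with "role" and "content" keys,
# so a type checker can verify the history's shape before runtime.
history: list[LLMMessage] = [
    {"role": "user", "content": "What is a knowledge graph?"},
    {"role": "assistant", "content": "A graph of entities and their relationships."},
]

response = llm.invoke("Give me an example.", message_history=history)
print(response.content)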
7 changes: 4 additions & 3 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -31,6 +31,7 @@
 )
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
+from neo4j_graphrag.llm.types import LLMMessage
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import RetrieverResult
@@ -87,7 +88,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -147,11 +148,11 @@ def search(
         return RagResultModel(**result)
 
     def build_query(
-        self, query_text: str, message_history: Optional[list[dict[str, str]]] = None
+        self, query_text: str, message_history: Optional[list[LLMMessage]] = None
     ) -> str:
         if message_history:
             summarization_prompt = ChatSummaryTemplate().format(
-                message_history=message_history
+                message_history=message_history  # type: ignore
             )
             summary = self.llm.invoke(
                 input=summarization_prompt,
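A usage sketch for the updated search() signature; retriever and llm below are placeholders for any Retriever and LLMInterface instances, and only the message_history typing is taken from this diff:

from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm.types import LLMMessage

rag = GraphRAG(retriever=retriever, llm=llm)  # placeholder components

history: list[LLMMessage] = [
    {"role": "user", "content": "Who directed The Matrix?"},
    {"role": "assistant", "content": "The Wachowskis directed The Matrix."},
]

# With a non-empty history, build_query() first asks the LLM to summarize the
# conversation (via ChatSummaryTemplate) and prepends that summary to the query.
result = rag.search(query_text="What else did they direct?", message_history=history)
print(result.answer)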
22 changes: 14 additions & 8 deletions src/neo4j_graphrag/llm/anthropic_llm.py
@@ -13,13 +13,19 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Iterable, Optional, TYPE_CHECKING
+from typing import Any, Iterable, Optional, TYPE_CHECKING, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage
+from neo4j_graphrag.llm.types import (
+    BaseMessage,
+    LLMMessage,
+    LLMResponse,
+    MessageList,
+    UserMessage,
+)
 
 if TYPE_CHECKING:
     from anthropic.types.message_param import MessageParam
@@ -71,22 +77,22 @@ def __init__(
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
 
     def get_messages(
-        self, input: str, message_history: Optional[list[dict[str, str]]] = None
+        self, input: str, message_history: Optional[list[LLMMessage]] = None
     ) -> Iterable[MessageParam]:
-        messages = []
+        messages: list[dict[str, str]] = []
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return messages  # type: ignore
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -119,7 +125,7 @@ async def ainvoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
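The get_messages() pattern above repeats in every provider file in this commit: the incoming list[LLMMessage] is validated at runtime through the pydantic MessageList model, then cast to the plain dict shape the provider client expects. A standalone sketch of that flow (build_messages is a hypothetical name; the casts mirror the diff):

from typing import Any, Iterable, cast

from pydantic import ValidationError

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.types import BaseMessage, LLMMessage, MessageList, UserMessage


def build_messages(
    input: str, message_history: list[LLMMessage]
) -> list[dict[str, str]]:
    messages: list[dict[str, str]] = []
    try:
        # TypedDict only helps static checkers; pydantic re-validates roles
        # and content at runtime and reports precise errors.
        MessageList(messages=cast(list[BaseMessage], message_history))
    except ValidationError as e:
        raise LLMGenerationError(e.errors()) from e
    messages.extend(cast(Iterable[dict[str, Any]], message_history))
    messages.append(UserMessage(content=input).model_dump())
    return messages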
6 changes: 3 additions & 3 deletions src/neo4j_graphrag/llm/base.py
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Optional
 
-from .types import LLMResponse
+from .types import LLMMessage, LLMResponse
 
 
 class LLMInterface(ABC):
@@ -45,7 +45,7 @@ def __init__(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
@@ -66,7 +66,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.
14 changes: 8 additions & 6 deletions src/neo4j_graphrag/llm/cohere_llm.py
@@ -14,12 +14,14 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
 from neo4j_graphrag.llm.types import (
+    BaseMessage,
+    LLMMessage,
     LLMResponse,
     MessageList,
     SystemMessage,
@@ -76,7 +78,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> ChatMessages:
         messages = []
@@ -89,17 +91,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return messages  # type: ignore
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -127,7 +129,7 @@
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
14 changes: 8 additions & 6 deletions src/neo4j_graphrag/llm/mistralai_llm.py
@@ -15,12 +15,14 @@
 from __future__ import annotations
 
 import os
-from typing import Any, Optional, cast
+from typing import Any, Iterable, Optional, cast
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
 from neo4j_graphrag.llm.types import (
+    BaseMessage,
+    LLMMessage,
     LLMResponse,
     MessageList,
     SystemMessage,
@@ -67,7 +69,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> list[Messages]:
         messages = []
@@ -80,17 +82,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return cast(list[Messages], messages)
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
@@ -126,7 +128,7 @@
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat
21 changes: 14 additions & 7 deletions src/neo4j_graphrag/llm/ollama_llm.py
@@ -13,14 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import annotations
-from typing import Any, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Iterable, Optional, Sequence, TYPE_CHECKING, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList
+from .types import (
+    BaseMessage,
+    LLMMessage,
+    LLMResponse,
+    SystemMessage,
+    UserMessage,
+    MessageList,
+)
 
 if TYPE_CHECKING:
     from ollama import Message
@@ -53,7 +60,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> Sequence[Message]:
         messages = []
@@ -66,17 +73,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return messages  # type: ignore
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -103,7 +110,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
21 changes: 14 additions & 7 deletions src/neo4j_graphrag/llm/openai_llm.py
@@ -15,13 +15,20 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Any, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
 
 from pydantic import ValidationError
 
 from ..exceptions import LLMGenerationError
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList
+from .types import (
+    BaseMessage,
+    LLMMessage,
+    LLMResponse,
+    SystemMessage,
+    UserMessage,
+    MessageList,
+)
 
 if TYPE_CHECKING:
     import openai
@@ -63,7 +70,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
@@ -76,17 +83,17 @@ def get_messages(
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)  # type: ignore
+                MessageList(messages=cast(list[BaseMessage], message_history))
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
-            messages.extend(message_history)
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
         messages.append(UserMessage(content=input).model_dump())
         return messages  # type: ignore
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
@@ -117,7 +124,7 @@
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[dict[str, str]]] = None,
+        message_history: Optional[list[LLMMessage]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
7 changes: 6 additions & 1 deletion src/neo4j_graphrag/llm/types.py
@@ -1,11 +1,16 @@
 from pydantic import BaseModel
-from typing import Literal
+from typing import Literal, TypedDict
 
 
 class LLMResponse(BaseModel):
     content: str
 
 
+class LLMMessage(TypedDict):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+
 class BaseMessage(BaseModel):
     role: Literal["user", "assistant", "system"]
     content: str
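Because LLMMessage is a TypedDict rather than a pydantic model, it adds static checking without any runtime object or conversion cost: plain dict literals with the right keys satisfy it directly. A small illustration (variable names invented):

from neo4j_graphrag.llm.types import LLMMessage

ok: LLMMessage = {"role": "user", "content": "hello"}

# A checker such as mypy flags both lines below; at runtime they are
# still ordinary dicts, so nothing is raised without pydantic validation.
bad_role: LLMMessage = {"role": "tool", "content": "hi"}  # "tool" is not in the Literal
missing: LLMMessage = {"role": "user"}  # required "content" key is absent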