
Commit

System instruction override
leila-messallem committed Dec 16, 2024
1 parent abef33c commit d7df9e8
Showing 7 changed files with 141 additions and 36 deletions.
26 changes: 22 additions & 4 deletions src/neo4j_graphrag/llm/anthropic_llm.py
@@ -84,22 +84,31 @@ def get_messages(
return messages

def invoke(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from the LLM.
"""
try:
messages = self.get_messages(input, message_history)
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
response = self.client.messages.create(
model=self.model_name,
system=self.system_instruction,
system=system_message,
messages=messages,
**self.model_params,
)
@@ -108,22 +117,31 @@ def invoke(
raise LLMGenerationError(e)

async def ainvoke(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Asynchronously sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from the LLM.
"""
try:
messages = self.get_messages(input, message_history)
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
response = await self.async_client.messages.create(
model=self.model_name,
system=self.system_instruction,
system=system_message,
messages=messages,
**self.model_params,
)
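
For context, a minimal usage sketch of the new per-call override (not part of the commit): it assumes the class is AnthropicLLM, that its constructor accepts model_name and system_instruction, and that Anthropic credentials are configured in the environment; only the invoke() parameter and its fallback behaviour come from the diff above.

from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM

# Assumed constructor arguments; the default system instruction is stored on the instance.
llm = AnthropicLLM(
    model_name="claude-3-5-sonnet-20240620",
    system_instruction="You are a terse assistant.",
)

# Uses the instance-level system instruction set at construction time.
default_reply = llm.invoke("Summarise the neo4j-graphrag package in one sentence.")

# Overrides the system message for this call only; the instance default is left untouched.
override_reply = llm.invoke(
    "Summarise the neo4j-graphrag package in one sentence.",
    system_instruction="Answer in French.",
)

print(default_reply.content)
print(override_reply.content)
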
7 changes: 5 additions & 2 deletions src/neo4j_graphrag/llm/base.py
@@ -64,14 +64,17 @@ def invoke(

@abstractmethod
async def ainvoke(
self, input: str, message_history: Optional[list[dict[str, str]]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Asynchronously sends a text input to the LLM and retrieves a response.
Args:
input (str): Text sent to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from the LLM.
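
To illustrate the contract the updated base class now expresses, here is a hypothetical subclass sketch. EchoLLM is invented for illustration; the import paths, the constructor arguments, and the assumption that LLMInterface stores system_instruction on the instance are inferred from the diff rather than shown in it.

from typing import Optional

from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse


class EchoLLM(LLMInterface):
    """Toy implementation that echoes the effective system message back."""

    def invoke(
        self,
        input: str,
        message_history: Optional[list] = None,
        system_instruction: Optional[str] = None,
    ) -> LLMResponse:
        # Same fallback pattern as the concrete providers in this commit:
        # a per-call value wins, otherwise fall back to the instance default.
        system_message = (
            system_instruction
            if system_instruction is not None
            else self.system_instruction
        )
        return LLMResponse(content=f"[{system_message}] {input}")

    async def ainvoke(
        self,
        input: str,
        message_history: Optional[list] = None,
        system_instruction: Optional[str] = None,
    ) -> LLMResponse:
        return self.invoke(input, message_history, system_instruction)


llm = EchoLLM(model_name="echo", system_instruction="default")  # assumed constructor signature
print(llm.invoke("hi").content)                                  # [default] hi
print(llm.invoke("hi", system_instruction="override").content)   # [override] hi
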
30 changes: 23 additions & 7 deletions src/neo4j_graphrag/llm/cohere_llm.py
@@ -75,11 +75,19 @@ def __init__(
self.async_client = cohere.AsyncClientV2(**kwargs)

def get_messages(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> ChatMessages:
messages = []
if self.system_instruction:
messages.append(SystemMessage(content=self.system_instruction).model_dump())
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
if system_message:
messages.append(SystemMessage(content=system_message).model_dump())
if message_history:
try:
MessageList(messages=message_history)
@@ -90,19 +98,23 @@ def get_messages(
return messages

def invoke(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from the LLM.
"""
try:
messages = self.get_messages(input, message_history)
messages = self.get_messages(input, message_history, system_instruction)
res = self.client.chat(
messages=messages,
model=self.model_name,
@@ -114,19 +126,23 @@ def invoke(
)

async def ainvoke(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Asynchronously sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from the LLM.
"""
try:
messages = self.get_messages(input, message_history)
messages = self.get_messages(input, message_history, system_instruction)
res = self.async_client.chat(
messages=messages,
model=self.model_name,
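
One detail worth calling out in the get_messages() change: the override is compared against None, so an explicit empty string wins the fallback but then fails the truthiness check, which suppresses the system message entirely. A small standalone sketch of that logic follows; the SystemMessage import path and the exact dict produced by model_dump() are assumptions, while the fallback expression itself mirrors the hunk.

from typing import Optional

from neo4j_graphrag.llm.types import SystemMessage  # assumed import path


def effective_system_message(
    per_call: Optional[str], instance_default: Optional[str]
) -> Optional[str]:
    # Mirrors the diff: only None falls back; any explicit string, even "", wins.
    return per_call if per_call is not None else instance_default


messages = []
system_message = effective_system_message("Answer in one word.", "Be helpful.")
if system_message:
    messages.append(SystemMessage(content=system_message).model_dump())

print(messages)  # expected: [{'role': 'system', 'content': 'Answer in one word.'}]

# Passing "" suppresses the system message altogether:
assert effective_system_message("", "Be helpful.") == ""
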
30 changes: 23 additions & 7 deletions src/neo4j_graphrag/llm/mistralai_llm.py
@@ -66,11 +66,19 @@ def __init__(
self.client = Mistral(api_key=api_key, **kwargs)

def get_messages(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> list[Messages]:
messages = []
if self.system_instruction:
messages.append(SystemMessage(content=self.system_instruction).model_dump())
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
if system_message:
messages.append(SystemMessage(content=system_message).model_dump())
if message_history:
try:
MessageList(messages=message_history)
@@ -81,14 +89,18 @@ def get_messages(
return messages

def invoke(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Sends a text input to the Mistral chat completion model
and returns the response's content.
Args:
input (str): Text sent to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from MistralAI.
@@ -97,7 +109,7 @@ def invoke(
LLMGenerationError: If anything goes wrong.
"""
try:
messages = self.get_messages(input, message_history)
messages = self.get_messages(input, message_history, system_instruction)
response = self.client.chat.complete(
model=self.model_name,
messages=messages,
Expand All @@ -113,14 +125,18 @@ def invoke(
raise LLMGenerationError(e)

async def ainvoke(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Asynchronously sends a text input to the MistralAI chat
completion model and returns the response's content.
Args:
input (str): Text sent to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from MistralAI.
@@ -129,7 +145,7 @@ async def ainvoke(
LLMGenerationError: If anything goes wrong.
"""
try:
messages = self.get_messages(input, message_history)
messages = self.get_messages(input, message_history, system_instruction)
response = await self.client.chat.complete_async(
model=self.model_name,
messages=messages,
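
An async usage sketch for the same override on the Mistral wrapper (the class name MistralAILLM, its constructor arguments, the model name, and the need for a MISTRAL_API_KEY in the environment are assumptions; the ainvoke() parameter comes from the diff above).

import asyncio

from neo4j_graphrag.llm.mistralai_llm import MistralAILLM


async def main() -> None:
    llm = MistralAILLM(model_name="mistral-small-latest")

    # Per-call system instruction, without setting one on the instance.
    response = await llm.ainvoke(
        "List three graph database use cases.",
        system_instruction="Respond as a bulleted list, nothing else.",
    )
    print(response.content)


asyncio.run(main())
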
52 changes: 45 additions & 7 deletions src/neo4j_graphrag/llm/ollama_llm.py
@@ -50,11 +50,19 @@ def __init__(
)

def get_messages(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> Sequence[Message]:
messages = []
if self.system_instruction:
messages.append(SystemMessage(content=self.system_instruction).model_dump())
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
if system_message:
messages.append(SystemMessage(content=system_message).model_dump())
if message_history:
try:
MessageList(messages=message_history)
@@ -65,12 +73,25 @@ def get_messages(
return messages

def invoke(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from the LLM.
"""
try:
response = self.client.chat(
model=self.model_name,
messages=self.get_messages(input, message_history),
messages=self.get_messages(input, message_history, system_instruction),
options=self.model_params,
)
content = response.message.content or ""
@@ -79,12 +100,29 @@ def invoke(
raise LLMGenerationError(e)

async def ainvoke(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Asynchronously sends a text input to the OpenAI chat
completion model and returns the response's content.
Args:
input (str): Text sent to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from Ollama.
Raises:
LLMGenerationError: If anything goes wrong.
"""
try:
response = await self.async_client.chat(
model=self.model_name,
messages=self.get_messages(input, message_history),
messages=self.get_messages(input, message_history, system_instruction),
options=self.model_params,
)
content = response.message.content or ""
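
A test-style sketch of how the override could be verified for the Ollama wrapper. It is not a test from this commit: the OllamaLLM class name, the constructor arguments, and the exact dict shape produced by SystemMessage.model_dump() are assumptions, and it assumes the ollama package is importable so the constructor runs.

from unittest.mock import MagicMock

from neo4j_graphrag.llm.ollama_llm import OllamaLLM


def test_invoke_prefers_per_call_system_instruction() -> None:
    llm = OllamaLLM(model_name="llama3", system_instruction="default instruction")
    # Swap in a mock so no Ollama server is needed.
    llm.client = MagicMock()
    llm.client.chat.return_value = MagicMock(message=MagicMock(content="ok"))

    llm.invoke("hello", system_instruction="override instruction")

    sent = llm.client.chat.call_args.kwargs["messages"]
    assert sent[0]["role"] == "system"          # dict shape assumed from model_dump()
    assert sent[0]["content"] == "override instruction"


test_invoke_prefers_per_call_system_instruction()
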
8 changes: 6 additions & 2 deletions src/neo4j_graphrag/llm/openai_llm.py
@@ -115,14 +115,18 @@ def invoke(
raise LLMGenerationError(e)

async def ainvoke(
self, input: str, message_history: Optional[list[BaseMessage]] = None
self,
input: str,
message_history: Optional[list[BaseMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Asynchronously sends a text input to the OpenAI chat
completion model and returns the response's content.
Args:
input (str): Text sent to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
Returns:
LLMResponse: The response from OpenAI.
@@ -132,7 +136,7 @@ async def ainvoke(
"""
try:
response = await self.async_client.chat.completions.create(
messages=self.get_messages(input, message_history),
messages=self.get_messages(input, message_history, system_instruction),
model=self.model_name,
**self.model_params,
)
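
Finally, a sketch of one motivation for a per-call override: a single client instance can serve several system prompts without being rebuilt. The OpenAILLM class name, the model name, and the need for OPENAI_API_KEY in the environment are assumptions; only the ainvoke() parameter is taken from the diff above.

import asyncio

from neo4j_graphrag.llm.openai_llm import OpenAILLM


async def main() -> None:
    llm = OpenAILLM(model_name="gpt-4o-mini")
    personas = {
        "terse": "Answer in at most five words.",
        "verbose": "Explain thoroughly, with a short example.",
    }
    for name, instruction in personas.items():
        # Each call overrides the system message; the instance itself is unchanged.
        response = await llm.ainvoke(
            "What is a knowledge graph?",
            system_instruction=instruction,
        )
        print(f"{name}: {response.content}")


asyncio.run(main())
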