Commit 8abb2c0
handled case where message is None
bauerem committed Nov 23, 2024
1 parent 7170a4e · commit 8abb2c0
Showing 1 changed file with 21 additions and 11 deletions.
libs/partners/ollama/langchain_ollama/chat_models.py (32 changes: 21 additions & 11 deletions)

@@ -65,7 +65,7 @@ def _get_tool_calls_from_response(
 ) -> List[ToolCall]:
     """Get tool calls from ollama response."""
     tool_calls = []
-    if "message" in response:
+    if response.get("message", None):
         if "tool_calls" in response["message"]:
             for tc in response["message"]["tool_calls"]:
                 tool_calls.append(
@@ -350,7 +350,8 @@ def _chat_params(
         ollama_messages = self._convert_messages_to_ollama_messages(messages)
 
         if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
+            raise ValueError(
+                "`stop` found in both the input and default params.")
         elif self.stop is not None:
             stop = self.stop
 
@@ -427,7 +428,8 @@ def _convert_messages_to_ollama_messages(
                 role = "tool"
                 tool_call_id = message.tool_call_id
             else:
-                raise ValueError("Received unsupported message type for Ollama.")
+                raise ValueError(
+                    "Received unsupported message type for Ollama.")
 
             content = ""
             images = []
@@ -536,7 +538,8 @@ def _chat_stream_with_aggregation(
                         tool_calls=_get_tool_calls_from_response(stream_resp),
                     ),
                     generation_info=(
-                        dict(stream_resp) if stream_resp.get("done") is True else None
+                        dict(stream_resp) if stream_resp.get(
+                            "done") is True else None
                     ),
                 )
                 if final_chunk is None:
@@ -579,7 +582,8 @@ async def _achat_stream_with_aggregation(
                         tool_calls=_get_tool_calls_from_response(stream_resp),
                     ),
                     generation_info=(
-                        dict(stream_resp) if stream_resp.get("done") is True else None
+                        dict(stream_resp) if stream_resp.get(
+                            "done") is True else None
                     ),
                 )
                 if final_chunk is None:
@@ -626,8 +630,10 @@ def _generate(
         chat_generation = ChatGeneration(
             message=AIMessage(
                 content=final_chunk.text,
-                usage_metadata=cast(AIMessageChunk, final_chunk.message).usage_metadata,
-                tool_calls=cast(AIMessageChunk, final_chunk.message).tool_calls,
+                usage_metadata=cast(
+                    AIMessageChunk, final_chunk.message).usage_metadata,
+                tool_calls=cast(
+                    AIMessageChunk, final_chunk.message).tool_calls,
             ),
             generation_info=generation_info,
         )
@@ -656,7 +662,8 @@ def _stream(
                         tool_calls=_get_tool_calls_from_response(stream_resp),
                     ),
                     generation_info=(
-                        dict(stream_resp) if stream_resp.get("done") is True else None
+                        dict(stream_resp) if stream_resp.get(
+                            "done") is True else None
                     ),
                 )
                 if run_manager:
@@ -689,7 +696,8 @@ async def _astream(
                         tool_calls=_get_tool_calls_from_response(stream_resp),
                     ),
                     generation_info=(
-                        dict(stream_resp) if stream_resp.get("done") is True else None
+                        dict(stream_resp) if stream_resp.get(
+                            "done") is True else None
                     ),
                 )
                 if run_manager:
@@ -713,8 +721,10 @@ async def _agenerate(
         chat_generation = ChatGeneration(
             message=AIMessage(
                 content=final_chunk.text,
-                usage_metadata=cast(AIMessageChunk, final_chunk.message).usage_metadata,
-                tool_calls=cast(AIMessageChunk, final_chunk.message).tool_calls,
+                usage_metadata=cast(
+                    AIMessageChunk, final_chunk.message).usage_metadata,
+                tool_calls=cast(
+                    AIMessageChunk, final_chunk.message).tool_calls,
             ),
             generation_info=generation_info,
         )
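
The only behavioral change in this commit is the new guard in _get_tool_calls_from_response (the hunk at line 65); every other hunk just re-wraps long lines. Below is a minimal standalone sketch of why the membership test was replaced with a truthiness check. get_tool_calls is a simplified stand-in for the real function, and the payloads are hypothetical Ollama-style responses, not output captured from the commit:

    from typing import Any, Dict, List

    def get_tool_calls(response: Dict[str, Any]) -> List[dict]:
        """Simplified stand-in for _get_tool_calls_from_response."""
        tool_calls: List[dict] = []
        # .get("message", None) is falsy when the key is missing *or* when it
        # is present but set to None, so both cases skip the block. The old
        # check, `if "message" in response:`, passed for {"message": None},
        # and the nested `"tool_calls" in response["message"]` lookup then
        # raised: TypeError: argument of type 'NoneType' is not iterable.
        if response.get("message", None):
            if "tool_calls" in response["message"]:
                for tc in response["message"]["tool_calls"]:
                    tool_calls.append(tc)
        return tool_calls

    print(get_tool_calls({"message": None}))  # [] (TypeError before this fix)
    print(get_tool_calls({"done": True}))     # []
    print(get_tool_calls({"message": {"tool_calls": [{"id": 1}]}}))  # [{'id': 1}]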
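
The generation_info ternary that gets re-wrapped in the four streaming methods attaches the full response dict only to the final streamed chunk. A small sketch of that pattern, assuming Ollama marks the last chunk with done=True; the chunk payloads here are made up:

    # Hypothetical stream: two partial chunks, then the final one.
    chunks = [
        {"message": {"content": "Hel"}, "done": False},
        {"message": {"content": "lo"}, "done": False},
        {"message": {"content": ""}, "done": True, "total_duration": 123},
    ]

    for stream_resp in chunks:
        generation_info = (
            dict(stream_resp) if stream_resp.get("done") is True else None
        )
        print(generation_info)
    # None
    # None
    # {'message': {'content': ''}, 'done': True, 'total_duration': 123}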
