chore: tools feature
phil65 committed Nov 19, 2024
1 parent 0a6e3f7 commit de65b87
Showing 19 changed files with 1,464 additions and 403 deletions.
336 changes: 218 additions & 118 deletions src/llmling/client.py

Large diffs are not rendered by default.

32 changes: 29 additions & 3 deletions src/llmling/config/models.py
@@ -5,7 +5,7 @@
from collections.abc import Sequence as TypingSequence # noqa: TCH003
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from llmling.core.typedefs import ProcessingStep # noqa: TCH001
from llmling.processors.base import ProcessorConfig # noqa: TCH001
@@ -28,9 +28,19 @@ class LLMProviderConfig(BaseModel):
temperature: float | None = None
max_tokens: int | None = None
top_p: float | None = None
tools: dict[str, dict[str, Any]] | list[str] | None = None # Allow both formats
tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051

model_config = ConfigDict(frozen=True)

@field_validator("tools", mode="before")
@classmethod
def convert_tools(cls, v: Any) -> dict[str, dict[str, Any]] | None:
"""Convert tool references to dictionary format."""
if isinstance(v, list):
return {tool: {} for tool in v}
return v

@model_validator(mode="after")
def validate_model_format(self) -> LLMProviderConfig:
"""Validate that model follows provider/name format."""
@@ -41,11 +51,13 @@ def validate_model_format(self) -> LLMProviderConfig:


class TaskSettings(BaseModel):
"""Settings for a specific task."""
"""Settings for a task."""

temperature: float | None = None
max_tokens: int | None = None
top_p: float | None = None
tools: list[str] | None = None # Add tools field
tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051

model_config = ConfigDict(frozen=True)

@@ -158,7 +170,20 @@ class TaskTemplate(BaseModel):
provider: str # provider name or group name
context: str # context name or group name
settings: TaskSettings | None = None
inherit_tools: bool = True
# Make tool-related fields optional with None defaults
inherit_tools: bool | None = None
tools: list[str] | None = None
tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051

model_config = ConfigDict(frozen=True)


class ToolConfig(BaseModel):
"""Configuration for a tool."""

import_path: str
name: str | None = None
description: str | None = None

model_config = ConfigDict(frozen=True)

@@ -174,6 +199,7 @@ class Config(BaseModel):
contexts: dict[str, Context]
context_groups: dict[str, list[str]] = Field(default_factory=dict)
task_templates: dict[str, TaskTemplate]
tools: dict[str, ToolConfig] = Field(default_factory=dict)

model_config = ConfigDict(
frozen=True,
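The pieces above compose as follows: Config gains a tools mapping of name to ToolConfig, and a TaskTemplate can opt into those tools by name and set a tool_choice. A minimal sketch, not part of the commit — the "analyze" tool, its import path, and the provider/context names are invented for illustration, and any TaskTemplate fields outside the hunk shown above would still need to be supplied:

from llmling.config.models import TaskTemplate, ToolConfig

# Hypothetical tool registry; "analyze" and its import path are illustrative.
tools = {
    "analyze": ToolConfig(
        import_path="my_pkg.tools.analyze",
        description="Analyze a piece of text",
    ),
}

# A template that opts into that tool by name.
template = TaskTemplate(
    provider="gpt4",          # provider or provider-group name (illustrative)
    context="code_review",    # context or context-group name (illustrative)
    tools=["analyze"],        # must match keys in Config.tools
    tool_choice="auto",       # "none", "auto", or a specific tool name
)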
20 changes: 20 additions & 0 deletions src/llmling/config/validation.py
@@ -170,3 +170,23 @@ def validate_references(self) -> list[str]:
)

return warnings

def validate_tools(self) -> list[str]:
"""Validate tool configuration."""
warnings: list[str] = []

# Skip tool validation if tools aren't configured
if not self.config.tools:
return warnings

# Validate tool references in tasks
for name, template in self.config.task_templates.items():
if not template.tools:
continue
warnings.extend(
f"Tool {tool} referenced in task {name} not found"
for tool in template.tools
if tool not in self.config.tools
)

return warnings
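The class that hosts validate_tools() is not shown in this diff; the check itself reduces to the comprehension below, sketched here as a standalone function over a Config instance (the function name is invented for illustration):

def find_missing_tool_references(config) -> list[str]:
    """Mirror validate_tools(): warn when a task references an unknown tool."""
    if not config.tools:
        return []
    return [
        f"Tool {tool} referenced in task {name} not found"
        for name, template in config.task_templates.items()
        if template.tools
        for tool in template.tools
        if tool not in config.tools
    ]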
26 changes: 22 additions & 4 deletions src/llmling/llm/base.py
@@ -16,7 +16,7 @@


if TYPE_CHECKING:
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator


class LLMConfig(BaseModel):
@@ -31,17 +31,35 @@ class LLMConfig(BaseModel):
timeout: int = 30
max_retries: int = 3
streaming: bool = False
# New fields for tool support
tools: list[dict[str, Any]] | None = None
tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051

model_config = ConfigDict(frozen=True)


MessageRole = Literal["system", "user", "assistant"]
"""Valid message roles for chat completion."""


class ToolCall(BaseModel):
"""A tool call request from the LLM."""

id: str # Required by OpenAI
name: str
parameters: dict[str, Any]

model_config = ConfigDict(frozen=True)


class Message(BaseModel):
"""A chat message."""

role: MessageRole
role: Literal["system", "user", "assistant", "tool"]
content: str
name: str | None = None # For tool messages
tool_calls: list[ToolCall] | None = None # For assistant messages
tool_call_id: str | None = None # For tool response messages

model_config = ConfigDict(frozen=True)

Expand All @@ -52,7 +70,7 @@ class CompletionResult(BaseModel):
content: str
model: str
finish_reason: str | None = None
is_stream_chunk: bool = False
tool_calls: list[ToolCall] | None = None
metadata: dict[str, Any] = Field(default_factory=dict)

model_config = ConfigDict(frozen=True)
@@ -201,7 +219,7 @@ async def _complete_stream_impl(
self,
messages: list[Message],
**kwargs: Any,
) -> AsyncGenerator[CompletionResult, None]:
) -> AsyncIterator[CompletionResult]:
"""Implement actual streaming completion logic."""
yield NotImplemented # pragma: no cover

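Taken together, these models describe one turn of a tool-call exchange: the assistant message carries ToolCall entries, and the tool's result is sent back as a role="tool" message that references the call id. A sketch using only the fields shown above; the call id, tool name, and payloads are illustrative:

from llmling.llm.base import Message, ToolCall

call = ToolCall(id="call_0", name="analyze", parameters={"text": "hello"})

# Assistant turn that requests the tool call.
assistant_msg = Message(role="assistant", content="", tool_calls=[call])

# Tool turn that answers it, linked back via tool_call_id.
tool_msg = Message(
    role="tool",
    content='{"sentiment": "positive"}',  # serialized tool output
    name="analyze",
    tool_call_id=call.id,
)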
84 changes: 70 additions & 14 deletions src/llmling/llm/providers/litellm.py
@@ -7,13 +7,11 @@
import litellm

from llmling.core import exceptions
from llmling.llm.base import CompletionResult, RetryableProvider
from llmling.llm.base import CompletionResult, Message, RetryableProvider, ToolCall


if TYPE_CHECKING:
from collections.abc import AsyncGenerator

from llmling.llm.base import Message
from collections.abc import AsyncIterator


class LiteLLMProvider(RetryableProvider):
@@ -26,40 +24,97 @@ async def _complete_impl(
) -> CompletionResult:
"""Implement completion using LiteLLM."""
try:
# Convert messages to dict format, explicitly handling tool_calls
messages_dict = []
for msg in messages:
msg_dict: dict[str, Any] = {
"role": msg.role,
"content": msg.content,
}
if msg.name:
msg_dict["name"] = msg.name
if msg.tool_calls:
msg_dict["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls]
messages_dict.append(msg_dict)

# Add tool configuration if present and provider supports it
if self.config.tools and not self._is_local_provider():
kwargs["tools"] = self.config.tools
if self.config.tool_choice is not None:
kwargs["tool_choice"] = self.config.tool_choice

response = await litellm.acompletion(
model=self.config.model,
messages=[msg.model_dump() for msg in messages],
messages=messages_dict,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
timeout=self.config.timeout,
**kwargs,
)

# Handle tool calls if present
tool_calls = None
if hasattr(response.choices[0].message, "tool_calls"):
tc = response.choices[0].message.tool_calls
if tc:
tool_calls = [
ToolCall(
id=call.id,
name=call.function.name,
parameters=call.function.arguments,
)
for call in tc
]

return CompletionResult(
content=response.choices[0].message.content,
content=response.choices[0].message.content or "",
model=response.model,
finish_reason=response.choices[0].finish_reason,
tool_calls=tool_calls,
metadata={
"provider": "litellm",
"usage": response.usage.model_dump(),
},
)

except Exception as exc:
msg = f"LiteLLM completion failed: {exc}"
raise exceptions.LLMError(msg) from exc
msg_ = f"LiteLLM completion failed: {exc}"
raise exceptions.LLMError(msg_) from exc

def _is_local_provider(self) -> bool:
"""Check if the current model is a local provider (like Ollama)."""
return self.config.model.startswith(("ollama/", "local/"))

async def _complete_stream_impl(
self,
messages: list[Message],
**kwargs: Any,
) -> AsyncGenerator[CompletionResult, None]:
) -> AsyncIterator[CompletionResult]:
"""Implement streaming completion using LiteLLM."""
try:
# Convert messages to dict format, same as above
messages_dict = []
for msg in messages:
msg_dict: dict[str, Any] = {
"role": msg.role,
"content": msg.content,
}
if msg.name:
msg_dict["name"] = msg.name
if msg.tool_calls:
msg_dict["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls]
messages_dict.append(msg_dict)

# Add tool configuration if present and provider supports it
if self.config.tools and not self._is_local_provider():
kwargs["tools"] = self.config.tools
if self.config.tool_choice is not None:
kwargs["tool_choice"] = self.config.tool_choice

response_stream = await litellm.acompletion(
model=self.config.model,
messages=[msg.model_dump() for msg in messages],
messages=messages_dict,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
@@ -72,16 +127,17 @@ async def _complete_stream_impl(
if not chunk.choices[0].delta.content:
continue

# Tool calls aren't supported in streaming mode yet
yield CompletionResult(
content=chunk.choices[0].delta.content,
model=chunk.model,
finish_reason=chunk.choices[0].finish_reason,
is_stream_chunk=True,
metadata={
"provider": "litellm",
"chunk": True,
},
)

except Exception as exc:
msg = f"LiteLLM streaming failed: {exc}"
raise exceptions.LLMError(msg) from exc
except Exception as e:
error_msg = f"LiteLLM streaming failed: {e}"
raise exceptions.LLMError(error_msg) from e
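Because self.config.tools is forwarded to litellm.acompletion unchanged, the entries are presumably OpenAI-style function schemas. A hedged sketch of what one such entry could look like; the function name and parameter schema are illustrative, not taken from the commit:

# Illustrative payload for LLMConfig.tools / the "tools" kwarg passed to litellm.
tools = [
    {
        "type": "function",
        "function": {
            "name": "analyze",                       # illustrative tool name
            "description": "Analyze a piece of text",
            "parameters": {                          # JSON Schema for the arguments
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
        },
    }
]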