From de65b87954b5241643aa694720a11a763ceb2bbf Mon Sep 17 00:00:00 2001 From: Philipp Temminghoff Date: Tue, 19 Nov 2024 02:19:41 +0100 Subject: [PATCH] chore: tools feature --- src/llmling/client.py | 336 +++++++++++++++++---------- src/llmling/config/models.py | 32 ++- src/llmling/config/validation.py | 20 ++ src/llmling/llm/base.py | 26 ++- src/llmling/llm/providers/litellm.py | 84 +++++-- src/llmling/resources/test.yml | 117 +++++----- src/llmling/task/executor.py | 93 ++++++-- src/llmling/task/manager.py | 62 +++-- src/llmling/task/models.py | 6 +- src/llmling/testing/tools.py | 38 +++ src/llmling/tools/__init__.py | 14 ++ src/llmling/tools/base.py | 168 ++++++++++++++ src/llmling/tools/code.py | 64 +++++ src/llmling/tools/exceptions.py | 21 ++ src/llmling/utils/calling.py | 65 ++++-- tests/test_client.py | 255 +++++++++----------- tests/test_task_manager.py | 62 ++++- tests/test_tool_integration.py | 244 +++++++++++++++++++ tests/test_tools.py | 160 +++++++++++++ 19 files changed, 1464 insertions(+), 403 deletions(-) create mode 100644 src/llmling/testing/tools.py create mode 100644 src/llmling/tools/__init__.py create mode 100644 src/llmling/tools/base.py create mode 100644 src/llmling/tools/code.py create mode 100644 src/llmling/tools/exceptions.py create mode 100644 tests/test_tool_integration.py create mode 100644 tests/test_tools.py diff --git a/src/llmling/client.py b/src/llmling/client.py index 8095c81..eed26f6 100644 --- a/src/llmling/client.py +++ b/src/llmling/client.py @@ -5,6 +5,7 @@ import asyncio from typing import TYPE_CHECKING, Any, Literal, Protocol, Self, TypeVar, cast, overload +from llmling.config.loading import load_config from llmling.config.manager import ConfigManager from llmling.context import default_registry as context_registry from llmling.core import exceptions @@ -14,6 +15,7 @@ from llmling.task.concurrent import execute_concurrent from llmling.task.executor import TaskExecutor from llmling.task.manager import TaskManager +from llmling.tools.base import ToolRegistry if TYPE_CHECKING: @@ -24,7 +26,6 @@ logger = get_logger(__name__) - T = TypeVar("T") @@ -34,7 +35,7 @@ class Registerable(Protocol): def register(self, name: str, item: Any) -> None: ... -ComponentType = Literal["processor", "context", "provider"] +ComponentType = Literal["processor", "context", "provider", "tool"] class LLMLingClient: @@ -54,7 +55,7 @@ def __init__( config_path: Path to YAML configuration file log_level: Optional logging level validate_config: Whether to validate configuration on load - components: Optional components to register, organized by type + components: Optional dictionary of components to register Example: { "processor": {"name": ProcessorConfig(...)}, @@ -68,17 +69,18 @@ def __init__( self.config_path = config_path self.validate_config = validate_config self.components = components or {} - - # Initialize components as None + self.tool_registry = ToolRegistry() + # Components will be initialized in startup self.config_manager: ConfigManager | None = None - self.processor_registry: ProcessorRegistry | None = None - self.executor: TaskExecutor | None = None + self._processor_registry: ProcessorRegistry | None = None + self._executor: TaskExecutor | None = None self._manager: TaskManager | None = None self._initialized = False @property def manager(self) -> TaskManager: """Get the task manager, raising if not initialized.""" + self._ensure_initialized() if self._manager is None: msg = "Task manager not initialized" raise exceptions.LLMLingError(msg) @@ -98,24 +100,41 @@ def create( Returns: Initialized client instance + + Raises: + LLMLingError: If initialization fails """ client = cls(config_path, **kwargs) - asyncio.run(client.startup()) - return client + try: + # Create new event loop for sync operations + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(client.startup()) + return client + finally: + loop.close() + asyncio.set_event_loop(None) + except Exception as exc: + logger.exception("Failed to create client") + msg = f"Failed to create client: {exc}" + raise exceptions.LLMLingError(msg) from exc async def startup(self) -> None: - """Initialize all components.""" + """Initialize all components. + + Raises: + LLMLingError: If initialization fails + """ if self._initialized: return try: # Initialize registries - self.processor_registry = ProcessorRegistry() + self._processor_registry = ProcessorRegistry() + llm_registry.reset() # Ensure clean state # Load configuration - logger.info("Loading configuration from %s", self.config_path) - from llmling.config.loading import load_config - config = load_config( self.config_path, validate=self.validate_config, @@ -129,31 +148,34 @@ async def startup(self) -> None: await self._register_components() # Start processor registry - await self.processor_registry.startup() + await self._processor_registry.startup() # Create executor and manager - self.executor = TaskExecutor( + self._executor = TaskExecutor( context_registry=context_registry, - processor_registry=self.processor_registry, + processor_registry=self._processor_registry, provider_registry=llm_registry, + tool_registry=self.tool_registry, ) + self._manager = TaskManager(self.config_manager.config, self._executor) - if self.config_manager is None: - msg = "Configuration manager not initialized" - raise exceptions.LLMLingError(msg) # noqa: TRY301 - - self._manager = TaskManager(self.config_manager.config, self.executor) self._initialized = True - logger.info("Client initialized successfully") + logger.debug("Client initialized successfully") except Exception as exc: - msg = "Failed to initialize client" + logger.exception("Client initialization failed") + await self.shutdown() # Ensure cleanup on failure + msg = f"Failed to initialize client: {exc}" raise exceptions.LLMLingError(msg) from exc async def _register_components(self) -> None: - """Register all configured components.""" + """Register all configured components. + + Raises: + LLMLingError: If component registration fails + """ registries: dict[ComponentType, Registerable] = { - "processor": cast(Registerable, self.processor_registry), + "processor": cast(Registerable, self._processor_registry), "context": cast(Registerable, context_registry), "provider": cast(Registerable, llm_registry), } @@ -167,46 +189,108 @@ async def _register_components(self) -> None: for name, item in items.items(): try: registry.register(name, item) - logger.debug( - "Registered %s: %s", - component_type, - name, - ) - except Exception: - logger.exception("Failed to register %s %s", component_type, name) - - async def shutdown(self) -> None: - """Clean up resources.""" - if not self._initialized: - return - - if self.processor_registry: - await self.processor_registry.shutdown() - self._initialized = False - logger.info("Client shut down successfully") - - def _ensure_initialized(self) -> None: - """Ensure client is initialized.""" - if not self._initialized: - msg = "Client not initialized" - raise exceptions.LLMLingError(msg) + logger.debug("Registered %s: %s", component_type, name) + except Exception as exc: + msg = f"Failed to register {component_type} {name}: {exc}" + raise exceptions.LLMLingError(msg) from exc async def _register_providers(self) -> None: - """Register all providers from configuration.""" + """Register all providers from configuration. + + Raises: + LLMLingError: If provider registration fails + """ if not self.config_manager: msg = "Configuration not loaded" raise exceptions.LLMLingError(msg) - # Register direct providers - for provider_key in self.config_manager.config.llm_providers: - llm_registry.register_provider(provider_key, "litellm") - logger.debug("Registered provider: %s", provider_key) + try: + # Register direct providers + for provider_key in self.config_manager.config.llm_providers: + llm_registry.register_provider(provider_key, "litellm") + logger.debug("Registered provider: %s", provider_key) + + # Register provider groups + for ( + group_name, + providers, + ) in self.config_manager.config.provider_groups.items(): + if providers: + llm_registry.register_provider(group_name, "litellm") + logger.debug("Registered provider group: %s", group_name) + except Exception as exc: + msg = "Failed to register providers" + raise exceptions.LLMLingError(msg) from exc + + def execute_sync( + self, + template: str, + *, + system_prompt: str | None = None, + **kwargs: Any, + ) -> TaskResult: + """Execute a task template synchronously. + + Args: + template: Name of template to execute + system_prompt: Optional system prompt + **kwargs: Additional parameters for LLM + + Returns: + Task result + + Raises: + TaskError: If execution fails + """ + try: + # Create new event loop for sync operation + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete( + self.execute( + template, + system_prompt=system_prompt, + stream=False, + **kwargs, + ) + ) + finally: + loop.close() + asyncio.set_event_loop(None) + except Exception as exc: + msg = f"Synchronous execution failed: {exc}" + raise exceptions.TaskError(msg) from exc + + def execute_many_sync( + self, + templates: Sequence[str], + **kwargs: Any, + ) -> list[TaskResult]: + """Execute multiple templates synchronously. + + Args: + templates: Template names to execute + **kwargs: Additional parameters passed to execute_many - # Register provider groups - for group_name, providers in self.config_manager.config.provider_groups.items(): - if providers: - llm_registry.register_provider(group_name, "litellm") - logger.debug("Registered provider group: %s", group_name) + Returns: + List of task results + + Raises: + TaskError: If execution fails + """ + try: + # Create new event loop for sync operation + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(self.execute_many(templates, **kwargs)) + finally: + loop.close() + asyncio.set_event_loop(None) + except Exception as exc: + msg = f"Synchronous concurrent execution failed: {exc}" + raise exceptions.TaskError(msg) from exc @overload async def execute( @@ -246,6 +330,9 @@ async def execute( Returns: Task result or async iterator of results if streaming + + Raises: + TaskError: If execution fails """ self._ensure_initialized() try: @@ -261,35 +348,10 @@ async def execute( **kwargs, ) except Exception as exc: - msg = f"Failed to execute template {template}" + logger.exception("Task execution failed") + msg = f"Failed to execute template {template}: {exc}" raise exceptions.TaskError(msg) from exc - def execute_sync( - self, - template: str, - *, - system_prompt: str | None = None, - **kwargs: Any, - ) -> TaskResult: - """Execute a task template synchronously. - - Args: - template: Name of template to execute - system_prompt: Optional system prompt - **kwargs: Additional parameters for LLM - - Returns: - Task result - """ - return asyncio.run( - self.execute( - template, - system_prompt=system_prompt, - stream=False, - **kwargs, - ) - ) - async def execute_many( self, templates: Sequence[str], @@ -308,31 +370,23 @@ async def execute_many( Returns: List of task results - """ - self._ensure_initialized() - return await execute_concurrent( - self.manager, - templates, - system_prompt=system_prompt, - max_concurrent=max_concurrent, - **kwargs, - ) - - def execute_many_sync( - self, - templates: Sequence[str], - **kwargs: Any, - ) -> list[TaskResult]: - """Execute multiple templates concurrently (synchronous version). - - Args: - templates: Template names to execute - **kwargs: Additional parameters passed to execute_many - Returns: - List of task results + Raises: + TaskError: If execution fails """ - return asyncio.run(self.execute_many(templates, **kwargs)) + self._ensure_initialized() + try: + return await execute_concurrent( + self.manager, + templates, + system_prompt=system_prompt, + max_concurrent=max_concurrent, + **kwargs, + ) + except Exception as exc: + logger.exception("Concurrent execution failed") + msg = f"Concurrent execution failed: {exc}" + raise exceptions.TaskError(msg) from exc async def stream( self, @@ -343,16 +397,16 @@ async def stream( ) -> AsyncIterator[TaskResult]: """Stream results from a task template. - This is a more explicit way to stream results compared to using - execute() with stream=True. - Args: template: Name of template to execute system_prompt: Optional system prompt **kwargs: Additional parameters for LLM - Returns: - Async iterator of task results + Yields: + Task results + + Raises: + TaskError: If streaming fails """ self._ensure_initialized() try: @@ -363,9 +417,32 @@ async def stream( ): yield result except Exception as exc: - msg = f"Failed to stream template {template}" + logger.exception("Task streaming failed") + msg = f"Failed to stream template {template}: {exc}" raise exceptions.TaskError(msg) from exc + def _ensure_initialized(self) -> None: + """Ensure client is initialized.""" + if not self._initialized: + msg = "Client not initialized" + raise exceptions.LLMLingError(msg) + + async def shutdown(self) -> None: + """Clean up resources.""" + if not self._initialized: + return + + try: + if self._processor_registry: + await self._processor_registry.shutdown() + except Exception as exc: + logger.exception("Error during shutdown") + msg = f"Failed to shutdown client: {exc}" + raise exceptions.LLMLingError(msg) from exc + finally: + self._initialized = False + logger.info("Client shut down successfully") + async def __aenter__(self) -> Self: """Async context manager entry.""" await self.startup() @@ -377,12 +454,35 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: def __enter__(self) -> Self: """Synchronous context manager entry.""" - asyncio.run(self.startup()) - return self + try: + # Create new event loop for sync operation + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(self.startup()) + return self + finally: + loop.close() + asyncio.set_event_loop(None) + except Exception as exc: + msg = "Failed to enter context" + raise exceptions.LLMLingError(msg) from exc def __exit__(self, exc_type, exc_val, exc_tb) -> None: """Synchronous context manager exit.""" - asyncio.run(self.shutdown()) + try: + # Create new event loop for sync operation + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(self.shutdown()) + finally: + loop.close() + asyncio.set_event_loop(None) + except Exception as exc: + logger.exception("Error during context exit") + msg = "Failed to exit context" + raise exceptions.LLMLingError(msg) from exc async def async_example() -> None: diff --git a/src/llmling/config/models.py b/src/llmling/config/models.py index 63162f8..6b829af 100644 --- a/src/llmling/config/models.py +++ b/src/llmling/config/models.py @@ -5,7 +5,7 @@ from collections.abc import Sequence as TypingSequence # noqa: TCH003 from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from llmling.core.typedefs import ProcessingStep # noqa: TCH001 from llmling.processors.base import ProcessorConfig # noqa: TCH001 @@ -28,9 +28,19 @@ class LLMProviderConfig(BaseModel): temperature: float | None = None max_tokens: int | None = None top_p: float | None = None + tools: dict[str, dict[str, Any]] | list[str] | None = None # Allow both formats + tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051 model_config = ConfigDict(frozen=True) + @field_validator("tools", mode="before") + @classmethod + def convert_tools(cls, v: Any) -> dict[str, dict[str, Any]] | None: + """Convert tool references to dictionary format.""" + if isinstance(v, list): + return {tool: {} for tool in v} + return v + @model_validator(mode="after") def validate_model_format(self) -> LLMProviderConfig: """Validate that model follows provider/name format.""" @@ -41,11 +51,13 @@ def validate_model_format(self) -> LLMProviderConfig: class TaskSettings(BaseModel): - """Settings for a specific task.""" + """Settings for a task.""" temperature: float | None = None max_tokens: int | None = None top_p: float | None = None + tools: list[str] | None = None # Add tools field + tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051 model_config = ConfigDict(frozen=True) @@ -158,7 +170,20 @@ class TaskTemplate(BaseModel): provider: str # provider name or group name context: str # context name or group name settings: TaskSettings | None = None - inherit_tools: bool = True + # Make tool-related fields optional with None defaults + inherit_tools: bool | None = None + tools: list[str] | None = None + tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051 + + model_config = ConfigDict(frozen=True) + + +class ToolConfig(BaseModel): + """Configuration for a tool.""" + + import_path: str + name: str | None = None + description: str | None = None model_config = ConfigDict(frozen=True) @@ -174,6 +199,7 @@ class Config(BaseModel): contexts: dict[str, Context] context_groups: dict[str, list[str]] = Field(default_factory=dict) task_templates: dict[str, TaskTemplate] + tools: dict[str, ToolConfig] = Field(default_factory=dict) model_config = ConfigDict( frozen=True, diff --git a/src/llmling/config/validation.py b/src/llmling/config/validation.py index 8aa7563..c909138 100644 --- a/src/llmling/config/validation.py +++ b/src/llmling/config/validation.py @@ -170,3 +170,23 @@ def validate_references(self) -> list[str]: ) return warnings + + def validate_tools(self) -> list[str]: + """Validate tool configuration.""" + warnings: list[str] = [] + + # Skip tool validation if tools aren't configured + if not self.config.tools: + return warnings + + # Validate tool references in tasks + for name, template in self.config.task_templates.items(): + if not template.tools: + continue + warnings.extend( + f"Tool {tool} referenced in task {name} not found" + for tool in template.tools + if tool not in self.config.tools + ) + + return warnings diff --git a/src/llmling/llm/base.py b/src/llmling/llm/base.py index e29d50f..9204845 100644 --- a/src/llmling/llm/base.py +++ b/src/llmling/llm/base.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, AsyncIterator class LLMConfig(BaseModel): @@ -31,17 +31,35 @@ class LLMConfig(BaseModel): timeout: int = 30 max_retries: int = 3 streaming: bool = False + # New fields for tool support + tools: list[dict[str, Any]] | None = None + tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051 + + model_config = ConfigDict(frozen=True) MessageRole = Literal["system", "user", "assistant"] """Valid message roles for chat completion.""" +class ToolCall(BaseModel): + """A tool call request from the LLM.""" + + id: str # Required by OpenAI + name: str + parameters: dict[str, Any] + + model_config = ConfigDict(frozen=True) + + class Message(BaseModel): """A chat message.""" - role: MessageRole + role: Literal["system", "user", "assistant", "tool"] content: str + name: str | None = None # For tool messages + tool_calls: list[ToolCall] | None = None # For assistant messages + tool_call_id: str | None = None # For tool response messages model_config = ConfigDict(frozen=True) @@ -52,7 +70,7 @@ class CompletionResult(BaseModel): content: str model: str finish_reason: str | None = None - is_stream_chunk: bool = False + tool_calls: list[ToolCall] | None = None metadata: dict[str, Any] = Field(default_factory=dict) model_config = ConfigDict(frozen=True) @@ -201,7 +219,7 @@ async def _complete_stream_impl( self, messages: list[Message], **kwargs: Any, - ) -> AsyncGenerator[CompletionResult, None]: + ) -> AsyncIterator[CompletionResult]: """Implement actual streaming completion logic.""" yield NotImplemented # pragma: no cover diff --git a/src/llmling/llm/providers/litellm.py b/src/llmling/llm/providers/litellm.py index b79a6d6..7692c7d 100644 --- a/src/llmling/llm/providers/litellm.py +++ b/src/llmling/llm/providers/litellm.py @@ -7,13 +7,11 @@ import litellm from llmling.core import exceptions -from llmling.llm.base import CompletionResult, RetryableProvider +from llmling.llm.base import CompletionResult, Message, RetryableProvider, ToolCall if TYPE_CHECKING: - from collections.abc import AsyncGenerator - - from llmling.llm.base import Message + from collections.abc import AsyncIterator class LiteLLMProvider(RetryableProvider): @@ -26,9 +24,28 @@ async def _complete_impl( ) -> CompletionResult: """Implement completion using LiteLLM.""" try: + # Convert messages to dict format, explicitly handling tool_calls + messages_dict = [] + for msg in messages: + msg_dict: dict[str, Any] = { + "role": msg.role, + "content": msg.content, + } + if msg.name: + msg_dict["name"] = msg.name + if msg.tool_calls: + msg_dict["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls] + messages_dict.append(msg_dict) + + # Add tool configuration if present and provider supports it + if self.config.tools and not self._is_local_provider(): + kwargs["tools"] = self.config.tools + if self.config.tool_choice is not None: + kwargs["tool_choice"] = self.config.tool_choice + response = await litellm.acompletion( model=self.config.model, - messages=[msg.model_dump() for msg in messages], + messages=messages_dict, temperature=self.config.temperature, max_tokens=self.config.max_tokens, top_p=self.config.top_p, @@ -36,10 +53,25 @@ async def _complete_impl( **kwargs, ) + # Handle tool calls if present + tool_calls = None + if hasattr(response.choices[0].message, "tool_calls"): + tc = response.choices[0].message.tool_calls + if tc: + tool_calls = [ + ToolCall( + id=call.id, + name=call.function.name, + parameters=call.function.arguments, + ) + for call in tc + ] + return CompletionResult( - content=response.choices[0].message.content, + content=response.choices[0].message.content or "", model=response.model, finish_reason=response.choices[0].finish_reason, + tool_calls=tool_calls, metadata={ "provider": "litellm", "usage": response.usage.model_dump(), @@ -47,19 +79,42 @@ async def _complete_impl( ) except Exception as exc: - msg = f"LiteLLM completion failed: {exc}" - raise exceptions.LLMError(msg) from exc + msg_ = f"LiteLLM completion failed: {exc}" + raise exceptions.LLMError(msg_) from exc + + def _is_local_provider(self) -> bool: + """Check if the current model is a local provider (like Ollama).""" + return self.config.model.startswith(("ollama/", "local/")) async def _complete_stream_impl( self, messages: list[Message], **kwargs: Any, - ) -> AsyncGenerator[CompletionResult, None]: + ) -> AsyncIterator[CompletionResult]: """Implement streaming completion using LiteLLM.""" try: + # Convert messages to dict format, same as above + messages_dict = [] + for msg in messages: + msg_dict: dict[str, Any] = { + "role": msg.role, + "content": msg.content, + } + if msg.name: + msg_dict["name"] = msg.name + if msg.tool_calls: + msg_dict["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls] + messages_dict.append(msg_dict) + + # Add tool configuration if present and provider supports it + if self.config.tools and not self._is_local_provider(): + kwargs["tools"] = self.config.tools + if self.config.tool_choice is not None: + kwargs["tool_choice"] = self.config.tool_choice + response_stream = await litellm.acompletion( model=self.config.model, - messages=[msg.model_dump() for msg in messages], + messages=messages_dict, temperature=self.config.temperature, max_tokens=self.config.max_tokens, top_p=self.config.top_p, @@ -72,16 +127,17 @@ async def _complete_stream_impl( if not chunk.choices[0].delta.content: continue + # Tool calls aren't supported in streaming mode yet yield CompletionResult( content=chunk.choices[0].delta.content, model=chunk.model, finish_reason=chunk.choices[0].finish_reason, - is_stream_chunk=True, metadata={ "provider": "litellm", + "chunk": True, }, ) - except Exception as exc: - msg = f"LiteLLM streaming failed: {exc}" - raise exceptions.LLMError(msg) from exc + except Exception as e: + error_msg = f"LiteLLM streaming failed: {e}" + raise exceptions.LLMError(error_msg) from e diff --git a/src/llmling/resources/test.yml b/src/llmling/resources/test.yml index b54a896..973186f 100644 --- a/src/llmling/resources/test.yml +++ b/src/llmling/resources/test.yml @@ -1,31 +1,21 @@ # Version of the configuration schema -# Used for compatibility checking and migrations version: "1.0" -# Global settings that apply to all components unless overridden +# Global settings that apply to all components global_settings: - timeout: 30 # Global timeout in seconds - max_retries: 3 # Default retry count - temperature: 0.7 # Default temperature for all LLMs + timeout: 30 + max_retries: 3 + temperature: 0.7 -# Reusable text processors that can be referenced in context definitions -# Each processor either references a Python function or defines a Jinja template +# Context processors definitions context_processors: - # Clean and standardize Python code python_cleaner: type: function import_path: llmling.testing.processors.uppercase_text - # Function must take str input and return str output - # Additional parameters can be passed via kwargs in context definitions - - # Remove sensitive information sanitize: type: function import_path: llmling.testing.processors.multiply - # Will be called with content as first argument - # kwargs can include patterns to remove, replacement text, etc. - # Add metadata header to any content add_metadata: type: template template: | @@ -34,74 +24,83 @@ context_processors: # Version: {{ version }} {{ content }} - # Templates always have access to: - # - content: the text being processed - # - now(): current datetime - # - env: environment variables - # Plus any kwargs passed in the context definition -# LLM provider configurations using litellm format +# Global tools definitions +tools: + analyze_code: + import_path: "llmling.tools.code.analyze_complexity" + description: "Analyze Python code complexity metrics" + analyze_ast: + import_path: "llmling.testing.tools.analyze_ast" + description: "Analyze Python code structure" + format_code: + import_path: "llmling.tools.code.format_code" + name: "format_python" + +# LLM provider configurations llm_providers: - # GPT-4 Turbo configuration gpt4-turbo: name: "GPT-4 Turbo" model: openai/gpt-4-1106-preview temperature: 0.8 max_tokens: 4096 top_p: 0.95 - # Can include any parameter supported by the provider + # Can now use either format: + tools: + - analyze_code # Simple string reference + - analyze_ast + - format_code + # Or with settings: + # tools: + # analyze_code: {} + # analyze_ast: {max_lines: 1000} + # format_code: {style: "black"} - # Anthropic Claude configuration claude2: name: "Claude 2" model: anthropic/claude-2 temperature: 0.7 max_tokens: 8192 - # Local model configuration local-llama: name: "Local Llama" model: ollama/smollm2:360m temperature: 0.7 max_tokens: 2048 - # Local models can use the same configuration format -# Groups of providers for different use cases or fallback chains +# Provider groups provider_groups: - # High-quality responses for code review code_review: - gpt4-turbo - - claude2 # Fallback if first model fails + - claude2 - # Cost-effective group for draft content draft_content: - local-llama - # Fallback chain from best to most reliable fallback_chain: - - gpt4-turbo # First choice - - claude2 # Second choice - - local-llama # Last resort + - gpt4-turbo + - claude2 + - local-llama -# Context definitions - sources of text content +# Context definitions contexts: - # Loading from a URL python_guidelines: - type: path # Can be URL or file path + type: path path: "https://example.com/python-guidelines.md" description: "Python coding standards and best practices" - processors: # Chain of processors to apply + processors: - name: sanitize keyword_args: { remove_emails: true } - name: add_metadata keyword_args: source: "company guidelines" version: "1.2.3" + my_utils: type: source import_path: "my_project.utils" description: "Utility module source code" - recursive: true # Include all submodules + recursive: true processors: - name: python_cleaner @@ -109,8 +108,9 @@ contexts: type: source import_path: "my_project.models.user" description: "User model implementation" - recursive: false # Default, only this module - include_tests: false # Don't include test files + recursive: false + include_tests: false + system_info: type: callable import_path: "my_project.utils.system_diagnostics.get_info" @@ -118,7 +118,7 @@ contexts: keyword_args: include_memory: true include_disk: true - # Local file template + code_review_template: type: path path: "./templates/code_review.txt" @@ -126,7 +126,6 @@ contexts: processors: - name: python_cleaner - # Raw text content system_prompt: type: text content: | @@ -144,53 +143,49 @@ contexts: or meta-commentary. description: "Test prompt for consistent output" - # Dynamic content from CLI git_diff: type: cli command: "git diff HEAD~1" description: "Current git changes" - shell: true # Execute in shell + shell: true processors: - name: python_cleaner - # Process the git diff output before use -# Groups of related contexts +# Context groups context_groups: - # Basic code review contexts code_review_basic: - system_prompt - code_review_template - # Advanced code review with additional contexts code_review_advanced: - system_prompt - code_review_template - python_guidelines - git_diff -# Task templates combining providers and contexts +# Task templates task_templates: - # Simple code review task quick_review: - provider: local-llama # Single provider - context: system_prompt # Context group - settings: # Task-specific settings + provider: local-llama + context: system_prompt + inherit_tools: false # Explicitly set + tools: [] # Empty list for testing + settings: temperature: 0.7 max_tokens: 2048 + tools: [] # Empty list for testing + tool_choice: "auto" - # Comprehensive code review detailed_review: - provider: code_review # Provider group - context: code_review_advanced # Context group + provider: code_review + context: code_review_advanced settings: temperature: 0.5 max_tokens: 4096 - # These settings override both global and provider settings - # Draft generation task generate_draft: - provider: draft_content # Provider group - context: system_prompt # Single context + provider: draft_content + context: system_prompt settings: temperature: 0.9 max_tokens: 2048 diff --git a/src/llmling/task/executor.py b/src/llmling/task/executor.py index 5ba340c..ac10e8f 100644 --- a/src/llmling/task/executor.py +++ b/src/llmling/task/executor.py @@ -10,6 +10,7 @@ from llmling.core.log import get_logger from llmling.llm.base import LLMConfig, Message from llmling.task.models import TaskContext, TaskProvider, TaskResult +from llmling.tools.base import ToolRegistry if TYPE_CHECKING: @@ -31,6 +32,7 @@ def __init__( context_registry: ContextLoaderRegistry, processor_registry: ProcessorRegistry, provider_registry: ProviderRegistry, + tool_registry: ToolRegistry | None = None, *, default_timeout: int = 30, default_max_retries: int = 3, @@ -41,15 +43,56 @@ def __init__( context_registry: Registry of context loaders processor_registry: Registry of processors provider_registry: Registry of LLM providers + tool_registry: Registry of LLM model tools default_timeout: Default timeout for LLM calls default_max_retries: Default retry count for LLM calls """ self.context_registry = context_registry self.processor_registry = processor_registry self.provider_registry = provider_registry + self.tool_registry = tool_registry or ToolRegistry() self.default_timeout = default_timeout self.default_max_retries = default_max_retries + def _prepare_tool_config( + self, + task_context: TaskContext, + task_provider: TaskProvider, + ) -> dict[str, Any] | None: + """Prepare tool configuration if tools are enabled.""" + if not self.tool_registry: + return None + + available_tools = [] + + # Add inherited tools from provider if enabled + if ( + task_context.inherit_tools + and task_provider.settings + and task_provider.settings.tools + ): + available_tools.extend(task_provider.settings.tools) + + # Add task-specific tools + if task_context.tools: + available_tools.extend( + self.tool_registry.get_schema(tool) for tool in task_context.tools + ) + + if not available_tools: + return None + + return { + "tools": available_tools, + "tool_choice": ( + task_context.tool_choice + or ( + task_provider.settings.tool_choice if task_provider.settings else None + ) + or "auto" + ), + } + @logfire.instrument( "Executing task with provider {task_provider.name}, model {task_provider.model}" ) @@ -62,31 +105,53 @@ async def execute( ) -> TaskResult: """Execute a task.""" try: + # Add tool configuration if available + if tool_config := self._prepare_tool_config(task_context, task_provider): + kwargs.update(tool_config) # Load and process context context_result = await self._load_context(task_context) # Prepare messages - messages = self._prepare_messages( - context_result.content, - system_prompt, - ) + messages = self._prepare_messages(context_result.content, system_prompt) # Configure and create provider llm_config = self._create_llm_config(task_provider) provider = self.provider_registry.create_provider( - task_provider.name, # This is the lookup key + task_provider.name, llm_config, ) - # Get completion - completion = await provider.complete(messages, **kwargs) - - return TaskResult( - content=completion.content, - model=completion.model, - context_metadata=context_result.metadata, - completion_metadata=completion.metadata, - ) + # Get completion with potential tool calls + while True: + completion = await provider.complete(messages, **kwargs) + + # Handle tool calls if present + if completion.tool_calls: + tool_results = [] + for tool_call in completion.tool_calls: + result = await self.tool_registry.execute( + tool_call.name, + **tool_call.parameters, + ) + tool_results.append(result) + + # Add tool results to messages + messages.append( + Message( + role="tool", + content=str(tool_results), + name="tool_results", + ) + ) + continue # Get next completion + + # No tool calls, return final result + return TaskResult( + content=completion.content, + model=completion.model, + context_metadata=context_result.metadata, + completion_metadata=completion.metadata, + ) except Exception as exc: msg = "Task execution failed" diff --git a/src/llmling/task/manager.py b/src/llmling/task/manager.py index 38f8b57..8b6a7c0 100644 --- a/src/llmling/task/manager.py +++ b/src/llmling/task/manager.py @@ -31,6 +31,22 @@ def __init__( self.config = config self.executor = executor + # Register tools only if they exist in config + if self.config.tools: + self._register_tools() + + def _register_tools(self) -> None: + """Register tools from configuration.""" + if not self.config.tools: + return + + for tool_id, tool_config in self.config.tools.items(): + self.executor.tool_registry.register_path( + import_path=tool_config.import_path, + name=tool_config.name or tool_id, + description=tool_config.description, + ) + def _prepare_task( self, template_name: str, @@ -41,12 +57,16 @@ def _prepare_task( context = self._resolve_context(template) provider_name, provider_config = self._resolve_provider(template) + # Create task context with proper tool configuration task_context = TaskContext( context=context, processors=context.processors, - inherit_tools=template.inherit_tools, + inherit_tools=template.inherit_tools or False, + tools=template.tools, + tool_choice=template.tool_choice, ) + # Create task provider with settings task_provider = TaskProvider( name=provider_name, model=provider_config.model, @@ -63,13 +83,18 @@ async def execute_template( **kwargs: Any, ) -> TaskResult: """Execute a task template.""" - task_context, task_provider = self._prepare_task(template_name, system_prompt) - return await self.executor.execute( - task_context, - task_provider, - system_prompt=system_prompt, - **kwargs, - ) + try: + task_context, task_provider = self._prepare_task(template_name, system_prompt) + return await self.executor.execute( + task_context, + task_provider, + system_prompt=system_prompt, + **kwargs, + ) + except Exception as exc: + logger.exception("Task execution failed") + msg = f"Task execution failed for template {template_name}: {exc}" + raise exceptions.TaskError(msg) from exc async def execute_template_stream( self, @@ -78,14 +103,19 @@ async def execute_template_stream( **kwargs: Any, ) -> AsyncIterator[TaskResult]: """Execute a task template with streaming results.""" - task_context, task_provider = self._prepare_task(template_name, system_prompt) - async for result in self.executor.execute_stream( - task_context, - task_provider, - system_prompt=system_prompt, - **kwargs, - ): - yield result + try: + task_context, task_provider = self._prepare_task(template_name, system_prompt) + async for result in self.executor.execute_stream( + task_context, + task_provider, + system_prompt=system_prompt, + **kwargs, + ): + yield result + except Exception as exc: + logger.exception("Task streaming failed") + msg = f"Task streaming failed: {exc}" + raise exceptions.TaskError(msg) from exc def _get_template(self, name: str) -> TaskTemplate: """Get a task template by name.""" diff --git a/src/llmling/task/models.py b/src/llmling/task/models.py index 0dbf1b9..1687cc3 100644 --- a/src/llmling/task/models.py +++ b/src/llmling/task/models.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, ConfigDict @@ -15,7 +15,9 @@ class TaskContext(BaseModel): context: Context processors: list[ProcessingStep] - inherit_tools: bool = True + inherit_tools: bool = False # Set default value to False + tools: list[str] | None = None + tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051 model_config = ConfigDict(frozen=True) diff --git a/src/llmling/testing/tools.py b/src/llmling/testing/tools.py new file mode 100644 index 0000000..ea1f880 --- /dev/null +++ b/src/llmling/testing/tools.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import ast + + +async def failing_tool(text: str) -> str: + """Tool that always fails.""" + msg = "Intentional failure" + raise ValueError(msg) + + +async def example_tool(text: str, repeat: int = 1) -> str: + """Example tool that repeats text. + + Args: + text: Text to repeat + repeat: Number of times to repeat + + Returns: + The repeated text + """ + return text * repeat + + +async def analyze_ast(code: str) -> dict[str, int]: + """Analyze Python code AST. + + Args: + code: Python source code to analyze + + Returns: + Dictionary with analysis results + """ + tree = ast.parse(code) + return { + "classes": len([n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)]), + "functions": len([n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef)]), + } diff --git a/src/llmling/tools/__init__.py b/src/llmling/tools/__init__.py new file mode 100644 index 0000000..3862355 --- /dev/null +++ b/src/llmling/tools/__init__.py @@ -0,0 +1,14 @@ +"""Tool system for LLMling.""" + +from __future__ import annotations + +from llmling.tools.base import BaseTool, ToolRegistry, DynamicTool +from llmling.tools.exceptions import ToolError, ToolExecutionError + +__all__ = [ + "BaseTool", + "DynamicTool", + "ToolRegistry", + "ToolError", + "ToolExecutionError", +] diff --git a/src/llmling/tools/base.py b/src/llmling/tools/base.py new file mode 100644 index 0000000..005639e --- /dev/null +++ b/src/llmling/tools/base.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar + +import py2openai +from pydantic import BaseModel, ConfigDict + +from llmling.tools.exceptions import ToolError +from llmling.utils import calling + + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + +class ToolSchema(BaseModel): + """OpenAPI-compatible schema for a tool.""" + + type: str = "function" + function: dict[str, Any] + + model_config = ConfigDict(frozen=True) + + +class BaseTool(ABC): + """Base class for implementing complex tools that need state or custom logic.""" + + # Class-level schema definition + name: ClassVar[str] + description: ClassVar[str] + parameters_schema: ClassVar[dict[str, Any]] + + @classmethod + def get_schema(cls) -> ToolSchema: + """Get the tool's schema for LLM function calling.""" + return ToolSchema( + type="function", + function={ + "name": cls.name, + "description": cls.description, + "parameters": cls.parameters_schema, + }, + ) + + @abstractmethod + async def execute(self, **params: Any) -> Any | Awaitable[Any]: + """Execute the tool with given parameters.""" + + +class DynamicTool: + """Tool created from a function import path.""" + + def __init__( + self, + import_path: str, + name: str | None = None, + description: str | None = None, + ) -> None: + """Initialize tool from import path.""" + self.import_path = import_path + self._func: Callable[..., Any] | None = None + self._name = name + self._description = description + + @property + def name(self) -> str: + """Get tool name.""" + if self._name: + return self._name + return self.import_path.split(".")[-1] + + @property + def description(self) -> str: + """Get tool description.""" + if self._description: + return self._description + if self.func.__doc__: + return self.func.__doc__.strip() + return f"Tool imported from {self.import_path}" + + @property + def func(self) -> Callable[..., Any]: + """Get the imported function.""" + if self._func is None: + self._func = calling.import_callable(self.import_path) + return self._func + + def get_schema(self) -> ToolSchema: + """Generate schema from function signature.""" + func_schema = py2openai.create_schema(self.func) + schema_dict = func_schema.model_dump_openai() + + # Override description if custom one is provided + if self._description: + schema_dict["description"] = self._description + + return ToolSchema( + type="function", + function=schema_dict, + ) + + async def execute(self, **params: Any) -> Any: + """Execute the function.""" + return await calling.execute_callable(self.import_path, **params) + + +class ToolRegistry: + """Registry for available tools.""" + + def __init__(self) -> None: + """Initialize an empty registry.""" + self._tools: dict[str, BaseTool | DynamicTool] = {} + + def has_tool(self, name: str) -> bool: + """Check if a tool is registered.""" + return name in self._tools + + def is_empty(self) -> bool: + """Check if registry has any tools.""" + return not bool(self._tools) + + def register(self, tool: type[BaseTool] | BaseTool) -> None: + """Register a tool class or instance.""" + if isinstance(tool, type): + instance = tool() + self._tools[tool.name] = instance + else: + self._tools[tool.name] = tool + + def register_path( + self, + import_path: str, + name: str | None = None, + description: str | None = None, + ) -> None: + """Register a tool from import path.""" + tool = DynamicTool( + import_path=import_path, + name=name, + description=description, + ) + if tool.name in self._tools: + msg = f"Tool already registered: {tool.name}" + raise ToolError(msg) + self._tools[tool.name] = tool + + def get_tool(self, name: str) -> DynamicTool | BaseTool: + """Get a tool by name.""" + try: + return self._tools[name] + except KeyError as exc: + msg = f"Tool not found: {name}" + raise ToolError(msg) from exc + + def get_schema(self, name: str) -> ToolSchema: + """Get schema for a tool.""" + tool = self.get_tool(name) + return tool.get_schema() + + def list_tools(self) -> list[str]: + """List all registered tool names.""" + return list(self._tools.keys()) + + async def execute(self, name: str, **params: Any) -> Any: + """Execute a tool by name.""" + tool = self.get_tool(name) + return await tool.execute(**params) diff --git a/src/llmling/tools/code.py b/src/llmling/tools/code.py new file mode 100644 index 0000000..f95d779 --- /dev/null +++ b/src/llmling/tools/code.py @@ -0,0 +1,64 @@ +"""Code analysis tools.""" + +from __future__ import annotations + +import ast +from typing import Any + + +async def analyze_code(code: str) -> dict[str, Any]: + """Analyze Python code complexity and structure. + + Args: + code: Python code to analyze + + Returns: + Dictionary containing analysis metrics including: + - Number of classes + - Number of functions + - Number of imports + - Total lines of code + """ + try: + tree = ast.parse(code) + return { + "classes": len([n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)]), + "functions": len([ + n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef) + ]), + "imports": len([n for n in ast.walk(tree) if isinstance(n, ast.Import)]), + "lines": len(code.splitlines()), + } + except SyntaxError as exc: + msg = f"Invalid Python code: {exc}" + raise ValueError(msg) from exc + + +async def count_tokens( + text: str, + model: str = "gpt-4", +) -> dict[str, Any]: + """Count the approximate number of tokens in text. + + Args: + text: Text to analyze + model: Model to use for tokenization (default: gpt-4) + + Returns: + Dictionary containing: + - token_count: Number of tokens + - model: Model used for tokenization + """ + import tiktoken + + try: + encoding = tiktoken.encoding_for_model(model) + token_count = len(encoding.encode(text)) + except Exception as exc: + msg = f"Token counting failed: {exc}" + raise ValueError(msg) from exc + else: + return { + "token_count": token_count, + "model": model, + } diff --git a/src/llmling/tools/exceptions.py b/src/llmling/tools/exceptions.py new file mode 100644 index 0000000..6e28f1a --- /dev/null +++ b/src/llmling/tools/exceptions.py @@ -0,0 +1,21 @@ +"""Exceptions for the tool system.""" + +from __future__ import annotations + +from llmling.core.exceptions import LLMLingError + + +class ToolError(LLMLingError): + """Base exception for tool-related errors.""" + + +class ToolExecutionError(ToolError): + """Error during tool execution.""" + + +class ToolNotFoundError(ToolError): + """Tool not found in registry.""" + + +class ToolValidationError(ToolError): + """Tool parameter validation error.""" diff --git a/src/llmling/utils/calling.py b/src/llmling/utils/calling.py index a8633c9..f3c4c02 100644 --- a/src/llmling/utils/calling.py +++ b/src/llmling/utils/calling.py @@ -8,40 +8,67 @@ if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Awaitable, Callable -def is_async_callable(obj: Any) -> TypeGuard[Callable[..., Any]]: +def is_async_callable(obj: Any) -> TypeGuard[Callable[..., Awaitable[Any]]]: """Check if an object is an async callable.""" return asyncio.iscoroutinefunction(obj) -async def execute_callable(import_path: str, **kwargs: Any) -> str: - """Execute a callable and return its result as a string.""" +def import_callable(import_path: str) -> Callable[..., Any]: + """Import a callable from an import path. + + Args: + import_path: Dot-separated path to callable (e.g., "module.submodule.func") + + Returns: + The imported callable + + Raises: + ValueError: If import fails or object is not callable + """ try: - # Import the callable module_path, callable_name = import_path.rsplit(".", 1) module = importlib.import_module(module_path) callable_obj = getattr(module, callable_name) + if not callable(callable_obj): + msg = f"Imported object {import_path} is not callable" + raise ValueError(msg) # noqa: TRY004 + except ImportError as exc: + msg = f"Could not import callable: {import_path}" + raise ValueError(msg) from exc + except AttributeError as exc: + msg = f"Could not find callable {import_path!r} " + raise ValueError(msg) from exc + else: + return callable_obj + + +async def execute_callable(import_path: str, **kwargs: Any) -> Any: + """Execute a callable and return its result. + + Args: + import_path: Dot-separated path to callable + **kwargs: Arguments to pass to the callable + + Returns: + Result of the callable execution + + Raises: + ValueError: If import or execution fails + """ + try: + callable_obj = import_callable(import_path) + # Execute the callable if is_async_callable(callable_obj): result = await callable_obj(**kwargs) else: result = callable_obj(**kwargs) - - # Convert result to string - if isinstance(result, str): - return result - if isinstance(result, list | dict | set | tuple): - import json - - return json.dumps(result, indent=2, default=str) - return str(result) - - except ImportError as exc: - msg = f"Could not import callable: {import_path}" - raise ValueError(msg) from exc - except (TypeError, ValueError) as exc: + except Exception as exc: msg = f"Error executing callable {import_path}: {exc}" raise ValueError(msg) from exc + else: + return result diff --git a/tests/test_client.py b/tests/test_client.py index 60b1967..ba27e5c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -30,14 +30,14 @@ MAX_CONCURRENT_TASKS = 3 # LLM output related constants -MAX_CONTENT_DIFF_RATIO = 0.5 # Allow 50% difference between streaming and non-streaming -MIN_CONTENT_LENGTH = 10 # Minimum expected content length -MAX_RETRIES = 3 # Maximum number of retries for consistency test +MAX_CONTENT_DIFF_RATIO = 0.5 +MIN_CONTENT_LENGTH = 10 +MAX_RETRIES = 3 -STREAM_TIMEOUT = 30.0 # Maximum time to wait for streaming -MIN_CHUNKS = 1 # Minimum number of chunks expected -MIN_CHUNK_LENGTH = 1 # Minimum length of each chunk -TEST_TEMPLATE = "quick_review" # Template known to work +STREAM_TIMEOUT = 30.0 +MIN_CHUNKS = 1 +MIN_CHUNK_LENGTH = 1 +TEST_TEMPLATE = "quick_review" # Mock response for LLM calls MOCK_RESPONSE = CompletionResult( @@ -48,6 +48,7 @@ ) +# Common fixtures @pytest.fixture def config_path() -> Path: """Provide path to test configuration file.""" @@ -67,38 +68,52 @@ def components() -> dict[ComponentType, dict[str, Any]]: import_path="llmling.testing.processors.uppercase_text", ), }, + "tool": { + "test_tool": { + "name": "test_tool", + "description": "Test tool", + "import_path": "llmling.testing.tools.example_tool", + } + }, } -@pytest.fixture -def mock_llm_response() -> CompletionResult: - """Provide mock LLM response.""" - return MOCK_RESPONSE - - @pytest.fixture def mock_provider(): - """Mock LLM provider.""" + """Mock LLM provider with proper async support.""" with mock.patch("llmling.llm.registry.ProviderRegistry.create_provider") as m: provider = mock.AsyncMock() - provider.complete.return_value = MOCK_RESPONSE - # Properly mock the streaming response - async def mock_stream(*args, **kwargs): + async def mock_complete(*args, tools=None, tool_choice=None, **kwargs): + return MOCK_RESPONSE + + provider.complete = mock_complete + + async def mock_stream(*args, tools=None, tool_choice=None, **kwargs): yield MOCK_RESPONSE provider.complete_stream = mock_stream + + provider.model = MOCK_RESPONSE.model m.return_value = provider + + from llmling.llm.registry import default_registry + + default_registry.reset() + default_registry.register_provider("local-llama", "litellm") + yield provider + default_registry.reset() + @pytest.fixture -async def client( +async def async_client( config_path: Path, components: dict[ComponentType, dict[str, Any]], mock_provider, ) -> AsyncGenerator[LLMLingClient, None]: - """Provide initialized LLMLing client.""" + """Provide initialized client for async tests.""" client = LLMLingClient( config_path, log_level=TEST_LOG_LEVEL, @@ -113,7 +128,7 @@ async def client( @pytest.mark.unit class TestClientCreation: - """Test client initialization and context managers.""" + """Test client initialization and synchronous operations.""" def test_create_sync( self, @@ -122,9 +137,25 @@ def test_create_sync( mock_provider, ) -> None: """Test synchronous client creation.""" + # Ensure no event loop exists + asyncio.set_event_loop(None) + client = LLMLingClient.create(config_path, components=components) - assert isinstance(client, LLMLingClient) - assert client._initialized + try: + assert isinstance(client, LLMLingClient) + assert client._initialized + + # Test sync execution works + result = client.execute_sync("quick_review") + assert isinstance(result, TaskResult) + assert result.content == MOCK_RESPONSE.content + finally: + # Clean up synchronously + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(client.shutdown()) + finally: + loop.close() def test_sync_context_manager( self, @@ -133,12 +164,25 @@ def test_sync_context_manager( mock_provider, ) -> None: """Test synchronous context manager.""" + # Ensure no event loop exists + asyncio.set_event_loop(None) + with LLMLingClient.create(config_path, components=components) as client: result = client.execute_sync("quick_review") assert isinstance(result, TaskResult) assert result.content == MOCK_RESPONSE.content - @pytest.mark.asyncio + def test_invalid_config_sync(self) -> None: + """Test synchronous initialization with invalid configuration.""" + with pytest.raises(exceptions.LLMLingError): + LLMLingClient.create(NONEXISTENT_CONFIG_PATH) + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestAsyncOperations: + """Test asynchronous operations.""" + async def test_async_context_manager( self, config_path: Path, @@ -151,32 +195,18 @@ async def test_async_context_manager( assert isinstance(result, TaskResult) assert result.content == MOCK_RESPONSE.content - @pytest.mark.asyncio - async def test_client_invalid_config(self) -> None: - """Test client initialization with invalid configuration.""" - client = LLMLingClient(NONEXISTENT_CONFIG_PATH) - with pytest.raises(exceptions.LLMLingError): - await client.startup() - - -@pytest.mark.unit -class TestMockedTaskExecution: - """Unit tests with mocked LLM responses.""" - - @pytest.mark.asyncio - async def test_execute_single_task(self, client: LLMLingClient) -> None: - """Test executing a single task with mocked LLM.""" - result = await client.execute( + async def test_execute_single_task(self, async_client: LLMLingClient) -> None: + """Test executing a single task.""" + result = await async_client.execute( "quick_review", system_prompt=DEFAULT_SYSTEM_PROMPT, ) assert result.content == MOCK_RESPONSE.content - @pytest.mark.asyncio - async def test_execute_stream(self, client: LLMLingClient) -> None: - """Test streaming execution with mocked LLM.""" + async def test_execute_stream(self, async_client: LLMLingClient) -> None: + """Test streaming execution.""" chunks = [] - stream = await client.execute( + stream = await async_client.execute( "quick_review", system_prompt=DEFAULT_SYSTEM_PROMPT, stream=True, @@ -192,10 +222,9 @@ async def test_execute_stream(self, client: LLMLingClient) -> None: assert len(chunks) >= MIN_CHUNKS - @pytest.mark.asyncio - async def test_concurrent_execution(self, client: LLMLingClient) -> None: - """Test concurrent task execution.""" - results = await client.execute_many( + async def test_concurrent_execution(self, async_client: LLMLingClient) -> None: + """Test concurrent execution.""" + results = await async_client.execute_many( TEST_TEMPLATES, max_concurrent=MAX_CONCURRENT_TASKS, ) @@ -203,16 +232,35 @@ async def test_concurrent_execution(self, client: LLMLingClient) -> None: assert all(isinstance(r, TaskResult) for r in results) assert all(r.content == MOCK_RESPONSE.content for r in results) + +@pytest.mark.unit +class TestErrorHandling: + """Test error handling in both sync and async operations.""" + @pytest.mark.asyncio - async def test_error_handling(self, client: LLMLingClient) -> None: - """Test error handling for invalid templates.""" + async def test_async_error_handling(self, async_client: LLMLingClient) -> None: + """Test async error handling.""" with pytest.raises(LLMLingError): - await client.execute("nonexistent_template") + await async_client.execute("nonexistent_template") + + def test_sync_error_handling( + self, + config_path: Path, + components: dict[ComponentType, dict[str, Any]], + mock_provider, + ) -> None: + """Test sync error handling.""" + asyncio.set_event_loop(None) + with ( + LLMLingClient.create(config_path, components=components) as client, + pytest.raises(LLMLingError), + ): + client.execute_sync("nonexistent_template") @pytest.mark.integration class TestIntegrationTaskExecution: - """Integration tests for task execution.""" + """Integration tests with real LLM.""" @pytest.fixture async def integration_client( @@ -220,7 +268,7 @@ async def integration_client( config_path: Path, components: dict[ComponentType, dict[str, Any]], ) -> AsyncGenerator[LLMLingClient, None]: - """Provide client for integration tests without mocks.""" + """Provide client for integration tests.""" client = LLMLingClient( config_path, log_level=TEST_LOG_LEVEL, @@ -235,7 +283,7 @@ async def integration_client( @pytest.mark.slow @pytest.mark.asyncio async def test_real_llm_execution(self, integration_client: LLMLingClient) -> None: - """Test executing a task with real LLM.""" + """Test with real LLM.""" result = await integration_client.execute( "quick_review", system_prompt=DEFAULT_SYSTEM_PROMPT, @@ -279,22 +327,14 @@ def _validate_task_result(result: TaskResult) -> None: @staticmethod def _validate_chunk(chunk: TaskResult, index: int) -> None: - """Validate individual stream chunk.""" + """Validate streaming chunk.""" try: - assert isinstance(chunk, TaskResult), ( - f"Chunk {index}: Invalid type {type(chunk)}" - ) - assert chunk.model, f"Chunk {index}: Missing model" - assert isinstance(chunk.content, str), f"Chunk {index}: Content is not string" - assert chunk.context_metadata is not None, ( - f"Chunk {index}: Missing context metadata" - ) - assert chunk.completion_metadata is not None, ( - f"Chunk {index}: Missing completion metadata" - ) - assert len(chunk.content) >= MIN_CHUNK_LENGTH, ( - f"Chunk {index}: Content too short ({len(chunk.content)} chars)" - ) + assert isinstance(chunk, TaskResult) + assert chunk.model + assert isinstance(chunk.content, str) + assert chunk.context_metadata is not None + assert chunk.completion_metadata is not None + assert len(chunk.content) >= MIN_CHUNK_LENGTH except AssertionError: print(f"\nChunk {index} Validation Error:") @@ -302,82 +342,3 @@ def _validate_chunk(chunk: TaskResult, index: int) -> None: print(f"Model: {chunk.model}") print(f"Metadata: {chunk.completion_metadata}") raise - - -@pytest.mark.unit -class TestCustomization: - """Test client customization options.""" - - @pytest.mark.asyncio - async def test_custom_components( - self, - config_path: Path, - components: dict[ComponentType, dict[str, Any]], - mock_provider, - ) -> None: - """Test execution with custom components.""" - client = LLMLingClient(config_path, components=components) - await client.startup() - try: - result = await client.execute("quick_review") - assert isinstance(result, TaskResult) - assert result.content == MOCK_RESPONSE.content - finally: - await client.shutdown() - - @pytest.mark.asyncio - async def test_component_registration( - self, - config_path: Path, - components: dict[ComponentType, dict[str, Any]], - mock_provider, - ) -> None: - """Test that components are properly registered.""" - client = LLMLingClient(config_path, components=components) - await client.startup() - - try: - # Verify processor registration - if "processor" in components: - for name in components["processor"]: - assert client.processor_registry - assert name in client.processor_registry._processors - - # Verify successful task execution with custom components - result = await client.execute("quick_review") - assert isinstance(result, TaskResult) - assert result.content - - finally: - await client.shutdown() - - @pytest.mark.asyncio - async def test_invalid_component_type(self, config_path: Path) -> None: - """Test handling of invalid component types.""" - invalid_components = { - "invalid_type": {"test": "value"} # type: ignore - } - - client = LLMLingClient(config_path, components=invalid_components) # type: ignore - await client.startup() - - try: - # Should still work even with invalid component type - result = await client.execute("quick_review") - assert isinstance(result, TaskResult) - finally: - await client.shutdown() - - @pytest.mark.slow - def test_sync_execution(self, config_path: Path) -> None: - """Test synchronous execution methods.""" - with LLMLingClient.create(config_path) as client: - # Test single execution - result = client.execute_sync("quick_review") - assert isinstance(result, TaskResult) - assert result.content - - # Test concurrent execution - results = client.execute_many_sync(TEST_TEMPLATES) - assert len(results) == len(TEST_TEMPLATES) - assert all(isinstance(r, TaskResult) for r in results) diff --git a/tests/test_task_manager.py b/tests/test_task_manager.py index a715a28..9f4cff9 100644 --- a/tests/test_task_manager.py +++ b/tests/test_task_manager.py @@ -13,9 +13,17 @@ @pytest.fixture -def mock_config() -> Config: +def mock_config() -> mock.MagicMock: """Create a mock configuration.""" - return mock.MagicMock(spec=Config) + config = mock.MagicMock(spec=Config) + # Set required attributes + config.tools = {} # Empty dict for no tools + config.llm_providers = {} + config.provider_groups = {} + config.task_templates = {} + config.contexts = {} + config.context_groups = {} + return config @pytest.fixture @@ -32,7 +40,6 @@ def test_resolve_provider_direct(mock_config: mock.MagicMock) -> None: model="test/model", ) mock_config.llm_providers = {"test-provider": provider_config} - mock_config.provider_groups = {} template = TaskTemplate( provider="test-provider", @@ -70,8 +77,7 @@ def test_resolve_provider_group(mock_config: mock.MagicMock) -> None: def test_resolve_provider_not_found(mock_config: mock.MagicMock) -> None: """Test provider resolution failure.""" - mock_config.llm_providers = {} - mock_config.provider_groups = {} + # Config already has empty dicts from fixture template = TaskTemplate( provider="non-existent", @@ -81,3 +87,49 @@ def test_resolve_provider_not_found(mock_config: mock.MagicMock) -> None: manager = TaskManager(mock_config, mock.MagicMock()) with pytest.raises(exceptions.TaskError): manager._resolve_provider(template) + + +# Additional tests for context resolution +def test_resolve_context_direct(mock_config: mock.MagicMock) -> None: + """Test direct context resolution.""" + context = mock.MagicMock() + mock_config.contexts = {"test-context": context} + + template = TaskTemplate( + provider="test-provider", + context="test-context", + ) + + manager = TaskManager(mock_config, mock.MagicMock()) + result = manager._resolve_context(template) + + assert result == context + + +def test_resolve_context_group(mock_config: mock.MagicMock) -> None: + """Test context group resolution.""" + context = mock.MagicMock() + mock_config.contexts = {"test-context": context} + mock_config.context_groups = {"group1": ["test-context"]} + + template = TaskTemplate( + provider="test-provider", + context="group1", + ) + + manager = TaskManager(mock_config, mock.MagicMock()) + result = manager._resolve_context(template) + + assert result == context + + +def test_resolve_context_not_found(mock_config: mock.MagicMock) -> None: + """Test context resolution failure.""" + template = TaskTemplate( + provider="test-provider", + context="non-existent", + ) + + manager = TaskManager(mock_config, mock.MagicMock()) + with pytest.raises(exceptions.TaskError): + manager._resolve_context(template) diff --git a/tests/test_tool_integration.py b/tests/test_tool_integration.py new file mode 100644 index 0000000..7df4395 --- /dev/null +++ b/tests/test_tool_integration.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from unittest import mock + +import pytest + +from llmling.config.models import ( + Config, + GlobalSettings, + LLMProviderConfig, + TextContext, + ToolConfig, +) +from llmling.context.models import LoadedContext +from llmling.llm.base import CompletionResult +from llmling.task.executor import TaskExecutor +from llmling.task.models import TaskContext, TaskProvider +from llmling.tools.base import ToolRegistry + + +@pytest.fixture +def tool_config() -> Config: + """Create a test configuration with tools.""" + return Config( + version="1.0", + global_settings=GlobalSettings( + timeout=30, + max_retries=3, + temperature=0.7, + ), + context_processors={}, + llm_providers={ + "test_provider": LLMProviderConfig( + name="Test Provider", + model="test/model", + ), + }, + contexts={ + "test_context": TextContext( + type="text", + content="Test content", + description="Test context", + ), + }, + task_templates={}, + tools={ + "analyze": ToolConfig( + import_path="llmling.testing.tools.analyze_ast", + description="Analyze code", + ), + "repeat": ToolConfig( + import_path="llmling.testing.tools.example_tool", + name="repeat_text", + ), + }, + ) + + +@pytest.fixture +def mock_context_registry() -> mock.MagicMock: + """Create mock context registry with async support.""" + registry = mock.MagicMock() + loader = mock.MagicMock() + + # Make load method a coroutine + loader.load = mock.AsyncMock( + return_value=LoadedContext( + content="Test content", + source_type="test", + metadata={}, + ) + ) + + registry.get_loader.return_value = loader + return registry + + +@pytest.fixture +def mock_processor_registry() -> mock.MagicMock: + """Create mock processor registry.""" + registry = mock.MagicMock() + registry.process = mock.AsyncMock() + return registry + + +@pytest.fixture +def mock_provider_registry() -> mock.MagicMock: + """Create mock provider registry with streaming support.""" + registry = mock.MagicMock() + + # Create mock provider with both regular and streaming methods + mock_provider = mock.MagicMock() + + # Regular completion + mock_provider.complete = mock.AsyncMock( + return_value=CompletionResult( + content="Test response", + model="test/model", + metadata={}, + ) + ) + + # Streaming completion will be set by the test that needs it + mock_provider.complete_stream = None + + registry.create_provider.return_value = mock_provider + return registry + + +@pytest.mark.asyncio +async def test_task_with_tools( + tool_config: Config, + mock_context_registry: mock.MagicMock, + mock_processor_registry: mock.MagicMock, + mock_provider_registry: mock.MagicMock, +) -> None: + """Test task execution with tools.""" + # Setup + tool_registry = ToolRegistry() + + # Register tools from config + for tool_id, tool_cfg in tool_config.tools.items(): + tool_registry.register_path( + import_path=tool_cfg.import_path, + name=tool_cfg.name or tool_id, + description=tool_cfg.description, + ) + + executor = TaskExecutor( + context_registry=mock_context_registry, + processor_registry=mock_processor_registry, + provider_registry=mock_provider_registry, + tool_registry=tool_registry, + ) + + # Create test context + test_context = tool_config.contexts["test_context"] + + # Create and execute task + task_context = TaskContext( + context=test_context, + processors=[], + tools=["analyze", "repeat_text"], + tool_choice="auto", + ) + + task_provider = TaskProvider( + name="test_provider", + model="test/model", + display_name="Test Provider", + ) + + result = await executor.execute(task_context, task_provider) + + # Verify the result + assert result.content == "Test response" + assert result.model == "test/model" + + # Verify interactions + mock_context_registry.get_loader.assert_called_once() + mock_provider_registry.create_provider.assert_called_once() + + # Verify provider was called with correct configuration + mock_provider = mock_provider_registry.create_provider.return_value + mock_provider.complete.assert_called_once() + + # Verify tool configuration was passed correctly + call_kwargs = mock_provider.complete.call_args[1] + assert "tools" in call_kwargs + assert len(call_kwargs["tools"]) == 2 # noqa: PLR2004 + assert call_kwargs["tool_choice"] == "auto" + + +@pytest.mark.asyncio +async def test_task_with_tools_streaming( + tool_config: Config, + mock_context_registry: mock.MagicMock, + mock_processor_registry: mock.MagicMock, + mock_provider_registry: mock.MagicMock, +) -> None: + """Test streaming task execution with tools.""" + + # Create a proper async generator for streaming + async def mock_stream(*args, **kwargs): + for content in ["Chunk 1", "Chunk 2"]: + yield CompletionResult( + content=content, + model="test/model", + metadata={}, + ) + + # Set up mock provider with streaming support + mock_provider = mock.MagicMock() + mock_provider.complete_stream = mock_stream + mock_provider_registry.create_provider.return_value = mock_provider + + # Setup + tool_registry = ToolRegistry() + for tool_id, tool_cfg in tool_config.tools.items(): + tool_registry.register_path( + import_path=tool_cfg.import_path, + name=tool_cfg.name or tool_id, + description=tool_cfg.description, + ) + + executor = TaskExecutor( + context_registry=mock_context_registry, + processor_registry=mock_processor_registry, + provider_registry=mock_provider_registry, + tool_registry=tool_registry, + ) + + test_context = tool_config.contexts["test_context"] + task_context = TaskContext( + context=test_context, + processors=[], + tools=["analyze", "repeat_text"], + tool_choice="auto", + ) + + task_provider = TaskProvider( + name="test_provider", + model="test/model", + display_name="Test Provider", + ) + # Collect streaming results + chunks = [ + chunk async for chunk in executor.execute_stream(task_context, task_provider) + ] + + # Verify results + assert len(chunks) == 2 # noqa: PLR2004 + assert chunks[0].content == "Chunk 1" + assert chunks[1].content == "Chunk 2" + assert all(chunk.model == "test/model" for chunk in chunks) + + # Verify interactions + mock_context_registry.get_loader.assert_called_once() + mock_provider_registry.create_provider.assert_called_once() + + # Verify provider configuration + provider_config = mock_provider_registry.create_provider.call_args[0][1] + assert provider_config.streaming is True # Verify streaming was enabled + assert provider_config.model == "test/model" diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 0000000..d7591f9 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import pytest + +from llmling.tools.base import DynamicTool, ToolRegistry +from llmling.tools.exceptions import ToolError + + +# Test fixtures +@pytest.fixture +def registry() -> ToolRegistry: + """Create a fresh tool registry.""" + return ToolRegistry() + + +# Test DynamicTool +class TestDynamicTool: + def test_init(self) -> None: + """Test tool initialization.""" + tool = DynamicTool( + import_path="llmling.testing.tools.example_tool", + name="custom_name", + description="Custom description", + ) + assert tool.name == "custom_name" + assert tool.description == "Custom description" + assert tool.import_path == "llmling.testing.tools.example_tool" + + def test_default_name(self) -> None: + """Test default name from import path.""" + tool = DynamicTool("llmling.testing.tools.example_tool") + assert tool.name == "example_tool" + + def test_default_description(self) -> None: + """Test default description from docstring.""" + tool = DynamicTool("llmling.testing.tools.example_tool") + assert "repeats text" in tool.description.lower() + + def test_schema_generation(self) -> None: + """Test schema generation from function signature.""" + tool = DynamicTool("llmling.testing.tools.example_tool") + schema = tool.get_schema() + + assert schema.type == "function" + assert schema.function["name"] == "example_tool" + assert "text" in schema.function["parameters"]["properties"] + assert "repeat" in schema.function["parameters"]["properties"] + assert schema.function["parameters"]["required"] == ["text"] + + @pytest.mark.asyncio + async def test_execution(self) -> None: + """Test tool execution.""" + tool = DynamicTool("llmling.testing.tools.example_tool") + result = await tool.execute(text="test", repeat=2) + assert result == "testtest" + + @pytest.mark.asyncio + async def test_execution_failure(self) -> None: + """Test tool execution failure.""" + tool = DynamicTool("llmling.testing.tools.failing_tool") + with pytest.raises(Exception, match="test"): + await tool.execute(text="test") + + +# Test ToolRegistry +class TestToolRegistry: + def test_register_path(self, registry: ToolRegistry) -> None: + """Test registering a tool by import path.""" + registry.register_path( + "llmling.testing.tools.example_tool", + name="custom_tool", + ) + assert "custom_tool" in registry.list_tools() + + def test_register_duplicate(self, registry: ToolRegistry) -> None: + """Test registering duplicate tool names.""" + registry.register_path("llmling.testing.tools.example_tool", name="tool1") + with pytest.raises(ToolError): + registry.register_path("llmling.testing.tools.example_tool", name="tool1") + + def test_get_nonexistent(self, registry: ToolRegistry) -> None: + """Test getting non-existent tool.""" + with pytest.raises(ToolError): + registry.get_tool("nonexistent") + + def test_list_tools(self, registry: ToolRegistry) -> None: + """Test listing registered tools.""" + registry.register_path("llmling.testing.tools.example_tool", name="tool1") + registry.register_path("llmling.testing.tools.analyze_ast", name="tool2") + tools = registry.list_tools() + assert len(tools) == 2 # noqa: PLR2004 + assert "tool1" in tools + assert "tool2" in tools + + @pytest.mark.asyncio + async def test_execute(self, registry: ToolRegistry) -> None: + """Test executing a registered tool.""" + registry.register_path("llmling.testing.tools.example_tool") + result = await registry.execute("example_tool", text="test", repeat=3) + assert result == "testtesttest" + + @pytest.mark.asyncio + async def test_execute_with_validation(self, registry: ToolRegistry) -> None: + """Test tool execution with invalid parameters.""" + registry.register_path("llmling.testing.tools.analyze_ast") + + # Valid Python code + result = await registry.execute( + "analyze_ast", + code="class Test: pass\ndef func(): pass", + ) + assert result["classes"] == 1 + assert result["functions"] == 1 + + # Invalid Python code + with pytest.raises(Exception, match="invalid syntax"): + await registry.execute("analyze_ast", code="invalid python") + + def test_schema_generation(self, registry: ToolRegistry) -> None: + """Test schema generation for registered tools.""" + registry.register_path( + "llmling.testing.tools.analyze_ast", + description="Custom description", + ) + schema = registry.get_schema("analyze_ast") + + assert schema.type == "function" + assert "code" in schema.function["parameters"]["properties"] + assert schema.function["parameters"]["required"] == ["code"] + assert schema.function["description"] == "Custom description" + + +# Integration tests +@pytest.mark.asyncio +async def test_tool_integration() -> None: + """Test full tool workflow.""" + # Setup + registry = ToolRegistry() + registry.register_path( + "llmling.testing.tools.analyze_ast", + name="analyze", + description="Analyze Python code", + ) + + # Get schema + schema = registry.get_schema("analyze") + assert schema.type == "function" + + # Execute tool + code = """ +class TestClass: + def method1(self): + pass + def method2(self): + pass + """ + result = await registry.execute("analyze", code=code) + + assert result["classes"] == 1 + assert result["functions"] == 2 # noqa: PLR2004