Commit

chore: rework
phil65 committed Nov 19, 2024
1 parent 2fe8121 commit 56a87a7
Showing 7 changed files with 430 additions and 509 deletions.
141 changes: 29 additions & 112 deletions src/llmling/llm/base.py
@@ -3,7 +3,6 @@
from __future__ import annotations

from abc import ABC, abstractmethod
-import asyncio
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseModel, ConfigDict, Field
@@ -12,33 +11,54 @@
from llmling.core.log import get_logger


-logger = get_logger(__name__)
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator


-if TYPE_CHECKING:
-    from collections.abc import AsyncGenerator, AsyncIterator
+logger = get_logger(__name__)


class LLMConfig(BaseModel):
    """Configuration for LLM providers."""

    # Core identification
    model: str
    provider_name: str  # Key used for provider lookup
    display_name: str = ""  # Human-readable name

    # LLM parameters
    temperature: float = 0.7
    max_tokens: int | None = None
    top_p: float | None = None
    timeout: int = 30
    max_retries: int = 3
    streaming: bool = False
    # New fields for tool support
    tools: list[dict[str, Any]] | None = None
    tool_choice: Literal["none", "auto"] | str | None = None  # noqa: PYI051

    # LiteLLM settings
    api_base: str | None = None
    api_key: str | None = None
    num_retries: int | None = None
    request_timeout: float | None = None
    metadata: dict[str, Any] | None = None
    mock_response: str | None = None
    cache: bool | None = None
    cache_key: str | None = None
    fallbacks: list[str] | None = None
    context_window_fallbacks: list[str] | None = None
    bearer_token: str | None = None
    model_list: list[str] | None = None
    drop_params: bool = False
    add_function_to_prompt: bool = False
    force_timeout: float | None = None
    proxy_url: str | None = None
    api_version: str | None = None
    use_queue: bool = False

    model_config = ConfigDict(frozen=True)
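
For orientation, a rough usage sketch (not part of this commit) of constructing the reworked, frozen config. The model name, endpoint, and provider key below are placeholder values, not defaults from the library.

# Hypothetical usage; all values are placeholders.
config = LLMConfig(
    model="openai/gpt-4o-mini",  # placeholder model identifier
    provider_name="litellm",  # key used for provider lookup
    display_name="GPT-4o mini",
    temperature=0.2,
    max_tokens=1024,
    api_base="https://example.invalid/v1",  # placeholder endpoint
    fallbacks=["openai/gpt-3.5-turbo"],
    drop_params=True,
)
# Because model_config is frozen, later mutation such as `config.temperature = 0.9`
# raises a pydantic validation error.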


MessageRole = Literal["system", "user", "assistant"]
MessageRole = Literal["system", "user", "assistant", "tool"]
"""Valid message roles for chat completion."""


Expand All @@ -55,7 +75,7 @@ class ToolCall(BaseModel):
class Message(BaseModel):
    """A chat message."""

-    role: Literal["system", "user", "assistant", "tool"]
+    role: MessageRole
    content: str
    name: str | None = None  # For tool messages
    tool_calls: list[ToolCall] | None = None  # For assistant messages
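
A rough conversation sketch (not part of this commit), using only the Message fields shown above and the newly added "tool" role:

# Hypothetical example; the tool name "calculator" is made up for illustration.
messages = [
    Message(role="system", content="You are a calculator assistant."),
    Message(role="user", content="What is 6 * 7?"),
    Message(role="assistant", content="Calling the calculator tool."),
    Message(role="tool", content="42", name="calculator"),  # `name` identifies the responding tool
]
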
@@ -111,7 +131,7 @@ async def complete_stream(
        self,
        messages: list[Message],
        **kwargs: Any,
-    ) -> AsyncGenerator[CompletionResult, None]:
+    ) -> AsyncIterator[CompletionResult]:
        """Generate a streaming completion for the messages.
        Args:
@@ -138,106 +158,3 @@ async def validate_response(self, result: CompletionResult) -> None:
        if not result.content and not result.tool_calls:
            msg = "Empty response from LLM"
            raise exceptions.LLMError(msg)


-class RetryableProvider(LLMProvider):
-    """LLM provider with retry support."""
-
-    async def complete(
-        self,
-        messages: list[Message],
-        **kwargs: Any,
-    ) -> CompletionResult:
-        """Generate a completion with retry support."""
-        retries = 0
-        last_error = None
-
-        while retries <= self.config.max_retries:
-            try:
-                result = await self._complete_impl(messages, **kwargs)
-                await self.validate_response(result)
-            except exceptions.LLMError as exc:
-                # Only retry LLM-specific errors
-                last_error = exc
-                retries += 1
-                if retries <= self.config.max_retries:
-                    await self._handle_retry(exc, retries)
-                    continue
-                break
-            except Exception as exc:
-                # Don't retry other errors
-                raise exc  # noqa: TRY201
-            else:
-                return result
-
-        msg = f"Failed after {retries} retries"
-        raise exceptions.LLMError(msg) from last_error
-
-    async def complete_stream(
-        self,
-        messages: list[Message],
-        **kwargs: Any,
-    ) -> AsyncGenerator[CompletionResult, None]:
-        """Generate a streaming completion for the messages.
-        Args:
-            messages: List of messages for chat completion
-            **kwargs: Additional provider-specific parameters
-        Yields:
-            Streamed completion results
-        Raises:
-            LLMError: If completion fails
-        """
-        retries = 0
-        last_error = None
-
-        while retries <= self.config.max_retries:
-            try:
-                async for result in self._complete_stream_impl(messages, **kwargs):
-                    await self.validate_response(result)
-                    yield result
-            except Exception as exc:  # noqa: BLE001
-                last_error = exc
-                retries += 1
-                if retries <= self.config.max_retries:
-                    await self._handle_retry(exc, retries)
-                    continue
-                break
-            else:
-                return
-        msg = f"Failed after {retries} retries"
-        raise exceptions.LLMError(msg) from last_error
-
-    @abstractmethod
-    async def _complete_impl(
-        self,
-        messages: list[Message],
-        **kwargs: Any,
-    ) -> CompletionResult:
-        """Implement actual completion logic."""
-
-    @abstractmethod
-    async def _complete_stream_impl(
-        self,
-        messages: list[Message],
-        **kwargs: Any,
-    ) -> AsyncIterator[CompletionResult]:
-        """Implement actual streaming completion logic."""
-        yield NotImplemented  # pragma: no cover
-
-    async def _handle_retry(self, error: Exception, attempt: int) -> None:
-        """Handle retry after error.
-        Args:
-            error: The error that triggered the retry
-            attempt: The retry attempt number
-        """
-        logger.warning(
-            "Attempt %d failed, retrying: %s",
-            attempt,
-            error,
-            exc_info=error,  # This will log the full traceback
-        )
-        await asyncio.sleep(2**attempt)
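
RetryableProvider is dropped in this commit, so the 2**attempt backoff it applied (2 s, 4 s, 8 s across the default max_retries of 3) no longer lives in base.py. Below is a rough caller-side sketch of the same behaviour, assuming `provider` is some concrete LLMProvider instance and that `exceptions` resolves to llmling's exceptions module (its exact import is collapsed out of this diff); it mirrors only the non-streaming path.

import asyncio

from llmling.core import exceptions  # import path assumed; base.py refers to exceptions.LLMError


async def complete_with_retry(provider, messages):
    """Caller-side sketch of the retry loop the removed RetryableProvider implemented."""
    last_error: Exception | None = None
    for attempt in range(1, provider.config.max_retries + 1):
        try:
            result = await provider.complete(messages)
            await provider.validate_response(result)
        except exceptions.LLMError as exc:  # only retry LLM-specific errors, as before
            last_error = exc
            await asyncio.sleep(2**attempt)  # 2 s, 4 s, 8 s for attempts 1-3
        else:
            return result
    msg = f"Failed after {provider.config.max_retries} retries"
    raise exceptions.LLMError(msg) from last_error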