Commit 830f6cd
chore: vision stuff
phil65 committed Nov 21, 2024
1 parent eb7a001 commit 830f6cd
Showing 9 changed files with 324 additions and 46 deletions.
src/llmling/config/models.py (29 additions, 1 deletion)
@@ -31,6 +31,8 @@ class LLMProviderConfig(BaseModel):
tools: dict[str, dict[str, Any]] | list[str] | None = None # Optional tools
tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051

max_image_size: int | None = None

model_config = ConfigDict(frozen=True)

@field_validator("tools", mode="before")
@@ -161,7 +163,33 @@ def validate_import_path(self) -> CallableContext:
return self


Context = PathContext | TextContext | CLIContext | SourceContext | CallableContext
class ImageContext(BaseContext):
"""Context for image input."""

type: Literal["image"]
path: str # Local path or URL
alt_text: str | None = None

model_config = ConfigDict(frozen=True)

@model_validator(mode="before")
@classmethod
def validate_path(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Validate that path is not empty."""
if isinstance(data, dict) and not data.get("path"):
msg = "Path cannot be empty for image context"
raise ValueError(msg)
return data


Context = (
PathContext
| TextContext
| CLIContext
| SourceContext
| CallableContext
| ImageContext
)


class TaskTemplate(BaseModel):
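For reference, a minimal usage sketch of the new ImageContext model (hypothetical paths; assumes the imports resolve as in this diff and that BaseContext adds no other required fields):

from llmling.config.models import ImageContext

# Local file with optional alt text; validate_path rejects empty paths.
ctx = ImageContext(type="image", path="images/diagram.png", alt_text="Pipeline diagram")

# URLs are accepted too; URL-vs-local handling happens later in the loader.
remote = ImageContext(type="image", path="https://example.com/cat.png")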
src/llmling/context/__init__.py (2 additions, 0 deletions)
@@ -10,9 +10,11 @@
)
from llmling.context.registry import ContextLoaderRegistry
from llmling.context.models import LoadedContext
from llmling.context.loaders.image import ImageContextLoader

# Create and populate the default registry
default_registry = ContextLoaderRegistry()
default_registry.register("image", ImageContextLoader)
default_registry.register("path", PathContextLoader)
default_registry.register("text", TextContextLoader)
default_registry.register("cli", CLIContextLoader)
src/llmling/context/loaders/image.py (102 additions, 0 deletions)
@@ -0,0 +1,102 @@
"""Image context loader implementation."""

from __future__ import annotations

import base64
from typing import TYPE_CHECKING

import upath

from llmling.config.models import ImageContext
from llmling.context.base import ContextLoader
from llmling.context.models import LoadedContext
from llmling.core import exceptions
from llmling.core.log import get_logger
from llmling.llm.base import MessageContent


if TYPE_CHECKING:
from llmling.config.models import Context
from llmling.processors.registry import ProcessorRegistry

logger = get_logger(__name__)


class ImageContextLoader(ContextLoader):
"""Loads image content from files or URLs."""

async def load(
self,
context: Context,
processor_registry: ProcessorRegistry,
) -> LoadedContext:
"""Load and process image content.
Args:
context: Image context configuration
processor_registry: Registry of available processors
Returns:
Loaded and processed context
Raises:
LoaderError: If loading fails or context type is invalid
"""
if not isinstance(context, ImageContext):
msg = f"Expected ImageContext, got {type(context).__name__}"
raise exceptions.LoaderError(msg)

try:
# Use UPath to handle the path
path_obj = upath.UPath(context.path)
# Note: as_uri() raises for relative local paths, so check the raw string
is_url = str(path_obj).startswith(("http://", "https://"))

content_item = MessageContent(
type="image_url" if is_url else "image_base64",
content=await self._load_content(path_obj, is_url),
alt_text=context.alt_text,
)

return LoadedContext(
content="", # Keep empty for backward compatibility
content_items=[content_item],
source_type="image",
metadata={
"path": context.path,
"type": "url" if is_url else "local",
"alt_text": context.alt_text,
},
)

except Exception as exc:
msg = f"Failed to load image from {context.path}"
raise exceptions.LoaderError(msg) from exc

async def _load_content(self, path_obj: upath.UPath, is_url: bool) -> str:
"""Load content from path.
Args:
path_obj: UPath object representing the path
is_url: Whether the path is a URL
Returns:
URL or base64-encoded content
Raises:
LoaderError: If loading fails
"""
if is_url:
return path_obj.as_uri()

try:
if not path_obj.exists():
msg = f"Image file not found: {path_obj}"
raise exceptions.LoaderError(msg) # noqa: TRY301

with path_obj.open("rb") as f:
return base64.b64encode(f.read()).decode()
except Exception as exc:
if isinstance(exc, exceptions.LoaderError):
raise
msg = f"Failed to read image file: {path_obj}"
raise exceptions.LoaderError(msg) from exc
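A hedged sketch of driving the new loader end to end (illustrative path; assumes ProcessorRegistry can be constructed without arguments, which this diff does not show):

import asyncio

from llmling.config.models import ImageContext
from llmling.context.loaders.image import ImageContextLoader
from llmling.context.loaders.image import ImageContextLoader
from llmling.processors.registry import ProcessorRegistry

async def main() -> None:
    loader = ImageContextLoader()
    ctx = ImageContext(type="image", path="images/diagram.png")
    loaded = await loader.load(ctx, ProcessorRegistry())
    # Local files are base64-encoded; URLs pass through as "image_url" items.
    print(loaded.content_items[0].type)  # -> "image_base64"

asyncio.run(main())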
src/llmling/context/models.py (31 additions, 3 deletions)
@@ -4,7 +4,9 @@

from typing import Any

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator

from llmling.llm.base import MessageContent


class BaseContext(BaseModel):
@@ -27,10 +29,36 @@ class ProcessingContext(BaseModel):  # type: ignore[no-redef]
model_config = ConfigDict(frozen=True)


class LoadedContext(BaseContext):
class LoadedContext(BaseModel):
"""Result of loading and processing a context."""

content: str = "" # Keep for backward compatibility
content_items: list[MessageContent] = Field(default_factory=list)
source_type: str | None = None
source_metadata: dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)

model_config = ConfigDict(frozen=True)

@model_validator(mode="before")
@classmethod
def ensure_content_sync(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Ensure content and content_items are in sync."""
if isinstance(data, dict):
content = data.get("content", "")
content_items = data.get("content_items", [])

# If we have content but no items, create a text item
if content and not content_items:
data["content_items"] = [
MessageContent(type="text", content=content).model_dump()
]
# If we have items but no content, use first text item's content
elif content_items and not content:
text_items = [
item
for item in content_items
if isinstance(item, dict) and item.get("type") == "text"
]
if text_items:
data["content"] = text_items[0]["content"]
return data
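To illustrate both sync directions of ensure_content_sync, a small sketch (assuming the models import as shown in this diff):

from llmling.context.models import LoadedContext
from llmling.llm.base import MessageContent

# Legacy callers that only set content get a matching text item.
lc = LoadedContext(content="hello")
assert lc.content_items[0].type == "text"

# Callers that only set items get the legacy content field backfilled
# (the validator matches dict items, hence model_dump()).
lc2 = LoadedContext(
    content_items=[MessageContent(type="text", content="hi").model_dump()]
)
assert lc2.content == "hi"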
src/llmling/llm/base.py (34 additions, 5 deletions)
@@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator

from llmling.core import exceptions
from llmling.core.log import get_logger
@@ -35,6 +35,8 @@ class LLMConfig(BaseModel):
tools: list[dict[str, Any]] | None = None
tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051

max_image_size: int | None = None # Maximum image size in pixels

# LiteLLM settings
api_base: str | None = None
api_key: str | None = None
@@ -72,17 +74,44 @@ class ToolCall(BaseModel):
model_config = ConfigDict(frozen=True)


ContentType = Literal["text", "image_url", "image_base64"]


class MessageContent(BaseModel):
"""Content item in a message."""

type: ContentType = "text" # Default to text for backward compatibility
content: str
alt_text: str | None = None # For image descriptions

model_config = ConfigDict(frozen=True)


class Message(BaseModel):
"""A chat message."""

role: MessageRole
content: str
name: str | None = None # For tool messages
tool_calls: list[ToolCall] | None = None # For assistant messages
tool_call_id: str | None = None # For tool response messages
content: str = "" # Keep for backward compatibility
content_items: list[MessageContent] = Field(default_factory=list)
name: str | None = None
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None

model_config = ConfigDict(frozen=True)

@model_validator(mode="before")
@classmethod
def ensure_content_items(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Ensure content_items is populated from content if empty."""
if isinstance(data, dict): # Type check for static analysis
content = data.get("content", "")
content_items = data.get("content_items", [])
if content and not content_items:
data["content_items"] = [
MessageContent(type="text", content=content).model_dump()
]
return data
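A brief sketch of building a mixed text-plus-image message with the new fields (illustrative URL; assumes "user" is a valid MessageRole):

from llmling.llm.base import Message, MessageContent

msg = Message(
    role="user",
    content_items=[
        MessageContent(type="text", content="What is in this picture?"),
        MessageContent(type="image_url", content="https://example.com/cat.png"),
    ],
)

# Plain-text construction still works; ensure_content_items synthesizes
# a single text item from the legacy content field.
legacy = Message(role="user", content="Hello")
assert legacy.content_items[0].content == "Hello"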


class CompletionResult(BaseModel):
"""Result from an LLM completion."""
src/llmling/llm/providers/litellm.py (48 additions, 25 deletions)
@@ -74,6 +74,42 @@ def _get_provider_from_model(self) -> str:
except Exception: # noqa: BLE001
return "unknown"

def _prepare_content(self, msg: Message) -> str | list[dict[str, Any]]:
"""Prepare message content for LiteLLM.
Handles both text and image content, converting to the format
expected by the API.
"""
if not msg.content_items:
return msg.content

content: list[Any] = []
for i in msg.content_items:
match i.type:
case "text":
content.append({"type": "text", "text": i.content})
case "image_url":
content.append({"type": "image_url", "image_url": {"url": i.content}})
case "image_base64":
content.append({
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{i.content}"},
})

# For broader compatibility, collapse a lone text item to a plain string
if len(content) == 1 and content[0]["type"] == "text":
return content[0]["text"]

return content
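For the mixed message sketched earlier, _prepare_content would emit OpenAI-style content parts along these lines (illustrative values, mirroring the mapping above):

expected = [
    {"type": "text", "text": "What is in this picture?"},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]
# An "image_base64" item maps to a data URL instead:
# {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<...>"}}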

def _check_vision_support(self, messages: list[Message]) -> None:
"""Check if model supports vision when image content is present."""
types = ("image_url", "image_base64")
has_images = any(i.type in types for msg in messages for i in msg.content_items)
if has_images and not self.model_info.supports_vision:
msg = f"Model {self.config.model} does not support vision inputs"
raise exceptions.LLMError(msg)

def _prepare_request_kwargs(self, **additional_kwargs: Any) -> dict[str, Any]:
"""Prepare request kwargs from config and additional kwargs."""
# Start with essential settings preserved from initialization
@@ -113,29 +149,25 @@ async def complete(
) -> CompletionResult:
"""Implement completion using LiteLLM."""
try:
# Check vision support if needed
self._check_vision_support(messages)

# Convert messages to dict format
messages_dict = [
messages_list: list[dict[str, Any]] = [
{
"role": msg.role,
"content": msg.content,
**({"name": msg.name} if msg.name else {}),
"content": self._prepare_content(msg),
}
for msg in messages
]

# Clean up kwargs
# Remove empty tools array and related settings
if "tools" in kwargs and not kwargs["tools"]:
kwargs.pop("tools")
kwargs.pop("tool_choice", None)

# Prepare request kwargs
request_kwargs = self._prepare_request_kwargs(**kwargs)

# Execute completion
response = await litellm.acompletion(
model=self.config.model,
messages=messages_dict,
messages=messages_list,
**request_kwargs,
)

@@ -152,27 +184,18 @@ async def complete_stream(
) -> AsyncIterator[CompletionResult]:
"""Implement streaming completion using LiteLLM."""
try:
# Check vision support if needed
self._check_vision_support(messages)

# Convert messages to dict format
messages_dict = [
messages_dict: list[dict[str, Any]] = [
{
"role": msg.role,
"content": msg.content,
**({"name": msg.name} if msg.name else {}),
"content": self._prepare_content(msg),
}
for msg in messages
]

# Clean up kwargs
# Remove empty tools array and related settings
if "tools" in kwargs and not kwargs["tools"]:
kwargs.pop("tools")
kwargs.pop("tool_choice", None)

# Remove tool-related kwargs if model doesn't support them
if not self.model_info.supports_function_calling:
kwargs.pop("tools", None)
kwargs.pop("tool_choice", None)

# Prepare kwargs with streaming enabled
request_kwargs = self._prepare_request_kwargs(stream=True, **kwargs)

@@ -201,9 +224,9 @@ async def complete_stream(

def _process_response(self, response: Any) -> CompletionResult:
"""Process LiteLLM response into CompletionResult."""
tool_calls = None
try:
# Handle tool calls if present
tool_calls = None
if hasattr(response.choices[0].message, "tool_calls"):
tc = response.choices[0].message.tool_calls
logger.debug("Received tool calls from LLM: %s", tc)
(Diffs for the remaining 3 files not shown.)
