
Commit

chore: multimodal
(cherry picked from commit 55fa3aa5ab1dba649a94df2d111eba001aa8033c)
phil65 committed Nov 20, 2024
1 parent cc3feb4 commit a893000
Showing 14 changed files with 568 additions and 121 deletions.
15 changes: 8 additions & 7 deletions src/llmling/config/models.py
@@ -7,7 +7,7 @@

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from llmling.core.typedefs import ProcessingStep # noqa: TCH001
from llmling.core.typedefs import ContentType, ProcessingStep # noqa: TCH001
from llmling.processors.base import ProcessorConfig # noqa: TCH001


@@ -167,12 +167,13 @@ def validate_import_path(self) -> CallableContext:
class TaskTemplate(BaseModel):
"""Template for a specific task."""

provider: str # Required: provider name or group name
context: str # Required: context name or group name
settings: TaskSettings | None = None # Optional
inherit_tools: bool | None = None # Optional
tools: list[str] | None = None # Optional
tool_choice: Literal["none", "auto"] | str | None = None # noqa: PYI051
provider: str
context: str
settings: TaskSettings | None = None
inherit_tools: bool | None = None
tools: list[str] | None = None
tool_choice: Literal["none", "auto"] | str | None = None

Check failure: Ruff PYI051 at src/llmling/config/models.py:175:26: `Literal["none"]` is redundant in a union with `str`
Check failure: Ruff PYI051 at src/llmling/config/models.py:175:34: `Literal["auto"]` is redundant in a union with `str`
content_type: ContentType | None = None

model_config = ConfigDict(frozen=True)

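For illustration, a minimal sketch of how the new content_type field might be set on a task template; the provider and context names here ("vision", "screenshot") are placeholders, not entries from a real configuration.

from llmling.config.models import TaskTemplate
from llmling.core.typedefs import ContentType

template = TaskTemplate(
    provider="vision",        # placeholder provider name
    context="screenshot",     # placeholder context name
    content_type=ContentType.IMAGE,
)
print(template.content_type)  # ContentType.IMAGE; the other optional fields default to None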
32 changes: 18 additions & 14 deletions src/llmling/config/validation.py
@@ -42,25 +42,29 @@ def validate_all(self) -> list[str]:
return warnings

def _validate_providers(self) -> list[str]:
"""Validate provider configuration.
Returns:
List of validation warnings
"""
warnings = [
f"Provider {provider} in group {group} not found"
for group, providers in self.config.provider_groups.items()
for provider in providers
if provider not in self.config.llm_providers
]
"""Validate provider configuration."""
warnings = []

# Check provider models
for name, provider in self.config.llm_providers.items():
if "/" not in provider.model:
# Check model capabilities
if provider.model.startswith("gpt-4-vision") and not any(
context.type in {"path", "vision_callable"}
for context in self.config.contexts.values()
):
warnings.append(
f"Provider {name} model should be in format 'provider/model'",
f"Vision model {name} configured but no image contexts found"
)

# Check image generation capabilities
if provider.model.startswith(("dall-e", "stable-diffusion")):
for template in self.config.task_templates.values():
if template.provider == name and template.context:
context = self.config.contexts.get(template.context)
if context and context.type != "text":
warnings.append(
f"Image generation model {name} requires text context"
)

return warnings

def _validate_contexts(self) -> list[str]:
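For illustration, a standalone sketch of the vision-capability check added above, written over plain dictionaries rather than the real Config model (whose full shape is not part of this diff):

def check_vision_providers(
    providers: dict[str, str], context_types: set[str]
) -> list[str]:
    """Warn when a vision model is configured but no image-capable context exists."""
    warnings: list[str] = []
    for name, model in providers.items():
        if model.startswith("gpt-4-vision") and not (
            context_types & {"path", "vision_callable"}
        ):
            warnings.append(f"Vision model {name} configured but no image contexts found")
    return warnings

print(check_vision_providers({"vision": "gpt-4-vision-preview"}, {"text"}))
# ['Vision model vision configured but no image contexts found']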
145 changes: 145 additions & 0 deletions src/llmling/context/loaders/vision.py
@@ -0,0 +1,145 @@
from pathlib import Path
from typing import Any

from PIL import Image
from upath import UPath

from llmling.context.base import ContextLoader
from llmling.context.models import LoadedContext
from llmling.core.exceptions import LoaderError
from llmling.core.typedefs import Content, ContentType


class VisionPathLoader(ContextLoader):
"""Loads images from files or URLs."""

SUPPORTED_FORMATS = {".png", ".jpg", ".jpeg", ".webp", ".avif"}

Check failure: Ruff RUF012 at src/llmling/context/loaders/vision.py:16:25: Mutable class attributes should be annotated with `typing.ClassVar`

async def load(
self,
context: Context,

Check failure: Ruff F821 at src/llmling/context/loaders/vision.py:20:18: Undefined name `Context`
processor_registry: ProcessorRegistry,

Check failure: Ruff F821 at src/llmling/context/loaders/vision.py:21:29: Undefined name `ProcessorRegistry`
) -> LoadedContext:
"""Load image content from path."""
if not isinstance(context, PathContext):

Check failure: Ruff F821 at src/llmling/context/loaders/vision.py:24:36: Undefined name `PathContext`
msg = f"Expected PathContext, got {type(context).__name__}"
raise LoaderError(msg)

try:
path = UPath(context.path)
if not self._is_supported_format(path):
msg = f"Unsupported image format: {path.suffix}"
raise LoaderError(msg)

Check failure: Ruff TRY301 at src/llmling/context/loaders/vision.py:32:17: Abstract `raise` to an inner function

# Load and validate image
image_data = await self._load_image(path)
image_meta = await self._get_image_metadata(image_data)

content = Content(
type=ContentType.IMAGE, data=image_data, metadata=image_meta
)

# Process if needed
if procs := context.processors:
processed = await processor_registry.process(content, procs)
content = processed.content

return LoadedContext(
content=content,
source_type="vision_path",
metadata={"path": str(path), "format": path.suffix.lower(), **image_meta},
)

except Exception as exc:
msg = f"Failed to load image from {context.path}: {exc}"
raise LoaderError(msg) from exc

def _is_supported_format(self, path: Path | UPath) -> bool:
"""Check if file format is supported."""
return path.suffix.lower() in self.SUPPORTED_FORMATS

async def _load_image(self, path: Path | UPath) -> bytes:
"""Load image data from path or URL."""
if isinstance(path, UPath) and path.protocol != "file":
# Handle remote URLs
async with aiohttp.ClientSession() as session:

Check failure: Ruff F821 at src/llmling/context/loaders/vision.py:65:24: Undefined name `aiohttp`
async with session.get(str(path)) as response:

Check failure: Ruff SIM117 at src/llmling/context/loaders/vision.py:65:13: Use a single `with` statement with multiple contexts instead of nested `with` statements
response.raise_for_status()
return await response.read()

# Local file
return path.read_bytes()

async def _get_image_metadata(self, image_data: bytes) -> dict[str, Any]:
"""Extract image metadata."""
with Image.open(io.BytesIO(image_data)) as img:

Check failure: Ruff F821 at src/llmling/context/loaders/vision.py:75:25: Undefined name `io`
return {
"size": img.size,
"mode": img.mode,
"format": img.format,
"has_exif": hasattr(img, "_getexif") and img._getexif() is not None,
}


class VisionCallableLoader(ContextLoader):
"""Loads images from callable execution."""

async def load(
self,
context: Context,
processor_registry: ProcessorRegistry,
) -> LoadedContext:
"""Load image from callable execution."""
if not isinstance(context, CallableContext):
msg = f"Expected CallableContext, got {type(context).__name__}"
raise LoaderError(msg)

try:
# Execute callable
result = await calling.execute_callable(
context.import_path, **context.keyword_args
)

# Validate and process result
if isinstance(result, (str, Path)):
# Load from path
path = UPath(str(result))
loader = VisionPathLoader()
return await loader.load(PathContext(path=str(path)), processor_registry)
if isinstance(result, bytes):
# Direct bytes
image_meta = await self._get_image_metadata(result)
content = Content(
type=ContentType.IMAGE, data=result, metadata=image_meta
)
elif isinstance(result, Image.Image):
# PIL Image
buffer = io.BytesIO()
result.save(buffer, format=result.format or "PNG")
content = Content(
type=ContentType.IMAGE,
data=buffer.getvalue(),
metadata={
"size": result.size,
"mode": result.mode,
"format": result.format,
},
)
else:
msg = f"Unsupported image result type: {type(result)}"
raise LoaderError(msg)

# Process if needed
if procs := context.processors:
processed = await processor_registry.process(content, procs)
content = processed.content

return LoadedContext(
content=content,
source_type="vision_callable",
metadata={"import_path": context.import_path, **content.metadata},
)

except Exception as exc:
msg = f"Failed to load image from callable {context.import_path}: {exc}"
raise LoaderError(msg) from exc
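For reference, a sketch of the imports this new module would need to clear the F821 failures flagged above; io and aiohttp are standard-library/third-party names and certain, while the in-project locations of Context, PathContext, CallableContext, ProcessorRegistry, and the calling helper are not shown in this diff.

from __future__ import annotations

import io       # used in _get_image_metadata (io.BytesIO)
import aiohttp  # used in _load_image for remote URLs

# Context, PathContext, CallableContext, ProcessorRegistry, and `calling`
# are also undefined here; their import locations are elsewhere in llmling
# and are not visible in this diff.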
13 changes: 9 additions & 4 deletions src/llmling/context/models.py
@@ -6,21 +6,26 @@

from pydantic import BaseModel, ConfigDict, Field

from llmling.core.typedefs import Content, ContentData, ContentType


class BaseContext(BaseModel):
"""Base class for all context types."""

content: str
content: Content[ContentData]
metadata: dict[str, Any] = Field(default_factory=dict)

model_config = ConfigDict(frozen=True)
# model_config = ConfigDict(frozen=True)
@property
def content_type(self) -> ContentType:
return self.content.type


class ProcessingContext(BaseModel): # type: ignore[no-redef]
"""Context for processor execution."""

original_content: str
current_content: str
original_content: Content[ContentData]
current_content: Content[ContentData]
metadata: dict[str, Any] = Field(default_factory=dict)
kwargs: dict[str, Any] = Field(default_factory=dict)

12 changes: 12 additions & 0 deletions src/llmling/core/exceptions.py
@@ -41,3 +41,15 @@ class TaskError(LLMLingError):

class LLMError(LLMLingError):
"""LLM related errors."""


class ImageProcessingError(ProcessorError):
"""Error during image processing."""


class UnsupportedImageFormatError(ImageProcessingError):
"""Image format not supported."""


class ImageValidationError(ValidationError):
"""Image validation failed."""
47 changes: 46 additions & 1 deletion src/llmling/core/typedefs.py
@@ -3,11 +3,56 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable
from typing import Any, Protocol, TypeVar
from enum import Enum
from typing import Any, Literal, Protocol, TypedDict, TypeVar

from pydantic import BaseModel, ConfigDict, Field


class ImageMetadata(TypedDict, total=False):
"""Metadata for image content."""

format: str
mime_type: str
width: int
height: int
channels: int
mode: str # PIL image mode
has_alpha: bool
file_size: int
source: str # Origin of the image


class ImageContent(TypedDict):
"""Structure for image content in messages."""

type: Literal["image"]
data: bytes
metadata: ImageMetadata


class ContentType(Enum):
"""Types of content that can be processed."""

TEXT = "text"
IMAGE = "image"
MULTI = "multimodal"


T = TypeVar("T", str, bytes, list[Any])


class Content[T]:
"""Generic content container."""

type: ContentType
data: T
metadata: dict[str, Any]


ContentData = str | bytes | list[Any]


class SupportsStr(Protocol):
"""Protocol for objects that can be converted to string."""

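As declared above, Content[T] only carries bare annotations, yet the vision loaders construct it with keyword arguments (Content(type=..., data=..., metadata=...)). A minimal sketch of a constructible version, assuming a pydantic generic model is intended; this is not the project's actual definition:

from enum import Enum
from typing import Any, Generic, TypeVar

from pydantic import BaseModel, Field


class ContentType(Enum):  # mirrors the enum added in this diff
    TEXT = "text"
    IMAGE = "image"
    MULTI = "multimodal"


T = TypeVar("T", str, bytes, list[Any])


class Content(BaseModel, Generic[T]):
    """Generic content container, made constructible for the sketch."""

    type: ContentType
    data: T
    metadata: dict[str, Any] = Field(default_factory=dict)


text = Content(type=ContentType.TEXT, data="hello", metadata={"source": "inline"})
print(text.type)  # ContentType.TEXT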
35 changes: 34 additions & 1 deletion src/llmling/llm/base.py
@@ -72,17 +72,50 @@ class ToolCall(BaseModel):
model_config = ConfigDict(frozen=True)


class MessageContent(BaseModel):
"""Content structure for messages."""

type: Literal["text", "image"]
data: str | dict[str, Any] # text or image data

model_config = ConfigDict(frozen=True)


class Message(BaseModel):
"""A chat message."""

role: MessageRole
content: str
content: str | list[MessageContent]
name: str | None = None # For tool messages
tool_calls: list[ToolCall] | None = None # For assistant messages
tool_call_id: str | None = None # For tool response messages

model_config = ConfigDict(frozen=True)

@classmethod
def text(cls, role: MessageRole, content: str, **kwargs: Any) -> Message:
"""Create a text-only message."""
return cls(role=role, content=content, **kwargs)

@classmethod
def multimodal(
cls, role: MessageRole, contents: list[MessageContent], **kwargs: Any
) -> Message:
"""Create a multimodal message."""
return cls(role=role, content=contents, **kwargs)

def get_text_content(self) -> str:
"""Get text content, useful for backwards compatibility."""
if isinstance(self.content, str):
return self.content
# For multimodal, concatenate any text content
text_parts = [
c.data
for c in self.content
if isinstance(c, MessageContent) and c.type == "text"
]
return " ".join(text_parts) if text_parts else ""


class CompletionResult(BaseModel):
"""Result from an LLM completion."""
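For illustration, a sketch of building a multimodal message with the helpers added above; it assumes MessageRole accepts "user" and that image data is passed as a dict, whose exact payload shape is not shown in this diff.

from llmling.llm.base import Message, MessageContent

msg = Message.multimodal(
    role="user",
    contents=[
        MessageContent(type="text", data="What is in this image?"),
        MessageContent(type="image", data={"source": "screenshot.png"}),  # assumed payload shape
    ],
)
print(msg.get_text_content())  # "What is in this image?"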