From 830f6cd5eecbe6c7735a6f1abee24136917b9fc4 Mon Sep 17 00:00:00 2001
From: Philipp Temminghoff
Date: Thu, 21 Nov 2024 02:16:40 +0100
Subject: [PATCH] feat: add image contexts and vision model support

---
 src/llmling/config/models.py          |  30 +++++++-
 src/llmling/context/__init__.py       |   2 +
 src/llmling/context/loaders/image.py  | 102 ++++++++++++++++++++++++++
 src/llmling/context/models.py         |  34 ++++++++-
 src/llmling/llm/base.py               |  39 ++++++++--
 src/llmling/llm/providers/litellm.py  |  73 +++++++++++-------
 src/llmling/processors/base.py        |   5 +-
 src/llmling/resources/vision_test.yml |  40 ++++++++++
 src/llmling/task/executor.py          |  45 +++++++++---
 9 files changed, 324 insertions(+), 46 deletions(-)
 create mode 100644 src/llmling/context/loaders/image.py
 create mode 100644 src/llmling/resources/vision_test.yml

diff --git a/src/llmling/config/models.py b/src/llmling/config/models.py
index 2a294fb..395eb70 100644
--- a/src/llmling/config/models.py
+++ b/src/llmling/config/models.py
@@ -31,6 +31,8 @@ class LLMProviderConfig(BaseModel):
     tools: dict[str, dict[str, Any]] | list[str] | None = None  # Optional tools
     tool_choice: Literal["none", "auto"] | str | None = None  # noqa: PYI051
 
+    max_image_size: int | None = None
+
     model_config = ConfigDict(frozen=True)
 
     @field_validator("tools", mode="before")
@@ -161,7 +163,33 @@ def validate_import_path(self) -> CallableContext:
         return self
 
 
-Context = PathContext | TextContext | CLIContext | SourceContext | CallableContext
+class ImageContext(BaseContext):
+    """Context for image input."""
+
+    type: Literal["image"]
+    path: str  # Local path or URL
+    alt_text: str | None = None
+
+    model_config = ConfigDict(frozen=True)
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_path(cls, data: dict[str, Any]) -> dict[str, Any]:
+        """Validate that path is not empty."""
+        if isinstance(data, dict) and not data.get("path"):
+            msg = "Path cannot be empty for image context"
+            raise ValueError(msg)
+        return data
+
+
+Context = (
+    PathContext
+    | TextContext
+    | CLIContext
+    | SourceContext
+    | CallableContext
+    | ImageContext
+)
 
 
 class TaskTemplate(BaseModel):
diff --git a/src/llmling/context/__init__.py b/src/llmling/context/__init__.py
index 42d6b51..a6c5be1 100644
--- a/src/llmling/context/__init__.py
+++ b/src/llmling/context/__init__.py
@@ -10,9 +10,11 @@
 )
 from llmling.context.registry import ContextLoaderRegistry
 from llmling.context.models import LoadedContext
+from llmling.context.loaders.image import ImageContextLoader
 
 # Create and populate the default registry
 default_registry = ContextLoaderRegistry()
+default_registry.register("image", ImageContextLoader)
 default_registry.register("path", PathContextLoader)
 default_registry.register("text", TextContextLoader)
 default_registry.register("cli", CLIContextLoader)
diff --git a/src/llmling/context/loaders/image.py b/src/llmling/context/loaders/image.py
new file mode 100644
index 0000000..916d37f
--- /dev/null
+++ b/src/llmling/context/loaders/image.py
@@ -0,0 +1,102 @@
+"""Image context loader implementation."""
+
+from __future__ import annotations
+
+import base64
+from typing import TYPE_CHECKING
+
+import upath
+
+from llmling.config.models import ImageContext
+from llmling.context.base import ContextLoader
+from llmling.context.models import LoadedContext
+from llmling.core import exceptions
+from llmling.core.log import get_logger
+from llmling.llm.base import MessageContent
+
+
+if TYPE_CHECKING:
+    from llmling.config.models import Context
+    from llmling.processors.registry import ProcessorRegistry
+
+logger = get_logger(__name__)
+
+
+class ImageContextLoader(ContextLoader):
+    """Loads image content from files or URLs."""
+
+    async def load(
+        self,
+        context: Context,
+        processor_registry: ProcessorRegistry,
+    ) -> LoadedContext:
+        """Load and process image content.
+
+        Args:
+            context: Image context configuration
+            processor_registry: Registry of available processors
+
+        Returns:
+            Loaded and processed context
+
+        Raises:
+            LoaderError: If loading fails or context type is invalid
+        """
+        if not isinstance(context, ImageContext):
+            msg = f"Expected ImageContext, got {type(context).__name__}"
+            raise exceptions.LoaderError(msg)
+
+        try:
+            # Check the raw string: as_uri() raises for relative local paths
+            path_obj = upath.UPath(context.path)
+            is_url = context.path.startswith(("http://", "https://"))
+
+            content_item = MessageContent(
+                type="image_url" if is_url else "image_base64",
+                content=await self._load_content(path_obj, is_url),
+                alt_text=context.alt_text,
+            )
+
+            return LoadedContext(
+                content="",  # Keep empty for backward compatibility
+                content_items=[content_item],
+                source_type="image",
+                metadata={
+                    "path": context.path,
+                    "type": "url" if is_url else "local",
+                    "alt_text": context.alt_text,
+                },
+            )
+
+        except Exception as exc:
+            msg = f"Failed to load image from {context.path}"
+            raise exceptions.LoaderError(msg) from exc
+
+    async def _load_content(self, path_obj: upath.UPath, is_url: bool) -> str:
+        """Load content from path.
+
+        Args:
+            path_obj: UPath object representing the path
+            is_url: Whether the path is a URL
+
+        Returns:
+            URL or base64-encoded content
+
+        Raises:
+            LoaderError: If loading fails
+        """
+        if is_url:
+            return str(path_obj)  # already a URL
+
+        try:
+            if not path_obj.exists():
+                msg = f"Image file not found: {path_obj}"
+                raise exceptions.LoaderError(msg)  # noqa: TRY301
+
+            with path_obj.open("rb") as f:
+                return base64.b64encode(f.read()).decode()
+        except Exception as exc:
+            if isinstance(exc, exceptions.LoaderError):
+                raise
+            msg = f"Failed to read image file: {path_obj}"
+            raise exceptions.LoaderError(msg) from exc
diff --git a/src/llmling/context/models.py b/src/llmling/context/models.py
index 8941970..a0a9ec4 100644
--- a/src/llmling/context/models.py
+++ b/src/llmling/context/models.py
@@ -4,7 +4,9 @@
 
 from typing import Any
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+from llmling.llm.base import MessageContent
 
 
 class BaseContext(BaseModel):
@@ -27,10 +29,36 @@ class ProcessingContext(BaseModel):  # type: ignore[no-redef]
     model_config = ConfigDict(frozen=True)
 
 
-class LoadedContext(BaseContext):
+class LoadedContext(BaseModel):
     """Result of loading and processing a context."""
 
+    content: str = ""  # Keep for backward compatibility
+    content_items: list[MessageContent] = Field(default_factory=list)
     source_type: str | None = None
-    source_metadata: dict[str, Any] = Field(default_factory=dict)
+    metadata: dict[str, Any] = Field(default_factory=dict)
 
     model_config = ConfigDict(frozen=True)
+
+    @model_validator(mode="before")
+    @classmethod
+    def ensure_content_sync(cls, data: dict[str, Any]) -> dict[str, Any]:
+        """Ensure content and content_items are in sync."""
+        if isinstance(data, dict):
+            content = data.get("content", "")
+            content_items = data.get("content_items", [])
+
+            # If we have content but no items, create a text item
+            if content and not content_items:
+                data["content_items"] = [
+                    MessageContent(type="text", content=content).model_dump()
+                ]
+            # If we have items but no content, use first text item's content
+            elif content_items and not content:
+                text_items = [
+                    item
+                    for item in content_items
+                    if isinstance(item, dict) and item.get("type") == "text"
+                ]
+                if text_items:
+                    data["content"] = text_items[0]["content"]
+        return data
diff --git a/src/llmling/llm/base.py b/src/llmling/llm/base.py
index 651fff3..cf6f204 100644
--- a/src/llmling/llm/base.py
+++ b/src/llmling/llm/base.py
@@ -5,7 +5,7 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Literal
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 from llmling.core import exceptions
 from llmling.core.log import get_logger
@@ -35,6 +35,8 @@ class LLMConfig(BaseModel):
     tools: list[dict[str, Any]] | None = None
     tool_choice: Literal["none", "auto"] | str | None = None  # noqa: PYI051
 
+    max_image_size: int | None = None  # Maximum image size in pixels
+
     # LiteLLM settings
     api_base: str | None = None
     api_key: str | None = None
@@ -72,17 +74,44 @@ class ToolCall(BaseModel):
     model_config = ConfigDict(frozen=True)
 
 
+ContentType = Literal["text", "image_url", "image_base64"]
+
+
+class MessageContent(BaseModel):
+    """Content item in a message."""
+
+    type: ContentType = "text"  # Default to text for backward compatibility
+    content: str
+    alt_text: str | None = None  # For image descriptions
+
+    model_config = ConfigDict(frozen=True)
+
+
 class Message(BaseModel):
     """A chat message."""
 
     role: MessageRole
-    content: str
-    name: str | None = None  # For tool messages
-    tool_calls: list[ToolCall] | None = None  # For assistant messages
-    tool_call_id: str | None = None  # For tool response messages
+    content: str = ""  # Keep for backward compatibility
+    content_items: list[MessageContent] = Field(default_factory=list)
+    name: str | None = None
+    tool_calls: list[ToolCall] | None = None
+    tool_call_id: str | None = None
 
     model_config = ConfigDict(frozen=True)
 
+    @model_validator(mode="before")
+    @classmethod
+    def ensure_content_items(cls, data: dict[str, Any]) -> dict[str, Any]:
+        """Ensure content_items is populated from content if empty."""
+        if isinstance(data, dict):  # Type check for static analysis
+            content = data.get("content", "")
+            content_items = data.get("content_items", [])
+            if content and not content_items:
+                data["content_items"] = [
+                    MessageContent(type="text", content=content).model_dump()
+                ]
+        return data
+
 
 class CompletionResult(BaseModel):
     """Result from an LLM completion."""
diff --git a/src/llmling/llm/providers/litellm.py b/src/llmling/llm/providers/litellm.py
index 5a13a43..14b0777 100644
--- a/src/llmling/llm/providers/litellm.py
+++ b/src/llmling/llm/providers/litellm.py
@@ -74,6 +74,42 @@ def _get_provider_from_model(self) -> str:
         except Exception:  # noqa: BLE001
             return "unknown"
 
+    def _prepare_content(self, msg: Message) -> str | list[dict[str, Any]]:
+        """Prepare message content for LiteLLM.
+
+        Handles both text and image content, converting to the format
+        expected by the API.
+ """ + if not msg.content_items: + return msg.content + + content: list[Any] = [] + for i in msg.content_items: + match i.type: + case "text": + content.append({"type": "text", "text": i.content}) + case "image_url": + content.append({"type": "image_url", "image_url": {"url": i.content}}) + case "image_base64": + content.append({ + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{i.content}"}, + }) + + # For better compatibility, if only text and single item, return just the text + if len(content) == 1 and content[0]["type"] == "text": + return content[0]["text"] + + return content + + def _check_vision_support(self, messages: list[Message]) -> None: + """Check if model supports vision when image content is present.""" + types = ("image_url", "image_base64") + has_images = any(i.type in types for msg in messages for i in msg.content_items) + if has_images and not self.model_info.supports_vision: + msg = f"Model {self.config.model} does not support vision inputs" + raise exceptions.LLMError(msg) + def _prepare_request_kwargs(self, **additional_kwargs: Any) -> dict[str, Any]: """Prepare request kwargs from config and additional kwargs.""" # Start with essential settings preserved from initialization @@ -113,29 +149,25 @@ async def complete( ) -> CompletionResult: """Implement completion using LiteLLM.""" try: + # Check vision support if needed + self._check_vision_support(messages) + # Convert messages to dict format - messages_dict = [ + messages_list: list[dict[str, Any]] = [ { "role": msg.role, - "content": msg.content, **({"name": msg.name} if msg.name else {}), + "content": self._prepare_content(msg), } for msg in messages ] - # Clean up kwargs - # Remove empty tools array and related settings - if "tools" in kwargs and not kwargs["tools"]: - kwargs.pop("tools") - kwargs.pop("tool_choice", None) - - # Prepare request kwargs request_kwargs = self._prepare_request_kwargs(**kwargs) # Execute completion response = await litellm.acompletion( model=self.config.model, - messages=messages_dict, + messages=messages_list, **request_kwargs, ) @@ -152,27 +184,18 @@ async def complete_stream( ) -> AsyncIterator[CompletionResult]: """Implement streaming completion using LiteLLM.""" try: + # Check vision support if needed + self._check_vision_support(messages) + # Convert messages to dict format - messages_dict = [ + messages_dict: list[dict[str, Any]] = [ { "role": msg.role, - "content": msg.content, **({"name": msg.name} if msg.name else {}), + "content": self._prepare_content(msg), } for msg in messages ] - - # Clean up kwargs - # Remove empty tools array and related settings - if "tools" in kwargs and not kwargs["tools"]: - kwargs.pop("tools") - kwargs.pop("tool_choice", None) - - # Remove tool-related kwargs if model doesn't support them - if not self.model_info.supports_function_calling: - kwargs.pop("tools", None) - kwargs.pop("tool_choice", None) - # Prepare kwargs with streaming enabled request_kwargs = self._prepare_request_kwargs(stream=True, **kwargs) @@ -201,9 +224,9 @@ async def complete_stream( def _process_response(self, response: Any) -> CompletionResult: """Process LiteLLM response into CompletionResult.""" + tool_calls = None try: # Handle tool calls if present - tool_calls = None if hasattr(response.choices[0].message, "tool_calls"): tc = response.choices[0].message.tool_calls logger.debug("Received tool calls from LLM: %s", tc) diff --git a/src/llmling/processors/base.py b/src/llmling/processors/base.py index 828cdca..9729b36 100644 --- 
a/src/llmling/processors/base.py +++ b/src/llmling/processors/base.py @@ -11,6 +11,9 @@ from llmling.core.log import get_logger +PROCESSOR_TYPES = Literal["function", "template", "image"] + + if TYPE_CHECKING: from collections.abc import AsyncIterator @@ -37,7 +40,7 @@ class TemplateProcessorConfig(TypedDict, total=False): class ProcessorConfig(BaseModel): """Configuration for text processors.""" - type: Literal["function", "template"] + type: PROCESSOR_TYPES name: str | None = None description: str | None = None diff --git a/src/llmling/resources/vision_test.yml b/src/llmling/resources/vision_test.yml new file mode 100644 index 0000000..c7b5511 --- /dev/null +++ b/src/llmling/resources/vision_test.yml @@ -0,0 +1,40 @@ +llm_providers: + gpt4-vision: + name: "GPT-4 Vision" + model: "openai/gpt-4-vision-preview" + max_tokens: 4096 + max_image_size: 2048 + +contexts: + test_image: + type: "image" + path: "resources/test_image.jpg" + alt_text: "A test image for vision capabilities" + + web_image: + type: "image" + path: "https://example.com/image.jpg" + alt_text: "A test image from the web" + + multi_image: + type: "text" + content: "Compare these images and describe their differences:" + processors: + - name: "append_images" + kwargs: + images: ["resources/image1.jpg", "resources/image2.jpg"] + +task_templates: + analyze_image: + provider: gpt4-vision + context: test_image + settings: + temperature: 0.7 + max_tokens: 1000 + + compare_images: + provider: gpt4-vision + context: multi_image + settings: + temperature: 0.7 + max_tokens: 2000 diff --git a/src/llmling/task/executor.py b/src/llmling/task/executor.py index fa3237c..9db0f7a 100644 --- a/src/llmling/task/executor.py +++ b/src/llmling/task/executor.py @@ -18,6 +18,7 @@ from llmling.config.manager import ConfigManager from llmling.context import ContextLoaderRegistry + from llmling.context.models import LoadedContext from llmling.llm.registry import ProviderRegistry from llmling.processors.registry import ProcessorRegistry @@ -139,8 +140,8 @@ async def execute( # Load and process context context_result = await self._load_context(task_context) - # Prepare messages - messages = self._prepare_messages(context_result.content, system_prompt) + # Prepare messages with new content structure support + messages = self._prepare_messages(context_result, system_prompt) # Configure and create provider llm_config = self._create_llm_config(task_provider) @@ -148,6 +149,7 @@ async def execute( task_provider.name, llm_config, ) + # Get completion with potential tool calls while True: completion = await provider.complete(messages, **kwargs) @@ -156,8 +158,11 @@ async def execute( if completion.tool_calls: tool_results = [] for tool_call in completion.tool_calls: - msg = "Executing tool call: %s with params: %s" - logger.debug(msg, tool_call.name, tool_call.parameters) + logger.debug( + "Executing tool call: %s with params: %s", + tool_call.name, + tool_call.parameters, + ) result = await self.tool_registry.execute( tool_call.name, **tool_call.parameters, @@ -212,8 +217,8 @@ async def execute_stream( # Load and process context context_result = await self._load_context(task_context) - # Prepare messages - messages = self._prepare_messages(context_result.content, system_prompt) + # Prepare messages with new content structure support + messages = self._prepare_messages(context_result, system_prompt) # Configure and create provider llm_config = self._create_llm_config(task_provider, streaming=True) @@ -231,7 +236,8 @@ async def execute_stream( ) except 
Exception as exc: - msg = "Task streaming failed" + logger.exception("Task streaming failed") + msg = f"Task streaming failed: {exc}" raise exceptions.TaskError(msg) from exc async def _load_context(self, task_context: TaskContext) -> Any: @@ -262,13 +268,13 @@ async def _load_context(self, task_context: TaskContext) -> Any: def _prepare_messages( self, - content: str, + loaded_context: LoadedContext | str, # for bw compat system_prompt: str | None, ) -> list[Message]: """Prepare messages for LLM completion. Args: - content: Context content + loaded_context: Loaded and processed context system_prompt: Optional system prompt Returns: @@ -277,9 +283,26 @@ def _prepare_messages( messages: list[Message] = [] if system_prompt: - messages.append(Message(role="system", content=system_prompt)) + messages.append( + Message( + role="system", + content=system_prompt, + ) + ) + + # If context has content_items, use them directly + if isinstance(loaded_context, str): + # Backward compatibility: use plain content + messages.append(Message(role="user", content=loaded_context)) + + elif loaded_context.content_items: + messages.append( + Message(role="user", content_items=loaded_context.content_items) + ) + else: + # Backward compatibility: use plain content + messages.append(Message(role="user", content=loaded_context.content)) - messages.append(Message(role="user", content=content)) return messages def _create_llm_config(
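
Usage sketch (not part of the patch): how the new pieces are meant to fit
together, assuming exactly the API introduced above. The file path, the
no-arg loader construction, and passing None for the processor registry
(which this loader never touches) are illustrative assumptions only. Note
that the LiteLLM provider renders base64 items with a hard-coded
data:image/jpeg media type, so non-JPEG files may need extra handling.

    import asyncio

    from llmling.config.models import ImageContext
    from llmling.context.loaders.image import ImageContextLoader
    from llmling.llm.base import Message

    async def main() -> None:
        # A local file is base64-encoded into an "image_base64" content item;
        # an http(s) path would become an "image_url" item instead.
        ctx = ImageContext(
            type="image",
            path="resources/test_image.jpg",  # illustrative path
            alt_text="A test image",
        )
        loader = ImageContextLoader()  # assumes the no-arg constructor
        loaded = await loader.load(ctx, processor_registry=None)  # registry unused here

        # content_items plug straight into the multimodal Message model; the
        # LiteLLM provider's _prepare_content() then renders each item as an
        # OpenAI-style content part, e.g.
        # {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}
        message = Message(role="user", content_items=loaded.content_items)
        print(message.content_items[0].type)  # -> "image_base64"

    asyncio.run(main())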