Skip to content

Commit

Permalink
chore: context rework
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Nov 21, 2024
1 parent 830f6cd commit 24b9558
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 60 deletions.
17 changes: 10 additions & 7 deletions src/llmling/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from llmling.processors.base import ProcessorConfig # noqa: TCH001


ContextType = Literal["path", "text", "cli", "source", "callable", "image"]


class GlobalSettings(BaseModel):
"""Global settings that apply to all components."""

Expand Down Expand Up @@ -67,7 +70,7 @@ class TaskSettings(BaseModel):
class BaseContext(BaseModel):
"""Base class for all context types."""

type: str
context_type: ContextType = Field(...)
description: str = "" # Made optional with empty default
processors: list[ProcessingStep] = Field(
default_factory=list
Expand All @@ -78,7 +81,7 @@ class BaseContext(BaseModel):
class PathContext(BaseContext):
"""Context loaded from a file or URL."""

type: Literal["path"]
context_type: Literal["path"] = "path"
path: str

@model_validator(mode="after")
Expand All @@ -93,7 +96,7 @@ def validate_path(self) -> PathContext:
class TextContext(BaseContext):
"""Raw text context."""

type: Literal["text"]
context_type: Literal["text"] = "text"
content: str

@model_validator(mode="after")
Expand All @@ -108,7 +111,7 @@ def validate_content(self) -> TextContext:
class CLIContext(BaseContext):
"""Context from CLI command execution."""

type: Literal["cli"]
context_type: Literal["cli"] = "cli"
command: str | TypingSequence[str]
shell: bool = False
cwd: str | None = None
Expand All @@ -133,7 +136,7 @@ def validate_command(self) -> CLIContext:
class SourceContext(BaseContext):
"""Context from Python source code."""

type: Literal["source"]
context_type: Literal["source"] = "source"
import_path: str
recursive: bool = False
include_tests: bool = False
Expand All @@ -150,7 +153,7 @@ def validate_import_path(self) -> SourceContext:
class CallableContext(BaseContext):
"""Context from executing a Python callable."""

type: Literal["callable"]
context_type: Literal["callable"] = "callable"
import_path: str
keyword_args: dict[str, Any] = Field(default_factory=dict)

Expand All @@ -166,7 +169,7 @@ def validate_import_path(self) -> CallableContext:
class ImageContext(BaseContext):
"""Context for image input."""

type: Literal["image"]
context_type: Literal["image"] = "image"
path: str # Local path or URL
alt_text: str | None = None

Expand Down
4 changes: 2 additions & 2 deletions src/llmling/context/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def get_loader(self, context: Context) -> ContextLoader:
LoaderError: If no loader is registered for the context type
"""
try:
loader_cls = self._loaders[context.type]
loader_cls = self._loaders[context.context_type]
return loader_cls()
except KeyError as exc:
msg = f"No loader registered for context type: {context.type}"
msg = f"No loader registered for context type: {context.context_type}"
raise exceptions.LoaderError(msg) from exc
70 changes: 19 additions & 51 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from llmling.config.models import (
CallableContext,
CLIContext,
Context,
PathContext,
SourceContext,
TextContext,
Expand Down Expand Up @@ -89,7 +90,7 @@ def tmp_file(tmp_path: Path) -> Path:
@pytest.mark.asyncio
async def test_text_loader_basic() -> None:
"""Test basic text loading functionality."""
context = TextContext(type="text", content=SAMPLE_TEXT, description="Test text")
context = TextContext(content=SAMPLE_TEXT, description="Test text")
loader = TextContextLoader()
result = await loader.load(context, ProcessorRegistry())

Expand All @@ -103,18 +104,11 @@ async def test_text_loader_with_processors(processor_registry: ProcessorRegistry
"""Test text loading with processors."""
await processor_registry.startup()
try:
cfg = ProcessorConfig(
type="function", import_path="llmling.testing.processors.reverse_text"
)
path = "llmling.testing.processors.reverse_text"
cfg = ProcessorConfig(type="function", import_path=path)
processor_registry.register("reverse", cfg)

context = TextContext(
type="text",
content=SAMPLE_TEXT,
description="test",
processors=[ProcessingStep(name="reverse")],
)

steps = [ProcessingStep(name="reverse")]
context = TextContext(content=SAMPLE_TEXT, description="test", processors=steps)
loader = TextContextLoader()
result = await loader.load(context, processor_registry)
assert result.content == REVERSED_TEXT
Expand All @@ -126,7 +120,7 @@ async def test_text_loader_with_processors(processor_registry: ProcessorRegistry
@pytest.mark.asyncio
async def test_path_loader_file(tmp_file: Path) -> None:
"""Test loading from a file."""
context = PathContext(type="path", path=str(tmp_file), description="Test file")
context = PathContext(path=str(tmp_file), description="Test file")
loader = PathContextLoader()
result = await loader.load(context, ProcessorRegistry())

Expand All @@ -146,7 +140,7 @@ async def test_path_loader_with_file_protocol(tmp_path: Path) -> None:
path = upath.UPath(test_file)
file_url = str(path.as_uri()) # This will create the correct file:// URL format

context = PathContext(type="path", path=file_url, description="Test file URL")
context = PathContext(path=file_url, description="Test file URL")

loader = PathContextLoader()
result = await loader.load(context, ProcessorRegistry())
Expand All @@ -160,9 +154,7 @@ async def test_path_loader_with_file_protocol(tmp_path: Path) -> None:
@pytest.mark.asyncio
async def test_path_loader_error() -> None:
"""Test loading from a non-existent path."""
context = PathContext(
type="path", path="/nonexistent/file.txt", description="Test missing file"
)
context = PathContext(path="/nonexistent/file.txt", description="Test missing file")
loader = PathContextLoader()

with pytest.raises(exceptions.LoaderError):
Expand All @@ -173,12 +165,8 @@ async def test_path_loader_error() -> None:
@pytest.mark.asyncio
async def test_cli_loader_basic() -> None:
"""Test basic CLI command execution."""
context = CLIContext(
type="cli",
command=ECHO_COMMAND,
description="Test command",
shell=sys.platform == "win32",
)
is_shell = sys.platform == "win32"
context = CLIContext(command=ECHO_COMMAND, description="Test command", shell=is_shell)
loader = CLIContextLoader()
result = await loader.load(context, ProcessorRegistry())

Expand All @@ -189,9 +177,7 @@ async def test_cli_loader_basic() -> None:
@pytest.mark.asyncio
async def test_cli_loader_timeout() -> None:
"""Test CLI command timeout."""
context = CLIContext(
type="cli", command=SLEEP_COMMAND, timeout=0.1, description="test"
)
context = CLIContext(command=SLEEP_COMMAND, timeout=0.1, description="test")
loader = CLIContextLoader()

with pytest.raises(exceptions.LoaderError):
Expand All @@ -202,11 +188,8 @@ async def test_cli_loader_timeout() -> None:
@pytest.mark.asyncio
async def test_source_loader_basic() -> None:
"""Test basic source code loading."""
context = SourceContext(
type="source",
import_path="llmling.context.loaders.text",
description="Test source",
)
path = "llmling.context.loaders.text"
context = SourceContext(import_path=path, description="Test source")
loader = SourceContextLoader()
result = await loader.load(context, ProcessorRegistry())

Expand All @@ -217,9 +200,7 @@ async def test_source_loader_basic() -> None:
@pytest.mark.asyncio
async def test_source_loader_invalid_module() -> None:
"""Test loading from non-existent module."""
context = SourceContext(
type="source", import_path=INVALID_MODULE, description="Test invalid module"
)
context = SourceContext(import_path=INVALID_MODULE, description="Test invalid module")
loader = SourceContextLoader()

with pytest.raises(exceptions.LoaderError):
Expand All @@ -231,7 +212,6 @@ async def test_source_loader_invalid_module() -> None:
async def test_callable_loader_sync() -> None:
"""Test loading from synchronous callable."""
context = CallableContext(
type="callable",
import_path=f"{__name__}.sync_function",
description="Test sync callable",
keyword_args={"test": "value"},
Expand All @@ -247,7 +227,6 @@ async def test_callable_loader_sync() -> None:
async def test_callable_loader_async() -> None:
"""Test loading from asynchronous callable."""
context = CallableContext(
type="callable",
import_path=f"{__name__}.async_function",
description="Test async callable",
keyword_args={"test": "value"},
Expand All @@ -272,21 +251,10 @@ async def test_all_loaders_with_processors(
processor_registry.register("reverse", cfg)
processors = [ProcessingStep(name="upper"), ProcessingStep(name="reverse")]

contexts = [
TextContext(
type="text",
content=SAMPLE_TEXT,
description="Test text",
processors=processors,
),
PathContext(
type="path",
path=str(tmp_file),
description="Test file",
processors=processors,
),
contexts: list[Context] = [
TextContext(content=SAMPLE_TEXT, description="Test text", processors=processors),
PathContext(path=str(tmp_file), description="Test file", processors=processors),
CLIContext(
type="cli",
command=ECHO_COMMAND,
description="Test command",
shell=sys.platform == "win32",
Expand All @@ -301,7 +269,7 @@ async def test_all_loaders_with_processors(
}

for context in contexts:
loader = loaders[context.type]
loader = loaders[context.context_type]
result = await loader.load(context, processor_registry)
assert isinstance(result, LoadedContext)
assert result.content
Expand Down

0 comments on commit 24b9558

Please sign in to comment.