From 9df64d19defbd6fc93d2911921afe4f047e0d530 Mon Sep 17 00:00:00 2001 From: Philipp Temminghoff Date: Mon, 2 Dec 2024 22:34:42 +0100 Subject: [PATCH] feat: allow loading multiple resources from loader --- src/llmling/config/runtime.py | 5 +- src/llmling/resources/base.py | 33 ++++---- src/llmling/resources/loaders/callable.py | 6 +- src/llmling/resources/loaders/cli.py | 6 +- src/llmling/resources/loaders/image.py | 5 +- src/llmling/resources/loaders/path.py | 67 ++++++++++++---- src/llmling/resources/loaders/source.py | 6 +- src/llmling/resources/loaders/text.py | 11 +-- src/llmling/resources/registry.py | 48 +++++++----- tests/resources/test_loaders.py | 96 +++++++++++++++++++++-- tests/test_context.py | 25 +++--- 11 files changed, 223 insertions(+), 85 deletions(-) diff --git a/src/llmling/config/runtime.py b/src/llmling/config/runtime.py index 8731ebf..809d0db 100644 --- a/src/llmling/config/runtime.py +++ b/src/llmling/config/runtime.py @@ -317,7 +317,10 @@ async def load_resource_by_uri(self, uri: str) -> LoadedResource: resolved_uri, resource = await self.resolve_resource_uri(uri) loader = self._loader_registry.get_loader(resource) loader = loader.create(resource, loader.get_name_from_uri(resolved_uri)) - return await loader.load(processor_registry=self._processor_registry) + async for res in loader.load(processor_registry=self._processor_registry): + return res # Return first resource + msg = "No resources loaded" + raise exceptions.ResourceError(msg) # noqa: TRY301 except Exception as exc: msg = f"Failed to load resource from URI {uri}" raise exceptions.ResourceError(msg) from exc diff --git a/src/llmling/resources/base.py b/src/llmling/resources/base.py index 5f40642..407e8f1 100644 --- a/src/llmling/resources/base.py +++ b/src/llmling/resources/base.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast, overload import urllib.parse -import logfire import upath from llmling.completions.protocols import CompletionProvider @@ -24,6 +23,8 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator + from llmling.processors.registry import ProcessorRegistry @@ -208,31 +209,31 @@ def resource_type(self) -> str: return fields["resource_type"].default # type: ignore @overload - async def load( + def load( self, context: LoaderContext[TResource], processor_registry: ProcessorRegistry | None = None, - ) -> LoadedResource: ... + ) -> AsyncIterator[LoadedResource]: ... @overload - async def load( + def load( self, context: TResource, processor_registry: ProcessorRegistry | None = None, - ) -> LoadedResource: ... + ) -> AsyncIterator[LoadedResource]: ... @overload - async def load( + def load( self, context: None = None, processor_registry: ProcessorRegistry | None = None, - ) -> LoadedResource: ... + ) -> AsyncIterator[LoadedResource]: ... async def load( self, - context: LoaderContext[TResource] | TResource | None = None, + context: LoaderContext[TResource] | None = None, processor_registry: ProcessorRegistry | None = None, - ) -> LoadedResource: + ) -> AsyncIterator[LoadedResource]: """Load and process content. Args: @@ -262,12 +263,11 @@ async def load( case _: msg = f"Invalid context type: {type(context)}" raise exceptions.LoaderError(msg) - with logfire.span( - "Loading resource", - resource_type=self.resource_type, - name=name, - ): - return await self._load_impl(resource, name, processor_registry) + + generator = self._load_impl(resource, name, processor_registry) + # Then yield from the generator + async for result in generator: + yield result @abstractmethod async def _load_impl( @@ -275,5 +275,6 @@ async def _load_impl( resource: TResource, name: str, processor_registry: ProcessorRegistry | None, - ) -> LoadedResource: + ) -> AsyncIterator[LoadedResource]: """Implementation of actual loading logic.""" + yield NotImplemented # type: ignore diff --git a/src/llmling/resources/loaders/callable.py b/src/llmling/resources/loaders/callable.py index c516411..f2473a9 100644 --- a/src/llmling/resources/loaders/callable.py +++ b/src/llmling/resources/loaders/callable.py @@ -10,6 +10,8 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator + from llmling.processors.registry import ProcessorRegistry from llmling.resources.models import LoadedResource @@ -29,7 +31,7 @@ async def _load_impl( resource: CallableResource, name: str, processor_registry: ProcessorRegistry | None, - ) -> LoadedResource: + ) -> AsyncIterator[LoadedResource]: """Execute callable and load result.""" try: kwargs = resource.keyword_args @@ -39,7 +41,7 @@ async def _load_impl( processed = await processor_registry.process(content, procs) content = processed.content meta = {"import_path": resource.import_path, "args": resource.keyword_args} - return create_loaded_resource( + yield create_loaded_resource( content=content, source_type="callable", uri=self.create_uri(name=name), diff --git a/src/llmling/resources/loaders/cli.py b/src/llmling/resources/loaders/cli.py index 9109400..9ade4c0 100644 --- a/src/llmling/resources/loaders/cli.py +++ b/src/llmling/resources/loaders/cli.py @@ -12,6 +12,8 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator + from llmling.processors.registry import ProcessorRegistry from llmling.resources.models import LoadedResource @@ -31,7 +33,7 @@ async def _load_impl( resource: CLIResource, name: str, processor_registry: ProcessorRegistry | None, - ) -> LoadedResource: + ) -> AsyncIterator[LoadedResource]: """Execute command and load output.""" command = cmd if isinstance((cmd := resource.command), str) else " ".join(cmd) try: @@ -63,7 +65,7 @@ async def _load_impl( processed = await processor_registry.process(content, procs) content = processed.content meta = {"command": command, "exit_code": proc.returncode} - return create_loaded_resource( + yield create_loaded_resource( content=content, source_type="cli", uri=self.create_uri(name=name), diff --git a/src/llmling/resources/loaders/image.py b/src/llmling/resources/loaders/image.py index 5db5541..f537939 100644 --- a/src/llmling/resources/loaders/image.py +++ b/src/llmling/resources/loaders/image.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator import os from llmling.processors.registry import ProcessorRegistry @@ -63,7 +64,7 @@ async def _load_impl( resource: ImageResource, name: str, processor_registry: ProcessorRegistry | None, - ) -> LoadedResource: + ) -> AsyncIterator[LoadedResource]: """Load and process image content.""" try: path_obj = upath.UPath(resource.path) @@ -87,7 +88,7 @@ async def _load_impl( if resource.alt_text: placeholder_text = f"{placeholder_text} - {resource.alt_text}" - return create_loaded_resource( + yield create_loaded_resource( content=placeholder_text, source_type="image", uri=self.create_uri(name=name), diff --git a/src/llmling/resources/loaders/path.py b/src/llmling/resources/loaders/path.py index ca42df8..11aca94 100644 --- a/src/llmling/resources/loaders/path.py +++ b/src/llmling/resources/loaders/path.py @@ -15,6 +15,8 @@ logger = get_logger(__name__) if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from llmling.processors.registry import ProcessorRegistry from llmling.resources.models import LoadedResource @@ -82,25 +84,56 @@ async def _load_impl( resource: PathResource, name: str, processor_registry: ProcessorRegistry | None, - ) -> LoadedResource: - """Load content from a file or URL.""" + ) -> AsyncGenerator[LoadedResource, None]: + """Load content from file(s).""" try: path = UPath(resource.path) - content = path.read_text("utf-8") - - if processor_registry and (procs := resource.processors): - processed = await processor_registry.process(content, procs) - content = processed.content - meta = {"type": "path", "path": str(path), "scheme": path.protocol} - return create_loaded_resource( - content=content, - source_type="path", - uri=self.create_uri(name=name), - mime_type=self.supported_mime_types[0], - name=resource.description or path.name, - description=resource.description, - additional_metadata=meta, - ) + + if path.is_dir(): + # Handle directory recursively + for file_path in path.rglob("*"): + if file_path.is_file(): + content = file_path.read_text("utf-8") + if processor_registry and (procs := resource.processors): + processed = await processor_registry.process(content, procs) + content = processed.content + + yield create_loaded_resource( + content=content, + source_type="path", + uri=self.create_uri( + name=file_path.name + ), # Use filename for URI + mime_type=self.supported_mime_types[0], + name=resource.description or file_path.name, + description=resource.description, + additional_metadata={ + "type": "path", + "path": str(file_path), + "scheme": file_path.protocol, + "relative_to": str(path), # Add original directory + }, + ) + else: + # Handle single file + content = path.read_text("utf-8") + if processor_registry and (procs := resource.processors): + processed = await processor_registry.process(content, procs) + content = processed.content + + yield create_loaded_resource( + content=content, + source_type="path", + uri=self.create_uri(name=name), + mime_type=self.supported_mime_types[0], + name=resource.description or path.name, + description=resource.description, + additional_metadata={ + "type": "path", + "path": str(path), + "scheme": path.protocol, + }, + ) except Exception as exc: msg = f"Failed to load content from {resource.path}" raise exceptions.LoaderError(msg) from exc diff --git a/src/llmling/resources/loaders/source.py b/src/llmling/resources/loaders/source.py index d8af481..21ca31e 100644 --- a/src/llmling/resources/loaders/source.py +++ b/src/llmling/resources/loaders/source.py @@ -11,6 +11,8 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator + from llmling.processors.registry import ProcessorRegistry from llmling.resources.models import LoadedResource @@ -30,7 +32,7 @@ async def _load_impl( resource: SourceResource, name: str, processor_registry: ProcessorRegistry | None, - ) -> LoadedResource: + ) -> AsyncIterator[LoadedResource]: """Load Python source content.""" try: content = importing.get_module_source( @@ -43,7 +45,7 @@ async def _load_impl( processed = await processor_registry.process(content, procs) content = processed.content - return create_loaded_resource( + yield create_loaded_resource( content=content, source_type="source", uri=self.create_uri(name=name), diff --git a/src/llmling/resources/loaders/text.py b/src/llmling/resources/loaders/text.py index c0f21f9..86ab8fa 100644 --- a/src/llmling/resources/loaders/text.py +++ b/src/llmling/resources/loaders/text.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator + from llmling.processors.registry import ProcessorRegistry from llmling.resources.models import LoadedResource @@ -28,15 +30,15 @@ async def _load_impl( resource: TextResource, name: str, processor_registry: ProcessorRegistry | None, - ) -> LoadedResource: - """Implement actual loading logic.""" + ) -> AsyncIterator[LoadedResource]: + """Load text content.""" try: content = resource.content if processor_registry and (procs := resource.processors): processed = await processor_registry.process(content, procs) content = processed.content - return create_loaded_resource( + yield create_loaded_resource( content=content, source_type="text", uri=self.create_uri(name=name), @@ -46,6 +48,5 @@ async def _load_impl( additional_metadata={"type": "text"}, ) except Exception as exc: - logger.exception("Failed to load text content") - msg = "Failed to load text content" + msg = f"Failed to load text content: {exc}" raise exceptions.LoaderError(msg) from exc diff --git a/src/llmling/resources/registry.py b/src/llmling/resources/registry.py index 7e7ed7e..056c138 100644 --- a/src/llmling/resources/registry.py +++ b/src/llmling/resources/registry.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator + from llmling.processors.registry import ProcessorRegistry from llmling.resources.loaders.registry import ResourceLoaderRegistry from llmling.resources.models import LoadedResource @@ -147,43 +149,53 @@ def get_uri(self, name: str) -> str: loader = loader.create(resource, name) # Create instance return loader.create_uri(name=name) - @logfire.instrument("Loading resource {name}") - async def load(self, name: str, *, force_reload: bool = False) -> LoadedResource: - """Load a resource by name.""" + async def load_all( + self, name: str, *, force_reload: bool = False + ) -> AsyncIterator[LoadedResource]: + """Load all resources for a given name.""" try: resource = self[name] uri = self.get_uri(name) # Check cache unless force reload if not force_reload and uri in self._cache: - return self._cache[uri] + yield self._cache[uri] + return # Get loader and initialize with context loader = self.loader_registry.get_loader(resource) - loader = loader.create(resource, name) # Create with named context + loader = loader.create(resource, name) - loaded = await loader.load( - context=loader.context, # Pass the context we created + async for loaded in loader.load( + context=loader.context, processor_registry=self.processor_registry, - ) + ): + # Ensure the URI is set correctly + if loaded.metadata.uri != uri: + msg = "Loader returned different URI than expected: %s != %s" + logger.warning(msg, loaded.metadata.uri, uri) + loaded.metadata.uri = uri + + # Update cache using URI + self._cache[uri] = loaded + self._last_loaded[uri] = datetime.now() - # Ensure the URI is set correctly - if loaded.metadata.uri != uri: - msg = "Loader returned different URI than expected: %s != %s" - logger.warning(msg, loaded.metadata.uri, uri) - loaded.metadata.uri = uri + yield loaded - # Update cache using URI - self._cache[uri] = loaded - self._last_loaded[uri] = datetime.now() except KeyError as exc: msg = f"Resource not found: {name}" raise exceptions.ResourceError(msg) from exc except Exception as exc: msg = f"Failed to load resource {name}: {exc}" raise exceptions.ResourceError(msg) from exc - else: - return loaded + + @logfire.instrument("Loading resource {name}") + async def load(self, name: str, *, force_reload: bool = False) -> LoadedResource: + """Load first/single resource (backward compatibility).""" + async for resource in self.load_all(name, force_reload=force_reload): + return resource + msg = f"No resources loaded for {name}" + raise exceptions.ResourceError(msg) async def load_by_uri(self, uri: str) -> LoadedResource: """Load a resource by URI.""" diff --git a/tests/resources/test_loaders.py b/tests/resources/test_loaders.py index 84f6627..15186b6 100644 --- a/tests/resources/test_loaders.py +++ b/tests/resources/test_loaders.py @@ -13,6 +13,9 @@ TextResource, ) from llmling.core import exceptions +from llmling.core.typedefs import ProcessingStep +from llmling.processors.base import ProcessorConfig +from llmling.processors.registry import ProcessorRegistry from llmling.resources import ( CallableResourceLoader, CLIResourceLoader, @@ -28,8 +31,6 @@ if TYPE_CHECKING: from pathlib import Path - from llmling.processors.registry import ProcessorRegistry - @pytest.fixture def loader_registry() -> ResourceLoaderRegistry: @@ -46,8 +47,8 @@ def loader_registry() -> ResourceLoaderRegistry: @pytest.fixture def processor_registry() -> ProcessorRegistry: - """Mock processor registry.""" - return None # type: ignore + """Create a processor registry for testing.""" + return ProcessorRegistry() @pytest.mark.parametrize( @@ -155,7 +156,7 @@ async def test_text_loader(processor_registry: ProcessorRegistry) -> None: resource = TextResource(content=content) loader = TextResourceLoader(LoaderContext(resource=resource, name="test")) - result = await loader.load(processor_registry=processor_registry) + result = await anext(loader.load(processor_registry=processor_registry)) assert result.content == content assert result.metadata.mime_type == "text/plain" assert result.source_type == "text" @@ -175,7 +176,7 @@ async def test_path_loader( resource = PathResource(path=str(test_file)) loader = PathResourceLoader(LoaderContext(resource=resource, name="test")) - result = await loader.load(processor_registry=processor_registry) + result = await anext(loader.load(processor_registry=processor_registry)) assert result.content == content assert result.source_type == "path" @@ -187,7 +188,7 @@ async def test_cli_loader(processor_registry: ProcessorRegistry) -> None: resource = CLIResource(command="echo test", shell=True) loader = CLIResourceLoader(LoaderContext(resource=resource, name="test")) - result = await loader.load(processor_registry=processor_registry) + result = await anext(loader.load(processor_registry=processor_registry)) assert result.content.strip() == "test" assert result.source_type == "cli" @@ -198,7 +199,7 @@ async def test_source_loader(processor_registry: ProcessorRegistry) -> None: resource = SourceResource(import_path="llmling.core.log") loader = SourceResourceLoader(LoaderContext(resource=resource, name="test")) - result = await loader.load(processor_registry=processor_registry) + result = await anext(loader.load(processor_registry=processor_registry)) assert "get_logger" in result.content assert result.source_type == "source" assert result.metadata.mime_type == "text/x-python" @@ -284,5 +285,84 @@ def test_registry_uri_templates(loader_registry: ResourceLoaderRegistry) -> None assert all("scheme" in t and "template" in t and "mimeTypes" in t for t in templates) +@pytest.mark.asyncio +async def test_path_loader_directory( + tmp_path: Path, + processor_registry: ProcessorRegistry, +) -> None: + """Test PathResourceLoader with directory.""" + # Create test directory structure + (tmp_path / "subdir").mkdir() + (tmp_path / "file1.txt").write_text("content 1") + (tmp_path / "file2.md").write_text("content 2") + (tmp_path / "subdir" / "file3.txt").write_text("content 3") + + resource = PathResource(path=str(tmp_path)) + loader = PathResourceLoader(LoaderContext(resource=resource, name="test")) + + # Collect all loaded resources using async list comp + files = [ + result async for result in loader.load(processor_registry=processor_registry) + ] + + # Test results + assert len(files) == 3 # noqa: PLR2004 + assert {f.content for f in files} == {"content 1", "content 2", "content 3"} + # Test URIs use basenames + assert all(f.metadata.uri.startswith("file:///") for f in files) + assert {f.metadata.name for f in files} == {"file1.txt", "file2.md", "file3.txt"} + # Test relative path metadata + assert all("relative_to" in f.metadata.extra for f in files) + assert str(tmp_path) == files[0].metadata.extra["relative_to"] + + +@pytest.mark.asyncio +async def test_path_loader_empty_directory( + tmp_path: Path, + processor_registry: ProcessorRegistry, +) -> None: + """Test loading from an empty directory.""" + resource = PathResource(path=str(tmp_path)) + loader = PathResourceLoader(LoaderContext(resource=resource, name="test")) + + files = [ + result async for result in loader.load(processor_registry=processor_registry) + ] + + assert len(files) == 0 + + +@pytest.mark.asyncio +async def test_path_loader_directory_with_processors( + tmp_path: Path, + processor_registry: ProcessorRegistry, +) -> None: + """Test directory loading with processors applied to each file.""" + # Create test files + (tmp_path / "file1.txt").write_text("test1") + (tmp_path / "file2.txt").write_text("test2") + + # Set up processor + processor_registry.register( + "reverse", + ProcessorConfig( + type="function", + import_path="llmling.testing.processors.reverse_text", + ), + ) + + resource = PathResource( + path=str(tmp_path), + processors=[ProcessingStep(name="reverse")], + ) + loader = PathResourceLoader(LoaderContext(resource=resource, name="test")) + files = [ + result async for result in loader.load(processor_registry=processor_registry) + ] + + assert len(files) == 2 # noqa: PLR2004 + assert {f.content for f in files} == {"1tset", "2tset"} # Reversed content + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_context.py b/tests/test_context.py index 16dd7a9..faf1d75 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -90,7 +90,7 @@ async def test_text_loader_basic() -> None: """Test basic text loading functionality.""" context = TextResource(content=SAMPLE_TEXT, description="Test text") loader = TextResourceLoader() - result = await loader.load(context, ProcessorRegistry()) + result = await anext(loader.load(context, ProcessorRegistry())) assert isinstance(result, LoadedResource) assert result.content == SAMPLE_TEXT @@ -108,7 +108,7 @@ async def test_text_loader_with_processors(processor_registry: ProcessorRegistry steps = [ProcessingStep(name="reverse")] context = TextResource(content=SAMPLE_TEXT, description="test", processors=steps) loader = TextResourceLoader() - result = await loader.load(context, processor_registry) + result = await anext(loader.load(context, processor_registry)) assert result.content == REVERSED_TEXT finally: await processor_registry.shutdown() @@ -120,7 +120,8 @@ async def test_path_loader_file(tmp_file: Path) -> None: """Test loading from a file.""" context = PathResource(path=str(tmp_file), description="Test file") loader = PathResourceLoader() - result = await loader.load(context, ProcessorRegistry()) + coro = loader.load(context, ProcessorRegistry()) + result = await anext(coro) assert result.content == TEST_FILE_CONTENT assert result.metadata.extra["type"] == "path" @@ -141,7 +142,7 @@ async def test_path_loader_with_file_protocol(tmp_path: Path) -> None: context = PathResource(path=file_url, description="Test file URL") loader = PathResourceLoader() - result = await loader.load(context, ProcessorRegistry()) + result = await anext(loader.load(context, ProcessorRegistry())) assert result.content == TEST_FILE_CONTENT assert result.metadata.extra["path"] == file_url @@ -156,7 +157,7 @@ async def test_path_loader_error() -> None: loader = PathResourceLoader() with pytest.raises(exceptions.LoaderError): - await loader.load(context, ProcessorRegistry()) + await anext(loader.load(context, ProcessorRegistry())) # CLI Loader Tests @@ -168,7 +169,7 @@ async def test_cli_loader_basic() -> None: command=ECHO_COMMAND, description="Test command", shell=is_shell ) loader = CLIResourceLoader() - result = await loader.load(context, ProcessorRegistry()) + result = await anext(loader.load(context, ProcessorRegistry())) assert "test" in result.content.strip() assert result.metadata.extra["exit_code"] == 0 @@ -181,7 +182,7 @@ async def test_cli_loader_timeout() -> None: loader = CLIResourceLoader() with pytest.raises(exceptions.LoaderError): - await loader.load(context, ProcessorRegistry()) + await anext(loader.load(context, ProcessorRegistry())) # Source Loader Tests @@ -191,7 +192,7 @@ async def test_source_loader_basic() -> None: path = "llmling.resources.loaders.text" context = SourceResource(import_path=path, description="Test source") loader = SourceResourceLoader() - result = await loader.load(context, ProcessorRegistry()) + result = await anext(loader.load(context, ProcessorRegistry())) assert "class TextResourceLoader" in result.content assert result.metadata.extra["import_path"] == context.import_path @@ -206,7 +207,7 @@ async def test_source_loader_invalid_module() -> None: loader = SourceResourceLoader() with pytest.raises(exceptions.LoaderError): - await loader.load(context, ProcessorRegistry()) + await anext(loader.load(context, ProcessorRegistry())) # Callable Loader Tests @@ -219,7 +220,7 @@ async def test_callable_loader_sync() -> None: keyword_args={"test": "value"}, ) loader = CallableResourceLoader() - result = await loader.load(context, ProcessorRegistry()) + result = await anext(loader.load(context, processor_registry=ProcessorRegistry())) assert "Sync result with" in result.content assert result.metadata.extra["import_path"] == context.import_path @@ -234,7 +235,7 @@ async def test_callable_loader_async() -> None: keyword_args={"test": "value"}, ) loader = CallableResourceLoader() - result = await loader.load(context, ProcessorRegistry()) + result = await anext(loader.load(context, ProcessorRegistry())) assert "Async result with" in result.content assert result.metadata.extra["import_path"] == context.import_path @@ -272,7 +273,7 @@ async def test_all_loaders_with_processors( for context in resources: loader = loaders[context.resource_type] - result = await loader.load(context, processor_registry) + result = await anext(loader.load(context, processor_registry)) assert isinstance(result, LoadedResource) assert result.content assert result.content.startswith("'")