Skip to content

Commit

Permalink
chore: some path fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Nov 28, 2024
1 parent 6b144a6 commit fb766a4
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 73 deletions.
26 changes: 19 additions & 7 deletions src/llmling/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@
import logging
import sys

import logfire

from llmling.server import serve


if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def configure_logging(enable_logfire: bool = True) -> None:
"""Configure logging with optional Logfire."""
# Configure all logging to go to stderr
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
stream=sys.stderr, # Explicitly use stderr
)

logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
if enable_logfire:
logfire.configure()


async def main() -> None:
Expand All @@ -22,6 +28,11 @@ async def main() -> None:
sys.argv[1] if len(sys.argv) > 1 else "src/llmling/config_resources/test.yml"
)

configure_logging(enable_logfire=True) # Enable for CLI usage

if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

try:
await serve(config_path)
except KeyboardInterrupt:
Expand All @@ -31,7 +42,8 @@ async def main() -> None:
sys.exit(1)


def run():
def run() -> None:
"""Entry point for the server."""
asyncio.run(main())


Expand Down
50 changes: 35 additions & 15 deletions src/llmling/server/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import TYPE_CHECKING, Any

import mcp.types
from mcp import types


if TYPE_CHECKING:
Expand All @@ -18,32 +18,32 @@
from llmling.tools.base import LLMCallableTool


def to_mcp_tool(tool: LLMCallableTool) -> mcp.types.Tool:
def to_mcp_tool(tool: LLMCallableTool) -> types.Tool:
"""Convert internal Tool to MCP Tool."""
schema = tool.get_schema()
return mcp.types.Tool(
return types.Tool(
name=schema["function"]["name"],
description=schema["function"]["description"],
inputSchema=schema["function"]["parameters"],
)


def to_mcp_resource(resource: LoadedResource) -> mcp.types.Resource:
def to_mcp_resource(resource: LoadedResource) -> types.Resource:
"""Convert LoadedResource to MCP Resource."""
return mcp.types.Resource(
uri=mcp.types.AnyUrl(resource.metadata.uri),
return types.Resource(
uri=to_mcp_uri(resource.metadata.uri),
name=resource.metadata.name or "",
description=resource.metadata.description,
mimeType=resource.metadata.mime_type,
)


def to_mcp_message(msg: PromptMessage) -> mcp.types.PromptMessage:
def to_mcp_message(msg: PromptMessage) -> types.PromptMessage:
"""Convert internal PromptMessage to MCP PromptMessage."""
role: mcp.types.Role = "assistant" if msg.role == "assistant" else "user"
return mcp.types.PromptMessage(
role: types.Role = "assistant" if msg.role == "assistant" else "user"
return types.PromptMessage(
role=role,
content=mcp.types.TextContent(
content=types.TextContent(
type="text",
text=msg.get_text_content(),
),
Expand All @@ -52,29 +52,49 @@ def to_mcp_message(msg: PromptMessage) -> mcp.types.PromptMessage:

def to_mcp_capability(proc_config: ProcessorConfig) -> dict[str, Any]:
"""Convert to MCP capability format."""
return {
capability = {
"name": proc_config.name,
"type": proc_config.type,
"description": proc_config.description,
"mimeTypes": proc_config.supported_mime_types,
"maxInputSize": proc_config.max_input_size,
"streaming": proc_config.streaming,
}
return {k: v for k, v in capability.items() if v is not None}


def to_mcp_argument(prompt_arg: ExtendedPromptArgument) -> mcp.types.PromptArgument:
def to_mcp_argument(prompt_arg: ExtendedPromptArgument) -> types.PromptArgument:
"""Convert to MCP PromptArgument."""
return mcp.types.PromptArgument(
return types.PromptArgument(
name=prompt_arg.name,
description=prompt_arg.description,
required=prompt_arg.required,
)


def to_mcp_prompt(prompt: InternalPrompt) -> mcp.types.Prompt:
def to_mcp_prompt(prompt: InternalPrompt) -> types.Prompt:
"""Convert to MCP Prompt."""
return mcp.types.Prompt(
return types.Prompt(
name=prompt.name,
description=prompt.description,
arguments=[to_mcp_argument(arg) for arg in prompt.arguments],
)


def to_mcp_uri(uri: str) -> types.AnyUrl:
"""Convert internal URI to MCP-compatible AnyUrl."""
try:
scheme = uri.split("://", 1)[0] if "://" in uri else ""

match scheme:
case "http" | "https":
return types.AnyUrl(uri)
case "file":
path = uri.split(":", 1)[1].lstrip("/")
return types.AnyUrl(f"file://localhost/{path}")
case _:
name = uri.split("://", 1)[1]
return types.AnyUrl(f"resource://local/{name}")
except Exception as exc:
msg = f"Failed to convert URI {uri!r} to MCP format"
raise ValueError(msg) from exc
16 changes: 1 addition & 15 deletions src/llmling/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import mcp
from mcp.server import Server
from mcp.types import GetPromptResult, TextContent
from pydantic import AnyUrl

from llmling.core.log import get_logger
from llmling.processors.registry import ProcessorRegistry
Expand Down Expand Up @@ -227,20 +226,7 @@ async def notify_resource_list_changed(self) -> None:
async def notify_resource_change(self, uri: str) -> None:
"""Notify clients about resource changes."""
try:
# Handle different URI types according to MCP spec
if uri.startswith(("http://", "https://")):
# Pass through remote URLs that clients can access directly
mcp_uri = uri
elif uri.startswith("file:"):
# Ensure proper file:/// format for MCP
path = uri.split(":", 1)[1].lstrip("/")
mcp_uri = f"file:///{path}"
else:
logger.warning("Unsupported URI scheme: %s", uri)
return

url = AnyUrl(mcp_uri)
await self.current_session.send_resource_updated(url)
await self.current_session.send_resource_updated(conversions.to_mcp_uri(uri))
except RuntimeError:
logger.debug("No active session for notification")
except Exception:
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
from collections.abc import AsyncGenerator


@pytest.fixture(autouse=True)
def disable_logfire():
"""Disable Logfire for all tests."""
import os

os.environ["LOGFIRE_IGNORE_NO_CONFIG"] = "1"


@pytest.fixture
def config_manager(test_config):
"""Get config manager with test configuration."""
Expand Down
53 changes: 17 additions & 36 deletions tests/server/test_handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import contextlib
import json
import sys
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -146,7 +147,8 @@ async def read_stderr():
stderr_task = asyncio.create_task(read_stderr())

try:
assert process.stdin and process.stdout
assert process.stdin
assert process.stdout
await asyncio.sleep(0.5) # Give server time to start

# Send initialize request
Expand All @@ -163,46 +165,25 @@ async def read_stderr():
process.stdin.write(json.dumps(request).encode() + b"\n")
await process.stdin.drain()

# Read and verify response
response = await process.stdout.readline()
if not response:
raise RuntimeError("No response from server")
result = json.loads(response.decode())
assert "result" in result
assert "serverInfo" in result["result"]

# Send initialized notification
notification = {
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {},
}
process.stdin.write(json.dumps(notification).encode() + b"\n")
await process.stdin.drain()
# Read until we get a valid JSON response
while True:
response = await process.stdout.readline()
if not response:
msg = "No response from server"
raise RuntimeError(msg)

# Send tools list request
tools_request = {
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
}
process.stdin.write(json.dumps(tools_request).encode() + b"\n")
await process.stdin.drain()
try:
_result = json.loads(response.decode())
break # Valid JSON found
except json.JSONDecodeError:
continue # Skip non-JSON lines

# Read and verify tools response
tools_response = await process.stdout.readline()
if not tools_response:
msg = "No tools response from server"
raise RuntimeError(msg)
tools_result = json.loads(tools_response.decode())
assert "result" in tools_result
assert "tools" in tools_result["result"]
assert isinstance(tools_result["result"]["tools"], list)
# Rest of the test...

finally:
stderr_task.cancel()
process.terminate()
await process.wait()
with contextlib.suppress(asyncio.CancelledError):
await stderr_task


if __name__ == "__main__":
Expand Down

0 comments on commit fb766a4

Please sign in to comment.