From 733ba147fc3189675f3b3cd492f768d253aa2b8d Mon Sep 17 00:00:00 2001 From: Philipp Temminghoff Date: Sun, 15 Dec 2024 00:13:51 +0100 Subject: [PATCH] chore: cleanup --- src/llmling/config/runtime.py | 6 +- src/llmling/core/baseregistry.py | 8 +- src/llmling/core/chain.py | 74 ++++++++---------- src/llmling/resources/loaders/repository.py | 2 +- src/llmling/tools/openapi.py | 84 +++++++++++---------- src/llmling/utils/importing.py | 6 +- 6 files changed, 82 insertions(+), 98 deletions(-) diff --git a/src/llmling/config/runtime.py b/src/llmling/config/runtime.py index d56c791..857d046 100644 --- a/src/llmling/config/runtime.py +++ b/src/llmling/config/runtime.py @@ -890,11 +890,7 @@ def register_static_prompt( first_line[:100] + "..." if len(first_line) > 100 else first_line # noqa: PLR2004 ) - prompt = StaticPrompt( - name=name, - description=description, - messages=messages, - ) + prompt = StaticPrompt(name=name, description=description, messages=messages) self._prompt_registry.register(name, prompt, replace=replace) except Exception as exc: diff --git a/src/llmling/core/baseregistry.py b/src/llmling/core/baseregistry.py index bb6716a..ee84169 100644 --- a/src/llmling/core/baseregistry.py +++ b/src/llmling/core/baseregistry.py @@ -112,13 +112,13 @@ def _validate_item(self, item: Any) -> TItem: async def _initialize_item(self, item: TItem) -> None: """Initialize an item during startup.""" - if hasattr(item, "startup") and callable(item.startup): - await item.startup() + if hasattr(item, "startup") and callable(item.startup): # pyright: ignore + await item.startup() # pyright: ignore async def _cleanup_item(self, item: TItem) -> None: """Clean up an item during shutdown.""" - if hasattr(item, "shutdown") and callable(item.shutdown): - await item.shutdown() + if hasattr(item, "shutdown") and callable(item.shutdown): # pyright: ignore + await item.shutdown() # pyright: ignore # Implementing MutableMapping methods def __getitem__(self, key: TKey) -> TItem: diff --git a/src/llmling/core/chain.py b/src/llmling/core/chain.py index 64fd8f9..1563770 100644 --- a/src/llmling/core/chain.py +++ b/src/llmling/core/chain.py @@ -34,6 +34,32 @@ class StepCondition(BaseModel): operator: Literal["eq", "gt", "lt", "contains", "exists"] value: Any = None + def evaluate_with_value(self, value: Any) -> bool: + """Evaluate this condition against a value. + + Args: + value: The value to evaluate against the condition. + + Returns: + bool: True if the condition is met, False otherwise. + """ + field_value = value.get(self.field) if isinstance(value, dict) else value + + match self.operator: + case "eq": + return field_value == self.value + case "gt": + return field_value > self.value + case "lt": + return field_value < self.value + case "contains": + try: + return self.value in field_value # type: ignore + except TypeError: + return False + case "exists": + return field_value is not None + @dataclass class StepResult: @@ -246,15 +272,8 @@ async def _execute_step( while True: try: # Check condition if any - if step.condition and not self._evaluate_condition( - step.condition, input_value - ): - return StepResult( - success=True, - result=input_value, # Pass through unchanged - duration=0, - ) - + if step.condition and not step.condition.evaluate_with_value(input_value): + return StepResult(success=True, result=input_value, duration=0) # Prepare kwargs if isinstance(input_value, dict): kwargs = {**input_value, **step.keyword_args} @@ -263,10 +282,8 @@ async def _execute_step( # Execute with timeout if specified if step.timeout: - result = await asyncio.wait_for( - self.runtime.execute_tool(step.tool, **kwargs), - timeout=step.timeout, - ) + fut = self.runtime.execute_tool(step.tool, **kwargs) + result = await asyncio.wait_for(fut, timeout=step.timeout) else: result = await self.runtime.execute_tool(step.tool, **kwargs) @@ -303,30 +320,7 @@ async def _execute_step( continue raise # Max retries exceeded - def _evaluate_condition(self, condition: StepCondition, value: Any) -> bool: - """Evaluate a step condition.""" - field_value = value.get(condition.field) if isinstance(value, dict) else value - - match condition.operator: - case "eq": - return field_value == condition.value - case "gt": - return field_value > condition.value - case "lt": - return field_value < condition.value - case "contains": - try: - return condition.value in field_value # type: ignore - except TypeError: - return False - case "exists": - return field_value is not None - - async def _execute_sequential( - self, - pipeline: Pipeline, - results: StepResults, - ) -> Any: + async def _execute_sequential(self, pipeline: Pipeline, results: StepResults) -> Any: """Execute steps sequentially.""" current = pipeline.input @@ -338,11 +332,7 @@ async def _execute_sequential( return current - async def _execute_parallel( - self, - pipeline: Pipeline, - results: StepResults, - ) -> Any: + async def _execute_parallel(self, pipeline: Pipeline, results: StepResults) -> Any: """Execute independent steps in parallel.""" semaphore = asyncio.Semaphore(pipeline.max_parallel) diff --git a/src/llmling/resources/loaders/repository.py b/src/llmling/resources/loaders/repository.py index be894c6..bcac852 100644 --- a/src/llmling/resources/loaders/repository.py +++ b/src/llmling/resources/loaders/repository.py @@ -129,7 +129,7 @@ async def _load_impl( loaded.content = result.content yield loaded - except git.exc.GitCommandError as exc: + except git.exc.GitCommandError as exc: # type: ignore msg = f"Git operation failed: {exc}" raise exceptions.LoaderError(msg) from exc except Exception as exc: diff --git a/src/llmling/tools/openapi.py b/src/llmling/tools/openapi.py index 97078a9..b006799 100644 --- a/src/llmling/tools/openapi.py +++ b/src/llmling/tools/openapi.py @@ -39,6 +39,47 @@ } +def parse_operations(paths: dict) -> dict[str, dict[str, Any]]: + operations = {} + for path, path_item in paths.items(): + for method, operation in path_item.items(): + if method not in {"get", "post", "put", "delete", "patch"}: + continue + + # Generate operation ID if not provided + op_id = operation.get("operationId") + if not op_id: + op_id = f"{method}_{path.replace('/', '_').strip('_')}" + + # Collect all parameters (path, query, header) + params = operation.get("parameters", []) + if ( + (request_body := operation.get("requestBody")) + and (content := request_body.get("content", {})) + and (json_schema := content.get("application/json", {}).get("schema")) + and (properties := json_schema.get("properties", {})) + ): + # Convert request body to parameters + for name, schema in properties.items(): + params.append({ + "name": name, + "in": "body", + "required": name in json_schema.get("required", []), + "schema": schema, + "description": schema.get("description", ""), + }) + + operations[op_id] = { + "method": method, + "path": path, + "description": operation.get("description", ""), + "parameters": params, + "responses": operation.get("responses", {}), + } + + return operations + + class OpenAPITools(ToolSet): """Tool collection for OpenAPI endpoints.""" @@ -91,50 +132,11 @@ def _load_spec(self) -> Schema: def _parse_operations(self) -> dict[str, dict[str, Any]]: """Parse OpenAPI spec into operation configurations.""" - operations = {} - # Get server URL if not overridden if not self.base_url and "servers" in self._spec: self.base_url = self._spec["servers"][0]["url"] - - # Parse paths and operations - for path, path_item in self._spec.get("paths", {}).items(): - for method, operation in path_item.items(): - if method not in {"get", "post", "put", "delete", "patch"}: - continue - - # Generate operation ID if not provided - op_id = operation.get("operationId") - if not op_id: - op_id = f"{method}_{path.replace('/', '_').strip('_')}" - - # Collect all parameters (path, query, header) - parameters = operation.get("parameters", []) - if ( - (request_body := operation.get("requestBody")) - and (content := request_body.get("content", {})) - and (json_schema := content.get("application/json", {}).get("schema")) - and (properties := json_schema.get("properties", {})) - ): - # Convert request body to parameters - for name, schema in properties.items(): - parameters.append({ - "name": name, - "in": "body", - "required": name in json_schema.get("required", []), - "schema": schema, - "description": schema.get("description", ""), - }) - - operations[op_id] = { - "method": method, - "path": path, - "description": operation.get("description", ""), - "parameters": parameters, - "responses": operation.get("responses", {}), - } - - return operations + paths = self._spec.get("paths", {}) + return parse_operations(paths) def _resolve_schema_ref(self, schema: dict[str, Any]) -> dict[str, Any]: """Resolve schema reference.""" diff --git a/src/llmling/utils/importing.py b/src/llmling/utils/importing.py index 31a2171..b6e4ecc 100644 --- a/src/llmling/utils/importing.py +++ b/src/llmling/utils/importing.py @@ -23,11 +23,7 @@ def get_module_source( try: module = importlib.import_module(import_path) sources = list( - _get_sources( - module, - recursive=recursive, - include_tests=include_tests, - ) + _get_sources(module, recursive=recursive, include_tests=include_tests) ) return "\n\n# " + "-" * 40 + "\n\n".join(sources)