Skip to content

Commit

Permalink
chore: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Dec 14, 2024
1 parent 64f4dac commit 733ba14
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 98 deletions.
6 changes: 1 addition & 5 deletions src/llmling/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,11 +890,7 @@ def register_static_prompt(
first_line[:100] + "..." if len(first_line) > 100 else first_line # noqa: PLR2004
)

prompt = StaticPrompt(
name=name,
description=description,
messages=messages,
)
prompt = StaticPrompt(name=name, description=description, messages=messages)
self._prompt_registry.register(name, prompt, replace=replace)

except Exception as exc:
Expand Down
8 changes: 4 additions & 4 deletions src/llmling/core/baseregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,13 @@ def _validate_item(self, item: Any) -> TItem:

async def _initialize_item(self, item: TItem) -> None:
"""Initialize an item during startup."""
if hasattr(item, "startup") and callable(item.startup):
await item.startup()
if hasattr(item, "startup") and callable(item.startup): # pyright: ignore
await item.startup() # pyright: ignore

async def _cleanup_item(self, item: TItem) -> None:
"""Clean up an item during shutdown."""
if hasattr(item, "shutdown") and callable(item.shutdown):
await item.shutdown()
if hasattr(item, "shutdown") and callable(item.shutdown): # pyright: ignore
await item.shutdown() # pyright: ignore

# Implementing MutableMapping methods
def __getitem__(self, key: TKey) -> TItem:
Expand Down
74 changes: 32 additions & 42 deletions src/llmling/core/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,32 @@ class StepCondition(BaseModel):
operator: Literal["eq", "gt", "lt", "contains", "exists"]
value: Any = None

def evaluate_with_value(self, value: Any) -> bool:
"""Evaluate this condition against a value.
Args:
value: The value to evaluate against the condition.
Returns:
bool: True if the condition is met, False otherwise.
"""
field_value = value.get(self.field) if isinstance(value, dict) else value

match self.operator:
case "eq":
return field_value == self.value
case "gt":
return field_value > self.value
case "lt":
return field_value < self.value
case "contains":
try:
return self.value in field_value # type: ignore
except TypeError:
return False
case "exists":
return field_value is not None


@dataclass
class StepResult:
Expand Down Expand Up @@ -246,15 +272,8 @@ async def _execute_step(
while True:
try:
# Check condition if any
if step.condition and not self._evaluate_condition(
step.condition, input_value
):
return StepResult(
success=True,
result=input_value, # Pass through unchanged
duration=0,
)

if step.condition and not step.condition.evaluate_with_value(input_value):
return StepResult(success=True, result=input_value, duration=0)
# Prepare kwargs
if isinstance(input_value, dict):
kwargs = {**input_value, **step.keyword_args}
Expand All @@ -263,10 +282,8 @@ async def _execute_step(

# Execute with timeout if specified
if step.timeout:
result = await asyncio.wait_for(
self.runtime.execute_tool(step.tool, **kwargs),
timeout=step.timeout,
)
fut = self.runtime.execute_tool(step.tool, **kwargs)
result = await asyncio.wait_for(fut, timeout=step.timeout)
else:
result = await self.runtime.execute_tool(step.tool, **kwargs)

Expand Down Expand Up @@ -303,30 +320,7 @@ async def _execute_step(
continue
raise # Max retries exceeded

def _evaluate_condition(self, condition: StepCondition, value: Any) -> bool:
"""Evaluate a step condition."""
field_value = value.get(condition.field) if isinstance(value, dict) else value

match condition.operator:
case "eq":
return field_value == condition.value
case "gt":
return field_value > condition.value
case "lt":
return field_value < condition.value
case "contains":
try:
return condition.value in field_value # type: ignore
except TypeError:
return False
case "exists":
return field_value is not None

async def _execute_sequential(
self,
pipeline: Pipeline,
results: StepResults,
) -> Any:
async def _execute_sequential(self, pipeline: Pipeline, results: StepResults) -> Any:
"""Execute steps sequentially."""
current = pipeline.input

Expand All @@ -338,11 +332,7 @@ async def _execute_sequential(

return current

async def _execute_parallel(
self,
pipeline: Pipeline,
results: StepResults,
) -> Any:
async def _execute_parallel(self, pipeline: Pipeline, results: StepResults) -> Any:
"""Execute independent steps in parallel."""
semaphore = asyncio.Semaphore(pipeline.max_parallel)

Expand Down
2 changes: 1 addition & 1 deletion src/llmling/resources/loaders/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def _load_impl(
loaded.content = result.content
yield loaded

except git.exc.GitCommandError as exc:
except git.exc.GitCommandError as exc: # type: ignore
msg = f"Git operation failed: {exc}"
raise exceptions.LoaderError(msg) from exc
except Exception as exc:
Expand Down
84 changes: 43 additions & 41 deletions src/llmling/tools/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,47 @@
}


def parse_operations(paths: dict) -> dict[str, dict[str, Any]]:
operations = {}
for path, path_item in paths.items():
for method, operation in path_item.items():
if method not in {"get", "post", "put", "delete", "patch"}:
continue

# Generate operation ID if not provided
op_id = operation.get("operationId")
if not op_id:
op_id = f"{method}_{path.replace('/', '_').strip('_')}"

# Collect all parameters (path, query, header)
params = operation.get("parameters", [])
if (
(request_body := operation.get("requestBody"))
and (content := request_body.get("content", {}))
and (json_schema := content.get("application/json", {}).get("schema"))
and (properties := json_schema.get("properties", {}))
):
# Convert request body to parameters
for name, schema in properties.items():
params.append({
"name": name,
"in": "body",
"required": name in json_schema.get("required", []),
"schema": schema,
"description": schema.get("description", ""),
})

operations[op_id] = {
"method": method,
"path": path,
"description": operation.get("description", ""),
"parameters": params,
"responses": operation.get("responses", {}),
}

return operations


class OpenAPITools(ToolSet):
"""Tool collection for OpenAPI endpoints."""

Expand Down Expand Up @@ -91,50 +132,11 @@ def _load_spec(self) -> Schema:

def _parse_operations(self) -> dict[str, dict[str, Any]]:
"""Parse OpenAPI spec into operation configurations."""
operations = {}

# Get server URL if not overridden
if not self.base_url and "servers" in self._spec:
self.base_url = self._spec["servers"][0]["url"]

# Parse paths and operations
for path, path_item in self._spec.get("paths", {}).items():
for method, operation in path_item.items():
if method not in {"get", "post", "put", "delete", "patch"}:
continue

# Generate operation ID if not provided
op_id = operation.get("operationId")
if not op_id:
op_id = f"{method}_{path.replace('/', '_').strip('_')}"

# Collect all parameters (path, query, header)
parameters = operation.get("parameters", [])
if (
(request_body := operation.get("requestBody"))
and (content := request_body.get("content", {}))
and (json_schema := content.get("application/json", {}).get("schema"))
and (properties := json_schema.get("properties", {}))
):
# Convert request body to parameters
for name, schema in properties.items():
parameters.append({
"name": name,
"in": "body",
"required": name in json_schema.get("required", []),
"schema": schema,
"description": schema.get("description", ""),
})

operations[op_id] = {
"method": method,
"path": path,
"description": operation.get("description", ""),
"parameters": parameters,
"responses": operation.get("responses", {}),
}

return operations
paths = self._spec.get("paths", {})
return parse_operations(paths)

def _resolve_schema_ref(self, schema: dict[str, Any]) -> dict[str, Any]:
"""Resolve schema reference."""
Expand Down
6 changes: 1 addition & 5 deletions src/llmling/utils/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@ def get_module_source(
try:
module = importlib.import_module(import_path)
sources = list(
_get_sources(
module,
recursive=recursive,
include_tests=include_tests,
)
_get_sources(module, recursive=recursive, include_tests=include_tests)
)
return "\n\n# " + "-" * 40 + "\n\n".join(sources)

Expand Down

0 comments on commit 733ba14

Please sign in to comment.