diff --git a/docetl/operations/cluster.py b/docetl/operations/cluster.py index 19413a9a..76f57cc4 100644 --- a/docetl/operations/cluster.py +++ b/docetl/operations/cluster.py @@ -3,10 +3,11 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Tuple from .base import BaseOperation -from .utils import RichLoopBar +from .utils import RichLoopBar, strict_render from .clustering_utils import get_embeddings_for_clustering + class ClusterOperation(BaseOperation): def __init__( self, @@ -187,9 +188,7 @@ def annotate_clustering_tree(self, t): total_cost += futures[i].result() pbar.update(i) - prompt = self.prompt_template.render( - inputs=t["children"] - ) + prompt = strict_render(self.prompt_template, {"inputs": t["children"]}) def validation_fn(response: Dict[str, Any]): output = self.runner.api.parse_llm_response( diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py index 7ab99c0f..23967fab 100644 --- a/docetl/operations/equijoin.py +++ b/docetl/operations/equijoin.py @@ -9,6 +9,7 @@ from multiprocessing import Pool, cpu_count from typing import Any, Dict, List, Tuple, Optional +from docetl.operations.utils import strict_render import numpy as np from jinja2 import Template from litellm import model_cost @@ -94,8 +95,8 @@ def compare_pair( Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison. """ - prompt_template = Template(comparison_prompt) - prompt = prompt_template.render(left=item1, right=item2) + + prompt = strict_render(comparison_prompt, {"left": item1, "right": item2}) response = self.runner.api.call_llm( model, "compare", diff --git a/docetl/operations/link_resolve.py b/docetl/operations/link_resolve.py index e6a95c84..22ceded0 100644 --- a/docetl/operations/link_resolve.py +++ b/docetl/operations/link_resolve.py @@ -9,6 +9,7 @@ from docetl.operations.base import BaseOperation from docetl.operations.utils import RichLoopBar, rich_as_completed from docetl.utils import completion_cost, extract_jinja_variables +from docetl.operations.utils import strict_render from .clustering_utils import get_embeddings_for_clustering from sklearn.metrics.pairwise import cosine_similarity import numpy as np @@ -139,11 +140,11 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: return input_data, total_cost def compare(self, link_idx, id_idx, link_value, id_value, item): - prompt = self.prompt_template.render( - link_value = link_value, - id_value = id_value, - item = item - ) + prompt = strict_render(self.prompt_template, { + "link_value": link_value, + "id_value": id_value, + "item": item + }) schema = {"is_same": "bool"} diff --git a/docetl/operations/map.py b/docetl/operations/map.py index e4995fec..6b2dde2e 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Tuple, Union +from docetl.operations.utils import strict_render from jinja2 import Environment, Template from tqdm import tqdm @@ -16,17 +17,6 @@ from litellm.utils import ModelResponse -def render_jinja_template(template_string: str, data: Dict[str, Any]) -> str: - """ - Render a Jinja2 template with the given data, ensuring protection against template injection vulnerabilities. - If the data is empty, return an empty string. - """ - if not data: - return "" - - env = Environment(autoescape=True) - template = env.from_string(template_string) - return template.render(input=data) class MapOperation(BaseOperation): @@ -175,8 +165,8 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: self.status.stop() def _process_map_item(item: Dict, initial_result: Optional[Dict] = None) -> Tuple[Optional[Dict], float]: - prompt_template = Template(self.config["prompt"]) - prompt = prompt_template.render(input=item) + + prompt = strict_render(self.config["prompt"], {"input": item}) def validation_fn(response: Union[Dict[str, Any], ModelResponse]): output = self.runner.api.parse_llm_response( @@ -243,8 +233,7 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]): def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]: total_cost = 0 if len(items) > 1 and self.config.get("batch_prompt", None): - batch_prompt_template = Template(self.config["batch_prompt"]) - batch_prompt = batch_prompt_template.render(inputs=items) + batch_prompt = strict_render(self.config["batch_prompt"], {"inputs": items}) # Issue the batch call llm_result = self.runner.api.call_llm_batch( @@ -449,7 +438,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: self.status.stop() def process_prompt(item, prompt_config): - prompt = render_jinja_template(prompt_config["prompt"], item) + prompt = strict_render(prompt_config["prompt"], {"input": item}) local_output_schema = { key: output_schema[key] for key in prompt_config["output_keys"] } diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index 254323a6..08178900 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -19,6 +19,7 @@ from jinja2 import Template from docetl.operations.base import BaseOperation +from docetl.operations.utils import strict_render from docetl.operations.clustering_utils import ( cluster_documents, get_embeddings_for_clustering, @@ -509,10 +510,8 @@ def _semantic_similarity_sampling( self, key: Tuple, group_list: List[Dict], value_sampling: Dict, sample_size: int ) -> Tuple[List[Dict], float]: embedding_model = value_sampling["embedding_model"] - query_text_template = Template(value_sampling["query_text"]) - query_text = query_text_template.render( - reduce_key=dict(zip(self.config["reduce_key"], key)) - ) + query_text = strict_render(value_sampling["query_text"], {"reduce_key": dict(zip(self.config["reduce_key"], key))}) + embeddings, cost = get_embeddings_for_clustering( group_list, value_sampling, self.runner.api @@ -794,12 +793,11 @@ def _increment_fold( return self._batch_reduce(key, batch, scratchpad) start_time = time.time() - fold_prompt_template = Template(self.config["fold_prompt"]) - fold_prompt = fold_prompt_template.render( - inputs=batch, - output=current_output, - reduce_key=dict(zip(self.config["reduce_key"], key)), - ) + fold_prompt = strict_render(self.config["fold_prompt"], { + "inputs": batch, + "output": current_output, + "reduce_key": dict(zip(self.config["reduce_key"], key)) + }) response = self.runner.api.call_llm( self.config.get("model", self.default_model), @@ -857,10 +855,10 @@ def _merge_results( the prompt used, and the cost of the merge operation. """ start_time = time.time() - merge_prompt_template = Template(self.config["merge_prompt"]) - merge_prompt = merge_prompt_template.render( - outputs=outputs, reduce_key=dict(zip(self.config["reduce_key"], key)) - ) + merge_prompt = strict_render(self.config["merge_prompt"], { + "outputs": outputs, + "reduce_key": dict(zip(self.config["reduce_key"], key)) + }) response = self.runner.api.call_llm( self.config.get("model", self.default_model), "merge", @@ -963,10 +961,10 @@ def _batch_reduce( Tuple[Optional[Dict], str, float]: A tuple containing the reduced output (or None if processing failed), the prompt used, and the cost of the reduce operation. """ - prompt_template = Template(self.config["prompt"]) - prompt = prompt_template.render( - reduce_key=dict(zip(self.config["reduce_key"], key)), inputs=group_list - ) + prompt = strict_render(self.config["prompt"], { + "reduce_key": dict(zip(self.config["reduce_key"], key)), + "inputs": group_list + }) item_cost = 0 response = self.runner.api.call_llm( diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index bcfec266..f146bc03 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -9,6 +9,7 @@ import json from datetime import datetime +from docetl.operations.utils import strict_render import jinja2 from jinja2 import Template from rich.prompt import Confirm @@ -80,8 +81,11 @@ def compare_pair( ): return True, 0, "" - prompt_template = Template(comparison_prompt) - prompt = prompt_template.render(input1=item1, input2=item2) + + prompt = strict_render(comparison_prompt, { + "input1": item1, + "input2": item2 + }) response = self.runner.api.call_llm( model, "compare", @@ -543,14 +547,16 @@ def auto_batch() -> int: def process_cluster(cluster): if len(cluster) > 1: cluster_items = [input_data[i] for i in cluster] - reduction_template = Template(self.config["resolution_prompt"]) if input_schema: cluster_items = [ {k: item[k] for k in input_schema.keys() if k in item} for item in cluster_items ] - resolution_prompt = reduction_template.render(inputs=cluster_items) + + resolution_prompt = strict_render(self.config["resolution_prompt"], { + "inputs": cluster_items + }) reduction_response = self.runner.api.call_llm( self.config.get("resolution_model", self.default_model), "reduce", diff --git a/docetl/operations/utils/__init__.py b/docetl/operations/utils/__init__.py new file mode 100644 index 00000000..e787f842 --- /dev/null +++ b/docetl/operations/utils/__init__.py @@ -0,0 +1,36 @@ +from .api import APIWrapper +from .cache import ( + cache, + cache_key, + clear_cache, + flush_cache, + freezeargs, + CACHE_DIR, + LLM_CACHE_DIR, + DOCETL_HOME_DIR, +) +from .llm import LLMResult, InvalidOutputError, truncate_messages +from .progress import RichLoopBar, rich_as_completed +from .validation import safe_eval, convert_val, convert_dict_schema_to_list_schema, get_user_input_for_schema, strict_render + +__all__ = [ + 'APIWrapper', + 'cache', + 'cache_key', + 'clear_cache', + 'flush_cache', + 'freezeargs', + 'CACHE_DIR', + 'LLM_CACHE_DIR', + 'DOCETL_HOME_DIR', + 'LLMResult', + 'InvalidOutputError', + 'RichLoopBar', + 'rich_as_completed', + 'safe_eval', + 'convert_val', + 'convert_dict_schema_to_list_schema', + 'get_user_input_for_schema', + 'truncate_messages', + "strict_render" +] \ No newline at end of file diff --git a/docetl/operations/utils.py b/docetl/operations/utils/api.py similarity index 63% rename from docetl/operations/utils.py rename to docetl/operations/utils/api.py index d5362bb5..af64ab3d 100644 --- a/docetl/operations/utils.py +++ b/docetl/operations/utils/api.py @@ -1,384 +1,18 @@ import ast -import functools import hashlib import json -import os -import shutil -import threading -from concurrent.futures import as_completed -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union - -import litellm -import tiktoken -from asteval import Interpreter -from diskcache import Cache -from dotenv import load_dotenv -from frozendict import frozendict -from jinja2 import Template -from litellm import completion, embedding, model_cost, RateLimitError -from rich import print as rprint -from rich.console import Console -from rich.prompt import Prompt -from tqdm import tqdm -from pydantic import BaseModel - -from docetl.console import DOCETL_CONSOLE -from docetl.utils import completion_cost, count_tokens import time -from litellm.utils import ModelResponse - -aeval = Interpreter() - -load_dotenv() -# litellm.set_verbose = True -DOCETL_HOME_DIR = os.environ.get("DOCETL_HOME_DIR", os.path.expanduser("~"))+"/.cache/docetl" - -CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "general") -LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm") -cache = Cache(LLM_CACHE_DIR) -cache.close() - - -class LLMResult(BaseModel): - response: Any - total_cost: float - validated: bool - - -def freezeargs(func): - """ - Decorator to convert mutable dictionary arguments into immutable. - - This decorator is useful for making functions compatible with caching mechanisms - that require immutable arguments. - - Args: - func (callable): The function to be wrapped. - - Returns: - callable: The wrapped function with immutable dictionary arguments. - """ - - @functools.wraps(func) - def wrapped(*args, **kwargs): - args = tuple( - ( - frozendict(arg) - if isinstance(arg, dict) - else json.dumps(arg) if isinstance(arg, list) else arg - ) - for arg in args - ) - kwargs = { - k: ( - frozendict(v) - if isinstance(v, dict) - else json.dumps(v) if isinstance(v, list) else v - ) - for k, v in kwargs.items() - } - return func(*args, **kwargs) - - return wrapped - - -def flush_cache(console: Console = DOCETL_CONSOLE): - """ - Flush the cache to disk. - """ - console.log("[bold green]Flushing cache to disk...[/bold green]") - cache.close() - console.log("[bold green]Cache flushed to disk.[/bold green]") - - -def clear_cache(console: Console = DOCETL_CONSOLE): - """ - Clear the LLM cache stored on disk. - - This function removes all cached items from the disk-based cache, - effectively clearing the LLM's response history. - - Args: - console (Console, optional): A Rich console object for logging. - Defaults to a new Console instance. - """ - console.log("[bold yellow]Clearing LLM cache...[/bold yellow]") - try: - with cache as c: - c.clear() - # Remove all files in the cache directory - cache_dir = CACHE_DIR - if not os.path.exists(cache_dir): - os.makedirs(cache_dir) - for filename in os.listdir(cache_dir): - file_path = os.path.join(cache_dir, filename) - try: - if os.path.isfile(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - except Exception as e: - console.log( - f"[bold red]Error deleting {file_path}: {str(e)}[/bold red]" - ) - console.log("[bold green]Cache cleared successfully.[/bold green]") - except Exception as e: - console.log(f"[bold red]Error clearing cache: {str(e)}[/bold red]") - -def convert_dict_schema_to_list_schema(schema: Dict[str, Any]) -> Dict[str, Any]: - schema_str = "{" + ", ".join([f"{k}: {v}" for k, v in schema.items()]) + "}" - return {"results": f"list[{schema_str}]"} - -def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]: - """ - Convert a string representation of a type to a dictionary representation. - - This function takes a string value representing a data type and converts it - into a dictionary format suitable for JSON schema. - - Args: - value (Any): A string representing a data type. - model (str): The model being used. Defaults to "gpt-4o-mini". - - Returns: - Dict[str, Any]: A dictionary representing the type in JSON schema format. - - Raises: - ValueError: If the input value is not a supported type or is improperly formatted. - """ - value = value.strip().lower() - if value in ["str", "text", "string", "varchar"]: - return {"type": "string"} - elif value in ["int", "integer"]: - return {"type": "integer"} - elif value in ["float", "decimal", "number"]: - return {"type": "number"} - elif value in ["bool", "boolean"]: - return {"type": "boolean"} - elif value.startswith("list["): - inner_type = value[5:-1].strip() - return {"type": "array", "items": convert_val(inner_type, model)} - elif value == "list": - raise ValueError("List type must specify its elements, e.g., 'list[str]'") - elif value.startswith("{") and value.endswith("}"): - # Handle dictionary type - properties = {} - for item in value[1:-1].split(","): - key, val = item.strip().split(":") - properties[key.strip()] = convert_val(val.strip(), model) - result = { - "type": "object", - "properties": properties, - "required": list(properties.keys()), - } - # TODO: this is a hack to get around the fact that gemini doesn't support additionalProperties - if "gemini" not in model: - result["additionalProperties"] = False - return result - else: - raise ValueError(f"Unsupported value type: {value}") - - -def cache_key( - model: str, - op_type: str, - messages: List[Dict[str, str]], - output_schema: Dict[str, str], - scratchpad: Optional[str] = None, - system_prompt: Optional[Dict[str, str]] = None, -) -> str: - """ - Generate a unique cache key based on function arguments. - - This function creates a hash-based key using the input parameters, which can - be used for caching purposes. - - Args: - model (str): The model name. - op_type (str): The operation type. - messages (List[Dict[str, str]]): The messages to send to the LLM. - output_schema (Dict[str, str]): The output schema dictionary. - scratchpad (Optional[str]): The scratchpad to use for the operation. - - Returns: - str: A unique hash string representing the cache key. - """ - # Ensure no non-serializable objects are included - key_dict = { - "model": model, - "op_type": op_type, - "messages": json.dumps(messages, sort_keys=True), - "output_schema": json.dumps(output_schema, sort_keys=True), - "scratchpad": scratchpad, - "system_prompt": json.dumps(system_prompt, sort_keys=True), - } - return hashlib.md5(json.dumps(key_dict, sort_keys=True).encode()).hexdigest() - - -def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]: - """ - Prompt the user for input for each key in the schema using Rich, - then parse the input values with json.loads(). - - Args: - schema (Dict[str, Any]): The schema dictionary. - - Returns: - Dict[str, Any]: A dictionary with user inputs parsed according to the schema. - """ - user_input = {} - - for key, value_type in schema.items(): - prompt_text = f"Enter value for '{key}' ({value_type}): " - user_value = Prompt.ask(prompt_text) - - try: - # Parse the input value using json.loads() - parsed_value = json.loads(user_value) - - # Check if the parsed value matches the expected type - if isinstance(parsed_value, eval(value_type)): - user_input[key] = parsed_value - else: - rprint( - f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}." - ) - return get_user_input_for_schema(schema) # Recursive call to retry - - except json.JSONDecodeError: - rprint( - f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again." - ) - return get_user_input_for_schema(schema) # Recursive call to retry - - return user_input - - -class InvalidOutputError(Exception): - """ - Custom exception raised when the LLM output is invalid or cannot be parsed. - - Attributes: - message (str): Explanation of the error. - output (str): The invalid output that caused the exception. - expected_schema (Dict[str, Any]): The expected schema for the output. - messages (List[Dict[str, str]]): The messages sent to the LLM. - tools (Optional[List[Dict[str, str]]]): The tool calls generated by the LLM. - """ - - def __init__( - self, - message: str, - output: str, - expected_schema: Dict[str, Any], - messages: List[Dict[str, str]], - tools: Optional[List[Dict[str, str]]] = None, - ): - self.message = message - self.output = output - self.expected_schema = expected_schema - self.messages = messages - self.tools = tools - super().__init__(self.message) - - def __str__(self): - return ( - f"{self.message}\n" - f"Invalid output: {self.output}\n" - f"Expected schema: {self.expected_schema}\n" - f"Messages sent to LLM: {self.messages}\n" - f"Tool calls generated by LLM: {self.tools}" - ) - +from typing import Any, Dict, List, Optional -def timeout(seconds): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - result = [TimeoutError("Function call timed out")] +from litellm import completion, embedding, RateLimitError, ModelResponse +from rich.console import Console - def target(): - try: - result[0] = func(*args, **kwargs) - except Exception as e: - result[0] = e - - thread = threading.Thread(target=target) - thread.start() - thread.join(seconds) - if isinstance(result[0], Exception): - raise result[0] - return result[0] - - return wrapper - - return decorator - - -def truncate_messages( - messages: List[Dict[str, str]], model: str, from_agent: bool = False -) -> List[Dict[str, str]]: - """ - Truncate the messages to fit the model's context length. - """ - model_input_context_length = model_cost.get(model.split("/")[-1], {}).get( - "max_input_tokens", 8192 - ) - total_tokens = sum(count_tokens(json.dumps(msg), model) for msg in messages) - - if total_tokens <= model_input_context_length - 100: - return messages - - truncated_messages = messages.copy() - longest_message = max(truncated_messages, key=lambda x: len(x["content"])) - content = longest_message["content"] - excess_tokens = total_tokens - model_input_context_length + 200 # 200 token buffer - - try: - encoder = tiktoken.encoding_for_model(model.split("/")[-1]) - except Exception: - encoder = tiktoken.encoding_for_model("gpt-4o") - encoded_content = encoder.encode(content) - tokens_to_remove = min(len(encoded_content), excess_tokens) - mid_point = len(encoded_content) // 2 - truncated_encoded = ( - encoded_content[: mid_point - tokens_to_remove // 2] - + encoder.encode(f" ... [{tokens_to_remove} tokens truncated] ... ") - + encoded_content[mid_point + tokens_to_remove // 2 :] - ) - truncated_content = encoder.decode(truncated_encoded) - # Calculate the total number of tokens in the original content - total_tokens = len(encoded_content) - - # Print the warning message using rprint - warning_type = "User" if not from_agent else "Agent" - rprint( - f"[yellow]{warning_type} Warning:[/yellow] Cutting {tokens_to_remove} tokens from a prompt with {total_tokens} tokens..." - ) - - longest_message["content"] = truncated_content - - return truncated_messages - - -def safe_eval(expression: str, output: Dict) -> bool: - """ - Safely evaluate an expression with a given output dictionary. - Uses asteval to evaluate the expression. - https://lmfit.github.io/asteval/index.html - """ - try: - # Add the output dictionary to the symbol table - aeval.symtable["output"] = output - # Safely evaluate the expression - return bool(aeval(expression)) - except Exception: - # try to evaluate with python eval - try: - return bool(eval(expression, locals={"output": output})) - except Exception: - return False +from .cache import cache, cache_key, freezeargs +from .llm import LLMResult, InvalidOutputError, timeout, truncate_messages +from .validation import safe_eval, convert_dict_schema_to_list_schema, get_user_input_for_schema, convert_val, strict_render +from docetl.utils import completion_cost +from rich import print as rprint class APIWrapper(object): def __init__(self, runner): @@ -502,9 +136,6 @@ def _cached_call_llm( if gleaning_config: # Retry gleaning prompt + regular LLM num_gleaning_rounds = gleaning_config.get("num_rounds", 2) - validator_prompt_template = Template( - gleaning_config["validation_prompt"] - ) parsed_output = self.parse_llm_response( response, output_schema, tools @@ -523,9 +154,7 @@ def _cached_call_llm( for rnd in range(num_gleaning_rounds): # Prepare validator prompt - validator_prompt = validator_prompt_template.render( - output=parsed_output - ) + validator_prompt = strict_render(gleaning_config["validation_prompt"], {"output": parsed_output}) self.runner.rate_limiter.try_acquire("llm_call", weight=1) # Get params for should refine @@ -1053,137 +682,4 @@ def validate_output(self, operation: Dict, output: Dict, console: Console) -> bo console.log(f"[bold red]Validation error:[/bold red] {str(e)}") console.log(f"[yellow]Output:[/yellow] {output}") return False - return True - - -class RichLoopBar: - """ - A progress bar class that integrates with Rich console. - - This class provides a wrapper around tqdm to create progress bars that work - with Rich console output. - - Args: - iterable (Optional[Union[Iterable, range]]): An iterable to track progress. - total (Optional[int]): The total number of iterations. - desc (Optional[str]): Description to be displayed alongside the progress bar. - leave (bool): Whether to leave the progress bar on screen after completion. - console: The Rich console object to use for output. - """ - - def __init__( - self, - iterable: Optional[Union[Iterable, range]] = None, - total: Optional[int] = None, - desc: Optional[str] = None, - leave: bool = True, - console=None, - ): - if console is None: - raise ValueError("Console must be provided") - self.console = console - self.iterable = iterable - self.total = self._get_total(iterable, total) - self.description = desc - self.leave = leave - self.tqdm = None - - def _get_total(self, iterable, total): - """ - Determine the total number of iterations for the progress bar. - - Args: - iterable: The iterable to be processed. - total: The explicitly specified total, if any. - - Returns: - int or None: The total number of iterations, or None if it can't be determined. - """ - if total is not None: - return total - if isinstance(iterable, range): - return len(iterable) - try: - return len(iterable) - except TypeError: - return None - - def __iter__(self): - """ - Create and return an iterator with a progress bar. - - Returns: - Iterator: An iterator that yields items from the wrapped iterable. - """ - self.tqdm = tqdm( - self.iterable, - total=self.total, - desc=self.description, - file=self.console.file, - ) - for item in self.tqdm: - yield item - - def __enter__(self): - """ - Enter the context manager, initializing the progress bar. - - Returns: - RichLoopBar: The RichLoopBar instance. - """ - self.tqdm = tqdm( - total=self.total, - desc=self.description, - leave=self.leave, - file=self.console.file, - ) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - Exit the context manager, closing the progress bar. - - Args: - exc_type: The type of the exception that caused the context to be exited. - exc_val: The instance of the exception that caused the context to be exited. - exc_tb: A traceback object encoding the stack trace. - """ - self.tqdm.close() - - def update(self, n=1): - """ - Update the progress bar. - - Args: - n (int): The number of iterations to increment the progress bar by. - """ - if self.tqdm: - self.tqdm.update(n) - - -def rich_as_completed(futures, total=None, desc=None, leave=True, console=None): - """ - Yield completed futures with a Rich progress bar. - - This function wraps concurrent.futures.as_completed with a Rich progress bar. - - Args: - futures: An iterable of Future objects to monitor. - total (Optional[int]): The total number of futures. - desc (Optional[str]): Description for the progress bar. - leave (bool): Whether to leave the progress bar on screen after completion. - console: The Rich console object to use for output. - - Yields: - Future: Completed future objects. - - Raises: - ValueError: If no console object is provided. - """ - if console is None: - raise ValueError("Console must be provided") - - with RichLoopBar(total=total, desc=desc, leave=leave, console=console) as pbar: - for future in as_completed(futures): - yield future - pbar.update() + return True \ No newline at end of file diff --git a/docetl/operations/utils/cache.py b/docetl/operations/utils/cache.py new file mode 100644 index 00000000..51625e5a --- /dev/null +++ b/docetl/operations/utils/cache.py @@ -0,0 +1,92 @@ +import functools +import hashlib +import json +import os +import shutil +from typing import Any, Dict, List +from frozendict import frozendict +from diskcache import Cache +from rich.console import Console +from dotenv import load_dotenv + +from docetl.console import DOCETL_CONSOLE + +load_dotenv() + +DOCETL_HOME_DIR = os.environ.get("DOCETL_HOME_DIR", os.path.expanduser("~"))+"/.cache/docetl" +CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "general") +LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm") +cache = Cache(LLM_CACHE_DIR) +cache.close() + +def freezeargs(func): + """ + Decorator to convert mutable dictionary arguments into immutable. + """ + @functools.wraps(func) + def wrapped(*args, **kwargs): + args = tuple( + ( + frozendict(arg) + if isinstance(arg, dict) + else json.dumps(arg) if isinstance(arg, list) else arg + ) + for arg in args + ) + kwargs = { + k: ( + frozendict(v) + if isinstance(v, dict) + else json.dumps(v) if isinstance(v, list) else v + ) + for k, v in kwargs.items() + } + return func(*args, **kwargs) + return wrapped + +def flush_cache(console: Console = DOCETL_CONSOLE): + """Flush the cache to disk.""" + console.log("[bold green]Flushing cache to disk...[/bold green]") + cache.close() + console.log("[bold green]Cache flushed to disk.[/bold green]") + +def clear_cache(console: Console = DOCETL_CONSOLE): + """Clear the LLM cache stored on disk.""" + console.log("[bold yellow]Clearing LLM cache...[/bold yellow]") + try: + with cache as c: + c.clear() + # Remove all files in the cache directory + if not os.path.exists(CACHE_DIR): + os.makedirs(CACHE_DIR) + for filename in os.listdir(CACHE_DIR): + file_path = os.path.join(CACHE_DIR, filename) + try: + if os.path.isfile(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + console.log(f"[bold red]Error deleting {file_path}: {str(e)}[/bold red]") + console.log("[bold green]Cache cleared successfully.[/bold green]") + except Exception as e: + console.log(f"[bold red]Error clearing cache: {str(e)}[/bold red]") + +def cache_key( + model: str, + op_type: str, + messages: List[Dict[str, str]], + output_schema: Dict[str, str], + scratchpad: str = None, + system_prompt: Dict[str, str] = None, +) -> str: + """Generate a unique cache key based on function arguments.""" + key_dict = { + "model": model, + "op_type": op_type, + "messages": json.dumps(messages, sort_keys=True), + "output_schema": json.dumps(output_schema, sort_keys=True), + "scratchpad": scratchpad, + "system_prompt": json.dumps(system_prompt, sort_keys=True), + } + return hashlib.md5(json.dumps(key_dict, sort_keys=True).encode()).hexdigest() \ No newline at end of file diff --git a/docetl/operations/utils/llm.py b/docetl/operations/utils/llm.py new file mode 100644 index 00000000..f98413fa --- /dev/null +++ b/docetl/operations/utils/llm.py @@ -0,0 +1,105 @@ +import ast +import json +import threading +import time +from typing import Any, Dict, List, Optional +import tiktoken +from jinja2 import Template +from litellm import completion, RateLimitError +from pydantic import BaseModel +from rich import print as rprint + +from docetl.utils import completion_cost, count_tokens + +class LLMResult(BaseModel): + response: Any + total_cost: float + validated: bool + +class InvalidOutputError(Exception): + """Custom exception raised when the LLM output is invalid or cannot be parsed.""" + def __init__( + self, + message: str, + output: str, + expected_schema: Dict[str, Any], + messages: List[Dict[str, str]], + tools: Optional[List[Dict[str, str]]] = None, + ): + self.message = message + self.output = output + self.expected_schema = expected_schema + self.messages = messages + self.tools = tools + super().__init__(self.message) + + def __str__(self): + return ( + f"{self.message}\n" + f"Invalid output: {self.output}\n" + f"Expected schema: {self.expected_schema}\n" + f"Messages sent to LLM: {self.messages}\n" + f"Tool calls generated by LLM: {self.tools}" + ) + +def timeout(seconds): + def decorator(func): + def wrapper(*args, **kwargs): + result = [TimeoutError("Function call timed out")] + + def target(): + try: + result[0] = func(*args, **kwargs) + except Exception as e: + result[0] = e + + thread = threading.Thread(target=target) + thread.start() + thread.join(seconds) + if isinstance(result[0], Exception): + raise result[0] + return result[0] + + return wrapper + return decorator + +def truncate_messages( + messages: List[Dict[str, str]], + model: str, + from_agent: bool = False +) -> List[Dict[str, str]]: + """Truncate messages to fit within model's context length.""" + model_input_context_length = 8192 # Default + total_tokens = sum(count_tokens(json.dumps(msg), model) for msg in messages) + + if total_tokens <= model_input_context_length - 100: + return messages + + truncated_messages = messages.copy() + longest_message = max(truncated_messages, key=lambda x: len(x["content"])) + content = longest_message["content"] + excess_tokens = total_tokens - model_input_context_length + 200 + + try: + encoder = tiktoken.encoding_for_model(model.split("/")[-1]) + except Exception: + encoder = tiktoken.encoding_for_model("gpt-4o") + + encoded_content = encoder.encode(content) + tokens_to_remove = min(len(encoded_content), excess_tokens) + mid_point = len(encoded_content) // 2 + truncated_encoded = ( + encoded_content[: mid_point - tokens_to_remove // 2] + + encoder.encode(f" ... [{tokens_to_remove} tokens truncated] ... ") + + encoded_content[mid_point + tokens_to_remove // 2 :] + ) + truncated_content = encoder.decode(truncated_encoded) + total_tokens = len(encoded_content) + + warning_type = "User" if not from_agent else "Agent" + rprint( + f"[yellow]{warning_type} Warning:[/yellow] Cutting {tokens_to_remove} tokens from a prompt with {total_tokens} tokens..." + ) + + longest_message["content"] = truncated_content + return truncated_messages \ No newline at end of file diff --git a/docetl/operations/utils/progress.py b/docetl/operations/utils/progress.py new file mode 100644 index 00000000..b78f62fc --- /dev/null +++ b/docetl/operations/utils/progress.py @@ -0,0 +1,69 @@ +from typing import Iterable, Optional, Union +from concurrent.futures import as_completed +from tqdm import tqdm + +class RichLoopBar: + """A progress bar class that integrates with Rich console.""" + + def __init__( + self, + iterable: Optional[Union[Iterable, range]] = None, + total: Optional[int] = None, + desc: Optional[str] = None, + leave: bool = True, + console=None, + ): + if console is None: + raise ValueError("Console must be provided") + self.console = console + self.iterable = iterable + self.total = self._get_total(iterable, total) + self.description = desc + self.leave = leave + self.tqdm = None + + def _get_total(self, iterable, total): + if total is not None: + return total + if isinstance(iterable, range): + return len(iterable) + try: + return len(iterable) + except TypeError: + return None + + def __iter__(self): + self.tqdm = tqdm( + self.iterable, + total=self.total, + desc=self.description, + file=self.console.file, + ) + for item in self.tqdm: + yield item + + def __enter__(self): + self.tqdm = tqdm( + total=self.total, + desc=self.description, + leave=self.leave, + file=self.console.file, + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.tqdm.close() + + def update(self, n=1): + if self.tqdm: + self.tqdm.update(n) + +def rich_as_completed(futures, total=None, desc=None, leave=True, console=None): + """Yield completed futures with a Rich progress bar.""" + if console is None: + raise ValueError("Console must be provided") + + with RichLoopBar(total=total, desc=desc, leave=leave, console=console) as pbar: + for future in as_completed(futures): + yield future + pbar.update() \ No newline at end of file diff --git a/docetl/operations/utils/validation.py b/docetl/operations/utils/validation.py new file mode 100644 index 00000000..f193d93d --- /dev/null +++ b/docetl/operations/utils/validation.py @@ -0,0 +1,134 @@ +import ast +import json +from typing import Union, Dict, Any +from asteval import Interpreter +from rich import print as rprint +from rich.prompt import Prompt + +from jinja2 import Environment, StrictUndefined, Template +from jinja2.exceptions import UndefinedError + + +aeval = Interpreter() + +def strict_render(template: Union[Template, str], context: Dict[str, Any]) -> str: + """ + Renders a Jinja template with strict undefined checking. + + Args: + template: Either a Jinja2 Template object or a template string + context: Dictionary containing the template variables + + Returns: + The rendered template string + + Raises: + UndefinedError: When any undefined variable, attribute or index is accessed + ValueError: When template is invalid + """ + # Create strict environment + env = Environment(undefined=StrictUndefined) + + # Convert string to Template if needed + if isinstance(template, str): + + # # If "inputs" in the context, make sure they are not accessing some attribute of inputs + # if "inputs" in context and "{{ inputs." in template: + # raise UndefinedError("The inputs variable is a list, so you cannot access attributes of inputs. Use inputs[index].key instead.") + + try: + template = env.from_string(template) + except Exception as e: + raise ValueError(f"Invalid template: {str(e)}") + + try: + return template.render(context) + except UndefinedError as e: + # Get the available context keys for better error reporting + available_vars = list(context.keys()) + + # For each var in context, if it's a dict, get the keys + var_attributes = {} + for var in available_vars: + if isinstance(context[var], dict): + var_attributes[var] = list(context[var].keys()) + elif isinstance(context[var], list) and len(context[var]) > 0: + var_attributes[var] = [f"inputs[i].{k}" for k in context[var][0].keys() if "_observability" not in k] + + raise UndefinedError( + f"{str(e)}\n" + f"Your prompt can include the following variables: {available_vars}\n" + f"For dictionary variables, you can access keys using dot notation (e.g. input.key).\n" + f"Available keys for each document: {var_attributes}\n" + ) + + +def safe_eval(expression: str, output: Dict) -> bool: + """Safely evaluate an expression with a given output dictionary.""" + try: + aeval.symtable["output"] = output + return bool(aeval(expression)) + except Exception: + try: + return bool(eval(expression, locals={"output": output})) + except Exception: + return False + +def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]: + """Convert a string representation of a type to a dictionary representation.""" + value = value.strip().lower() + if value in ["str", "text", "string", "varchar"]: + return {"type": "string"} + elif value in ["int", "integer"]: + return {"type": "integer"} + elif value in ["float", "decimal", "number"]: + return {"type": "number"} + elif value in ["bool", "boolean"]: + return {"type": "boolean"} + elif value.startswith("list["): + inner_type = value[5:-1].strip() + return {"type": "array", "items": convert_val(inner_type, model)} + elif value == "list": + raise ValueError("List type must specify its elements, e.g., 'list[str]'") + elif value.startswith("{") and value.endswith("}"): + properties = {} + for item in value[1:-1].split(","): + key, val = item.strip().split(":") + properties[key.strip()] = convert_val(val.strip(), model) + result = { + "type": "object", + "properties": properties, + "required": list(properties.keys()), + } + if "gemini" not in model: + result["additionalProperties"] = False + return result + else: + raise ValueError(f"Unsupported value type: {value}") + +def convert_dict_schema_to_list_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """Convert a dictionary schema to a list schema.""" + schema_str = "{" + ", ".join([f"{k}: {v}" for k, v in schema.items()]) + "}" + return {"results": f"list[{schema_str}]"} + +def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """Prompt the user for input for each key in the schema.""" + user_input = {} + + for key, value_type in schema.items(): + prompt_text = f"Enter value for '{key}' ({value_type}): " + user_value = Prompt.ask(prompt_text) + + try: + parsed_value = json.loads(user_value) + if isinstance(parsed_value, eval(value_type)): + user_input[key] = parsed_value + else: + rprint(f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}.") + return get_user_input_for_schema(schema) + + except json.JSONDecodeError: + rprint(f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again.") + return get_user_input_for_schema(schema) + + return user_input \ No newline at end of file diff --git a/server/app/routes/pipeline.py b/server/app/routes/pipeline.py index 1db96c00..92af480c 100644 --- a/server/app/routes/pipeline.py +++ b/server/app/routes/pipeline.py @@ -300,6 +300,6 @@ async def run_pipeline(): error_traceback = traceback.format_exc() print(f"Error occurred:\n{error_traceback}") - await websocket.send_json({"type": "error", "data": str(e) + "\n" + error_traceback}) + await websocket.send_json({"type": "error", "data": str(e), "traceback": error_traceback}) finally: await websocket.close() diff --git a/website/src/components/OperationCard.tsx b/website/src/components/OperationCard.tsx index ab3eb15e..9fe75b3d 100644 --- a/website/src/components/OperationCard.tsx +++ b/website/src/components/OperationCard.tsx @@ -257,7 +257,7 @@ const OperationHeader: React.FC = React.memo( disabled={disabled} > - Decompose Operation + Optimize Operation )} diff --git a/website/src/components/PipelineGui.tsx b/website/src/components/PipelineGui.tsx index db7ad39c..8333203f 100644 --- a/website/src/components/PipelineGui.tsx +++ b/website/src/components/PipelineGui.tsx @@ -330,7 +330,7 @@ const PipelineGUI: React.FC = () => { setIsLoadingOutputs(false); } else if (lastMessage.type === "error") { toast({ - title: "Error", + title: "Execution Error", description: lastMessage.data, variant: "destructive", duration: Infinity, @@ -976,60 +976,6 @@ const PipelineGUI: React.FC = () => { - -
-
- - - - - - -

Load from YAML

-
-
-
- - - - - - - -

Save to YAML

-
-
-
- -
-