From 479432ae7640e3ea9e89786f8b051362ce3f5b2c Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Sun, 29 Sep 2024 15:31:44 -0700 Subject: [PATCH] feat: let timeout be configurable --- docetl/api.py | 6 +++++ docetl/operations/equijoin.py | 13 +++++++++- docetl/operations/filter.py | 4 +++ docetl/operations/map.py | 10 ++++++++ docetl/operations/reduce.py | 8 ++++++ docetl/operations/resolve.py | 12 +++++++++ docetl/operations/utils.py | 24 +++++++++++++---- docs/operators/filter.md | 2 ++ docs/operators/map.md | 4 +++ docs/operators/parallel-map.md | 14 +++++----- docs/operators/reduce.md | 26 ++++++++++--------- docs/operators/resolve.md | 26 ++++++++++--------- tests/test_map.py | 47 ++++++++++++++++++++++++++++++++++ tests/test_ollama.py | 3 +++ 14 files changed, 163 insertions(+), 36 deletions(-) diff --git a/docetl/api.py b/docetl/api.py index 4647572f..ede1b032 100644 --- a/docetl/api.py +++ b/docetl/api.py @@ -80,6 +80,7 @@ class MapOp(BaseOp): num_retries_on_validate_failure: Optional[int] = None gleaning: Optional[Dict[str, Any]] = None drop_keys: Optional[List[str]] = None + timeout: Optional[int] = None @dataclass @@ -98,6 +99,7 @@ class ResolveOp(BaseOp): compare_batch_size: Optional[int] = None limit_comparisons: Optional[int] = None optimize: Optional[bool] = None + timeout: Optional[int] = None @dataclass @@ -115,6 +117,7 @@ class ReduceOp(BaseOp): fold_batch_size: Optional[int] = None value_sampling: Optional[Dict[str, Any]] = None verbose: Optional[bool] = None + timeout: Optional[int] = None @dataclass @@ -126,6 +129,7 @@ class ParallelMapOp(BaseOp): recursively_optimize: Optional[bool] = None sample_size: Optional[int] = None drop_keys: Optional[List[str]] = None + timeout: Optional[int] = None @dataclass @@ -138,6 +142,7 @@ class FilterOp(BaseOp): sample_size: Optional[int] = None validate: Optional[List[str]] = None num_retries_on_validate_failure: Optional[int] = None + timeout: Optional[int] = None @dataclass @@ -156,6 +161,7 @@ class EquijoinOp(BaseOp): compare_batch_size: Optional[int] = None limit_comparisons: Optional[int] = None blocking_keys: Optional[Dict[str, List[str]]] = None + timeout: Optional[int] = None @dataclass diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py index e6e4e9aa..f61347f0 100644 --- a/docetl/operations/equijoin.py +++ b/docetl/operations/equijoin.py @@ -54,7 +54,12 @@ def process_left_item( def compare_pair( - comparison_prompt: str, model: str, item1: Dict, item2: Dict + comparison_prompt: str, + model: str, + item1: Dict, + item2: Dict, + timeout_seconds: int = 120, + max_retries_per_timeout: int = 2, ) -> Tuple[bool, float]: """ Compares two items using an LLM model to determine if they match. @@ -64,6 +69,8 @@ def compare_pair( model (str): The LLM model to use for comparison. item1 (Dict): The first item to compare. item2 (Dict): The second item to compare. + timeout_seconds (int): The timeout for the LLM call in seconds. + max_retries_per_timeout (int): The maximum number of retries per timeout. Returns: Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison. 
@@ -76,6 +83,8 @@ def compare_pair( "compare", [{"role": "user", "content": prompt}], {"is_match": "bool"}, + timeout_seconds=timeout_seconds, + max_retries_per_timeout=max_retries_per_timeout, ) output = parse_llm_response(response)[0] return output["is_match"], completion_cost(response) @@ -384,6 +393,8 @@ def get_embeddings( self.config.get("comparison_model", self.default_model), left, right, + self.config.get("timeout", 120), + self.config.get("max_retries_per_timeout", 2), ): (left, right) for left, right in blocked_pairs } diff --git a/docetl/operations/filter.py b/docetl/operations/filter.py index 98fa52c8..aee1df27 100644 --- a/docetl/operations/filter.py +++ b/docetl/operations/filter.py @@ -135,6 +135,10 @@ def validation_fn(response: Dict[str, Any]): messages, self.config["output"]["schema"], console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ), validation_fn=validation_fn, val_rule=self.config.get("validate", []), diff --git a/docetl/operations/map.py b/docetl/operations/map.py index 01f02a78..15953b20 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -154,6 +154,10 @@ def validation_fn(response: Dict[str, Any]): self.config["gleaning"]["validation_prompt"], self.config["gleaning"]["num_rounds"], self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ), validation_fn=validation_fn, val_rule=self.config.get("validate", []), @@ -170,6 +174,10 @@ def validation_fn(response: Dict[str, Any]): self.config["output"]["schema"], tools=self.config.get("tools", None), console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ), validation_fn=validation_fn, val_rule=self.config.get("validate", []), @@ -356,6 +364,8 @@ def process_prompt(item, prompt_config): local_output_schema, tools=prompt_config.get("tools", None), console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) output = parse_llm_response( response, tools=prompt_config.get("tools", None) diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index dda15dca..df9d8b16 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -691,6 +691,8 @@ def _increment_fold( self.config["output"]["schema"], scratchpad=scratchpad, console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) folded_output = parse_llm_response(response)[0] @@ -730,6 +732,8 @@ def _merge_results( [{"role": "user", "content": merge_prompt}], self.config["output"]["schema"], console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) merged_output = parse_llm_response(response)[0] merged_output.update(dict(zip(self.config["reduce_key"], key))) @@ -822,6 +826,8 @@ def _batch_reduce( self.config["gleaning"]["validation_prompt"], self.config["gleaning"]["num_rounds"], console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) item_cost += gleaning_cost else: @@ -832,6 +838,8 @@ def _batch_reduce( self.config["output"]["schema"], console=self.console, scratchpad=scratchpad, + 
timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) item_cost += completion_cost(response) diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index 27a71340..5d0666af 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -31,6 +31,8 @@ def compare_pair( item1: Dict, item2: Dict, blocking_keys: List[str] = [], + timeout_seconds: int = 120, + max_retries_per_timeout: int = 2, ) -> Tuple[bool, float]: """ Compares two items using an LLM model to determine if they match. @@ -58,6 +60,8 @@ def compare_pair( "compare", [{"role": "user", "content": prompt}], {"is_match": "bool"}, + timeout_seconds=timeout_seconds, + max_retries_per_timeout=max_retries_per_timeout, ) output = parse_llm_response(response)[0] return output["is_match"], completion_cost(response) @@ -362,6 +366,10 @@ def meets_blocking_conditions(pair): input_data[pair[0]], input_data[pair[1]], blocking_keys, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ): pair for pair in batch } @@ -400,6 +408,10 @@ def process_cluster(cluster): [{"role": "user", "content": resolution_prompt}], self.config["output"]["schema"], console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ) reduction_output = parse_llm_response(reduction_response)[0] reduction_cost = completion_cost(reduction_response) diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py index 9831e55f..9035f561 100644 --- a/docetl/operations/utils.py +++ b/docetl/operations/utils.py @@ -363,6 +363,8 @@ def call_llm( tools: Optional[List[Dict[str, str]]] = None, scratchpad: Optional[str] = None, console: Console = Console(), + timeout_seconds: int = 120, + max_retries_per_timeout: int = 2, ) -> Any: """ Wrapper function that uses caching for LLM calls. @@ -377,6 +379,8 @@ def call_llm( output_schema (Dict[str, str]): The output schema dictionary. tools (Optional[List[Dict[str, str]]]): The tools to pass to the LLM. scratchpad (Optional[str]): The scratchpad to use for the operation. + timeout_seconds (int): The timeout for the LLM call. + max_retries_per_timeout (int): The maximum number of retries per timeout. Returns: str: The result from the cached LLM call. @@ -385,10 +389,10 @@ def call_llm( """ key = cache_key(model, op_type, messages, output_schema, scratchpad) - max_retries = 2 - for attempt in range(max_retries): + max_retries = max_retries_per_timeout + for attempt in range(max_retries + 1): try: - return timeout(120)(cached_call_llm)( + return timeout(timeout_seconds)(cached_call_llm)( key, model, op_type, @@ -607,6 +611,8 @@ def call_llm_with_gleaning( validator_prompt_template: str, num_gleaning_rounds: int, console: Console = Console(), + timeout_seconds: int = 120, + max_retries_per_timeout: int = 2, ) -> Tuple[str, float]: """ Call LLM with a gleaning process, including validation and improvement rounds. @@ -621,7 +627,7 @@ def call_llm_with_gleaning( output_schema (Dict[str, str]): The output schema dictionary. validator_prompt_template (str): Template for the validator prompt. num_gleaning_rounds (int): Number of gleaning rounds to perform. - + timeout_seconds (int): The timeout for the LLM call. Returns: Tuple[str, float]: A tuple containing the final LLM response and the total cost. 
""" @@ -632,7 +638,15 @@ def call_llm_with_gleaning( parameters["additionalProperties"] = False # Initial LLM call - response = call_llm(model, op_type, messages, output_schema, console=console) + response = call_llm( + model, + op_type, + messages, + output_schema, + console=console, + timeout_seconds=timeout_seconds, + max_retries_per_timeout=max_retries_per_timeout, + ) cost = 0.0 diff --git a/docs/operators/filter.md b/docs/operators/filter.md index aaa83c2a..c3d51c31 100644 --- a/docs/operators/filter.md +++ b/docs/operators/filter.md @@ -89,6 +89,8 @@ This example demonstrates how the Filter operation distinguishes between high-im | `sample_size` | Number of samples to use for the operation | Processes all data | | `validate` | List of Python expressions to validate the output | None | | `num_retries_on_validate_failure` | Number of retry attempts on validation failure | 0 | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | !!! info "Validation" diff --git a/docs/operators/map.md b/docs/operators/map.md index 90afcc4f..8b536f56 100644 --- a/docs/operators/map.md +++ b/docs/operators/map.md @@ -142,9 +142,13 @@ This example demonstrates how the Map operation can transform long, unstructured | `num_retries_on_validate_failure` | Number of retry attempts on validation failure | 0 | | `gleaning` | Configuration for advanced validation and LLM-based refinement | None | | `drop_keys` | List of keys to drop from the input before processing | None | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters. +| `timeout` | Timeout for each LLM call in seconds | 120 | + !!! info "Validation and Gleaning" For more details on validation techniques and implementation, see [operators](../concepts/operators.md#validation). diff --git a/docs/operators/parallel-map.md b/docs/operators/parallel-map.md index 20551b71..a4e0f6f1 100644 --- a/docs/operators/parallel-map.md +++ b/docs/operators/parallel-map.md @@ -29,12 +29,14 @@ Each prompt configuration in the `prompts` list should contain: ### Optional Parameters -| Parameter | Description | Default | -| ---------------------- | ------------------------------------------ | ----------------------------- | -| `model` | The default language model to use | Falls back to `default_model` | -| `optimize` | Flag to enable operation optimization | True | -| `recursively_optimize` | Flag to enable recursive optimization | false | -| `sample_size` | Number of samples to use for the operation | Processes all data | +| Parameter | Description | Default | +| ------------------------- | ------------------------------------------ | ----------------------------- | +| `model` | The default language model to use | Falls back to `default_model` | +| `optimize` | Flag to enable operation optimization | True | +| `recursively_optimize` | Flag to enable recursive optimization | false | +| `sample_size` | Number of samples to use for the operation | Processes all data | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | ??? question "Why use Parallel Map instead of multiple Map operations?" 
diff --git a/docs/operators/reduce.md b/docs/operators/reduce.md index 0f782f43..018f1cde 100644 --- a/docs/operators/reduce.md +++ b/docs/operators/reduce.md @@ -49,18 +49,20 @@ This Reduce operation processes customer feedback grouped by department: ### Optional Parameters -| Parameter | Description | Default | -| -------------------- | ------------------------------------------------------------------------------- | --------------------------- | -| `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true | -| `model` | The language model to use | Falls back to default_model | -| `input` | Specifies the schema or keys to subselect from each item | All keys from input items | -| `pass_through` | If true, non-input keys from the first item in the group will be passed through | false | -| `associative` | If true, the reduce operation is associative (i.e., order doesn't matter) | true | -| `fold_prompt` | A prompt template for incremental folding | None | -| `fold_batch_size` | Number of items to process in each fold operation | None | -| `value_sampling` | A dictionary specifying the sampling strategy for large groups | None | -| `verbose` | If true, enables detailed logging of the reduce operation | false | -| `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false | +| Parameter | Description | Default | +| ------------------------- | ------------------------------------------------------------------------------------------------------ | --------------------------- | +| `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true | +| `model` | The language model to use | Falls back to default_model | +| `input` | Specifies the schema or keys to subselect from each item | All keys from input items | +| `pass_through` | If true, non-input keys from the first item in the group will be passed through | false | +| `associative` | If true, the reduce operation is associative (i.e., order doesn't matter) | true | +| `fold_prompt` | A prompt template for incremental folding | None | +| `fold_batch_size` | Number of items to process in each fold operation | None | +| `value_sampling` | A dictionary specifying the sampling strategy for large groups | None | +| `verbose` | If true, enables detailed logging of the reduce operation | false | +| `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | ## Advanced Features diff --git a/docs/operators/resolve.md b/docs/operators/resolve.md index c4447ff3..722e8e31 100644 --- a/docs/operators/resolve.md +++ b/docs/operators/resolve.md @@ -107,18 +107,20 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un ## Optional Parameters -| Parameter | Description | Default | -| ---------------------- | --------------------------------------------------------------------------------- | ----------------------------- | -| `embedding_model` | The model to use for creating embeddings | Falls back to `default_model` | -| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` | -| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` | -| `blocking_keys` | 
List of keys to use for initial blocking | All keys in the input data | -| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None | -| `blocking_conditions` | List of conditions for initial blocking | [] | -| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items | -| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 | -| `compare_batch_size` | The number of entity pairs processed in each batch during the comparison phase | 100 | -| `limit_comparisons` | Maximum number of comparisons to perform | None | +| Parameter | Description | Default | +| ------------------------- | --------------------------------------------------------------------------------- | ----------------------------- | +| `embedding_model` | The model to use for creating embeddings | Falls back to `default_model` | +| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` | +| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` | +| `blocking_keys` | List of keys to use for initial blocking | All keys in the input data | +| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None | +| `blocking_conditions` | List of conditions for initial blocking | [] | +| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items | +| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 | +| `compare_batch_size` | The number of entity pairs processed in each batch during the comparison phase | 100 | +| `limit_comparisons` | Maximum number of comparisons to perform | None | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | ## Best Practices diff --git a/tests/test_map.py b/tests/test_map.py index a5346a04..ccf390a4 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -55,3 +55,50 @@ def test_map_operation_with_word_count_tool(map_config_with_tools, synthetic_dat assert all("word_count" in result for result in results) assert [result["word_count"] for result in results] == [5, 6, 5, 1] assert cost > 0 # Ensure there was some cost associated with the operation + + +@pytest.fixture +def simple_map_config(): + return { + "name": "simple_sentiment_analysis", + "type": "map", + "prompt": "Analyze the sentiment of the following text: '{{ input.text }}'. 
Classify it as either positive, negative, or neutral.", + "output": {"schema": {"sentiment": "string"}}, + "model": "gpt-4o-mini", + } + + +@pytest.fixture +def simple_sample_data(): + import random + import string + + def generate_random_text(length): + return "".join( + random.choice( + string.ascii_letters + string.digits + string.punctuation + " " + ) + for _ in range(length) + ) + + return [ + {"text": generate_random_text(random.randint(20, 100000))}, + {"text": generate_random_text(random.randint(20, 100000))}, + {"text": generate_random_text(random.randint(20, 100000))}, + ] + + +def test_map_operation_with_timeout(simple_map_config, simple_sample_data): + # Add timeout to the map configuration + map_config_with_timeout = { + **simple_map_config, + "timeout": 1, + "max_retries_per_timeout": 0, + } + + operation = MapOperation(map_config_with_timeout, "gpt-4o-mini", 4) + + # Execute the operation and expect empty results + results, cost = operation.execute(simple_sample_data) + for result in results: + assert "sentiment" not in result diff --git a/tests/test_ollama.py b/tests/test_ollama.py index 2f9d9ecb..67933670 100644 --- a/tests/test_ollama.py +++ b/tests/test_ollama.py @@ -15,6 +15,9 @@ load_dotenv() +# Set the OLLAMA_API_BASE environment variable +os.environ["OLLAMA_API_BASE"] = "http://localhost:11434/" + @pytest.fixture def temp_input_file():
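
A minimal usage sketch (not part of the patch), based on the configuration shape exercised in `tests/test_map.py`: each operation reads `timeout` (seconds allowed per LLM call, default 120) and `max_retries_per_timeout` (additional attempts after a timeout, default 2) from its own config via `self.config.get(...)`, so both keys can be set per operation. The import path and constructor arguments below are assumptions inferred from the existing tests rather than shown in this diff.

```python
# Sketch only -- not part of this commit. Assumes MapOperation is importable from
# docetl.operations.map and is constructed the same way as in tests/test_map.py
# (config dict, default model name, then the same third argument the test passes).
from docetl.operations.map import MapOperation

config = {
    "name": "sentiment_analysis",
    "type": "map",
    "prompt": "Classify the sentiment of '{{ input.text }}' as positive, negative, or neutral.",
    "output": {"schema": {"sentiment": "string"}},
    "model": "gpt-4o-mini",
    "timeout": 30,                 # seconds allowed for each LLM call (default: 120)
    "max_retries_per_timeout": 1,  # retries after a timed-out call (default: 2)
}

operation = MapOperation(config, "gpt-4o-mini", 4)
results, cost = operation.execute([{"text": "Fast shipping and the device works well."}])
```

With these values, `call_llm` makes at most `max_retries_per_timeout + 1` attempts (here, two), each bounded by `timeout(30)`; as the new `test_map_operation_with_timeout` expects, items whose calls never complete come back without the output key.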