Commit

feat: let timeout be configurable

shreyashankar committed Sep 29, 2024
1 parent 691a58e commit 479432a
Showing 14 changed files with 163 additions and 36 deletions.
6 changes: 6 additions & 0 deletions docetl/api.py
@@ -80,6 +80,7 @@ class MapOp(BaseOp):
     num_retries_on_validate_failure: Optional[int] = None
     gleaning: Optional[Dict[str, Any]] = None
     drop_keys: Optional[List[str]] = None
+    timeout: Optional[int] = None


 @dataclass
@@ -98,6 +99,7 @@ class ResolveOp(BaseOp):
     compare_batch_size: Optional[int] = None
     limit_comparisons: Optional[int] = None
     optimize: Optional[bool] = None
+    timeout: Optional[int] = None


 @dataclass
@@ -115,6 +117,7 @@ class ReduceOp(BaseOp):
     fold_batch_size: Optional[int] = None
     value_sampling: Optional[Dict[str, Any]] = None
     verbose: Optional[bool] = None
+    timeout: Optional[int] = None


 @dataclass
@@ -126,6 +129,7 @@ class ParallelMapOp(BaseOp):
     recursively_optimize: Optional[bool] = None
     sample_size: Optional[int] = None
     drop_keys: Optional[List[str]] = None
+    timeout: Optional[int] = None


 @dataclass
@@ -138,6 +142,7 @@ class FilterOp(BaseOp):
     sample_size: Optional[int] = None
     validate: Optional[List[str]] = None
     num_retries_on_validate_failure: Optional[int] = None
+    timeout: Optional[int] = None


 @dataclass
@@ -156,6 +161,7 @@ class EquijoinOp(BaseOp):
     compare_batch_size: Optional[int] = None
     limit_comparisons: Optional[int] = None
     blocking_keys: Optional[Dict[str, List[str]]] = None
+    timeout: Optional[int] = None


 @dataclass
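
Every operation dataclass now carries the same optional field, so a per-operation timeout can be set directly from the Python API. A minimal usage sketch (the non-timeout field values below are illustrative, not taken from this diff):

from docetl.api import MapOp

summarize = MapOp(
    name="summarize_docs",
    type="map",
    prompt="Summarize the following document: {{ input.text }}",
    output={"schema": {"summary": "string"}},
    timeout=60,  # cap each LLM call at 60 seconds; the runtime default is 120
)
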
13 changes: 12 additions & 1 deletion docetl/operations/equijoin.py
@@ -54,7 +54,12 @@ def process_left_item(


 def compare_pair(
-    comparison_prompt: str, model: str, item1: Dict, item2: Dict
+    comparison_prompt: str,
+    model: str,
+    item1: Dict,
+    item2: Dict,
+    timeout_seconds: int = 120,
+    max_retries_per_timeout: int = 2,
 ) -> Tuple[bool, float]:
     """
     Compares two items using an LLM model to determine if they match.
@@ -64,6 +69,8 @@ def compare_pair(
         model (str): The LLM model to use for comparison.
         item1 (Dict): The first item to compare.
         item2 (Dict): The second item to compare.
+        timeout_seconds (int): The timeout for the LLM call in seconds.
+        max_retries_per_timeout (int): The maximum number of retries per timeout.

     Returns:
         Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison.
@@ -76,6 +83,8 @@ def compare_pair(
         "compare",
         [{"role": "user", "content": prompt}],
         {"is_match": "bool"},
+        timeout_seconds=timeout_seconds,
+        max_retries_per_timeout=max_retries_per_timeout,
     )
     output = parse_llm_response(response)[0]
     return output["is_match"], completion_cost(response)
@@ -384,6 +393,8 @@ def get_embeddings(
                 self.config.get("comparison_model", self.default_model),
                 left,
                 right,
+                self.config.get("timeout", 120),
+                self.config.get("max_retries_per_timeout", 2),
             ): (left, right)
             for left, right in blocked_pairs
         }
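
The call site fans comparisons out over a thread pool and threads the operation-level settings through to every pair. A self-contained sketch of that pattern, with generic names standing in for docetl's internals (compare here is a stub for compare_pair):

from concurrent.futures import ThreadPoolExecutor, as_completed

def compare(prompt, model, left, right, timeout_seconds=120, max_retries_per_timeout=2):
    # Stub for compare_pair: a real version would call the LLM under these limits.
    return left == right, 0.0

config = {"timeout": 60}  # operation config; max_retries_per_timeout falls back to 2
pairs = [({"name": "Acme"}, {"name": "Acme"}), ({"name": "Acme"}, {"name": "Apex"})]

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = {
        executor.submit(
            compare,
            "Do these records refer to the same entity?",
            "gpt-4o-mini",
            left,
            right,
            config.get("timeout", 120),
            config.get("max_retries_per_timeout", 2),
        ): (left, right)
        for left, right in pairs
    }
    for future in as_completed(futures):
        left, right = futures[future]
        is_match, cost = future.result()
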
4 changes: 4 additions & 0 deletions docetl/operations/filter.py
@@ -135,6 +135,10 @@ def validation_fn(response: Dict[str, Any]):
                 messages,
                 self.config["output"]["schema"],
                 console=self.console,
+                timeout_seconds=self.config.get("timeout", 120),
+                max_retries_per_timeout=self.config.get(
+                    "max_retries_per_timeout", 2
+                ),
             ),
             validation_fn=validation_fn,
             val_rule=self.config.get("validate", []),
10 changes: 10 additions & 0 deletions docetl/operations/map.py
@@ -154,6 +154,10 @@ def validation_fn(response: Dict[str, Any]):
                 self.config["gleaning"]["validation_prompt"],
                 self.config["gleaning"]["num_rounds"],
                 self.console,
+                timeout_seconds=self.config.get("timeout", 120),
+                max_retries_per_timeout=self.config.get(
+                    "max_retries_per_timeout", 2
+                ),
             ),
             validation_fn=validation_fn,
             val_rule=self.config.get("validate", []),
@@ -170,6 +174,10 @@ def validation_fn(response: Dict[str, Any]):
                 self.config["output"]["schema"],
                 tools=self.config.get("tools", None),
                 console=self.console,
+                timeout_seconds=self.config.get("timeout", 120),
+                max_retries_per_timeout=self.config.get(
+                    "max_retries_per_timeout", 2
+                ),
             ),
             validation_fn=validation_fn,
             val_rule=self.config.get("validate", []),
@@ -356,6 +364,8 @@ def process_prompt(item, prompt_config):
             local_output_schema,
             tools=prompt_config.get("tools", None),
             console=self.console,
+            timeout_seconds=self.config.get("timeout", 120),
+            max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
         )
         output = parse_llm_response(
             response, tools=prompt_config.get("tools", None)
8 changes: 8 additions & 0 deletions docetl/operations/reduce.py
@@ -691,6 +691,8 @@ def _increment_fold(
             self.config["output"]["schema"],
             scratchpad=scratchpad,
             console=self.console,
+            timeout_seconds=self.config.get("timeout", 120),
+            max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
         )
         folded_output = parse_llm_response(response)[0]

@@ -730,6 +732,8 @@ def _merge_results(
             [{"role": "user", "content": merge_prompt}],
             self.config["output"]["schema"],
             console=self.console,
+            timeout_seconds=self.config.get("timeout", 120),
+            max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
         )
         merged_output = parse_llm_response(response)[0]
         merged_output.update(dict(zip(self.config["reduce_key"], key)))
@@ -822,6 +826,8 @@ def _batch_reduce(
                 self.config["gleaning"]["validation_prompt"],
                 self.config["gleaning"]["num_rounds"],
                 console=self.console,
+                timeout_seconds=self.config.get("timeout", 120),
+                max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
             )
             item_cost += gleaning_cost
         else:
@@ -832,6 +838,8 @@ def _batch_reduce(
                 [{"role": "user", "content": prompt}],
                 self.config["output"]["schema"],
                 console=self.console,
                 scratchpad=scratchpad,
+                timeout_seconds=self.config.get("timeout", 120),
+                max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
             )

             item_cost += completion_cost(response)
12 changes: 12 additions & 0 deletions docetl/operations/resolve.py
@@ -31,6 +31,8 @@ def compare_pair(
     item1: Dict,
     item2: Dict,
     blocking_keys: List[str] = [],
+    timeout_seconds: int = 120,
+    max_retries_per_timeout: int = 2,
 ) -> Tuple[bool, float]:
     """
     Compares two items using an LLM model to determine if they match.
@@ -58,6 +60,8 @@ def compare_pair(
         "compare",
         [{"role": "user", "content": prompt}],
         {"is_match": "bool"},
+        timeout_seconds=timeout_seconds,
+        max_retries_per_timeout=max_retries_per_timeout,
     )
     output = parse_llm_response(response)[0]
     return output["is_match"], completion_cost(response)
@@ -362,6 +366,10 @@ def meets_blocking_conditions(pair):
                 input_data[pair[0]],
                 input_data[pair[1]],
                 blocking_keys,
+                timeout_seconds=self.config.get("timeout", 120),
+                max_retries_per_timeout=self.config.get(
+                    "max_retries_per_timeout", 2
+                ),
             ): pair
             for pair in batch
         }
@@ -400,6 +408,10 @@ def process_cluster(cluster):
             [{"role": "user", "content": resolution_prompt}],
             self.config["output"]["schema"],
             console=self.console,
+            timeout_seconds=self.config.get("timeout", 120),
+            max_retries_per_timeout=self.config.get(
+                "max_retries_per_timeout", 2
+            ),
         )
         reduction_output = parse_llm_response(reduction_response)[0]
         reduction_cost = completion_cost(reduction_response)
24 changes: 19 additions & 5 deletions docetl/operations/utils.py
@@ -363,6 +363,8 @@ def call_llm(
     tools: Optional[List[Dict[str, str]]] = None,
     scratchpad: Optional[str] = None,
     console: Console = Console(),
+    timeout_seconds: int = 120,
+    max_retries_per_timeout: int = 2,
 ) -> Any:
     """
     Wrapper function that uses caching for LLM calls.
@@ -377,6 +379,8 @@ def call_llm(
         output_schema (Dict[str, str]): The output schema dictionary.
         tools (Optional[List[Dict[str, str]]]): The tools to pass to the LLM.
         scratchpad (Optional[str]): The scratchpad to use for the operation.
+        timeout_seconds (int): The timeout for the LLM call.
+        max_retries_per_timeout (int): The maximum number of retries per timeout.

     Returns:
         str: The result from the cached LLM call.
@@ -385,10 +389,10 @@ def call_llm(
     """
     key = cache_key(model, op_type, messages, output_schema, scratchpad)

-    max_retries = 2
-    for attempt in range(max_retries):
+    max_retries = max_retries_per_timeout
+    for attempt in range(max_retries + 1):
         try:
-            return timeout(120)(cached_call_llm)(
+            return timeout(timeout_seconds)(cached_call_llm)(
                 key,
                 model,
                 op_type,
@@ -607,6 +611,8 @@ def call_llm_with_gleaning(
     validator_prompt_template: str,
     num_gleaning_rounds: int,
     console: Console = Console(),
+    timeout_seconds: int = 120,
+    max_retries_per_timeout: int = 2,
 ) -> Tuple[str, float]:
     """
     Call LLM with a gleaning process, including validation and improvement rounds.
@@ -621,7 +627,7 @@ def call_llm_with_gleaning(
         output_schema (Dict[str, str]): The output schema dictionary.
         validator_prompt_template (str): Template for the validator prompt.
         num_gleaning_rounds (int): Number of gleaning rounds to perform.
-
+        timeout_seconds (int): The timeout for the LLM call.
     Returns:
         Tuple[str, float]: A tuple containing the final LLM response and the total cost.
     """
@@ -632,7 +638,15 @@ def call_llm_with_gleaning(
     parameters["additionalProperties"] = False

     # Initial LLM call
-    response = call_llm(model, op_type, messages, output_schema, console=console)
+    response = call_llm(
+        model,
+        op_type,
+        messages,
+        output_schema,
+        console=console,
+        timeout_seconds=timeout_seconds,
+        max_retries_per_timeout=max_retries_per_timeout,
+    )

     cost = 0.0
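
The retry loop above is the core of the change: timeout(timeout_seconds) wraps the cached call so it raises if the model hangs, and because the loop runs range(max_retries + 1), max_retries_per_timeout counts retries after the first attempt rather than total attempts. A self-contained sketch of the same pattern, assuming a thread-based timeout decorator (the diff does not show docetl's actual timeout implementation):

import concurrent.futures
import functools
import time

def timeout(seconds):
    # Run the wrapped call in a worker thread; raise TimeoutError past `seconds`.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            try:
                return pool.submit(func, *args, **kwargs).result(timeout=seconds)
            finally:
                pool.shutdown(wait=False)  # let a hung call finish in the background
        return wrapper
    return decorator

def call_with_timeout_retries(fn, timeout_seconds=120, max_retries_per_timeout=2):
    # max_retries_per_timeout counts retries, so allow one extra (initial) attempt.
    for attempt in range(max_retries_per_timeout + 1):
        try:
            return timeout(timeout_seconds)(fn)()
        except concurrent.futures.TimeoutError:
            if attempt == max_retries_per_timeout:
                raise  # out of retries; surface the timeout to the caller

def slow_llm_call():
    time.sleep(1)  # stand-in for a model call
    return {"is_match": True}

print(call_with_timeout_retries(slow_llm_call, timeout_seconds=5))
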
2 changes: 2 additions & 0 deletions docs/operators/filter.md
@@ -89,6 +89,8 @@ This example demonstrates how the Filter operation distinguishes between high-im
 | `sample_size` | Number of samples to use for the operation | Processes all data |
 | `validate` | List of Python expressions to validate the output | None |
 | `num_retries_on_validate_failure` | Number of retry attempts on validation failure | 0 |
+| `timeout` | Timeout for each LLM call in seconds | 120 |
+| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |

 !!! info "Validation"
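
Both settings sit alongside the other optional keys of an operation's config, which the runtime reads as self.config in the code above. A sketch of a filter operation config as a plain Python dict; the non-timeout keys are illustrative:

filter_config = {
    "name": "keep_high_impact",
    "type": "filter",
    "prompt": "Is this feedback high-impact? {{ input.feedback }}",
    "output": {"schema": {"keep": "boolean"}},
    "timeout": 60,  # per-LLM-call timeout in seconds (default 120)
    "max_retries_per_timeout": 1,  # retries after each timeout (default 2)
}
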
4 changes: 4 additions & 0 deletions docs/operators/map.md
@@ -142,9 +142,13 @@ This example demonstrates how the Map operation can transform long, unstructured
 | `num_retries_on_validate_failure` | Number of retry attempts on validation failure | 0 |
 | `gleaning` | Configuration for advanced validation and LLM-based refinement | None |
 | `drop_keys` | List of keys to drop from the input before processing | None |
+| `timeout` | Timeout for each LLM call in seconds | 120 |
+| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |

 Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters.

 !!! info "Validation and Gleaning"

     For more details on validation techniques and implementation, see [operators](../concepts/operators.md#validation).
14 changes: 8 additions & 6 deletions docs/operators/parallel-map.md
@@ -29,12 +29,14 @@ Each prompt configuration in the `prompts` list should contain:

 ### Optional Parameters

-| Parameter | Description | Default |
-| ---------------------- | ------------------------------------------ | ----------------------------- |
-| `model` | The default language model to use | Falls back to `default_model` |
-| `optimize` | Flag to enable operation optimization | True |
-| `recursively_optimize` | Flag to enable recursive optimization | false |
-| `sample_size` | Number of samples to use for the operation | Processes all data |
+| Parameter | Description | Default |
+| ------------------------- | ------------------------------------------ | ----------------------------- |
+| `model` | The default language model to use | Falls back to `default_model` |
+| `optimize` | Flag to enable operation optimization | True |
+| `recursively_optimize` | Flag to enable recursive optimization | false |
+| `sample_size` | Number of samples to use for the operation | Processes all data |
+| `timeout` | Timeout for each LLM call in seconds | 120 |
+| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |

 ??? question "Why use Parallel Map instead of multiple Map operations?"
26 changes: 14 additions & 12 deletions docs/operators/reduce.md
@@ -49,18 +49,20 @@ This Reduce operation processes customer feedback grouped by department:

 ### Optional Parameters

-| Parameter | Description | Default |
-| -------------------- | ------------------------------------------------------------------------------- | --------------------------- |
-| `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true |
-| `model` | The language model to use | Falls back to default_model |
-| `input` | Specifies the schema or keys to subselect from each item | All keys from input items |
-| `pass_through` | If true, non-input keys from the first item in the group will be passed through | false |
-| `associative` | If true, the reduce operation is associative (i.e., order doesn't matter) | true |
-| `fold_prompt` | A prompt template for incremental folding | None |
-| `fold_batch_size` | Number of items to process in each fold operation | None |
-| `value_sampling` | A dictionary specifying the sampling strategy for large groups | None |
-| `verbose` | If true, enables detailed logging of the reduce operation | false |
-| `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false |
+| Parameter | Description | Default |
+| ------------------------- | -------------------------------------------------------------------------------------------------------- | --------------------------- |
+| `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true |
+| `model` | The language model to use | Falls back to default_model |
+| `input` | Specifies the schema or keys to subselect from each item | All keys from input items |
+| `pass_through` | If true, non-input keys from the first item in the group will be passed through | false |
+| `associative` | If true, the reduce operation is associative (i.e., order doesn't matter) | true |
+| `fold_prompt` | A prompt template for incremental folding | None |
+| `fold_batch_size` | Number of items to process in each fold operation | None |
+| `value_sampling` | A dictionary specifying the sampling strategy for large groups | None |
+| `verbose` | If true, enables detailed logging of the reduce operation | false |
+| `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false |
+| `timeout` | Timeout for each LLM call in seconds | 120 |
+| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 |

 ## Advanced Features