Skip to content

Commit

Permalink
fix: add strict jinja templating (#237)
Browse files Browse the repository at this point in the history
* fix: add strict jinja templating

* fix: tests had an error
  • Loading branch information
shreyashankar authored Dec 9, 2024
1 parent acb655c commit f701701
Show file tree
Hide file tree
Showing 16 changed files with 496 additions and 624 deletions.
7 changes: 3 additions & 4 deletions docetl/operations/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple
from .base import BaseOperation
from .utils import RichLoopBar
from .utils import RichLoopBar, strict_render
from .clustering_utils import get_embeddings_for_clustering



class ClusterOperation(BaseOperation):
def __init__(
self,
Expand Down Expand Up @@ -187,9 +188,7 @@ def annotate_clustering_tree(self, t):
total_cost += futures[i].result()
pbar.update(i)

prompt = self.prompt_template.render(
inputs=t["children"]
)
prompt = strict_render(self.prompt_template, {"inputs": t["children"]})

def validation_fn(response: Dict[str, Any]):
output = self.runner.api.parse_llm_response(
Expand Down
5 changes: 3 additions & 2 deletions docetl/operations/equijoin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from multiprocessing import Pool, cpu_count
from typing import Any, Dict, List, Tuple, Optional

from docetl.operations.utils import strict_render
import numpy as np
from jinja2 import Template
from litellm import model_cost
Expand Down Expand Up @@ -94,8 +95,8 @@ def compare_pair(
Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison.
"""

prompt_template = Template(comparison_prompt)
prompt = prompt_template.render(left=item1, right=item2)

prompt = strict_render(comparison_prompt, {"left": item1, "right": item2})
response = self.runner.api.call_llm(
model,
"compare",
Expand Down
11 changes: 6 additions & 5 deletions docetl/operations/link_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, rich_as_completed
from docetl.utils import completion_cost, extract_jinja_variables
from docetl.operations.utils import strict_render
from .clustering_utils import get_embeddings_for_clustering
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
Expand Down Expand Up @@ -139,11 +140,11 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
return input_data, total_cost

def compare(self, link_idx, id_idx, link_value, id_value, item):
prompt = self.prompt_template.render(
link_value = link_value,
id_value = id_value,
item = item
)
prompt = strict_render(self.prompt_template, {
"link_value": link_value,
"id_value": id_value,
"item": item
})

schema = {"is_same": "bool"}

Expand Down
21 changes: 5 additions & 16 deletions docetl/operations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Union

from docetl.operations.utils import strict_render
from jinja2 import Environment, Template
from tqdm import tqdm

Expand All @@ -16,17 +17,6 @@
from litellm.utils import ModelResponse


def render_jinja_template(template_string: str, data: Dict[str, Any]) -> str:
"""
Render a Jinja2 template with the given data, ensuring protection against template injection vulnerabilities.
If the data is empty, return an empty string.
"""
if not data:
return ""

env = Environment(autoescape=True)
template = env.from_string(template_string)
return template.render(input=data)


class MapOperation(BaseOperation):
Expand Down Expand Up @@ -175,8 +165,8 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
self.status.stop()

def _process_map_item(item: Dict, initial_result: Optional[Dict] = None) -> Tuple[Optional[Dict], float]:
prompt_template = Template(self.config["prompt"])
prompt = prompt_template.render(input=item)

prompt = strict_render(self.config["prompt"], {"input": item})

def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
output = self.runner.api.parse_llm_response(
Expand Down Expand Up @@ -243,8 +233,7 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]:
total_cost = 0
if len(items) > 1 and self.config.get("batch_prompt", None):
batch_prompt_template = Template(self.config["batch_prompt"])
batch_prompt = batch_prompt_template.render(inputs=items)
batch_prompt = strict_render(self.config["batch_prompt"], {"inputs": items})

# Issue the batch call
llm_result = self.runner.api.call_llm_batch(
Expand Down Expand Up @@ -449,7 +438,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
self.status.stop()

def process_prompt(item, prompt_config):
prompt = render_jinja_template(prompt_config["prompt"], item)
prompt = strict_render(prompt_config["prompt"], {"input": item})
local_output_schema = {
key: output_schema[key] for key in prompt_config["output_keys"]
}
Expand Down
34 changes: 16 additions & 18 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from jinja2 import Template

from docetl.operations.base import BaseOperation
from docetl.operations.utils import strict_render
from docetl.operations.clustering_utils import (
cluster_documents,
get_embeddings_for_clustering,
Expand Down Expand Up @@ -509,10 +510,8 @@ def _semantic_similarity_sampling(
self, key: Tuple, group_list: List[Dict], value_sampling: Dict, sample_size: int
) -> Tuple[List[Dict], float]:
embedding_model = value_sampling["embedding_model"]
query_text_template = Template(value_sampling["query_text"])
query_text = query_text_template.render(
reduce_key=dict(zip(self.config["reduce_key"], key))
)
query_text = strict_render(value_sampling["query_text"], {"reduce_key": dict(zip(self.config["reduce_key"], key))})


embeddings, cost = get_embeddings_for_clustering(
group_list, value_sampling, self.runner.api
Expand Down Expand Up @@ -794,12 +793,11 @@ def _increment_fold(
return self._batch_reduce(key, batch, scratchpad)

start_time = time.time()
fold_prompt_template = Template(self.config["fold_prompt"])
fold_prompt = fold_prompt_template.render(
inputs=batch,
output=current_output,
reduce_key=dict(zip(self.config["reduce_key"], key)),
)
fold_prompt = strict_render(self.config["fold_prompt"], {
"inputs": batch,
"output": current_output,
"reduce_key": dict(zip(self.config["reduce_key"], key))
})

response = self.runner.api.call_llm(
self.config.get("model", self.default_model),
Expand Down Expand Up @@ -857,10 +855,10 @@ def _merge_results(
the prompt used, and the cost of the merge operation.
"""
start_time = time.time()
merge_prompt_template = Template(self.config["merge_prompt"])
merge_prompt = merge_prompt_template.render(
outputs=outputs, reduce_key=dict(zip(self.config["reduce_key"], key))
)
merge_prompt = strict_render(self.config["merge_prompt"], {
"outputs": outputs,
"reduce_key": dict(zip(self.config["reduce_key"], key))
})
response = self.runner.api.call_llm(
self.config.get("model", self.default_model),
"merge",
Expand Down Expand Up @@ -963,10 +961,10 @@ def _batch_reduce(
Tuple[Optional[Dict], str, float]: A tuple containing the reduced output (or None if processing failed),
the prompt used, and the cost of the reduce operation.
"""
prompt_template = Template(self.config["prompt"])
prompt = prompt_template.render(
reduce_key=dict(zip(self.config["reduce_key"], key)), inputs=group_list
)
prompt = strict_render(self.config["prompt"], {
"reduce_key": dict(zip(self.config["reduce_key"], key)),
"inputs": group_list
})
item_cost = 0

response = self.runner.api.call_llm(
Expand Down
14 changes: 10 additions & 4 deletions docetl/operations/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json
from datetime import datetime

from docetl.operations.utils import strict_render
import jinja2
from jinja2 import Template
from rich.prompt import Confirm
Expand Down Expand Up @@ -80,8 +81,11 @@ def compare_pair(
):
return True, 0, ""

prompt_template = Template(comparison_prompt)
prompt = prompt_template.render(input1=item1, input2=item2)

prompt = strict_render(comparison_prompt, {
"input1": item1,
"input2": item2
})
response = self.runner.api.call_llm(
model,
"compare",
Expand Down Expand Up @@ -543,14 +547,16 @@ def auto_batch() -> int:
def process_cluster(cluster):
if len(cluster) > 1:
cluster_items = [input_data[i] for i in cluster]
reduction_template = Template(self.config["resolution_prompt"])
if input_schema:
cluster_items = [
{k: item[k] for k in input_schema.keys() if k in item}
for item in cluster_items
]

resolution_prompt = reduction_template.render(inputs=cluster_items)

resolution_prompt = strict_render(self.config["resolution_prompt"], {
"inputs": cluster_items
})
reduction_response = self.runner.api.call_llm(
self.config.get("resolution_model", self.default_model),
"reduce",
Expand Down
36 changes: 36 additions & 0 deletions docetl/operations/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from .api import APIWrapper
from .cache import (
cache,
cache_key,
clear_cache,
flush_cache,
freezeargs,
CACHE_DIR,
LLM_CACHE_DIR,
DOCETL_HOME_DIR,
)
from .llm import LLMResult, InvalidOutputError, truncate_messages
from .progress import RichLoopBar, rich_as_completed
from .validation import safe_eval, convert_val, convert_dict_schema_to_list_schema, get_user_input_for_schema, strict_render

__all__ = [
'APIWrapper',
'cache',
'cache_key',
'clear_cache',
'flush_cache',
'freezeargs',
'CACHE_DIR',
'LLM_CACHE_DIR',
'DOCETL_HOME_DIR',
'LLMResult',
'InvalidOutputError',
'RichLoopBar',
'rich_as_completed',
'safe_eval',
'convert_val',
'convert_dict_schema_to_list_schema',
'get_user_input_for_schema',
'truncate_messages',
"strict_render"
]
Loading

0 comments on commit f701701

Please sign in to comment.