fix: add strict jinja templating #237

Merged Dec 9, 2024 · 2 commits · Changes from all commits
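Across every operation, this PR swaps direct `Template(...).render(...)` calls for a shared `strict_render` helper. The motivation: Jinja2's default `Undefined` renders any missing variable as an empty string, so a typo in a prompt template silently produces a truncated prompt instead of an error. A minimal illustration of the failure mode and the strict alternative (illustrative code, not part of the diff):

```python
from jinja2 import Environment, StrictUndefined, Template, UndefinedError

# Default Jinja2: a misspelled variable silently renders as "".
print(Template("Summarize: {{ documnet }}").render(document="..."))
# -> "Summarize: "

# With StrictUndefined, the same typo raises at render time.
strict_env = Environment(undefined=StrictUndefined)
try:
    strict_env.from_string("Summarize: {{ documnet }}").render(document="...")
except UndefinedError as exc:
    print(exc)  # 'documnet' is undefined
```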
7 changes: 3 additions & 4 deletions docetl/operations/cluster.py
@@ -3,10 +3,11 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional, Tuple
 from .base import BaseOperation
-from .utils import RichLoopBar
+from .utils import RichLoopBar, strict_render
 from .clustering_utils import get_embeddings_for_clustering
+


 class ClusterOperation(BaseOperation):
     def __init__(
         self,
@@ -187,9 +188,7 @@ def annotate_clustering_tree(self, t):
                 total_cost += futures[i].result()
                 pbar.update(i)

-        prompt = self.prompt_template.render(
-            inputs=t["children"]
-        )
+        prompt = strict_render(self.prompt_template, {"inputs": t["children"]})

         def validation_fn(response: Dict[str, Any]):
             output = self.runner.api.parse_llm_response(
5 changes: 3 additions & 2 deletions docetl/operations/equijoin.py
@@ -9,6 +9,7 @@
 from multiprocessing import Pool, cpu_count
 from typing import Any, Dict, List, Tuple, Optional

+from docetl.operations.utils import strict_render
 import numpy as np
 from jinja2 import Template
 from litellm import model_cost
@@ -94,8 +95,8 @@ def compare_pair(
             Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison.
         """

-        prompt_template = Template(comparison_prompt)
-        prompt = prompt_template.render(left=item1, right=item2)
+
+        prompt = strict_render(comparison_prompt, {"left": item1, "right": item2})
         response = self.runner.api.call_llm(
             model,
             "compare",
11 changes: 6 additions & 5 deletions docetl/operations/link_resolve.py
@@ -9,6 +9,7 @@
 from docetl.operations.base import BaseOperation
 from docetl.operations.utils import RichLoopBar, rich_as_completed
 from docetl.utils import completion_cost, extract_jinja_variables
+from docetl.operations.utils import strict_render
 from .clustering_utils import get_embeddings_for_clustering
 from sklearn.metrics.pairwise import cosine_similarity
 import numpy as np
@@ -139,11 +140,11 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
         return input_data, total_cost

     def compare(self, link_idx, id_idx, link_value, id_value, item):
-        prompt = self.prompt_template.render(
-            link_value = link_value,
-            id_value = id_value,
-            item = item
-        )
+        prompt = strict_render(self.prompt_template, {
+            "link_value": link_value,
+            "id_value": id_value,
+            "item": item
+        })

         schema = {"is_same": "bool"}

21 changes: 5 additions & 16 deletions docetl/operations/map.py
@@ -5,6 +5,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional, Tuple, Union

+from docetl.operations.utils import strict_render
 from jinja2 import Environment, Template
 from tqdm import tqdm

@@ -16,17 +17,6 @@
 from litellm.utils import ModelResponse


-def render_jinja_template(template_string: str, data: Dict[str, Any]) -> str:
-    """
-    Render a Jinja2 template with the given data, ensuring protection against template injection vulnerabilities.
-    If the data is empty, return an empty string.
-    """
-    if not data:
-        return ""
-
-    env = Environment(autoescape=True)
-    template = env.from_string(template_string)
-    return template.render(input=data)


 class MapOperation(BaseOperation):
@@ -175,8 +165,8 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
             self.status.stop()

         def _process_map_item(item: Dict, initial_result: Optional[Dict] = None) -> Tuple[Optional[Dict], float]:
-            prompt_template = Template(self.config["prompt"])
-            prompt = prompt_template.render(input=item)
+
+            prompt = strict_render(self.config["prompt"], {"input": item})

             def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
                 output = self.runner.api.parse_llm_response(
@@ -243,8 +233,8 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
         def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]:
             total_cost = 0
             if len(items) > 1 and self.config.get("batch_prompt", None):
-                batch_prompt_template = Template(self.config["batch_prompt"])
-                batch_prompt = batch_prompt_template.render(inputs=items)
+                batch_prompt = strict_render(self.config["batch_prompt"], {"inputs": items})

                 # Issue the batch call
                 llm_result = self.runner.api.call_llm_batch(
@@ -449,7 +438,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
             self.status.stop()

         def process_prompt(item, prompt_config):
-            prompt = render_jinja_template(prompt_config["prompt"], item)
+            prompt = strict_render(prompt_config["prompt"], {"input": item})
             local_output_schema = {
                 key: output_schema[key] for key in prompt_config["output_keys"]
             }
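Note that the deleted `render_jinja_template` helper in map.py had two extra behaviors that strict rendering drops: `autoescape=True` (which HTML-escapes rendered values, garbling prompt text) and returning `""` when the data was empty (producing a silently empty prompt). A small demonstration of why those defaults are problematic for LLM prompts (illustrative, standard Jinja2 only):

```python
from jinja2 import Environment

env = Environment(autoescape=True)

# Autoescaping mangles ordinary prompt content:
print(env.from_string("{{ input.text }}").render(input={"text": "<b>key point</b>"}))
# -> "&lt;b&gt;key point&lt;/b&gt;"

# And a missing key renders as empty instead of failing:
print(env.from_string("{{ input.missing }}").render(input={}))
# -> ""
```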
34 changes: 16 additions & 18 deletions docetl/operations/reduce.py
@@ -19,6 +19,7 @@
 from jinja2 import Template

 from docetl.operations.base import BaseOperation
+from docetl.operations.utils import strict_render
 from docetl.operations.clustering_utils import (
     cluster_documents,
     get_embeddings_for_clustering,
@@ -509,10 +510,8 @@ def _semantic_similarity_sampling(
         self, key: Tuple, group_list: List[Dict], value_sampling: Dict, sample_size: int
     ) -> Tuple[List[Dict], float]:
         embedding_model = value_sampling["embedding_model"]
-        query_text_template = Template(value_sampling["query_text"])
-        query_text = query_text_template.render(
-            reduce_key=dict(zip(self.config["reduce_key"], key))
-        )
+        query_text = strict_render(value_sampling["query_text"], {"reduce_key": dict(zip(self.config["reduce_key"], key))})
+

 embeddings, cost = get_embeddings_for_clustering(
     group_list, value_sampling, self.runner.api
@@ -794,12 +793,11 @@ def _increment_fold(
             return self._batch_reduce(key, batch, scratchpad)

         start_time = time.time()
-        fold_prompt_template = Template(self.config["fold_prompt"])
-        fold_prompt = fold_prompt_template.render(
-            inputs=batch,
-            output=current_output,
-            reduce_key=dict(zip(self.config["reduce_key"], key)),
-        )
+        fold_prompt = strict_render(self.config["fold_prompt"], {
+            "inputs": batch,
+            "output": current_output,
+            "reduce_key": dict(zip(self.config["reduce_key"], key))
+        })

         response = self.runner.api.call_llm(
             self.config.get("model", self.default_model),
@@ -857,10 +855,10 @@ def _merge_results(
             the prompt used, and the cost of the merge operation.
         """
         start_time = time.time()
-        merge_prompt_template = Template(self.config["merge_prompt"])
-        merge_prompt = merge_prompt_template.render(
-            outputs=outputs, reduce_key=dict(zip(self.config["reduce_key"], key))
-        )
+        merge_prompt = strict_render(self.config["merge_prompt"], {
+            "outputs": outputs,
+            "reduce_key": dict(zip(self.config["reduce_key"], key))
+        })
         response = self.runner.api.call_llm(
             self.config.get("model", self.default_model),
             "merge",
@@ -963,10 +961,10 @@ def _batch_reduce(
             Tuple[Optional[Dict], str, float]: A tuple containing the reduced output (or None if processing failed),
             the prompt used, and the cost of the reduce operation.
         """
-        prompt_template = Template(self.config["prompt"])
-        prompt = prompt_template.render(
-            reduce_key=dict(zip(self.config["reduce_key"], key)), inputs=group_list
-        )
+        prompt = strict_render(self.config["prompt"], {
+            "reduce_key": dict(zip(self.config["reduce_key"], key)),
+            "inputs": group_list
+        })
         item_cost = 0

         response = self.runner.api.call_llm(
14 changes: 10 additions & 4 deletions docetl/operations/resolve.py
@@ -9,6 +9,7 @@
 import json
 from datetime import datetime

+from docetl.operations.utils import strict_render
 import jinja2
 from jinja2 import Template
 from rich.prompt import Confirm
@@ -80,8 +81,11 @@ def compare_pair(
         ):
             return True, 0, ""

-        prompt_template = Template(comparison_prompt)
-        prompt = prompt_template.render(input1=item1, input2=item2)
+
+        prompt = strict_render(comparison_prompt, {
+            "input1": item1,
+            "input2": item2
+        })
         response = self.runner.api.call_llm(
             model,
             "compare",
@@ -543,14 +547,16 @@ def auto_batch() -> int:
         def process_cluster(cluster):
             if len(cluster) > 1:
                 cluster_items = [input_data[i] for i in cluster]
-                reduction_template = Template(self.config["resolution_prompt"])
                 if input_schema:
                     cluster_items = [
                         {k: item[k] for k in input_schema.keys() if k in item}
                         for item in cluster_items
                     ]

-                resolution_prompt = reduction_template.render(inputs=cluster_items)
+
+                resolution_prompt = strict_render(self.config["resolution_prompt"], {
+                    "inputs": cluster_items
+                })
                 reduction_response = self.runner.api.call_llm(
                     self.config.get("resolution_model", self.default_model),
                     "reduce",
36 changes: 36 additions & 0 deletions docetl/operations/utils/__init__.py
@@ -0,0 +1,36 @@
+from .api import APIWrapper
+from .cache import (
+    cache,
+    cache_key,
+    clear_cache,
+    flush_cache,
+    freezeargs,
+    CACHE_DIR,
+    LLM_CACHE_DIR,
+    DOCETL_HOME_DIR,
+)
+from .llm import LLMResult, InvalidOutputError, truncate_messages
+from .progress import RichLoopBar, rich_as_completed
+from .validation import safe_eval, convert_val, convert_dict_schema_to_list_schema, get_user_input_for_schema, strict_render
+
+__all__ = [
+    'APIWrapper',
+    'cache',
+    'cache_key',
+    'clear_cache',
+    'flush_cache',
+    'freezeargs',
+    'CACHE_DIR',
+    'LLM_CACHE_DIR',
+    'DOCETL_HOME_DIR',
+    'LLMResult',
+    'InvalidOutputError',
+    'RichLoopBar',
+    'rich_as_completed',
+    'safe_eval',
+    'convert_val',
+    'convert_dict_schema_to_list_schema',
+    'get_user_input_for_schema',
+    'truncate_messages',
+    "strict_render"
+]
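The `strict_render` implementation itself lives in `docetl/operations/utils/validation.py`, which this diff does not show. Based on the call sites above, which pass both raw template strings and pre-compiled `Template` objects, here is a minimal sketch of the string case using Jinja2's `StrictUndefined` (the names and error message are assumptions for illustration, not the PR's actual code; handling of pre-compiled `Template` arguments is elided):

```python
from typing import Any, Dict
from jinja2 import Environment, StrictUndefined, UndefinedError


def strict_render(template: str, context: Dict[str, Any]) -> str:
    """Render a Jinja2 template string, raising UndefinedError when the
    template references a variable missing from `context`, rather than
    silently substituting an empty string."""
    env = Environment(undefined=StrictUndefined)
    try:
        return env.from_string(template).render(**context)
    except UndefinedError as exc:
        # Include the provided keys to make prompt bugs easy to spot.
        raise UndefinedError(f"{exc}; provided keys: {sorted(context)}") from exc
```

The design point is simply to move template failures from render output (a truncated prompt the LLM sees) to raise time (an error the pipeline author sees).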