From 4e380f96db3a6c2f8e1ce016c7eb94c47ca1abf4 Mon Sep 17 00:00:00 2001
From: Shreya Shankar
Date: Sun, 24 Nov 2024 08:39:10 -0800
Subject: [PATCH 1/4] feat: have global system prompt and description

---
 docetl/operations/utils.py                    |  31 ++++--
 docetl/runner.py                              |   2 +-
 website/src/app/api/utils.ts                  |  21 +++-
 .../src/app/api/writePipelineConfig/route.ts  |   4 +-
 website/src/app/localStorageKeys.ts           |   1 +
 website/src/components/AIChatPanel.tsx        |  73 +++++++++++-
 website/src/components/LLMContextPopover.tsx  |  48 ++++++--
 website/src/components/PipelineGui.tsx        | 104 ++++++++++++++----
 website/src/components/operations/args.tsx    | 100 +++++++++++------
 .../src/components/operations/components.tsx  |  11 +-
 website/src/contexts/PipelineContext.tsx      |  16 +++
 11 files changed, 328 insertions(+), 83 deletions(-)

diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py
index 127637a1..ee1327c5 100644
--- a/docetl/operations/utils.py
+++ b/docetl/operations/utils.py
@@ -185,6 +185,7 @@ def cache_key(
     messages: List[Dict[str, str]],
     output_schema: Dict[str, str],
     scratchpad: Optional[str] = None,
+    system_prompt: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Generate a unique cache key based on function arguments.
@@ -209,6 +210,7 @@
         "messages": json.dumps(messages, sort_keys=True),
         "output_schema": json.dumps(output_schema, sort_keys=True),
         "scratchpad": scratchpad,
+        "system_prompt": json.dumps(system_prompt, sort_keys=True),
     }
     return hashlib.md5(json.dumps(key_dict, sort_keys=True).encode()).hexdigest()
@@ -690,7 +692,7 @@ def call_llm(
         Raises:
             TimeoutError: If the call times out after retrying.
         """
-        key = cache_key(model, op_type, messages, output_schema, scratchpad)
+        key = cache_key(model, op_type, messages, output_schema, scratchpad, self.runner.config.get("system_prompt", {}))
         max_retries = max_retries_per_timeout
         attempt = 0
@@ -809,21 +811,33 @@ def _call_llm_with_cache(
         tools = None
         tool_choice = None
 
-        system_prompt = f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation. You will perform the specified task on the provided data. The result should be a structured output that you will send back to the user."
+        persona = self.runner.config.get("system_prompt", {}).get("persona", "a helpful assistant")
+        dataset_description = self.runner.config.get("system_prompt", {}).get("dataset_description", "a collection of unstructured documents")
+        parenthetical_op_instructions = "many inputs:one output" if op_type == "reduce" else "one input:one output"
+
+        system_prompt = f"You are a {persona}, intelligently transforming data. The dataset description is: {dataset_description}. You will be performing a {op_type} operation ({parenthetical_op_instructions}). You will perform the specified task on the provided data, as accurately, precisely, and exhaustively as possible. The result should be a structured output that you will send back to the user."
         if scratchpad:
             system_prompt += f"""
 
-You are incrementally processing data across multiple batches. Maintain intermediate state between batches to accomplish this task effectively.
+You are incrementally processing data across multiple batches. You will see:
+1. The current batch of data to process
+2. The intermediate output so far (what you returned last time)
+3. A scratchpad for tracking additional state: {scratchpad}
+
+The intermediate output contains the partial result that directly answers the user's task, just on a subset of the data.
+The scratchpad contains supporting information needed to process future batches correctly, but isn't part of the answer itself.
 
-Current scratchpad: {scratchpad}
+Example for counting words that appear >2 times:
+- Intermediate output: {{"frequent_words": ["the", "and"]}} # Words seen 3+ times
+- Scratchpad: {{"pending": {{"cat": 2, "dog": 1}}}} # Track words seen 1-2 times
 
 As you process each batch:
-1. Update the scratchpad with crucial information for subsequent batches.
-2. This may include partial results, counters, or data that doesn't fit into {list(output_schema.keys())}.
-3. Example: For counting elements that appear more than twice, track all occurrences in the scratchpad until an item exceeds the threshold.
+1. Use both the intermediate output and scratchpad to inform your processing
+2. Update the scratchpad with any new information needed for future batches
+3. Return both your partial result (representing the answer on the current batch and the previous batches' intermediate output) and updated scratchpad
 
 Keep the scratchpad concise (~500 chars) and easily parsable. Use clear structures like:
-- Bullet points 
+- Bullet points
 - Key-value pairs
 - JSON-like format
 
 Remember: The scratchpad should contain information necessary for processing future batches, not the final result."""
 
+
         # Truncate messages if they exceed the model's context length
         messages = truncate_messages(messages, model)
diff --git a/docetl/runner.py b/docetl/runner.py
index c36507ff..e02439d8 100644
--- a/docetl/runner.py
+++ b/docetl/runner.py
@@ -102,7 +102,7 @@ def __init__(self, config: Dict, max_threads: int = None, **kwargs):
                 all_ops_until_and_including_current = [
                     op_map[prev_op] for prev_op in step["operations"][:idx]
-                ] + [op_map[op_name]]
+                ] + [op_map[op_name]] + [self.config.get("system_prompt", {})]
                 # If there's no model in the op, add the default model
                 for op in all_ops_until_and_including_current:
                     if "model" not in op:
diff --git a/website/src/app/api/utils.ts b/website/src/app/api/utils.ts
index b3274cc0..30bca6dd 100644
--- a/website/src/app/api/utils.ts
+++ b/website/src/app/api/utils.ts
@@ -12,7 +12,11 @@ export function generatePipelineConfig(
   homeDir: string,
   sample_size: number | null,
   optimize: boolean = false,
-  clear_intermediate: boolean = false
+  clear_intermediate: boolean = false,
+  system_prompt: {
+    datasetDescription: string | null;
+    persona: string | null;
+  } | null = null
 ) {
   const datasets = {
     input: {
@@ -156,7 +160,7 @@ export function generatePipelineConfig(
     {
       name: "data_processing",
       input: Object.keys(datasets)[0], // Assuming the first dataset is the input
-      operations: operationsToRun.map((op: any) => op.name),
+      operations: operationsToRun.map((op) => op.name),
     },
   ],
   output: {
@@ -177,8 +181,21 @@ export function generatePipelineConfig(
       ),
     },
   },
+  system_prompt: {},
 };
 
+if (system_prompt) {
+  if (system_prompt.datasetDescription) {
+    // @ts-ignore
+    pipelineConfig.system_prompt!.dataset_description =
+      system_prompt.datasetDescription;
+  }
+  if (system_prompt.persona) {
+    // @ts-ignore
+    pipelineConfig.system_prompt!.persona = system_prompt.persona;
+  }
+}
+
 // Get the inputPath from the intermediate_dir
 let inputPath;
 let outputPath;
diff --git a/website/src/app/api/writePipelineConfig/route.ts b/website/src/app/api/writePipelineConfig/route.ts
index 21e57ddd..7f32d7c6 100644
--- a/website/src/app/api/writePipelineConfig/route.ts
+++ b/website/src/app/api/writePipelineConfig/route.ts
@@ -15,6 +15,7 @@ export async function POST(request: Request) {
     sample_size,
     optimize = false,
     clear_intermediate = false,
+    system_prompt,
   } = await request.json();
 
   if (!name) {
@@ -42,7 +43,8 @@ export async function POST(request: Request) {
       homeDir,
       sample_size,
      optimize,
-      clear_intermediate
+      clear_intermediate,
+      system_prompt
     );
 
     // Save the YAML file in the user's home directory
diff --git a/website/src/app/localStorageKeys.ts b/website/src/app/localStorageKeys.ts
index fb6748cc..aa22c9ae 100644
--- a/website/src/app/localStorageKeys.ts
+++ b/website/src/app/localStorageKeys.ts
@@ -16,3 +16,4 @@ export const DEFAULT_MODEL_KEY = "docetl_defaultModel";
 export const OPTIMIZER_MODEL_KEY = "docetl_optimizerModel";
 export const AUTO_OPTIMIZE_CHECK_KEY = "docetl_autoOptimizeCheck";
 export const HIGH_LEVEL_GOAL_KEY = "docetl_highLevelGoal";
+export const SYSTEM_PROMPT_KEY = "docetl_systemPrompt";
diff --git a/website/src/components/AIChatPanel.tsx b/website/src/components/AIChatPanel.tsx
index a1c24adf..3108395d 100644
--- a/website/src/components/AIChatPanel.tsx
+++ b/website/src/components/AIChatPanel.tsx
@@ -1,8 +1,14 @@
 "use client";
 
-import React, { useRef, useState, useEffect } from "react";
+import React, {
+  useRef,
+  useState,
+  useEffect,
+  useMemo,
+  useCallback,
+} from "react";
 import { ResizableBox } from "react-resizable";
-import { Eraser, RefreshCw, X, Copy } from "lucide-react";
+import { RefreshCw, X, Copy } from "lucide-react";
 import { Button } from "@/components/ui/button";
 import { Input } from "@/components/ui/input";
 import { ScrollArea } from "@/components/ui/scroll-area";
@@ -13,6 +19,13 @@ import "react-resizable/css/styles.css";
 import { LLMContextPopover } from "@/components/LLMContextPopover";
 import { usePipelineContext } from "@/contexts/PipelineContext";
 import ReactMarkdown from "react-markdown";
+import {
+  Popover,
+  PopoverContent,
+  PopoverTrigger,
+} from "@/components/ui/popover";
+import { Textarea } from "@/components/ui/textarea";
+import { debounce } from "lodash";
 
 interface AIChatPanelProps {
   onClose: () => void;
@@ -45,7 +58,9 @@ const AIChatPanel: React.FC<AIChatPanelProps> = ({ onClose }) => {
     initialMessages: [],
     id: "persistent-chat",
   });
-  const { serializeState } = usePipelineContext();
+  const { serializeState, highLevelGoal, setHighLevelGoal } =
+    usePipelineContext();
+  const [localGoal, setLocalGoal] = useState(highLevelGoal);
 
   const handleMouseDown = (e: React.MouseEvent) => {
     if ((e.target as HTMLElement).classList.contains("drag-handle")) {
@@ -184,6 +199,25 @@ Remember, all the output fields have been converted to strings, even if they wer
     );
   };
 
+  const debouncedSetHighLevelGoal = useMemo(
+    () => debounce((value: string) => setHighLevelGoal(value), 1000),
+    [setHighLevelGoal]
+  );
+
+  useEffect(() => {
+    return () => {
+      debouncedSetHighLevelGoal.cancel();
+    };
+  }, [debouncedSetHighLevelGoal]);
+
+  const handleGoalUpdate = useCallback(
+    (newGoal: string) => {
+      setLocalGoal(newGoal);
+      debouncedSetHighLevelGoal(newGoal);
+    },
+    [debouncedSetHighLevelGoal]
+  );
+
   return (
+      {/* JSX lost in extraction: this hunk adds a Popover to the chat panel header
+          whose content shows a "Pipeline Goal" heading and a Textarea bound to
+          localGoal via handleGoalUpdate (using the Popover/Textarea imports added above). */}
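For context, the patch routes a single top-level `system_prompt` block (with optional `persona` and `dataset_description` keys) from the pipeline config through the runner and into every LLM call. The sketch below is illustrative only: it mirrors the `.get(...)` lookups and fallback defaults added in `_call_llm_with_cache`, but the example config values and the shortened prompt string are made up, not part of the patch.

```python
# Minimal sketch of how the new top-level block is consumed.
# Key names and defaults come from the diff; the values are hypothetical examples.
config = {
    "system_prompt": {
        "persona": "a legal analyst",
        "dataset_description": "a collection of vendor contracts",
    }
}

persona = config.get("system_prompt", {}).get("persona", "a helpful assistant")
dataset_description = config.get("system_prompt", {}).get(
    "dataset_description", "a collection of unstructured documents"
)

# Abridged version of the system prompt string assembled in the patch.
print(
    f"You are a {persona}, intelligently transforming data. "
    f"The dataset description is: {dataset_description}."
)
```

Because the same dict is now folded into `cache_key`, changing either field also invalidates previously cached LLM responses.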