Skip to content

Commit

Permalink
chore: fix optimizer and API for user study (#196)
Browse files Browse the repository at this point in the history
* chore: fix optimizer and API for user study

* chore: fix optimizer and API for user study
  • Loading branch information
shreyashankar authored Nov 19, 2024
1 parent 4328a8e commit e75a9ed
Show file tree
Hide file tree
Showing 11 changed files with 151 additions and 99 deletions.
11 changes: 6 additions & 5 deletions docetl/optimizers/map_optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,12 @@ def optimize(
input_data, output_data, model_input_context_length, no_change_runtime, validator_prompt, assessment, data_exceeds_limit = self._should_optimize_helper(op_config, input_data)

# Check if improvement is needed based on the assessment
if not data_exceeds_limit and not assessment.get("needs_improvement", True):
self.console.log(
f"[green]No improvement needed for operation {op_config['name']}[/green]"
)
return [op_config], output_data, self.plan_generator.reduce_optimizer_cost
if not self.config.get("optimizer_config", {}).get("force_decompose", False):
if not data_exceeds_limit and not assessment.get("needs_improvement", True):
self.console.log(
f"[green]No improvement needed for operation {op_config['name']}[/green]"
)
return [op_config], output_data, self.plan_generator.reduce_optimizer_cost

candidate_plans = {}

Expand Down
54 changes: 33 additions & 21 deletions docetl/optimizers/reduce_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def optimize(

# Print the validation results
self.console.log("[bold]Validation Results on Initial Sample:[/bold]")
if validation_results["needs_improvement"]:
if validation_results["needs_improvement"] or self.config.get("optimizer_config", {}).get("force_decompose", False):
self.console.post_optimizer_rationale(
should_optimize=True,
rationale= "\n".join(
Expand Down Expand Up @@ -545,7 +545,7 @@ def _evaluate_decomposition(
user_agrees = Confirm.ask(
f"Do you agree with the decomposition assessment? "
f"[bold]{'Recommended' if should_decompose['should_decompose'] else 'Not recommended'}[/bold]",
self.console,
console=self.console,
)

# If user disagrees, invert the decomposition decision
Expand Down Expand Up @@ -640,14 +640,15 @@ def _should_decompose(
Sample values for other keys:
{json.dumps(sample_values, indent=2)}
Based on this information, determine if it would be beneficial to decompose this reduce operation into a sub-reduce operation followed by a final reduce operation. Consider the following:
Based on this information, determine if it would be beneficial to decompose this reduce operation into a sub-reduce operation followed by a final reduce operation. Consider ALL of the following:
1. Is there a natural hierarchy in the data (e.g., country -> state -> city) among the other available keys, with a key at a finer level of granularity than the current reduce key(s)?
2. Are the current reduce key(s) some form of ID, and are there many different types of inputs for that ID among the other available keys?
3. Does the prompt implicitly ask for sub-grouping based on the other available keys (e.g., "summarize policies by state, then by country")?
4. Would splitting the operation improve accuracy (i.e., make sure information isn't lost when reducing)?
5. Are all the keys of the potential hierarchy provided in the other available keys? If not, we should not decompose.
6. Importantly, do not suggest decomposition using any key that is already part of the current reduce key(s). We are looking for a new key from the other available keys to use for sub-grouping.
7. Do not suggest keys that don't contain meaningful information (e.g., id-related keys).
Provide your analysis in the following format:
"""
Expand Down Expand Up @@ -1315,6 +1316,7 @@ def _create_reduce_plans(
sample_output,
num_prompts=self.num_fold_prompts,
)
fold_prompts = list(set(fold_prompts))
if not fold_prompts:
raise ValueError("No fold prompts generated")
except Exception as e:
Expand Down Expand Up @@ -1485,10 +1487,16 @@ def _synthesize_fold_prompts(

input_schema = op_config.get("input", {}).get("schema", {})
output_schema = op_config["output"]["schema"]
reduce_key = op_config["reduce_key"]

def get_random_examples():
if isinstance(reduce_key, list):
reduce_key = op_config["reduce_key"]
reduce_key = list(reduce_key) if not isinstance(reduce_key, list) else reduce_key

if reduce_key == ["_all"]:
# For _all case, just pick random input and output examples
input_example = random.choice(sample_input)
output_example = random.choice(sample_output)
elif isinstance(reduce_key, list):
random_key = tuple(
random.choice(
[
Expand All @@ -1512,16 +1520,6 @@ def get_random_examples():
if all(item.get(k) == v for k, v in zip(reduce_key, random_key))
]
)
else:
random_key = random.choice(
[item[reduce_key] for item in sample_input if reduce_key in item]
)
input_example = random.choice(
[item for item in sample_input if item[reduce_key] == random_key]
)
output_example = random.choice(
[item for item in sample_output if item[reduce_key] == random_key]
)

if input_schema:
input_example = {
Expand Down Expand Up @@ -1560,9 +1558,9 @@ def generate_single_prompt():
3. Be designed to work iteratively, allowing for multiple fold operations. The first iteration will use the original prompt, and all successive iterations will use the fold prompt.
The fold prompt should be a Jinja2 template with the following variables available:
- {{ output }}: The current reduced value (a dictionary with the current output schema)
- {{ inputs }}: A list of new values to be folded in
- {{ reduce_key }}: The key used for grouping in the reduce operation
- {{{{ output }}}}: The current reduced value (a dictionary with the current output schema)
- {{{{ inputs }}}}: A list of new values to be folded in
- {{{{ reduce_key }}}}: The key used for grouping in the reduce operation
Provide the fold prompt as a string.
"""
Expand All @@ -1582,11 +1580,25 @@ def generate_single_prompt():
) # Use a small batch size for testing

# Run the operation with the fold prompt
self._run_operation(temp_plan, sample_input[: temp_plan["fold_batch_size"]])
try:
self._run_operation(temp_plan, sample_input[: temp_plan["fold_batch_size"]])

return fold_prompt
except Exception as e:
self.console.log(f"[red]Error in agent-generated fold prompt: {e}[/red]")

# Create a default fold prompt that instructs folding new data into existing output
fold_prompt = f"""Analyze this batch of data using the following instructions:
{original_prompt}
However, instead of starting fresh, fold your analysis into the existing output that has already been generated. The existing output is provided in the 'output' variable below:
# If the operation runs successfully, return the fold prompt
return fold_prompt
{{{{ output }}}}
Remember, you must fold the new data into the existing output, do not start fresh."""
return fold_prompt

with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
fold_prompts = list(
executor.map(lambda _: generate_single_prompt(), range(num_prompts))
Expand Down
9 changes: 2 additions & 7 deletions website/src/app/api/convertDocuments/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { NextRequest, NextResponse } from "next/server";

export const runtime = "nodejs";

export async function POST(request: NextRequest) {
try {
const formData = await request.formData();
Expand Down Expand Up @@ -42,10 +44,3 @@ export async function POST(request: NextRequest) {
);
}
}

// Increase the maximum request size limit for file uploads
export const config = {
api: {
bodyParser: false,
},
};
42 changes: 42 additions & 0 deletions website/src/app/api/serveDocument/[...path]/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// app/api/documents/[...path]/route.ts
import { NextRequest, NextResponse } from "next/server";
import { readFile } from "fs/promises";
import path from "path";
import { lookup } from "mime-types";

export const dynamic = "force-dynamic";

// GET /api/serveDocument/[...path] — reads a file from the server's
// filesystem and streams it back with a best-effort MIME type.
//
// NOTE(review): the catch-all segment is treated as a raw filesystem path.
// The ".." check below rejects relative traversal, but an ABSOLUTE path
// (e.g. "/etc/passwd") passes it, so any file readable by the server
// process can be fetched. Presumably callers only request previously
// uploaded documents (the client sends the full stored path) — confirm,
// and consider resolving against a fixed base directory instead.
export async function GET(
  request: NextRequest,
  { params }: { params: { path: string[] } }
) {
  try {
    // Join the catch-all segments back into one path and undo URL encoding
    // (the client encodes the whole path into a single segment).
    const filePath = decodeURIComponent(params.path.join("/"));

    // Basic security check to prevent directory traversal.
    // Rejects ".." anywhere in the normalized path — this also rejects
    // otherwise-legitimate filenames that merely contain "..".
    const normalizedPath = path.normalize(filePath);
    if (normalizedPath.includes("..")) {
      return NextResponse.json({ error: "Invalid file path" }, { status: 400 });
    }

    // Whole file is buffered in memory before responding — fine for
    // typical documents, but large files are not streamed.
    const fileBuffer = await readFile(normalizedPath);
    // Derive Content-Type from the file extension; fall back to a
    // generic binary type when the extension is unknown.
    const mimeType = lookup(normalizedPath) || "application/octet-stream";

    return new NextResponse(fileBuffer, {
      headers: {
        "Content-Type": mimeType,
        // "inline" lets the browser render viewable types (PDF, images)
        // in place instead of forcing a download.
        "Content-Disposition": `inline; filename="${path.basename(
          normalizedPath
        )}"`,
        // Allow clients/proxies to cache the document for an hour.
        "Cache-Control": "public, max-age=3600",
      },
    });
  } catch (error) {
    // readFile rejects on missing files / permission errors; log the
    // detail server-side and return a generic 500 with no path leaked.
    console.error("Error serving file:", error);
    return NextResponse.json(
      { error: "Failed to serve file" },
      { status: 500 }
    );
  }
}
33 changes: 0 additions & 33 deletions website/src/app/api/serveDocument/route.ts

This file was deleted.

3 changes: 3 additions & 0 deletions website/src/app/api/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ export function generatePipelineConfig(
const pipelineConfig = {
datasets,
default_model,
optimizer_config: {
force_decompose: true,
},
operations: updatedOperations,
pipeline: {
steps: [
Expand Down
1 change: 1 addition & 0 deletions website/src/components/DatasetView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ const DatasetView: React.FC<{ file: File | null }> = ({ file }) => {

setTimeout(() => {
try {
// @ts-ignore
const allContent = data.pages.map((page) => page.content).join("");
let documents: Record<string, unknown>[] = [];

Expand Down
65 changes: 39 additions & 26 deletions website/src/components/DocumentViewer.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import React from "react";
// DocumentViewer.tsx
import React, { useEffect, useState } from "react";
import DocViewer, { DocViewerRenderers } from "@cyntler/react-doc-viewer";
import {
Dialog,
Expand All @@ -21,8 +22,18 @@ export const DocumentViewer: React.FC<DocumentViewerProps> = ({
filePath,
fileName,
}) => {
const fileUrl = `/api/serveDocument?path=${encodeURIComponent(filePath)}`;
const docs = [{ uri: fileUrl, fileName: fileName }];
const [documentUrl, setDocumentUrl] = useState<string>("");

useEffect(() => {
if (isOpen && filePath) {
// Convert the full file path to a URL-safe format
const encodedPath = encodeURIComponent(filePath);
const url = `/api/serveDocument/${encodedPath}`;
setDocumentUrl(url);
}
}, [isOpen, filePath]);

const docs = documentUrl ? [{ uri: documentUrl, fileName: fileName }] : [];

return (
<Dialog open={isOpen} onOpenChange={(open) => !open && onClose()}>
Expand All @@ -32,29 +43,31 @@ export const DocumentViewer: React.FC<DocumentViewerProps> = ({
</DialogHeader>
<div className="flex-1 p-4 pt-0 overflow-hidden">
<div className="h-full w-full overflow-hidden">
<DocViewer
pluginRenderers={DocViewerRenderers}
documents={docs}
initialActiveDocument={docs[0]}
style={{
height: "100%",
width: "100%",
maxHeight: "100%",
overflow: "auto",
backgroundColor: "white",
}}
config={{
header: {
disableHeader: true,
disableFileName: true,
},
pdfZoom: {
defaultZoom: 1,
zoomJump: 0.2,
},
pdfVerticalScrollByDefault: true,
}}
/>
{documentUrl && (
<DocViewer
pluginRenderers={DocViewerRenderers}
documents={docs}
initialActiveDocument={docs[0]}
style={{
height: "100%",
width: "100%",
maxHeight: "100%",
overflow: "auto",
backgroundColor: "white",
}}
config={{
header: {
disableHeader: true,
disableFileName: true,
},
pdfZoom: {
defaultZoom: 1,
zoomJump: 0.2,
},
pdfVerticalScrollByDefault: true,
}}
/>
)}
</div>
</div>
</DialogContent>
Expand Down
7 changes: 7 additions & 0 deletions website/src/components/FileExplorer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async function getAllFiles(entry: FileSystemEntry): Promise<File[]> {
if (entry.isFile) {
const fileEntry = entry as FileSystemFileEntry;
const file = await new Promise<File>((resolve, reject) => {
// @ts-ignore
fileEntry.file(resolve, reject);
});

Expand All @@ -100,8 +101,10 @@ async function getAllFiles(entry: FileSystemEntry): Promise<File[]> {
) {
// Create a new file with the full path
const fullPath = path ? `${path}/${file.name}` : file.name;
// @ts-ignore
const newFile = new File([file], fullPath, { type: file.type });
Object.defineProperty(newFile, "relativePath", { value: fullPath });
// @ts-ignore
files.push(newFile);
}
} else if (entry.isDirectory) {
Expand Down Expand Up @@ -235,18 +238,21 @@ export const FileExplorer: React.FC<FileExplorerProps> = ({
const files: File[] = [];

const processItems = async () => {
// @ts-ignore
const items = Array.from(fileList);

for (const item of items) {
if ("webkitGetAsEntry" in item) {
// Handle drag and drop
// @ts-ignore
const entry = (item as DataTransferItem).webkitGetAsEntry();
if (entry) {
const entryFiles = await getAllFiles(entry);
files.push(...entryFiles);
}
} else {
// Handle regular file input
// @ts-ignore
const file = item as File;
const supportedExtensions = [
".pdf",
Expand All @@ -271,6 +277,7 @@ export const FileExplorer: React.FC<FileExplorerProps> = ({

// Create a new FileList-like object with the collected files
const dt = new DataTransfer();
// @ts-ignore
files.forEach((file) => dt.items.add(file));
setSelectedFiles((prevFiles) => mergeFileList(prevFiles, dt.files));
};
Expand Down
Loading

0 comments on commit e75a9ed

Please sign in to comment.