Add batching support to map operations with configurable parameters #16

Merged · 1 commit into ucbepic:orban-map-batching on Sep 30, 2024

Conversation

@orban (Contributor) commented Sep 26, 2024

Summary

This pull request introduces batching support to map operations as described in issue #7, with the aim of significantly enhancing performance and reducing costs when processing small documents. Key updates include new batching parameters, implementation of batching logic, configuration enhancements, expanded testing, and documentation improvements. Additionally, Pydantic models have been introduced in schemas.py to simplify and streamline validation logic.


Main Changes

  1. Batching Support in Map Operations

    • New Parameters: Added batch_size and clustering_method to the map operation interface to enable batching functionality.
    • Batching Logic: Implemented logic to group documents based on the specified batch size and clustering method, optimizing the efficiency of LLM calls.
    • LLM Call Handling: Updated LLM calls to handle batched inputs and ensure accurate mapping of outputs back to individual documents.
    • Configuration Updates: Modified the YAML configuration format to support the new batch-related parameters, allowing users to easily configure batching behavior.
  2. Testing Enhancements

    • Comprehensive Unit Tests: Added unit tests covering a range of batch sizes and clustering methods to confirm that functionality and accuracy hold across scenarios (a sketch of the test pattern follows this list).
  3. Documentation Improvements

    • Updated Documentation: Expanded documentation with detailed explanations, practical examples, and best practices for utilizing batching in map operations.
  4. Validation Logic Simplification

    • Introduction of Pydantic Models: Incorporated Pydantic models into schemas.py to simplify validation logic and improve code maintainability.
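
In that spirit, here is a self-contained sketch of the parametrized test pattern described under item 2 (run_map_with_batching is a toy stand-in, not the project's actual API):

import pytest

def run_map_with_batching(docs, batch_size, clustering_method=None):
    # Toy stand-in for the real map operation: process documents in chunks of
    # batch_size and return exactly one output per input document.
    outputs = []
    for i in range(0, len(docs), batch_size):
        for doc in docs[i : i + batch_size]:
            outputs.append({"summary": doc.upper()})
    return outputs

@pytest.mark.parametrize("clustering_method", [None, "random"])
@pytest.mark.parametrize("batch_size", [1, 2, 5])
def test_batching_preserves_one_to_one(batch_size, clustering_method):
    docs = ["a", "b", "c", "d"]
    results = run_map_with_batching(docs, batch_size, clustering_method)
    # Batching must not change the 1:1 input-to-output contract of map.
    assert len(results) == len(docs)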

New Pydantic Models

  • ToolFunction: Defines the structure of a tool function with fields for name, description, and parameters.
  • Tool: Represents a tool with fields for code and function.
  • OutputSchema: Specifies the output schema using a schema field.
  • MapOperationConfig: Configures map operations with optional fields such as drop_keys, prompt, output, model, and tools, including a validator for drop_keys.
  • ParallelMapOperationConfig: Configures parallel map operations with fields for prompts, model, and tools.
  • BatchConfig: Defines batch configurations with batch_size and an optional clustering_method.
  • OperationConfig: A generic operation configuration with fields for name, type, and a union of specific operation configurations (MapOperationConfig, ParallelMapOperationConfig, BatchConfig).
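
A rough sketch of a few of these models, assuming Pydantic v2 (field types, defaults, and the drop_keys check are inferred from the bullet descriptions above, so the actual schemas.py may differ):

from typing import List, Optional
from pydantic import BaseModel, field_validator

class ToolFunction(BaseModel):
    name: str
    description: str
    parameters: dict

class Tool(BaseModel):
    code: str
    function: ToolFunction

class BatchConfig(BaseModel):
    batch_size: int
    clustering_method: Optional[str] = None

class MapOperationConfig(BaseModel):
    drop_keys: Optional[List[str]] = None
    prompt: Optional[str] = None
    output: Optional[dict] = None
    model: Optional[str] = None
    tools: Optional[List[Tool]] = None

    @field_validator("drop_keys")
    @classmethod
    def check_drop_keys(cls, v):
        # Hypothetical check; the PR only states that drop_keys has a validator.
        if v is not None and len(v) == 0:
            raise ValueError("drop_keys must not be empty")
        return v

# Example: validating the batching parameters from the YAML configuration.
BatchConfig(batch_size=8, clustering_method="random")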

These enhancements collectively improve the efficiency and usability of map operations, making it easier to process small documents at scale while maintaining accuracy and performance.

@orban (Contributor, Author) commented Sep 26, 2024

Forgot to mention that testing_basic.py was split into testing_map.py and testing_map_parallel.py to group the tests more logically. Common pytest fixtures have also been moved to conftest.py, following pytest convention.

@shreyashankar (Collaborator) commented:

Wow! This looks so thorough. I will review today 🙏🙌🏽

@shreyashankar (Collaborator) commented:

Any reason for introducing the Flask dependency?

@orban (Contributor, Author) commented Sep 27, 2024

Great question! The Flask dependency was initially introduced as a safeguard against XSS and RCE vulnerabilities by ensuring proper escaping of Jinja2 templates. However, after revisiting the issue, I've found that we can achieve the same level of security without Flask by configuring Jinja2's Environment to enable autoescaping directly.

Based on the documentation here, we can remove the Flask dependency and simply configure Jinja2 with:

from jinja2 import Environment
env = Environment(autoescape=True)

This change would enforce the necessary escaping behavior during template rendering. I’m happy to update the PR to remove the Flask dependency and handle this via Jinja2. Let me know if that works for you!
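
For illustration, the escaping behavior can be checked with a throwaway inline template (not part of the PR itself):

from jinja2 import Environment

env = Environment(autoescape=True)
template = env.from_string("Hello {{ name }}!")
# With autoescaping on, markup in the rendered value is escaped rather than
# emitted verbatim, e.g. "<script>" becomes "&lt;script&gt;".
print(template.render(name="<script>alert('xss')</script>"))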

@shreyashankar (Collaborator) commented:

Awesome, I'll let you update it to replace Flask with Jinja. We're using Jinja elsewhere too, e.g., here, so it will be good to be consistent. Thank you 🙏🏽

@orban (Contributor, Author) commented Sep 28, 2024

Worked out the last few kinks -- make tests_basic is now passing!

I also swapped out the super-unsafe eval code 💀 in favor of ASTEVAL, which runs the validation in a stripped-down environment with limited functionality.
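
For reference, the asteval usage pattern looks roughly like this (variable names are illustrative, not the PR's actual wiring):

from asteval import Interpreter

# asteval evaluates expressions in a sandboxed mini-interpreter: no imports,
# no exec/eval, and only a restricted set of safe builtins.
aeval = Interpreter()
aeval.symtable["output"] = {"summary": "batched result"}
print(aeval("len(output['summary']) > 0"))  # True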

@shreyashankar (Collaborator) commented:

Amazing, will check this out, play around with the new functionality, & merge it this weekend! Thank you for taking the time to do this 😄

@shreyashankar (Collaborator) commented:

I went through the PR and most of the changes look good. I ended up removing the parallel map operation batching, since the code did not seem to be changing the functionality. I also removed the semantic similarity functionality here.

Overall, I think there is a misunderstanding between Map operation and Reduce operations (I should be more clear in the documents). Map operations are 1:1, where the prompt that the user writes only has access to one input. Reduce operations, on the other hand, are many:1. The example you had in your documentation looked like a reduce operation--and we support semantic similarity grouping for reduce operations actually :-)

I think the basic batching that we have now (thanks to you!) is good for limiting parallelism; if there are too many documents in the input, we should not try to process all of them at the same time, so batching is good. But in the future I wonder if it's possible to batch map operations in the same prompt, while ensuring the output still matches that same 1:1 expectation.

@shreyashankar changed the base branch from main to orban-map-batching on September 30, 2024 at 23:08
@shreyashankar (Collaborator) commented:

Merging into another branch so I can create a PR that runs the tests.

@shreyashankar merged commit e533ea2 into ucbepic:orban-map-batching on Sep 30, 2024
0 of 3 checks passed
@orban (Contributor, Author) commented Oct 1, 2024

But in the future, I wonder if it's possible to batch map operations in the same prompt while ensuring the output still matches that same 1:1 expectation.

I’ve been thinking along the same lines. The updates to ParallelMapOperation haven't quite hit the mark yet since we’re still submitting all the futures at once. To handle this correctly, we’d need to batch multiple ParallelMapOperation prompts into a single LLM call.

I’ve started working on this already but paused due to the size of the current PR. If we’re aligned on batching, I’m happy to continue and get this integrated. Let me know if you want me to go ahead or focus elsewhere.

Appreciate the quick merge!

@orban deleted the map-batching branch on October 1, 2024 at 01:01
@shreyashankar (Collaborator) commented:

Sounds good to me! It would be great to also limit the number of concurrent LLM calls for the ParallelMapOperation. The PR should be a lot smaller :-) LMK if any issues come up! Thank you!
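
One generic way to cap the number of in-flight calls (purely illustrative; llm_call is a placeholder for whatever the operation actually invokes):

from concurrent.futures import ThreadPoolExecutor

def run_with_bounded_concurrency(prompts, llm_call, max_workers=4):
    # At most max_workers calls are in flight at any time, instead of
    # submitting a future for every prompt simultaneously.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(llm_call, prompts))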
