From 9e361751f854d04e9cfc260e82d6ed9b95ec0754 Mon Sep 17 00:00:00 2001
From: Shreya Shankar
Date: Mon, 26 Aug 2024 15:29:43 +0800
Subject: [PATCH] Support delimiter-based splitting

---
 README.md                                   | 10 ++-
 motion/operations/split.py                  | 89 +++++++++++++------
 .../map_optimizer/operation_creators.py     |  3 +-
 tests/test_basic.py                         |  3 +-
 tests/test_split.py                         | 60 ++++++++++++-
 5 files changed, 132 insertions(+), 33 deletions(-)

diff --git a/README.md b/README.md
index 90a200d4..368d83d8 100644
--- a/README.md
+++ b/README.md
@@ -306,11 +306,15 @@ Required parameters:
 
 - type: Must be set to "split".
 - split_key: The key of the field containing the text to split.
-- chunk_size: The maximum size of each chunk in tokens.
+- method: The method to use for splitting. Options are "delimiter" and "token_count".
+- method_kwargs: A dictionary of keyword arguments to pass to the splitting method.
+  - delimiter: The delimiter to use for splitting. Only used if method is "delimiter".
+  - token_count: The maximum number of tokens to include in each chunk. Only used if method is "token_count".
 
 Optional parameters:
 
 - model: The language model's tokenizer to use; falls back to default_model if not specified. Note that we don't actually run a language model here.
+- num_splits_to_group: The number of splits to group together into one chunk; set inside method_kwargs. Only used if method is "delimiter".
 
 Example:
 
@@ -318,7 +322,9 @@ Example:
 split_operation:
   type: split
   split_key: content
-  chunk_size: 150
+  method: token_count
+  method_kwargs:
+    token_count: 150
   model: gpt-4o-mini
 ```
 
diff --git a/motion/operations/split.py b/motion/operations/split.py
index e48d54ea..c4264eec 100644
--- a/motion/operations/split.py
+++ b/motion/operations/split.py
@@ -11,7 +11,7 @@ class SplitOperation(BaseOperation):
     A class that implements a split operation on input data, dividing it into manageable chunks.
 
     This class extends BaseOperation to:
-    1. Split input data into chunks of specified size based on the 'split_key' and 'chunk_size' configuration.
+    1. Split input data into chunks of specified size based on the 'split_key' and 'token_count' configuration.
     2. Assign unique identifiers to each original document and number chunks sequentially.
     3. Return results containing:
        - {split_key}_chunk: The content of the split chunk.
@@ -24,7 +24,7 @@ def __init__(self, *args, **kwargs):
         self.name = self.config["name"]
 
     def syntax_check(self) -> None:
-        required_keys = ["split_key", "chunk_size"]
+        required_keys = ["split_key", "method", "method_kwargs"]
         for key in required_keys:
             if key not in self.config:
                 raise ValueError(
@@ -34,48 +34,81 @@ def syntax_check(self) -> None:
         if not isinstance(self.config["split_key"], str):
             raise TypeError("'split_key' must be a string")
 
-        if (
-            not isinstance(self.config["chunk_size"], int)
-            or self.config["chunk_size"] <= 0
-        ):
-            raise ValueError("'chunk_size' must be a positive integer")
+        if self.config["method"] not in ["token_count", "delimiter"]:
+            raise ValueError(f"Invalid method '{self.config['method']}'")
+
+        if self.config["method"] == "token_count":
+            if (
+                not isinstance(self.config["method_kwargs"]["token_count"], int)
+                or self.config["method_kwargs"]["token_count"] <= 0
+            ):
+                raise ValueError("'token_count' must be a positive integer")
+        elif self.config["method"] == "delimiter":
+            if not isinstance(self.config["method_kwargs"]["delimiter"], str):
+                raise ValueError("'delimiter' must be a string")
 
         if "model" in self.config and not isinstance(self.config["model"], str):
             raise TypeError("'model' in configuration must be a string")
 
     def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
         split_key = self.config["split_key"]
-        chunk_size = self.config["chunk_size"]
-        results = []
-        cost = 0.0
-
+        method = self.config["method"]
+        method_kwargs = self.config["method_kwargs"]
         encoder = tiktoken.encoding_for_model(
             self.config.get("model", self.default_model)
         )
+        results = []
+        cost = 0.0
 
         for item in input_data:
             if split_key not in item:
                 raise KeyError(f"Split key '{split_key}' not found in item")
 
             content = item[split_key]
-            tokens = encoder.encode(content)
-
-            # Generate a unique document ID
             doc_id = str(uuid.uuid4())
 
-            for chunk_num, i in enumerate(range(0, len(tokens), chunk_size), start=1):
-                chunk_tokens = tokens[i : i + chunk_size]
-                chunk = encoder.decode(chunk_tokens)
-
-                result = item.copy()
-                result.update(
-                    {
-                        f"{split_key}_chunk": chunk,
-                        f"{self.name}_id": doc_id,
-                        f"{self.name}_chunk_num": chunk_num,
-                    }
-                )
-
-                results.append(result)
+            if method == "token_count":
+                token_count = method_kwargs["token_count"]
+                tokens = encoder.encode(content)
+
+                for chunk_num, i in enumerate(
+                    range(0, len(tokens), token_count), start=1
+                ):
+                    chunk_tokens = tokens[i : i + token_count]
+                    chunk = encoder.decode(chunk_tokens)
+
+                    result = item.copy()
+                    result.update(
+                        {
+                            f"{split_key}_chunk": chunk,
+                            f"{self.name}_id": doc_id,
+                            f"{self.name}_chunk_num": chunk_num,
+                        }
+                    )
+                    results.append(result)
+
+            elif method == "delimiter":
+                delimiter = method_kwargs["delimiter"]
+                num_splits_to_group = method_kwargs.get("num_splits_to_group", 1)
+                chunks = content.split(delimiter)
+
+                # Get rid of empty chunks
+                chunks = [chunk for chunk in chunks if chunk.strip()]
+
+                for chunk_num, i in enumerate(
+                    range(0, len(chunks), num_splits_to_group), start=1
+                ):
+                    grouped_chunks = chunks[i : i + num_splits_to_group]
+                    joined_chunk = delimiter.join(grouped_chunks).strip()
+
+                    result = item.copy()
+                    result.update(
+                        {
+                            f"{split_key}_chunk": joined_chunk,
+                            f"{self.name}_id": doc_id,
+                            f"{self.name}_chunk_num": chunk_num,
+                        }
+                    )
+                    results.append(result)
 
         return results, cost
diff --git a/motion/optimizers/map_optimizer/operation_creators.py b/motion/optimizers/map_optimizer/operation_creators.py
index 8b272df2..fef30e2c 100644
--- a/motion/optimizers/map_optimizer/operation_creators.py
+++ b/motion/optimizers/map_optimizer/operation_creators.py
@@ -58,7 +58,8 @@ def create_split_map_gather_operations(
             "type": "split",
             "name": split_name,
             "split_key": split_key,
-            "chunk_size": chunk_size,
+            "method": "token_count",
+            "method_kwargs": {"token_count": chunk_size},
         }
         pipeline.append(split_config)
 
diff --git a/tests/test_basic.py b/tests/test_basic.py
index 8b07f62d..2552f65b 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -297,7 +297,8 @@ def split_config():
     return {
         "type": "split",
         "split_key": "content",
-        "chunk_size": 4,
+        "method": "token_count",
+        "method_kwargs": {"token_count": 4},
         "name": "split_doc",
     }
 
diff --git a/tests/test_split.py b/tests/test_split.py
index 58390875..33ab0662 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -19,11 +19,23 @@ def split_config():
     return {
         "type": "split",
         "split_key": "content",
-        "chunk_size": 10,
+        "method": "token_count",
+        "method_kwargs": {"token_count": 10},
         "name": "split_doc",
     }
 
 
+@pytest.fixture
+def split_config_delimiter():
+    return {
+        "type": "split",
+        "split_key": "content",
+        "method": "delimiter",
+        "method_kwargs": {"delimiter": "\n", "num_splits_to_group": 2},
+        "name": "split_doc_delimiter",
+    }
+
+
 @pytest.fixture
 def map_config():
     return {
@@ -184,3 +196,49 @@ def test_split_map_gather_empty_input(
     gather_results, gather_cost = gather_op.execute(map_results)
     assert len(gather_results) == 0
     assert gather_cost == 0
+
+
+def test_split_delimiter_no_summarization(
+    split_config_delimiter, default_model, max_threads
+):
+    input_data = [
+        {"id": "1", "content": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6"},
+        {"id": "2", "content": "Paragraph 1\n\nParagraph 2\n\nParagraph 3"},
+    ]
+
+    split_op = SplitOperation(split_config_delimiter, default_model, max_threads)
+    results, cost = split_op.execute(input_data)
+
+    assert len(results) == 5  # 3 chunks for first item, 2 for second
+    assert cost == 0  # No LLM calls, so cost should be 0
+
+    # Check first item's chunks
+    assert results[0]["content_chunk"] == "Line 1\nLine 2"
+    assert results[1]["content_chunk"] == "Line 3\nLine 4"
+    assert results[2]["content_chunk"] == "Line 5\nLine 6"
+
+    # Check second item's chunks
+    assert results[3]["content_chunk"] == "Paragraph 1\nParagraph 2"
+    assert results[4]["content_chunk"] == "Paragraph 3"
+
+    # Check that all results have the necessary fields
+    for result in results:
+        assert "split_doc_delimiter_id" in result
+        assert "split_doc_delimiter_chunk_num" in result
+        assert "id" in result  # Original field should be preserved
+
+    # Check that chunk numbers are correct
+    assert results[0]["split_doc_delimiter_chunk_num"] == 1
+    assert results[1]["split_doc_delimiter_chunk_num"] == 2
+    assert results[2]["split_doc_delimiter_chunk_num"] == 3
+    assert results[3]["split_doc_delimiter_chunk_num"] == 1
+    assert results[4]["split_doc_delimiter_chunk_num"] == 2
+
+    # Check that document IDs are consistent within each original item
+    assert (
+        results[0]["split_doc_delimiter_id"]
+        == results[1]["split_doc_delimiter_id"]
+        == results[2]["split_doc_delimiter_id"]
+    )
+    assert results[3]["split_doc_delimiter_id"] == results[4]["split_doc_delimiter_id"]
+    assert results[0]["split_doc_delimiter_id"] != results[3]["split_doc_delimiter_id"]
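
For reference, a delimiter-based configuration corresponding to the new method might look like the sketch below. It mirrors the options documented in the README hunk and the test fixture above; the specific delimiter value ("\n\n") and grouping factor (3) are illustrative assumptions, not values taken from this patch.

```yaml
split_operation:
  type: split
  split_key: content
  method: delimiter
  method_kwargs:
    delimiter: "\n\n"        # illustrative: split on blank lines
    num_splits_to_group: 3   # illustrative: rejoin every 3 non-empty splits into one chunk
```

With a config like this, execute() splits the text on the delimiter, drops empty pieces, rejoins each group of num_splits_to_group pieces with the delimiter, and emits {split_key}_chunk, {name}_id, and {name}_chunk_num on each result, at zero cost since no LLM is called.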