Support delimiter based splitting
shreyashankar committed Aug 26, 2024
1 parent 6ccde64 commit 9e36175
Showing 5 changed files with 132 additions and 33 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -306,19 +306,25 @@ Required parameters:

 - type: Must be set to "split".
 - split_key: The key of the field containing the text to split.
-- chunk_size: The maximum size of each chunk in tokens.
+- method: The method to use for splitting. Options are "delimiter" and "token_count".
+- method_kwargs: A dictionary of keyword arguments to pass to the splitting method.
+  - delimiter: The delimiter to use for splitting. Only used if method is "delimiter".
+  - token_count: The maximum number of tokens to include in each chunk. Only used if method is "token_count".

 Optional parameters:

 - model: The language model's tokenizer to use; falls back to default_model if not specified. Note that we don't actually run a language model here.
+- num_splits_to_group: The number of splits to group together into one chunk. Only used if method is "delimiter".

 Example:

 ```yaml
 split_operation:
   type: split
   split_key: content
-  chunk_size: 150
+  method: token_count
+  method_kwargs:
+    token_count: 150
   model: gpt-4o-mini
 ```

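For the new delimiter method, a configuration could look like this sketch (the delimiter value and grouping factor below are illustrative, not from the README):

```yaml
split_operation:
  type: split
  split_key: content
  method: delimiter
  method_kwargs:
    delimiter: "\n\n"   # illustrative: split on blank lines
    num_splits_to_group: 2
```

Here every two delimiter-separated pieces are rejoined into a single chunk; no tokenizer is involved on this path.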
89 changes: 61 additions & 28 deletions motion/operations/split.py
@@ -11,7 +11,7 @@ class SplitOperation(BaseOperation):
     A class that implements a split operation on input data, dividing it into manageable chunks.

     This class extends BaseOperation to:
-    1. Split input data into chunks of specified size based on the 'split_key' and 'chunk_size' configuration.
+    1. Split input data into chunks of specified size based on the 'split_key' and 'token_count' configuration.
     2. Assign unique identifiers to each original document and number chunks sequentially.
     3. Return results containing:
        - {split_key}_chunk: The content of the split chunk.
@@ -24,7 +24,7 @@ def __init__(self, *args, **kwargs):
         self.name = self.config["name"]

     def syntax_check(self) -> None:
-        required_keys = ["split_key", "chunk_size"]
+        required_keys = ["split_key", "method", "method_kwargs"]
         for key in required_keys:
             if key not in self.config:
                 raise ValueError(
@@ -34,48 +34,81 @@ def syntax_check(self) -> None:
         if not isinstance(self.config["split_key"], str):
             raise TypeError("'split_key' must be a string")

-        if (
-            not isinstance(self.config["chunk_size"], int)
-            or self.config["chunk_size"] <= 0
-        ):
-            raise ValueError("'chunk_size' must be a positive integer")
+        if self.config["method"] not in ["token_count", "delimiter"]:
+            raise ValueError(f"Invalid method '{self.config['method']}'")
+
+        if self.config["method"] == "token_count":
+            if (
+                not isinstance(self.config["method_kwargs"]["token_count"], int)
+                or self.config["method_kwargs"]["token_count"] <= 0
+            ):
+                raise ValueError("'token_count' must be a positive integer")
+        elif self.config["method"] == "delimiter":
+            if not isinstance(self.config["method_kwargs"]["delimiter"], str):
+                raise ValueError("'delimiter' must be a string")

         if "model" in self.config and not isinstance(self.config["model"], str):
             raise TypeError("'model' in configuration must be a string")

     def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
         split_key = self.config["split_key"]
-        chunk_size = self.config["chunk_size"]
-        results = []
-        cost = 0.0
-
+        method = self.config["method"]
+        method_kwargs = self.config["method_kwargs"]
         encoder = tiktoken.encoding_for_model(
             self.config.get("model", self.default_model)
         )
+        results = []
+        cost = 0.0

         for item in input_data:
             if split_key not in item:
                 raise KeyError(f"Split key '{split_key}' not found in item")

             content = item[split_key]
-            tokens = encoder.encode(content)

             # Generate a unique document ID
             doc_id = str(uuid.uuid4())

-            for chunk_num, i in enumerate(range(0, len(tokens), chunk_size), start=1):
-                chunk_tokens = tokens[i : i + chunk_size]
-                chunk = encoder.decode(chunk_tokens)
-
-                result = item.copy()
-                result.update(
-                    {
-                        f"{split_key}_chunk": chunk,
-                        f"{self.name}_id": doc_id,
-                        f"{self.name}_chunk_num": chunk_num,
-                    }
-                )
-
-                results.append(result)
+            if method == "token_count":
+                token_count = method_kwargs["token_count"]
+                tokens = encoder.encode(content)
+
+                for chunk_num, i in enumerate(
+                    range(0, len(tokens), token_count), start=1
+                ):
+                    chunk_tokens = tokens[i : i + token_count]
+                    chunk = encoder.decode(chunk_tokens)
+
+                    result = item.copy()
+                    result.update(
+                        {
+                            f"{split_key}_chunk": chunk,
+                            f"{self.name}_id": doc_id,
+                            f"{self.name}_chunk_num": chunk_num,
+                        }
+                    )
+                    results.append(result)
+
+            elif method == "delimiter":
+                delimiter = method_kwargs["delimiter"]
+                num_splits_to_group = method_kwargs.get("num_splits_to_group", 1)
+                chunks = content.split(delimiter)
+
+                # Get rid of empty chunks
+                chunks = [chunk for chunk in chunks if chunk.strip()]
+
+                for chunk_num, i in enumerate(
+                    range(0, len(chunks), num_splits_to_group), start=1
+                ):
+                    grouped_chunks = chunks[i : i + num_splits_to_group]
+                    joined_chunk = delimiter.join(grouped_chunks).strip()
+
+                    result = item.copy()
+                    result.update(
+                        {
+                            f"{split_key}_chunk": joined_chunk,
+                            f"{self.name}_id": doc_id,
+                            f"{self.name}_chunk_num": chunk_num,
+                        }
+                    )
+                    results.append(result)

         return results, cost
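As a standalone illustration of the delimiter branch above, here is a minimal sketch of its split-and-regroup logic (the helper name is hypothetical; the real operation additionally copies each input item and attaches the {split_key}_chunk, {name}_id, and {name}_chunk_num fields):

```python
def group_delimiter_splits(content: str, delimiter: str, num_splits_to_group: int = 1) -> list:
    # Hypothetical helper name; mirrors the delimiter branch of SplitOperation.execute.
    # Split on the delimiter and drop whitespace-only pieces.
    pieces = [p for p in content.split(delimiter) if p.strip()]
    # Rejoin consecutive pieces in groups of num_splits_to_group.
    return [
        delimiter.join(pieces[i : i + num_splits_to_group]).strip()
        for i in range(0, len(pieces), num_splits_to_group)
    ]


print(group_delimiter_splits("Line 1\nLine 2\nLine 3", "\n", num_splits_to_group=2))
# ['Line 1\nLine 2', 'Line 3']
```

Because whitespace-only pieces are filtered out before grouping, text with doubled delimiters (e.g. blank-line-separated paragraphs split on "\n") still yields clean chunks, which the new test in tests/test_split.py relies on.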
3 changes: 2 additions & 1 deletion motion/optimizers/map_optimizer/operation_creators.py
@@ -58,7 +58,8 @@ def create_split_map_gather_operations(
         "type": "split",
         "name": split_name,
         "split_key": split_key,
-        "chunk_size": chunk_size,
+        "method": "token_count",
+        "method_kwargs": {"token_count": chunk_size},
     }
     pipeline.append(split_config)
3 changes: 2 additions & 1 deletion tests/test_basic.py
@@ -297,7 +297,8 @@ def split_config():
     return {
         "type": "split",
         "split_key": "content",
-        "chunk_size": 4,
+        "method": "token_count",
+        "method_kwargs": {"token_count": 4},
         "name": "split_doc",
     }

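With token_count set to 4 as in this fixture, the operation encodes the text with tiktoken and decodes fixed-size token slices back to strings. A rough sketch of that behavior (the sample text and model name are illustrative, and exact chunk boundaries depend on the tokenizer):

```python
import tiktoken

# Model name is illustrative; the operation falls back to default_model.
encoder = tiktoken.encoding_for_model("gpt-4o-mini")
tokens = encoder.encode("This is a simple test document for splitting.")

token_count = 4  # mirrors the fixture above
chunks = [
    encoder.decode(tokens[i : i + token_count])
    for i in range(0, len(tokens), token_count)
]
print(chunks)  # a few short text slices of roughly 4 tokens each
```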
60 changes: 59 additions & 1 deletion tests/test_split.py
@@ -19,11 +19,23 @@ def split_config():
     return {
         "type": "split",
         "split_key": "content",
-        "chunk_size": 10,
+        "method": "token_count",
+        "method_kwargs": {"token_count": 10},
         "name": "split_doc",
     }


+@pytest.fixture
+def split_config_delimiter():
+    return {
+        "type": "split",
+        "split_key": "content",
+        "method": "delimiter",
+        "method_kwargs": {"delimiter": "\n", "num_splits_to_group": 2},
+        "name": "split_doc_delimiter",
+    }
+
+
 @pytest.fixture
 def map_config():
     return {
@@ -184,3 +196,49 @@ def test_split_map_gather_empty_input(
     gather_results, gather_cost = gather_op.execute(map_results)
     assert len(gather_results) == 0
     assert gather_cost == 0
+
+
+def test_split_delimiter_no_summarization(
+    split_config_delimiter, default_model, max_threads
+):
+    input_data = [
+        {"id": "1", "content": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6"},
+        {"id": "2", "content": "Paragraph 1\n\nParagraph 2\n\nParagraph 3"},
+    ]
+
+    split_op = SplitOperation(split_config_delimiter, default_model, max_threads)
+    results, cost = split_op.execute(input_data)
+
+    assert len(results) == 5  # 3 chunks for first item, 2 for second
+    assert cost == 0  # No LLM calls, so cost should be 0
+
+    # Check first item's chunks
+    assert results[0]["content_chunk"] == "Line 1\nLine 2"
+    assert results[1]["content_chunk"] == "Line 3\nLine 4"
+    assert results[2]["content_chunk"] == "Line 5\nLine 6"
+
+    # Check second item's chunks
+    assert results[3]["content_chunk"] == "Paragraph 1\nParagraph 2"
+    assert results[4]["content_chunk"] == "Paragraph 3"
+
+    # Check that all results have the necessary fields
+    for result in results:
+        assert "split_doc_delimiter_id" in result
+        assert "split_doc_delimiter_chunk_num" in result
+        assert "id" in result  # Original field should be preserved
+
+    # Check that chunk numbers are correct
+    assert results[0]["split_doc_delimiter_chunk_num"] == 1
+    assert results[1]["split_doc_delimiter_chunk_num"] == 2
+    assert results[2]["split_doc_delimiter_chunk_num"] == 3
+    assert results[3]["split_doc_delimiter_chunk_num"] == 1
+    assert results[4]["split_doc_delimiter_chunk_num"] == 2
+
+    # Check that document IDs are consistent within each original item
+    assert (
+        results[0]["split_doc_delimiter_id"]
+        == results[1]["split_doc_delimiter_id"]
+        == results[2]["split_doc_delimiter_id"]
+    )
+    assert results[3]["split_doc_delimiter_id"] == results[4]["split_doc_delimiter_id"]
+    assert results[0]["split_doc_delimiter_id"] != results[3]["split_doc_delimiter_id"]

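To exercise the delimiter path end to end outside pytest, a sketch along these lines should work (the import path is assumed from the file layout in this commit; the model name and thread count stand in for the default_model and max_threads fixtures):

```python
from motion.operations.split import SplitOperation  # assumed import path

config = {
    "type": "split",
    "split_key": "content",
    "method": "delimiter",
    "method_kwargs": {"delimiter": "\n", "num_splits_to_group": 2},
    "name": "split_doc_delimiter",
}

op = SplitOperation(config, "gpt-4o-mini", 4)  # (config, default_model, max_threads)
results, cost = op.execute([{"id": "1", "content": "Line 1\nLine 2\nLine 3"}])

print(results[0]["content_chunk"])  # "Line 1\nLine 2"
print(results[1]["content_chunk"])  # "Line 3"
print(cost)                         # 0.0; splitting makes no LLM calls
```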