ai21[minor]: AI21 Labs Semantic Text Splitter support (#19510)
Description: Added support for the AI21 Labs Segmentation model as a text
splitter.
Dependencies: ai21, langchain-text-splitters
Twitter handle: https://github.com/AI21Labs

---------

Co-authored-by: Bagatur <[email protected]>
3 people authored and hinthornw committed Apr 26, 2024
1 parent 26bf663 commit 533a6a1
Showing 11 changed files with 976 additions and 15 deletions.


@@ -44,6 +44,7 @@ LangChain offers many different types of text splitters. These all live in the `
| Token | Tokens | | Splits text on tokens. There exist a few different ways to measure tokens. |
| Character | A user defined character | | Splits text based on a user defined character. One of the simpler methods. |
| [Experimental] Semantic Chunker | Sentences | | First splits on sentences. Then combines ones next to each other if they are semantically similar enough. Taken from [Greg Kamradt](https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/5_Levels_Of_Text_Splitting.ipynb) |
| [AI21 Semantic Text Splitter](/docs/integrations/document_transformers/ai21_semantic_text_splitter) | Semantics | | Identifies distinct topics that form coherent pieces of text and splits along those. |


## Evaluate text splitters
16 changes: 16 additions & 0 deletions libs/partners/ai21/README.md
@@ -102,4 +102,20 @@ chain = tsm | StrOutputParser()
response = chain.invoke(
{"context": "Your context", "question": "Your question"},
)
```

## Text Splitters

### Semantic Text Splitter

You can use AI21's semantic text splitter to split text into segments.
Rather than splitting only on punctuation and newlines, it identifies distinct topics that form coherent pieces of text and splits along those.

For a list of examples, see [this page](https://github.com/langchain-ai/langchain/blob/master/docs/docs/modules/data_connection/document_transformers/semantic_text_splitter.ipynb).

```python
from langchain_ai21 import AI21SemanticTextSplitter

splitter = AI21SemanticTextSplitter()
response = splitter.split_text("Your text")
```
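
The splitter can also return `Document` objects, with each segment's type attached as metadata. A minimal sketch, using the `split_text_to_documents` method added in this commit:

```python
from langchain_ai21 import AI21SemanticTextSplitter

splitter = AI21SemanticTextSplitter()
documents = splitter.split_text_to_documents("Your text")

for doc in documents:
    # each Document carries the segment type returned by the Segmentation API
    print(doc.metadata["source_type"], doc.page_content)
```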
2 changes: 2 additions & 0 deletions libs/partners/ai21/langchain_ai21/__init__.py
@@ -2,10 +2,12 @@
from langchain_ai21.contextual_answers import AI21ContextualAnswers
from langchain_ai21.embeddings import AI21Embeddings
from langchain_ai21.llms import AI21LLM
from langchain_ai21.semantic_text_splitter import AI21SemanticTextSplitter

__all__ = [
"AI21LLM",
"ChatAI21",
"AI21Embeddings",
"AI21ContextualAnswers",
"AI21SemanticTextSplitter",
]
158 changes: 158 additions & 0 deletions libs/partners/ai21/langchain_ai21/semantic_text_splitter.py
@@ -0,0 +1,158 @@
import copy
import logging
import re
from typing import (
Any,
Iterable,
List,
Optional,
)

from ai21.models import DocumentType
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import SecretStr
from langchain_text_splitters import TextSplitter

from langchain_ai21.ai21_base import AI21Base

logger = logging.getLogger(__name__)


class AI21SemanticTextSplitter(TextSplitter):
"""Splitting text into coherent and readable units,
based on distinct topics and lines
"""

def __init__(
self,
chunk_size: int = 0,
chunk_overlap: int = 0,
client: Optional[Any] = None,
api_key: Optional[SecretStr] = None,
api_host: Optional[str] = None,
timeout_sec: Optional[float] = None,
num_retries: Optional[int] = None,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
**kwargs,
)

self._segmentation = AI21Base(
client=client,
api_key=api_key,
api_host=api_host,
timeout_sec=timeout_sec,
num_retries=num_retries,
).client.segmentation

def split_text(self, source: str) -> List[str]:
"""Split text into multiple components.
Args:
source: Specifies the text input for text segmentation
"""
response = self._segmentation.create(
source=source, source_type=DocumentType.TEXT
)

segments = [segment.segment_text for segment in response.segments]

if self._chunk_size > 0:
            return self._merge_splits_no_separator(segments)

return segments

def split_text_to_documents(self, source: str) -> List[Document]:
"""Split text into multiple documents.
Args:
source: Specifies the text input for text segmentation
"""
response = self._segmentation.create(
source=source, source_type=DocumentType.TEXT
)

return [
Document(
page_content=segment.segment_text,
metadata={"source_type": segment.segment_type},
)
for segment in response.segments
]

def create_documents(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> List[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []

for i, text in enumerate(texts):
normalized_text = self._normalized_text(text)
index = 0
previous_chunk_len = 0

for chunk in self.split_text_to_documents(text):
# merge metadata from user (if exists) and from segmentation api
metadata = copy.deepcopy(_metadatas[i])
metadata.update(chunk.metadata)

if self._add_start_index:
# find the start index of the chunk
offset = index + previous_chunk_len - self._chunk_overlap
normalized_chunk = self._normalized_text(chunk.page_content)
index = normalized_text.find(normalized_chunk, max(0, offset))
metadata["start_index"] = index
previous_chunk_len = len(normalized_chunk)

documents.append(
Document(
page_content=chunk.page_content,
metadata=metadata,
)
)

return documents

def _normalized_text(self, string: str) -> str:
"""Use regular expression to replace sequences of '\n'"""
return re.sub(r"\s+", "", string)

def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
"""This method overrides the default implementation of TextSplitter"""
        return self._merge_splits_no_separator(splits)

    def _merge_splits_no_separator(self, splits: Iterable[str]) -> List[str]:
"""Merge splits into chunks.
If the segment size is bigger than chunk_size,
it will be left as is (won't be cut to match to chunk_size).
If the segment size is smaller than chunk_size,
it will be merged with the next segment until the chunk_size is reached.
"""
chunks = []
current_chunk = ""

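        # Concatenate segments directly (no separator); start a new chunk
        # whenever appending the next segment would exceed chunk_size.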
for split in splits:
split_len = self._length_function(split)

if split_len > self._chunk_size:
logger.warning(
f"Split of length {split_len}"
f"exceeds chunk size {self._chunk_size}."
)

if self._length_function(current_chunk) + split_len > self._chunk_size:
if current_chunk != "":
chunks.append(current_chunk)
current_chunk = ""

current_chunk += split

if current_chunk != "":
chunks.append(current_chunk)

return chunks
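
For illustration, a hedged usage sketch of the behavior implemented above. It assumes a valid AI21 API key is configured in the environment and uses placeholder text; `chunk_size` and `add_start_index` are parameters of the base `TextSplitter`.

```python
from langchain_ai21 import AI21SemanticTextSplitter

# With chunk_size > 0, split_text merges adjacent segments until the chunk
# size is reached; oversized segments are kept whole rather than cut.
splitter = AI21SemanticTextSplitter(chunk_size=1000)
chunks = splitter.split_text("Your long text ...")

# With add_start_index=True, create_documents records each chunk's position
# (located on whitespace-normalized text) under the "start_index" metadata key
# and merges user metadata with the API's "source_type" metadata.
indexed_splitter = AI21SemanticTextSplitter(add_start_index=True)
docs = indexed_splitter.create_documents(
    texts=["Your long text ..."],
    metadatas=[{"source": "example.txt"}],
)
```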
46 changes: 32 additions & 14 deletions libs/partners/ai21/poetry.lock


1 change: 1 addition & 0 deletions libs/partners/ai21/pyproject.toml
@@ -8,6 +8,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.22"
langchain-text-splitters = "^0.0.1"
ai21 = "^2.1.2"

[tool.poetry.group.test]
