ai21[minor]: AI21 Labs Semantic Text Splitter support (#19510)
Description: Added support for the AI21 Labs Segmentation model as a text splitter.
Dependencies: ai21, langchain-text-splitters
Twitter handle: https://github.com/AI21Labs

Co-authored-by: Bagatur <[email protected]>
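In practice, the new splitter is used roughly as follows. This is a minimal sketch, not part of the commit: it assumes AI21SemanticTextSplitter is re-exported from the langchain_ai21 package root and that AI21Base resolves credentials from the environment (e.g. an AI21_API_KEY variable); only the class and method names below appear in the diff itself.

    # Minimal usage sketch (assumptions: package-root export and an API key
    # available to AI21Base, e.g. via the AI21_API_KEY environment variable).
    from langchain_ai21 import AI21SemanticTextSplitter

    text = (
        "Registration is free and open to everyone. "
        "The venue has moved to the main auditorium. "
        "Parking is available on the north side of campus."
    )

    # With the default chunk_size=0, split_text returns one chunk per
    # topic detected by the Segmentation API.
    splitter = AI21SemanticTextSplitter()
    for segment in splitter.split_text(text):
        print(repr(segment))

    # With chunk_size > 0, split_text merges small segments together
    # (see _merge_splits_no_separator in the new file below).
    merging_splitter = AI21SemanticTextSplitter(chunk_size=200)
    chunks = merging_splitter.split_text(text)

    # split_text_to_documents attaches the segment type reported by the API.
    for doc in splitter.split_text_to_documents(text):
        print(doc.metadata["source_type"], doc.page_content)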
Showing 11 changed files with 976 additions and 15 deletions.
docs/docs/integrations/document_transformers/ai21_semantic_text_splitter.ipynb (466 additions, 0 deletions). Large diffs are not rendered by default.
libs/partners/ai21/langchain_ai21/semantic_text_splitter.py (158 additions, 0 deletions):
import copy
import logging
import re
from typing import (
    Any,
    Iterable,
    List,
    Optional,
)

from ai21.models import DocumentType
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import SecretStr
from langchain_text_splitters import TextSplitter

from langchain_ai21.ai21_base import AI21Base

logger = logging.getLogger(__name__)


class AI21SemanticTextSplitter(TextSplitter):
    """Split text into coherent and readable units,
    based on distinct topics and lines.
    """

    def __init__(
        self,
        chunk_size: int = 0,
        chunk_overlap: int = 0,
        client: Optional[Any] = None,
        api_key: Optional[SecretStr] = None,
        api_host: Optional[str] = None,
        timeout_sec: Optional[float] = None,
        num_retries: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            **kwargs,
        )

        self._segmentation = AI21Base(
            client=client,
            api_key=api_key,
            api_host=api_host,
            timeout_sec=timeout_sec,
            num_retries=num_retries,
        ).client.segmentation

    def split_text(self, source: str) -> List[str]:
        """Split text into multiple components.

        Args:
            source: Specifies the text input for text segmentation.
        """
        response = self._segmentation.create(
            source=source, source_type=DocumentType.TEXT
        )

        segments = [segment.segment_text for segment in response.segments]

        if self._chunk_size > 0:
            return self._merge_splits_no_separator(segments)

        return segments

    def split_text_to_documents(self, source: str) -> List[Document]:
        """Split text into multiple documents.

        Args:
            source: Specifies the text input for text segmentation.
        """
        response = self._segmentation.create(
            source=source, source_type=DocumentType.TEXT
        )

        return [
            Document(
                page_content=segment.segment_text,
                metadata={"source_type": segment.segment_type},
            )
            for segment in response.segments
        ]

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []

        for i, text in enumerate(texts):
            normalized_text = self._normalized_text(text)
            index = 0
            previous_chunk_len = 0

            for chunk in self.split_text_to_documents(text):
                # Merge metadata from the user (if any) with metadata
                # returned by the segmentation API.
                metadata = copy.deepcopy(_metadatas[i])
                metadata.update(chunk.metadata)

                if self._add_start_index:
                    # Find the start index of the chunk in the
                    # whitespace-normalized source text.
                    offset = index + previous_chunk_len - self._chunk_overlap
                    normalized_chunk = self._normalized_text(chunk.page_content)
                    index = normalized_text.find(normalized_chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(normalized_chunk)

                documents.append(
                    Document(
                        page_content=chunk.page_content,
                        metadata=metadata,
                    )
                )

        return documents

    def _normalized_text(self, string: str) -> str:
        """Strip all whitespace with a regular expression so that chunk
        positions can be matched against the source text."""
        return re.sub(r"\s+", "", string)

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Override the default TextSplitter implementation.

        The separator argument is ignored; segments are concatenated
        directly.
        """
        return self._merge_splits_no_separator(splits)

    def _merge_splits_no_separator(self, splits: Iterable[str]) -> List[str]:
        """Merge splits into chunks.

        If a segment is bigger than chunk_size, it is left as is
        (it will not be cut to fit chunk_size).
        If a segment is smaller than chunk_size, it is merged with the
        following segments until chunk_size is reached.
        """
        chunks = []
        current_chunk = ""

        for split in splits:
            split_len = self._length_function(split)

            if split_len > self._chunk_size:
                logger.warning(
                    f"Split of length {split_len} "
                    f"exceeds chunk size {self._chunk_size}."
                )

            if self._length_function(current_chunk) + split_len > self._chunk_size:
                if current_chunk != "":
                    chunks.append(current_chunk)
                    current_chunk = ""

            current_chunk += split

        if current_chunk != "":
            chunks.append(current_chunk)

        return chunks
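The merge rule in _merge_splits_no_separator can be restated in isolation. The sketch below mirrors that method using plain len as the length function (the class uses self._length_function, which may differ); it is illustrative only, not part of the commit.

    from typing import Iterable, List

    def merge_segments(segments: Iterable[str], chunk_size: int) -> List[str]:
        # Standalone restatement of the merge rule above, for illustration.
        chunks: List[str] = []
        current = ""
        for seg in segments:
            # Close the current chunk once adding the next segment would
            # exceed the budget; an oversized segment is kept whole.
            if current and len(current) + len(seg) > chunk_size:
                chunks.append(current)
                current = ""
            current += seg
        if current:
            chunks.append(current)
        return chunks

    print(merge_segments(["aaaa", "bb", "ccccccc"], chunk_size=6))
    # ['aaaabb', 'ccccccc'] -- the short segments merge; the 7-character
    # segment exceeds chunk_size but is emitted uncut.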
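Similarly, the start_index bookkeeping in create_documents can be shown with plain strings: positions are computed in a whitespace-free view of the source, so formatting differences between the source and the returned segments cannot break the match. A small sketch with hypothetical inputs, mirroring _normalized_text:

    import re

    def normalized(s: str) -> str:
        # Same normalization as _normalized_text: strip all whitespace.
        return re.sub(r"\s+", "", s)

    source = "First topic.\n\nSecond topic continues here."
    chunk = "Second topic continues here."

    start_index = normalized(source).find(normalized(chunk))
    print(start_index)  # 11 == len("Firsttopic.")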