From b2497860794ed744a02aecb1131c275565e49a35 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 11 Dec 2024 20:12:06 -0500 Subject: [PATCH 01/11] Add core RAG types. --- sdks/python/apache_beam/ml/rag/__init__.py | 25 +++++++++ sdks/python/apache_beam/ml/rag/types.py | 65 ++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 sdks/python/apache_beam/ml/rag/__init__.py create mode 100644 sdks/python/apache_beam/ml/rag/types.py diff --git a/sdks/python/apache_beam/ml/rag/__init__.py b/sdks/python/apache_beam/ml/rag/__init__.py new file mode 100644 index 000000000000..554beb9d7aba --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/__init__.py @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Apache Beam RAG (Retrieval Augmented Generation) components. +This package provides components for building RAG pipelines in Apache Beam, +including: +- Chunking +- Embedding generation +- Vector storage +- Vector search enrichment +""" diff --git a/sdks/python/apache_beam/ml/rag/types.py b/sdks/python/apache_beam/ml/rag/types.py new file mode 100644 index 000000000000..82beeba5cc6c --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/types.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Core types for RAG pipelines. +This module contains the core dataclasses used throughout the RAG pipeline +implementation, including Chunk and Embedding types that define the data +contracts between different stages of the pipeline. +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Any +import uuid + + +@dataclass +class Content: + """Container for embeddable content. Add new types as when as necessary. + """ + text: Optional[str] = None + + +@dataclass +class Embedding: + """Represents vector embeddings. + + Attributes: + dense_embedding: Dense vector representation + sparse_embedding: Optional sparse vector representation for hybrid + search + """ + dense_embedding: Optional[List[float]] = None + # For hybrid search + sparse_embedding: Optional[Tuple[List[int], List[float]]] = None + + +@dataclass +class Chunk: + """Represents a chunk of embeddable content with metadata. + + Attributes: + content: The actual content of the chunk + id: Unique identifier for the chunk + index: Index of this chunk within the original document + metadata: Additional metadata about the chunk (e.g., document source) + embedding: Vector embeddings of the content + """ + content: Content + id: str = field(default_factory=lambda: str(uuid.uuid4())) + index: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + embedding: Optional[Embedding] = None From dec2ed93a23f7cb64ae606ca22e4b97b0f52e73a Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 11 Dec 2024 20:15:07 -0500 Subject: [PATCH 02/11] Add chunking base. --- .../apache_beam/ml/rag/chunking/__init__.py | 21 +++ .../apache_beam/ml/rag/chunking/base.py | 95 ++++++++++++++ .../apache_beam/ml/rag/chunking/base_test.py | 124 ++++++++++++++++++ 3 files changed, 240 insertions(+) create mode 100644 sdks/python/apache_beam/ml/rag/chunking/__init__.py create mode 100644 sdks/python/apache_beam/ml/rag/chunking/base.py create mode 100644 sdks/python/apache_beam/ml/rag/chunking/base_test.py diff --git a/sdks/python/apache_beam/ml/rag/chunking/__init__.py b/sdks/python/apache_beam/ml/rag/chunking/__init__.py new file mode 100644 index 000000000000..34a6a966b19e --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Chunking components for RAG pipelines. +This module provides components for splitting text into chunks for RAG +pipelines. +""" diff --git a/sdks/python/apache_beam/ml/rag/chunking/base.py b/sdks/python/apache_beam/ml/rag/chunking/base.py new file mode 100644 index 000000000000..9ab0dbde49b7 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/base.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import functools +from collections.abc import Callable +from typing import Any +from typing import Dict +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.transforms.base import MLTransformProvider + +ChunkIdFn = Callable[[Chunk], str] + + +def _assign_chunk_id(chunk_id_fn: ChunkIdFn, chunk: Chunk): + chunk.id = chunk_id_fn(chunk) + return chunk + + +class ChunkingTransformProvider(MLTransformProvider): + def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): + """Base class for chunking transforms in RAG pipelines. + + ChunkingTransformProvider defines the interface for splitting documents + into chunks for embedding and retrieval. Implementations should define how + to split content while preserving metadata and managing chunk IDs. + + The transform flow: + 1. Takes input documents with content and metadata + 2. Splits content into chunks using implementation-specific logic + 3. Preserves document metadata in resulting chunks + 4. Optionally assigns unique IDs to chunks (configurable via chunk_id_fn). + + Example usage: + ```python + class MyChunker(ChunkingTransformProvider): + def get_splitter_transform(self): + return beam.ParDo(MySplitterDoFn()) + + chunker = MyChunker(chunk_id_fn=my_id_function) + + with beam.Pipeline() as p: + chunks = ( + p + | beam.Create([{'text': 'document...', 'source': 'doc.txt'}]) + | MLTransform(...).with_transform(chunker)) + ``` + + Args: + chunk_id_fn: Optional function to generate chunk IDs. If not provided, + random UUIDs will be used. Function should take a Chunk and return + str. + """ + self.assign_chunk_id_fn = functools.partial( + _assign_chunk_id, chunk_id_fn) if chunk_id_fn is not None else None + + @abc.abstractmethod + def get_splitter_transform( + self + ) -> beam.PTransform[beam.PCollection[Dict[str, Any]], + beam.PCollection[Chunk]]: + """Creates transforms that emits splits for given content.""" + raise NotImplementedError( + "Subclasses must implement get_splitter_transform") + + def get_ptransform_for_processing( + self, **kwargs + ) -> beam.PTransform[beam.PCollection[Dict[str, Any]], + beam.PCollection[Chunk]]: + """Creates transform for processing documents into chunks.""" + ptransform = ( + "Split document" >> + self.get_splitter_transform().with_output_types(Chunk)) + if self.assign_chunk_id_fn: + ptransform = ( + ptransform | "Assign chunk id" >> beam.Map( + self.assign_chunk_id_fn).with_output_types(Chunk)) + return ptransform diff --git a/sdks/python/apache_beam/ml/rag/chunking/base_test.py b/sdks/python/apache_beam/ml/rag/chunking/base_test.py new file mode 100644 index 000000000000..5a8d816343da --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/base_test.py @@ -0,0 +1,124 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for apache_beam.ml.rag.chunking.base.""" + +import unittest +import pytest + +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider, ChunkIdFn +from apache_beam.ml.rag.types import Chunk, Content +from typing import Optional, Dict, Any + + +class WordSplitter(beam.DoFn): + def process(self, element): + words = element['text'].split() + for i, word in enumerate(words): + yield Chunk( + content=Content(text=word), + index=i, + metadata={'source': element['source']}) + + +class MockChunkingProvider(ChunkingTransformProvider): + def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): + super().__init__(chunk_id_fn=chunk_id_fn) + + def get_splitter_transform( + self + ) -> beam.PTransform[beam.PCollection[Dict[str, Any]], + beam.PCollection[Chunk]]: + return beam.ParDo(WordSplitter()) + + +def chunk_equals(expected, actual): + """Custom equality function for Chunk objects.""" + if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): + return False + # Don't compare IDs since they're randomly generated + return ( + expected.index == actual.index and expected.content == actual.content and + expected.metadata == actual.metadata) + + +def id_equals(expected, actual): + """Custom equality function for Chunk object id's.""" + if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): + return False + return (expected.id == actual.id) + + +@pytest.mark.uses_transformers +class ChunkingTransformProviderTest(unittest.TestCase): + def setUp(self): + self.test_doc = {'text': 'hello world test', 'source': 'test.txt'} + + def test_chunking_transform(self): + """Test the complete chunking transform.""" + provider = MockChunkingProvider() + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.test_doc]) + | provider.get_ptransform_for_processing()) + + expected = [ + Chunk( + content=Content(text="hello"), + index=0, + metadata={'source': 'test.txt'}), + Chunk( + content=Content(text="world"), + index=1, + metadata={'source': 'test.txt'}), + Chunk( + content=Content(text="test"), + index=2, + metadata={'source': 'test.txt'}) + ] + + assert_that(chunks, equal_to(expected, equals_fn=chunk_equals)) + + def test_custom_chunk_id_fn(self): + """Test the a custom chink id function.""" + def source_index_id_fn(chunk: Chunk): + return f"{chunk.metadata['source']}_{chunk.index}" + + provider = MockChunkingProvider(chunk_id_fn=source_index_id_fn) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.test_doc]) + | provider.get_ptransform_for_processing()) + + expected = [ + Chunk(content=Content(text="hello"), id="test.txt_0"), + Chunk(content=Content(text="world"), id="test.txt_1"), + Chunk(content=Content(text="test"), id="test.txt_2") + ] + + assert_that(chunks, equal_to(expected, equals_fn=id_equals)) + + +if __name__ == '__main__': + unittest.main() From f32f86924a2275df793737c9edbb29e7e300598b Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 11 Dec 2024 23:04:12 -0500 Subject: [PATCH 03/11] Add LangChain TextSplitter chunking. --- .../apache_beam/ml/rag/chunking/langchain.py | 120 ++++++++++ .../ml/rag/chunking/langchain_test.py | 216 ++++++++++++++++++ 2 files changed, 336 insertions(+) create mode 100644 sdks/python/apache_beam/ml/rag/chunking/langchain.py create mode 100644 sdks/python/apache_beam/ml/rag/chunking/langchain_test.py diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain.py b/sdks/python/apache_beam/ml/rag/chunking/langchain.py new file mode 100644 index 000000000000..f23e7be65d49 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain.py @@ -0,0 +1,120 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.rag.chunking.base import ChunkIdFn +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content + +try: + from langchain.text_splitter import TextSplitter +except ImportError: + TextSplitter = None # type: ignore + + +class LangChainChunker(ChunkingTransformProvider): + def __init__( + self, + text_splitter: TextSplitter, + document_field: str, + metadata_fields: List[str], + chunk_id_fn: Optional[ChunkIdFn] = None): + """A ChunkingTransformProvider that uses LangChain text splitters. + + This provider integrates LangChain's text splitting capabilities into + Beam's MLTransform framework. It supports various text splitting strategies + through LangChain's TextSplitter interface, including recursive character + splitting and other methods. + + The provider: + - Takes documents with text content and metadata + - Splits text using configured LangChain splitter + - Preserves document metadata in resulting chunks + - Assigns unique IDs to chunks (configurable via chunk_id_fn) + + Example usage: + ```python + from langchain.text_splitter import RecursiveCharacterTextSplitter + + splitter = RecursiveCharacterTextSplitter( + chunk_size=100, + chunk_overlap=20 + ) + + chunker = LangChainChunker(text_splitter=splitter) + + with beam.Pipeline() as p: + chunks = ( + p + | beam.Create([{'text': 'long document...', 'source': 'doc.txt'}]) + | MLTransform(...).with_transform(chunker)) + ``` + + Args: + text_splitter: A LangChain TextSplitter instance that defines how + documents are split into chunks. + metadata_fields: List of field names to copy from input documents to + chunk metadata. These fields will be preserved in each chunk created + from the document. + chunk_id_fn: Optional function that take a Chunk and return str to + generate chunk IDs. If not provided, random UUIDs will be used. + """ + if not TextSplitter: + raise ImportError( + "langchain is required to use LangChainChunker" + "Please install it with using `pip install langchain`.") + if not isinstance(text_splitter, TextSplitter): + raise TypeError("text_splitter must be a LangChain TextSplitter") + if not document_field: + raise ValueError("document_field cannot be empty") + super().__init__(chunk_id_fn) + self.text_splitter = text_splitter + self.document_field = document_field + self.metadata_fields = metadata_fields + + def get_splitter_transform( + self + ) -> beam.PTransform[beam.PCollection[Dict[str, Any]], + beam.PCollection[Chunk]]: + return "Langchain text split" >> beam.ParDo( + _LangChainTextSplitter( + text_splitter=self.text_splitter, + document_field=self.document_field, + metadata_fields=self.metadata_fields)) + + +class _LangChainTextSplitter(beam.DoFn): + def __init__( + self, + text_splitter: TextSplitter, + document_field: str, + metadata_fields: List[str]): + self.text_splitter = text_splitter + self.document_field = document_field + self.metadata_fields = metadata_fields + + def process(self, element): + text_chunks = self.text_splitter.split_text(element[self.document_field]) + metadata = {field: element[field] for field in self.metadata_fields} + for i, text_chunk in enumerate(text_chunks): + yield Chunk(content=Content(text=text_chunk), index=i, metadata=metadata) diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py new file mode 100644 index 000000000000..5bc17a824775 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py @@ -0,0 +1,216 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for apache_beam.ml.rag.chunking.langchain.""" + +import unittest + +import apache_beam as beam + +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.ml.rag.types import Chunk + +try: + from apache_beam.ml.rag.chunking.langchain import LangChainChunker + from langchain.text_splitter import ( + RecursiveCharacterTextSplitter, + CharacterTextSplitter, + ) + LANGCHAIN_AVAILABLE = True +except ImportError: + LANGCHAIN_AVAILABLE = False + +# Import optional dependencies +try: + from transformers import AutoTokenizer + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + + +def chunk_equals(expected, actual): + """Custom equality function for Chunk objects.""" + if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): + return False + return ( + expected.content == actual.content and expected.index == actual.index and + expected.metadata == actual.metadata) + + +@unittest.skipIf(not LANGCHAIN_AVAILABLE, 'langchain is not installed.') +class LangChainChunkingTest(unittest.TestCase): + def setUp(self): + self.simple_text = { + 'content': 'This is a simple test document. It has multiple sentences. ' + 'We will use it to test basic splitting.', + 'source': 'simple.txt', + 'language': 'en' + } + + self.complex_text = { + 'content': ( + 'The patient arrived at 2 p.m. yesterday. ' + 'Initial assessment was completed. ' + 'Lab results showed normal ranges. ' + 'Follow-up scheduled for next week.'), + 'source': 'medical.txt', + 'language': 'en' + } + + def test_no_metadata_fields(self): + """Test chunking with no metadata fields specified.""" + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + provider = LangChainChunker( + document_field='content', metadata_fields=[], text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + chunks_count = chunks | beam.combiners.Count.Globally() + + assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + + assert_that(chunks, lambda x: all(c.metadata == {} for c in x)) + + def test_multiple_metadata_fields(self): + """Test chunking with multiple metadata fields.""" + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + provider = LangChainChunker( + document_field='content', + metadata_fields=['source', 'language'], + text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + chunks_count = chunks | beam.combiners.Count.Globally() + + assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + assert_that( + chunks, + lambda x: all( + c.metadata == { + 'source': 'simple.txt', 'language': 'en' + } for c in x)) + + def test_recursive_splitter_no_overlap(self): + """Test RecursiveCharacterTextSplitter with no overlap.""" + splitter = RecursiveCharacterTextSplitter( + chunk_size=30, chunk_overlap=0, separators=[". "]) + provider = LangChainChunker( + document_field='content', + metadata_fields=['source'], + text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + chunks_count = chunks | beam.combiners.Count.Globally() + + assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + assert_that(chunks, lambda x: all(len(c.content.text) <= 30 for c in x)) + + @unittest.skipIf(not TRANSFORMERS_AVAILABLE, "transformers not available") + def test_huggingface_tokenizer_splitter(self): + """Test text splitter created from HuggingFace tokenizer.""" + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer( + tokenizer, + chunk_size=10, # tokens + chunk_overlap=2 # tokens + ) + + provider = LangChainChunker( + document_field='content', + metadata_fields=['source'], + text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + + def check_token_lengths(chunks): + for chunk in chunks: + # Verify each chunk's token length is within limits + num_tokens = len(tokenizer.encode(chunk.content.text)) + if not num_tokens <= 10: + raise AssertionError( + f"Chunk has {num_tokens} tokens, expected <= 10") + return True + + chunks_count = chunks | beam.combiners.Count.Globally() + + assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks') + assert_that(chunks, check_token_lengths) + + def test_invalid_document_field(self): + """Test that using an invalid document field raises KeyError.""" + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + provider = LangChainChunker( + document_field='nonexistent', + metadata_fields={}, + text_splitter=splitter) + + with self.assertRaises(KeyError): + with TestPipeline() as p: + _ = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + + def test_invalid_text_splitter(self): + """Test that using an invalid document field raises KeyError.""" + + with self.assertRaises(TypeError): + provider = LangChainChunker( + document_field='nonexistent', text_splitter="Not a text splitter!") + with TestPipeline() as p: + _ = ( + p + | beam.Create([self.simple_text]) + | provider.get_ptransform_for_processing()) + + def test_empty_text(self): + """Test that empty text produces no chunks.""" + empty_doc = {'content': '', 'source': 'empty.txt'} + + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + provider = LangChainChunker( + document_field='content', + metadata_fields=['source'], + text_splitter=splitter) + + with TestPipeline() as p: + chunks = ( + p + | beam.Create([empty_doc]) + | provider.get_ptransform_for_processing()) + + assert_that(chunks, equal_to([])) + + +if __name__ == '__main__': + unittest.main() From 676b9663185ece65bcb731699c63c5520d85bf7f Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 11 Dec 2024 23:11:26 -0500 Subject: [PATCH 04/11] Add generic type support for embeddings. --- sdks/python/apache_beam/ml/transforms/base.py | 181 ++++++++++++------ .../apache_beam/ml/transforms/base_test.py | 60 +++--- .../ml/transforms/embeddings/huggingface.py | 4 +- 3 files changed, 152 insertions(+), 93 deletions(-) diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index a963f602a06d..8fd9e2f61ffe 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -15,22 +15,29 @@ # limitations under the License. import abc -import collections import logging import os import tempfile import uuid +from collections.abc import Callable from collections.abc import Mapping from collections.abc import Sequence from typing import Any +from typing import cast +from typing import Dict from typing import Generic +from typing import Iterable +from typing import List from typing import Optional from typing import TypeVar from typing import Union +import functools import jsonpickle import numpy as np +from dataclasses import dataclass + import apache_beam as beam from apache_beam.io.filesystems import FileSystems from apache_beam.metrics.metric import Metrics @@ -62,36 +69,31 @@ # Output of the apply() method of BaseOperation. OperationOutputT = TypeVar('OperationOutputT') +# Input to the EmbeddingTypeAdapter input_fn +EmbeddingTypeAdapterInputT = TypeVar( + 'EmbeddingTypeAdapterInputT') # e.g., Chunk +# Output of the EmbeddingTypeAdapter output_fn +EmbeddingTypeAdapterOutputT = TypeVar( + 'EmbeddingTypeAdapterOutputT') # e.g., Embedding -def _convert_list_of_dicts_to_dict_of_lists( - list_of_dicts: Sequence[dict[str, Any]]) -> dict[str, list[Any]]: - keys_to_element_list = collections.defaultdict(list) - input_keys = list_of_dicts[0].keys() - for d in list_of_dicts: - if set(d.keys()) != set(input_keys): - extra_keys = set(d.keys()) - set(input_keys) if len( - d.keys()) > len(input_keys) else set(input_keys) - set(d.keys()) - raise RuntimeError( - f'All the dicts in the input data should have the same keys. ' - f'Got: {extra_keys} instead.') - for key, value in d.items(): - keys_to_element_list[key].append(value) - return keys_to_element_list - - -def _convert_dict_of_lists_to_lists_of_dict( - dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]: - batch_length = len(next(iter(dict_of_lists.values()))) - result: list[dict[str, Any]] = [{} for _ in range(batch_length)] - # all the values in the dict_of_lists should have same length - for key, values in dict_of_lists.items(): - assert len(values) == batch_length, ( - "This function expects all the values " - "in the dict_of_lists to have same length." - ) - for i in range(len(values)): - result[i][key] = values[i] - return result + +@dataclass +class EmbeddingTypeAdapter(Generic[EmbeddingTypeAdapterInputT, + EmbeddingTypeAdapterOutputT]): + """Adapts input types to text for embedding and converts output embeddings. + + Args: + input_fn: Function to extract text for embedding from input type + output_fn: Function to create output type from input and embeddings + """ + input_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT]], List[str]] + output_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT], Sequence[Any]], + List[EmbeddingTypeAdapterOutputT]] + + def __reduce__(self): + """Custom serialization that preserves type information during + jsonpickle.""" + return (self.__class__, (self.input_fn, self.output_fn)) def _map_errors_to_beam_row(element, cls_name=None): @@ -182,24 +184,96 @@ def append_transform(self, transform: BaseOperation): """ +def _dict_input_fn(columns: Sequence[str], + batch: Sequence[Dict[str, Any]]) -> List[str]: + """Extract text from specified columns in batch.""" + if not batch or not isinstance(batch[0], dict): + raise TypeError( + 'Expected data to be dicts, got ' + f'{type(batch[0])} instead.') + + result = [] + expected_keys = set(batch[0].keys()) + expected_columns = set(columns) + # Process one batch item at a time + for item in batch: + item_keys = item.keys() + if set(item_keys) != expected_keys: + extra_keys = item_keys - expected_keys + missing_keys = expected_keys - item_keys + raise RuntimeError( + f'All dicts in batch must have the same keys. ' + f'extra keys: {extra_keys}, ' + f'missing keys: {missing_keys}') + missing_columns = expected_columns - item_keys + if (missing_columns): + raise RuntimeError( + f'Data does not contain the following columns ' + f': {missing_columns}.') + + # Get all columns for this item + for col in columns: + result.append(item[col]) + return result + + +def _dict_output_fn( + columns: Sequence[str], + batch: Sequence[Dict[str, Any]], + embeddings: Sequence[Any]) -> List[Dict[str, Any]]: + """Map embeddings back to columns in batch.""" + result = [] + for batch_idx, item in enumerate(batch): + for col_idx, col in enumerate(columns): + embedding_idx = batch_idx * len(columns) + col_idx + item[col] = embeddings[embedding_idx] + result.append(item) + return result + + +def _create_dict_adapter( + columns: List[str]) -> EmbeddingTypeAdapter[Dict[str, Any], Dict[str, Any]]: + """Create adapter for dict-based processing.""" + return EmbeddingTypeAdapter[Dict[str, Any], Dict[str, Any]]( + input_fn=cast( + Callable[[Sequence[Dict[str, Any]]], List[str]], + functools.partial(_dict_input_fn, columns)), + output_fn=cast( + Callable[[Sequence[Dict[str, Any]], Sequence[Any]], + List[Dict[str, Any]]], + functools.partial(_dict_output_fn, columns))) + + # TODO:https://github.com/apache/beam/issues/29356 # Add support for inference_fn class EmbeddingsManager(MLTransformProvider): def __init__( self, - columns: list[str], *, + columns: Optional[list[str]] = None, + type_adapter: Optional[EmbeddingTypeAdapter] = None, # common args for all ModelHandlers. load_model_args: Optional[dict[str, Any]] = None, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, large_model: bool = False, **kwargs): + if columns is not None and type_adapter is not None: + raise ValueError( + "Cannot specify both 'columns' and 'type_adapter'. " + "Use either columns for dict processing or type_adapter " + "for custom types.") self.load_model_args = load_model_args or {} self.min_batch_size = min_batch_size self.max_batch_size = max_batch_size self.large_model = large_model self.columns = columns + if columns is not None: + self.type_adapter = _create_dict_adapter(columns) + elif type_adapter is not None: + self.type_adapter = type_adapter + else: + raise ValueError("Either columns or type_adapter must be specified") self.inference_args = kwargs.pop('inference_args', {}) if kwargs: @@ -622,32 +696,6 @@ def _validate_batch(self, batch: Sequence[dict[str, Any]]): 'Expected data to be dicts, got ' f'{type(batch[0])} instead.') - def _process_batch( - self, - dict_batch: dict[str, list[Any]], - model: ModelT, - inference_args: Optional[dict[str, Any]]) -> dict[str, list[Any]]: - result: dict[str, list[Any]] = collections.defaultdict(list) - input_keys = dict_batch.keys() - missing_columns_in_data = set(self.columns) - set(input_keys) - if missing_columns_in_data: - raise RuntimeError( - f'Data does not contain the following columns ' - f': {missing_columns_in_data}.') - for key, batch in dict_batch.items(): - if key in self.columns: - self._validate_column_data(batch) - prediction = self._underlying.run_inference( - batch, model, inference_args) - if isinstance(prediction, np.ndarray): - prediction = prediction.tolist() - result[key] = prediction # type: ignore[assignment] - else: - result[key] = prediction # type: ignore[assignment] - else: - result[key] = batch - return result - def run_inference( self, batch: Sequence[dict[str, list[str]]], @@ -659,12 +707,19 @@ def run_inference( a list of dicts. Each dict should have the same keys, and the shape should be of the same size for a single key across the batch. """ - self._validate_batch(batch) - dict_batch = _convert_list_of_dicts_to_dict_of_lists(list_of_dicts=batch) - transformed_batch = self._process_batch(dict_batch, model, inference_args) - return _convert_dict_of_lists_to_lists_of_dict( - dict_of_lists=transformed_batch, - ) + embedding_input = self.embedding_config.type_adapter.input_fn(batch) + self._validate_column_data(batch=embedding_input) + prediction = self._underlying.run_inference( + embedding_input, model, inference_args) + # Convert prediction to Sequence[Any] + if isinstance(prediction, np.ndarray): + prediction_seq = prediction.tolist() + elif isinstance(prediction, Iterable) and not isinstance(prediction, + (str, bytes)): + prediction_seq = list(prediction) + else: + prediction_seq = [prediction] + return self.embedding_config.type_adapter.output_fn(batch, prediction_seq) def get_metrics_namespace(self) -> str: return ( diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 3db5a63b9542..a5f179b726fd 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -430,9 +430,9 @@ def test_handler_on_multiple_columns(self): 'x': "Apache Beam", 'y': "Hello world", 'z': 'unchanged' }, ] - self.embedding_conig.columns = ['x', 'y'] + embedding_config = FakeEmbeddingsManager(columns=['x', 'y']) expected_data = [{ - key: (value[::-1] if key in self.embedding_conig.columns else value) + key: (value[::-1] if key in embedding_config.columns else value) for key, value in d.items() } for d in data] @@ -440,9 +440,8 @@ def test_handler_on_multiple_columns(self): result = ( p | beam.Create(data) - | base.MLTransform( - write_artifact_location=self.artifact_location).with_transform( - self.embedding_conig)) + | base.MLTransform(write_artifact_location=self.artifact_location). + with_transform(embedding_config)) assert_that( result, equal_to(expected_data), @@ -457,16 +456,15 @@ def test_handler_on_columns_not_exist_in_input_data(self): 'x': "Apache Beam", 'y': "Hello world" }, ] - self.embedding_conig.columns = ['x', 'y', 'a'] + embedding_config = FakeEmbeddingsManager(columns=['x', 'y', 'a']) with self.assertRaises(RuntimeError): with beam.Pipeline() as p: _ = ( p | beam.Create(data) - | base.MLTransform( - write_artifact_location=self.artifact_location).with_transform( - self.embedding_conig)) + | base.MLTransform(write_artifact_location=self.artifact_location). + with_transform(embedding_config)) def test_handler_with_list_data(self): data = [{ @@ -588,31 +586,37 @@ def test_handler_with_dict_inputs(self): class TestUtilFunctions(unittest.TestCase): - def test_list_of_dicts_to_dict_of_lists_normal(self): + def test_dict_input_fn_normal(self): input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] - expected_output = {'a': [1, 3], 'b': [2, 4]} + columns = ['a', 'b'] + + expected_output = [1, 2, 3, 4] + self.assertEqual(base._dict_input_fn(columns, input_list), expected_output) + + def test_dict_output_fn_normal(self): + input_list = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] + columns = ['a', 'b'] + embeddings = [1.1, 2.2, 3.3, 4.4] + + expected_output = [{'a': 1.1, 'b': 2.2}, {'a': 3.3, 'b': 4.4}] self.assertEqual( - base._convert_list_of_dicts_to_dict_of_lists(input_list), - expected_output) + base._dict_output_fn(columns, input_list, embeddings), expected_output) - def test_list_of_dicts_to_dict_of_lists_on_list_inputs(self): + def test_dict_input_fn_on_list_inputs(self): input_list = [{'a': [1, 2, 10], 'b': 3}, {'a': [1], 'b': 5}] - expected_output = {'a': [[1, 2, 10], [1]], 'b': [3, 5]} - self.assertEqual( - base._convert_list_of_dicts_to_dict_of_lists(input_list), - expected_output) + columns = ['a', 'b'] - def test_dict_of_lists_to_lists_of_dict_normal(self): - input_dict = {'a': [1, 3], 'b': [2, 4]} - expected_output = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] - self.assertEqual( - base._convert_dict_of_lists_to_lists_of_dict(input_dict), - expected_output) + expected_output = [[1, 2, 10], 3, [1], 5] + self.assertEqual(base._dict_input_fn(columns, input_list), expected_output) - def test_dict_of_lists_to_lists_of_dict_unequal_length(self): - input_dict = {'a': [1, 3], 'b': [2]} - with self.assertRaises(AssertionError): - base._convert_dict_of_lists_to_lists_of_dict(input_dict) + def test_dict_output_fn_on_list_inputs(self): + input_list = [{'a': [1, 2, 10], 'b': 3}, {'a': [1], 'b': 5}] + columns = ['a', 'b'] + embeddings = [1.1, 2.2, 3.3, 4.4] + + expected_output = [{'a': 1.1, 'b': 2.2}, {'a': 3.3, 'b': 4.4}] + self.assertEqual( + base._dict_output_fn(columns, input_list, embeddings), expected_output) class TestJsonPickleTransformAttributeManager(unittest.TestCase): diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py index 2162ed050c42..e492cb164222 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py @@ -133,7 +133,7 @@ def __init__( max_batch_size: The maximum batch size to be used for inference. large_model: Whether to share the model across processes. """ - super().__init__(columns, **kwargs) + super().__init__(columns=columns, **kwargs) self.model_name = model_name self.max_seq_length = max_seq_length self.image_model = image_model @@ -219,7 +219,7 @@ def __init__( api_url: Optional[str] = None, **kwargs, ): - super().__init__(columns, **kwargs) + super().__init__(columns=columns, **kwargs) self._authorization_token = {"Authorization": f"Bearer {hf_token}"} self._model_name = model_name self.hf_token = hf_token From 28900a6844f017b1aedb8b76989baf2cf11bc38a Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 11 Dec 2024 23:29:34 -0500 Subject: [PATCH 05/11] Add base rag EmbeddingTypeAdapter. --- .../apache_beam/ml/rag/embeddings/__init__.py | 20 +++++ .../apache_beam/ml/rag/embeddings/base.py | 53 +++++++++++ .../ml/rag/embeddings/base_test.py | 90 +++++++++++++++++++ 3 files changed, 163 insertions(+) create mode 100644 sdks/python/apache_beam/ml/rag/embeddings/__init__.py create mode 100644 sdks/python/apache_beam/ml/rag/embeddings/base.py create mode 100644 sdks/python/apache_beam/ml/rag/embeddings/base_test.py diff --git a/sdks/python/apache_beam/ml/rag/embeddings/__init__.py b/sdks/python/apache_beam/ml/rag/embeddings/__init__.py new file mode 100644 index 000000000000..d2cdb63c0bde --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Embedding components for RAG pipelines. +This module provides components for generating embeddings in RAG pipelines. +""" diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base.py b/sdks/python/apache_beam/ml/rag/embeddings/base.py new file mode 100644 index 000000000000..712bc768e9d7 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/base.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from apache_beam.ml.transforms.base import EmbeddingTypeAdapter +from apache_beam.ml.rag.types import Embedding, Chunk +from typing import List +from collections.abc import Sequence + + +def create_rag_adapter() -> EmbeddingTypeAdapter[Chunk, Chunk]: + """Creates adapter for converting between Chunk and Embedding types. + + The adapter: + - Extracts text from Chunk.content.text for embedding + - Creates Embedding objects from model output + - Sets Embedding in Chunk.embedding + + Returns: + EmbeddingTypeAdapter configured for RAG pipeline types + """ + return EmbeddingTypeAdapter( + input_fn=_extract_chunk_text, output_fn=_add_embedding_fn) + + +def _extract_chunk_text(chunks: Sequence[Chunk]) -> List[str]: + """Extract text from chunks for embedding.""" + chunk_texts = [] + for chunk in chunks: + if not chunk.content.text: + raise ValueError("Expected chunk text content.") + chunk_texts.append(chunk.content.text) + return chunk_texts + + +def _add_embedding_fn( + chunks: Sequence[Chunk], embeddings: Sequence[List[float]]) -> List[Chunk]: + """Create Embeddings from chunks and embedding vectors.""" + for chunk, embedding in zip(chunks, embeddings): + chunk.embedding = Embedding(dense_embedding=embedding) + return list(chunks) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py new file mode 100644 index 000000000000..77dc7a67bf55 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from apache_beam.ml.rag.types import Chunk, Content, Embedding +from apache_beam.ml.rag.embeddings.base import (create_rag_adapter) + + +class RAGBaseEmbeddingsTest(unittest.TestCase): + def setUp(self): + self.test_chunks = [ + Chunk( + content=Content(text="This is a test sentence."), + id="1", + metadata={ + "source": "test.txt", "language": "en" + }), + Chunk( + content=Content(text="Another example."), + id="2", + metadata={ + "source": "test2.txt", "language": "en" + }) + ] + + def test_adapter_input_conversion(self): + """Test the RAG type adapter converts correctly.""" + adapter = create_rag_adapter() + + # Test input conversion + texts = adapter.input_fn(self.test_chunks) + self.assertEqual(texts, ["This is a test sentence.", "Another example."]) + + def test_adapter_input_conversion_missing_text_content(self): + """Test the RAG type adapter converts correctly.""" + adapter = create_rag_adapter() + + # Test input conversion + with self.assertRaisesRegex(ValueError, "Expected chunk text content"): + adapter.input_fn([ + Chunk( + content=Content(), + id="1", + metadata={ + "source": "test.txt", "language": "en" + }) + ]) + + def test_adapter_output_conversion(self): + """Test the RAG type adapter converts correctly.""" + # Test output conversion + mock_embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + # Expected outputs + expected = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + metadata={ + 'source': 'test.txt', 'language': 'en' + }, + content=Content(text='This is a test sentence.')), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.4, 0.5, 0.6]), + metadata={ + 'source': 'test2.txt', 'language': 'en' + }, + content=Content(text='Another example.')), + ] + adapter = create_rag_adapter() + + embeddings = adapter.output_fn(self.test_chunks, mock_embeddings) + self.assertListEqual(embeddings, expected) + + +if __name__ == '__main__': + unittest.main() From 4eb7cef62b0ab948aa54adc5f96ca62a7e5ab554 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Dec 2024 13:12:10 -0500 Subject: [PATCH 06/11] Create HuggingfaceTextEmbeddings. --- .../ml/rag/embeddings/huggingface.py | 69 ++++++++++++ .../ml/rag/embeddings/huggingface_test.py | 104 ++++++++++++++++++ 2 files changed, 173 insertions(+) create mode 100644 sdks/python/apache_beam/ml/rag/embeddings/huggingface.py create mode 100644 sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py new file mode 100644 index 000000000000..8e3f449b0af0 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RAG-specific embedding implementations using HuggingFace models.""" + +from typing import Optional + +import apache_beam as beam +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.rag.embeddings.base import create_rag_adapter +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.transforms.base import ( + EmbeddingsManager, _TextEmbeddingHandler) +from apache_beam.ml.transforms.embeddings.huggingface import ( + SentenceTransformer, _SentenceTransformerModelHandler) + + +class HuggingfaceTextEmbeddings(EmbeddingsManager): + """SentenceTransformer embeddings for RAG pipeline. + + Extends EmbeddingsManager to work with RAG-specific types: + - Input: Chunk objects containing text to embed + - Output: Chunk objects with embedding property set + """ + def __init__( + self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs): + """Initialize RAG embeddings. + + Args: + model_name: Name of the sentence-transformers model to use + max_seq_length: Maximum sequence length for the model + **kwargs: Additional arguments passed to parent + """ + super().__init__(type_adapter=create_rag_adapter(), **kwargs) + self.model_name = model_name + self.max_seq_length = max_seq_length + self.model_class = SentenceTransformer + + def get_model_handler(self): + """Returns model handler configured with RAG adapter.""" + return _SentenceTransformerModelHandler( + model_class=self.model_class, + max_seq_length=self.max_seq_length, + model_name=self.model_name, + load_model_args=self.load_model_args, + min_batch_size=self.min_batch_size, + max_batch_size=self.max_batch_size, + large_model=self.large_model) + + def get_ptransform_for_processing( + self, **kwargs + ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]: + """Returns PTransform that uses the RAG adapter.""" + return RunInference( + model_handler=_TextEmbeddingHandler(self), + inference_args=self.inference_args).with_output_types(Chunk) diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py new file mode 100644 index 000000000000..d640aed577cc --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for apache_beam.ml.rag.embeddings.huggingface.""" + +import pytest +import tempfile +import unittest + +import apache_beam as beam +from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings +from apache_beam.ml.rag.types import Chunk, Content, Embedding +from apache_beam.ml.transforms.base import MLTransform +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that, equal_to + +# pylint: disable=unused-import +try: + from sentence_transformers import SentenceTransformer + SENTENCE_TRANSFORMERS_AVAILABLE = True +except ImportError: + SENTENCE_TRANSFORMERS_AVAILABLE = False + + +def chunk_approximately_equals(expected, actual): + """Compare embeddings allowing for numerical differences.""" + if not isinstance(expected, Chunk) or not isinstance(actual, Chunk): + return False + + return ( + expected.id == actual.id and expected.metadata == actual.metadata and + expected.content == actual.content and + len(expected.embedding.dense_embedding) == len( + actual.embedding.dense_embedding) and + all(isinstance(x, float) for x in actual.embedding.dense_embedding)) + + +@pytest.mark.uses_transformers +@unittest.skipIf( + not SENTENCE_TRANSFORMERS_AVAILABLE, "sentence-transformers not available") +class HuggingfaceTextEmbeddingsTest(unittest.TestCase): + def setUp(self): + self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_') + self.test_chunks = [ + Chunk( + content=Content(text="This is a test sentence."), + id="1", + metadata={ + "source": "test.txt", "language": "en" + }), + Chunk( + content=Content(text="Another example."), + id="2", + metadata={ + "source": "test.txt", "language": "en" + }) + ] + + def test_embedding_pipeline(self): + expected = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.0] * 384), + metadata={ + "source": "test.txt", "language": "en" + }, + content=Content(text="This is a test sentence.")), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.0] * 384), + metadata={ + "source": "test.txt", "language": "en" + }, + content=Content(text="Another example.")) + ] + embedder = HuggingfaceTextEmbeddings( + model_name="sentence-transformers/all-MiniLM-L6-v2") + + with TestPipeline() as p: + embeddings = ( + p + | beam.Create(self.test_chunks) + | MLTransform(write_artifact_location=self.artifact_location). + with_transform(embedder)) + + assert_that( + embeddings, equal_to(expected, equals_fn=chunk_approximately_equals)) + + +if __name__ == '__main__': + unittest.main() From 0a28de33eb0e509d1e05fb35ce45be0ad24f0e35 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Dec 2024 13:13:17 -0500 Subject: [PATCH 07/11] Linter fixes. --- sdks/python/apache_beam/ml/rag/chunking/base_test.py | 11 ++++++++--- .../apache_beam/ml/rag/chunking/langchain_test.py | 8 +++----- sdks/python/apache_beam/ml/rag/embeddings/base.py | 8 +++++--- .../apache_beam/ml/rag/embeddings/base_test.py | 7 +++++-- .../apache_beam/ml/rag/embeddings/huggingface.py | 8 ++++---- .../ml/rag/embeddings/huggingface_test.py | 10 +++++++--- sdks/python/apache_beam/ml/rag/types.py | 9 +++++++-- sdks/python/apache_beam/ml/transforms/base.py | 12 +++--------- 8 files changed, 42 insertions(+), 31 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/chunking/base_test.py b/sdks/python/apache_beam/ml/rag/chunking/base_test.py index 5a8d816343da..d6a2c0037e3a 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/base_test.py +++ b/sdks/python/apache_beam/ml/rag/chunking/base_test.py @@ -17,15 +17,20 @@ """Tests for apache_beam.ml.rag.chunking.base.""" import unittest +from typing import Any +from typing import Dict +from typing import Optional + import pytest import apache_beam as beam +from apache_beam.ml.rag.chunking.base import ChunkIdFn +from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider, ChunkIdFn -from apache_beam.ml.rag.types import Chunk, Content -from typing import Optional, Dict, Any class WordSplitter(beam.DoFn): diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py index 5bc17a824775..615c67207d9d 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py @@ -19,18 +19,16 @@ import unittest import apache_beam as beam - +from apache_beam.ml.rag.types import Chunk from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.ml.rag.types import Chunk try: from apache_beam.ml.rag.chunking.langchain import LangChainChunker + from langchain.text_splitter import ( - RecursiveCharacterTextSplitter, - CharacterTextSplitter, - ) + CharacterTextSplitter, RecursiveCharacterTextSplitter) LANGCHAIN_AVAILABLE = True except ImportError: LANGCHAIN_AVAILABLE = False diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base.py b/sdks/python/apache_beam/ml/rag/embeddings/base.py index 712bc768e9d7..25dc3ee47e80 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/base.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/base.py @@ -14,10 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from apache_beam.ml.transforms.base import EmbeddingTypeAdapter -from apache_beam.ml.rag.types import Embedding, Chunk -from typing import List from collections.abc import Sequence +from typing import List + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.transforms.base import EmbeddingTypeAdapter def create_rag_adapter() -> EmbeddingTypeAdapter[Chunk, Chunk]: diff --git a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py index 77dc7a67bf55..3a27ae8e7ebb 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/base_test.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/base_test.py @@ -15,8 +15,11 @@ # limitations under the License. import unittest -from apache_beam.ml.rag.types import Chunk, Content, Embedding -from apache_beam.ml.rag.embeddings.base import (create_rag_adapter) + +from apache_beam.ml.rag.embeddings.base import create_rag_adapter +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding class RAGBaseEmbeddingsTest(unittest.TestCase): diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py index 8e3f449b0af0..68468ad1875b 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py @@ -22,10 +22,10 @@ from apache_beam.ml.inference.base import RunInference from apache_beam.ml.rag.embeddings.base import create_rag_adapter from apache_beam.ml.rag.types import Chunk -from apache_beam.ml.transforms.base import ( - EmbeddingsManager, _TextEmbeddingHandler) -from apache_beam.ml.transforms.embeddings.huggingface import ( - SentenceTransformer, _SentenceTransformerModelHandler) +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformer +from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler class HuggingfaceTextEmbeddings(EmbeddingsManager): diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py index d640aed577cc..aa63d13025a1 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py @@ -16,16 +16,20 @@ """Tests for apache_beam.ml.rag.embeddings.huggingface.""" -import pytest import tempfile import unittest +import pytest + import apache_beam as beam from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings -from apache_beam.ml.rag.types import Chunk, Content, Embedding +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding from apache_beam.ml.transforms.base import MLTransform from apache_beam.testing.test_pipeline import TestPipeline -from apache_beam.testing.util import assert_that, equal_to +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to # pylint: disable=unused-import try: diff --git a/sdks/python/apache_beam/ml/rag/types.py b/sdks/python/apache_beam/ml/rag/types.py index 82beeba5cc6c..5d7d8b486739 100644 --- a/sdks/python/apache_beam/ml/rag/types.py +++ b/sdks/python/apache_beam/ml/rag/types.py @@ -21,9 +21,14 @@ contracts between different stages of the pipeline. """ -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Any import uuid +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple @dataclass diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index 8fd9e2f61ffe..703892886bef 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -15,6 +15,7 @@ # limitations under the License. import abc +import functools import logging import os import tempfile @@ -22,8 +23,8 @@ from collections.abc import Callable from collections.abc import Mapping from collections.abc import Sequence +from dataclasses import dataclass from typing import Any -from typing import cast from typing import Dict from typing import Generic from typing import Iterable @@ -31,13 +32,11 @@ from typing import Optional from typing import TypeVar from typing import Union +from typing import cast -import functools import jsonpickle import numpy as np -from dataclasses import dataclass - import apache_beam as beam from apache_beam.io.filesystems import FileSystems from apache_beam.metrics.metric import Metrics @@ -258,11 +257,6 @@ def __init__( max_batch_size: Optional[int] = None, large_model: bool = False, **kwargs): - if columns is not None and type_adapter is not None: - raise ValueError( - "Cannot specify both 'columns' and 'type_adapter'. " - "Use either columns for dict processing or type_adapter " - "for custom types.") self.load_model_args = load_model_args or {} self.min_batch_size = min_batch_size self.max_batch_size = max_batch_size From 9637269573065ea686c3f8d31da3e0681e679b42 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 13 Dec 2024 13:15:55 -0500 Subject: [PATCH 08/11] Docstring fixes. --- .../apache_beam/ml/rag/chunking/base.py | 33 +++++++++---------- .../ml/rag/embeddings/huggingface.py | 12 ++----- sdks/python/apache_beam/ml/rag/types.py | 7 ++-- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/chunking/base.py b/sdks/python/apache_beam/ml/rag/chunking/base.py index 9ab0dbde49b7..014e458d4286 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/base.py +++ b/sdks/python/apache_beam/ml/rag/chunking/base.py @@ -37,7 +37,7 @@ def _assign_chunk_id(chunk_id_fn: ChunkIdFn, chunk: Chunk): class ChunkingTransformProvider(MLTransformProvider): def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): """Base class for chunking transforms in RAG pipelines. - + ChunkingTransformProvider defines the interface for splitting documents into chunks for embedding and retrieval. Implementations should define how to split content while preserving metadata and managing chunk IDs. @@ -49,24 +49,23 @@ def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): 4. Optionally assigns unique IDs to chunks (configurable via chunk_id_fn). Example usage: - ```python - class MyChunker(ChunkingTransformProvider): - def get_splitter_transform(self): - return beam.ParDo(MySplitterDoFn()) - - chunker = MyChunker(chunk_id_fn=my_id_function) - - with beam.Pipeline() as p: - chunks = ( - p - | beam.Create([{'text': 'document...', 'source': 'doc.txt'}]) - | MLTransform(...).with_transform(chunker)) - ``` + ```python + class MyChunker(ChunkingTransformProvider): + def get_splitter_transform(self): + return beam.ParDo(MySplitterDoFn()) + + chunker = MyChunker(chunk_id_fn=my_id_function) + + with beam.Pipeline() as p: + chunks = ( + p + | beam.Create([{'text': 'document...', 'source': 'doc.txt'}]) + | MLTransform(...).with_transform(chunker)) + ``` Args: - chunk_id_fn: Optional function to generate chunk IDs. If not provided, - random UUIDs will be used. Function should take a Chunk and return - str. + chunk_id_fn: Optional function to generate chunk IDs. If not provided, + random UUIDs will be used. Function should take a Chunk and return str. """ self.assign_chunk_id_fn = functools.partial( _assign_chunk_id, chunk_id_fn) if chunk_id_fn is not None else None diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py index 68468ad1875b..8355c3e5a2a5 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py @@ -29,20 +29,14 @@ class HuggingfaceTextEmbeddings(EmbeddingsManager): - """SentenceTransformer embeddings for RAG pipeline. - - Extends EmbeddingsManager to work with RAG-specific types: - - Input: Chunk objects containing text to embed - - Output: Chunk objects with embedding property set - """ def __init__( self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs): - """Initialize RAG embeddings. - + """Utilizes huggingface SentenceTransformer embeddings for RAG pipeline. + Args: model_name: Name of the sentence-transformers model to use max_seq_length: Maximum sequence length for the model - **kwargs: Additional arguments passed to parent + **kwargs: Additional arguments including ModelHandlers arguments """ super().__init__(type_adapter=create_rag_adapter(), **kwargs) self.model_name = model_name diff --git a/sdks/python/apache_beam/ml/rag/types.py b/sdks/python/apache_beam/ml/rag/types.py index 5d7d8b486739..79429899e4c1 100644 --- a/sdks/python/apache_beam/ml/rag/types.py +++ b/sdks/python/apache_beam/ml/rag/types.py @@ -34,6 +34,9 @@ @dataclass class Content: """Container for embeddable content. Add new types as when as necessary. + + Args: + text: Text content to be embedded """ text: Optional[str] = None @@ -42,7 +45,7 @@ class Content: class Embedding: """Represents vector embeddings. - Attributes: + Args: dense_embedding: Dense vector representation sparse_embedding: Optional sparse vector representation for hybrid search @@ -56,7 +59,7 @@ class Embedding: class Chunk: """Represents a chunk of embeddable content with metadata. - Attributes: + Args: content: The actual content of the chunk id: Unique identifier for the chunk index: Index of this chunk within the original document From 384fcfbadf5f3bb460ad5941ff44ceead009ffb4 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Dec 2024 12:08:42 -0500 Subject: [PATCH 09/11] Typehint fixes. --- sdks/python/apache_beam/ml/rag/chunking/langchain.py | 2 +- .../apache_beam/ml/rag/embeddings/huggingface.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain.py b/sdks/python/apache_beam/ml/rag/chunking/langchain.py index f23e7be65d49..9e3b6b0c8ef9 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/langchain.py +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain.py @@ -29,7 +29,7 @@ try: from langchain.text_splitter import TextSplitter except ImportError: - TextSplitter = None # type: ignore + TextSplitter = None class LangChainChunker(ChunkingTransformProvider): diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py index 8355c3e5a2a5..be34fdbd36a1 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py @@ -24,9 +24,13 @@ from apache_beam.ml.rag.types import Chunk from apache_beam.ml.transforms.base import EmbeddingsManager from apache_beam.ml.transforms.base import _TextEmbeddingHandler -from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformer from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler +try: + from sentence_transformers import SentenceTransformer +except ImportError: + SentenceTransformer = None + class HuggingfaceTextEmbeddings(EmbeddingsManager): def __init__( @@ -38,6 +42,11 @@ def __init__( max_seq_length: Maximum sequence length for the model **kwargs: Additional arguments including ModelHandlers arguments """ + if not SentenceTransformer: + raise ImportError( + "sentence-transformers is required to use " + "HuggingfaceTextEmbeddings." + "Please install it with using `pip install sentence-transformers`.") super().__init__(type_adapter=create_rag_adapter(), **kwargs) self.model_name = model_name self.max_seq_length = max_seq_length From a5bfaf398680133fad13c4f91a88fee9509c43c6 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 16 Dec 2024 12:48:58 -0500 Subject: [PATCH 10/11] Docstring fix. --- .../apache_beam/ml/rag/chunking/base.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/chunking/base.py b/sdks/python/apache_beam/ml/rag/chunking/base.py index 014e458d4286..626a6ea8abbe 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/base.py +++ b/sdks/python/apache_beam/ml/rag/chunking/base.py @@ -43,25 +43,23 @@ def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): to split content while preserving metadata and managing chunk IDs. The transform flow: - 1. Takes input documents with content and metadata - 2. Splits content into chunks using implementation-specific logic - 3. Preserves document metadata in resulting chunks - 4. Optionally assigns unique IDs to chunks (configurable via chunk_id_fn). + - Takes input documents with content and metadata + - Splits content into chunks using implementation-specific logic + - Preserves document metadata in resulting chunks + - Optionally assigns unique IDs to chunks (configurable via chunk_id_fn Example usage: - ```python - class MyChunker(ChunkingTransformProvider): - def get_splitter_transform(self): - return beam.ParDo(MySplitterDoFn()) - - chunker = MyChunker(chunk_id_fn=my_id_function) - - with beam.Pipeline() as p: - chunks = ( - p - | beam.Create([{'text': 'document...', 'source': 'doc.txt'}]) - | MLTransform(...).with_transform(chunker)) - ``` + >>> class MyChunker(ChunkingTransformProvider): + ... def get_splitter_transform(self): + ... return beam.ParDo(MySplitterDoFn()) + ... + >>> chunker = MyChunker(chunk_id_fn=my_id_function) + >>> + >>> with beam.Pipeline() as p: + ... chunks = ( + ... p + ... | beam.Create([{'text': 'document...', 'source': 'doc.txt'}]) + ... | MLTransform(...).with_transform(chunker)) Args: chunk_id_fn: Optional function to generate chunk IDs. If not provided, From 7fe1022fda12f4eb19aea23ff1effff8fed213e7 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 17 Dec 2024 07:06:34 -0500 Subject: [PATCH 11/11] Add EmbeddingManager to args and more test coverage. --- .../apache_beam/ml/rag/chunking/base_test.py | 10 +++++ .../ml/rag/chunking/langchain_test.py | 15 ++++--- .../ml/rag/embeddings/huggingface.py | 4 +- sdks/python/apache_beam/ml/transforms/base.py | 6 --- .../apache_beam/ml/transforms/base_test.py | 43 ++++++++++++++++++- 5 files changed, 64 insertions(+), 14 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/chunking/base_test.py b/sdks/python/apache_beam/ml/rag/chunking/base_test.py index d6a2c0037e3a..54e25591c348 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/base_test.py +++ b/sdks/python/apache_beam/ml/rag/chunking/base_test.py @@ -43,6 +43,11 @@ def process(self, element): metadata={'source': element['source']}) +class InvalidChunkingProvider(ChunkingTransformProvider): + def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): + super().__init__(chunk_id_fn=chunk_id_fn) + + class MockChunkingProvider(ChunkingTransformProvider): def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None): super().__init__(chunk_id_fn=chunk_id_fn) @@ -76,6 +81,11 @@ class ChunkingTransformProviderTest(unittest.TestCase): def setUp(self): self.test_doc = {'text': 'hello world test', 'source': 'test.txt'} + def test_doesnt_override_get_text_splitter_transform(self): + provider = InvalidChunkingProvider() + with self.assertRaises(NotImplementedError): + provider.get_splitter_transform() + def test_chunking_transform(self): """Test the complete chunking transform.""" provider = MockChunkingProvider() diff --git a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py index 615c67207d9d..83a4fc1a778f 100644 --- a/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py +++ b/sdks/python/apache_beam/ml/rag/chunking/langchain_test.py @@ -179,17 +179,20 @@ def test_invalid_document_field(self): | beam.Create([self.simple_text]) | provider.get_ptransform_for_processing()) + def test_empty_document_field(self): + """Test that using an invalid document field raises KeyError.""" + splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20) + + with self.assertRaises(ValueError): + _ = LangChainChunker( + document_field='', metadata_fields={}, text_splitter=splitter) + def test_invalid_text_splitter(self): """Test that using an invalid document field raises KeyError.""" with self.assertRaises(TypeError): - provider = LangChainChunker( + _ = LangChainChunker( document_field='nonexistent', text_splitter="Not a text splitter!") - with TestPipeline() as p: - _ = ( - p - | beam.Create([self.simple_text]) - | provider.get_ptransform_for_processing()) def test_empty_text(self): """Test that empty text produces no chunks.""" diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py index be34fdbd36a1..4cb0aecd6e82 100644 --- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface.py @@ -40,7 +40,9 @@ def __init__( Args: model_name: Name of the sentence-transformers model to use max_seq_length: Maximum sequence length for the model - **kwargs: Additional arguments including ModelHandlers arguments + **kwargs: Additional arguments passed to + :class:`~apache_beam.ml.transforms.base.EmbeddingsManager` + constructor including ModelHandler arguments """ if not SentenceTransformer: raise ImportError( diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index 703892886bef..57a5efd3ff0e 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -684,12 +684,6 @@ def load_model(self): def _validate_column_data(self, batch): pass - def _validate_batch(self, batch: Sequence[dict[str, Any]]): - if not batch or not isinstance(batch[0], dict): - raise TypeError( - 'Expected data to be dicts, got ' - f'{type(batch[0])} instead.') - def run_inference( self, batch: Sequence[dict[str, list[str]]], diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index a5f179b726fd..1ef01acca18a 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -78,6 +78,10 @@ def setUp(self) -> None: def tearDown(self): shutil.rmtree(self.artifact_location) + def test_ml_transform_no_read_or_write_artifact_lcoation(self): + with self.assertRaises(ValueError): + _ = base.MLTransform(transforms=[]) + @unittest.skipIf(tft is None, 'tft module is not installed.') def test_ml_transform_appends_transforms_to_process_handler_correctly(self): fake_fn_1 = _FakeOperation(name='fake_fn_1', columns=['x']) @@ -354,6 +358,21 @@ def __repr__(self): return 'FakeEmbeddingsManager' +class InvalidEmbeddingsManager(base.EmbeddingsManager): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_model_handler(self) -> ModelHandler: + InvalidEmbeddingsManager.__repr__ = lambda x: 'InvalidEmbeddingsManager' # type: ignore[method-assign] + return FakeModelHandler() + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + return (RunInference(model_handler=base._TextEmbeddingHandler(self))) + + def __repr__(self): + return 'InvalidEmbeddingsManager' + + class TextEmbeddingHandlerTest(unittest.TestCase): def setUp(self) -> None: self.embedding_conig = FakeEmbeddingsManager(columns=['x']) @@ -362,6 +381,10 @@ def setUp(self) -> None: def tearDown(self) -> None: shutil.rmtree(self.artifact_location) + def test_no_columns_or_type_adapter(self): + with self.assertRaises(ValueError): + _ = InvalidEmbeddingsManager() + def test_handler_with_incompatible_datatype(self): text_handler = base._TextEmbeddingHandler( embeddings_manager=self.embedding_conig) @@ -548,7 +571,7 @@ def tearDown(self) -> None: shutil.rmtree(self.artifact_location) @unittest.skipIf(PIL is None, 'PIL module is not installed.') - def test_handler_with_incompatible_datatype(self): + def test_handler_with_non_dict_datatype(self): image_handler = base._ImageEmbeddingHandler( embeddings_manager=self.embedding_config) data = [ @@ -559,6 +582,24 @@ def test_handler_with_incompatible_datatype(self): with self.assertRaises(TypeError): image_handler.run_inference(data, None, None) + @unittest.skipIf(PIL is None, 'PIL module is not installed.') + def test_handler_with_non_image_datatype(self): + image_handler = base._ImageEmbeddingHandler( + embeddings_manager=self.embedding_config) + data = [ + { + 'x': 'hi there' + }, + { + 'x': 'not an image' + }, + { + 'x': 'image_path.jpg' + }, + ] + with self.assertRaises(TypeError): + image_handler.run_inference(data, None, None) + @unittest.skipIf(PIL is None, 'PIL module is not installed.') def test_handler_with_dict_inputs(self): img_one = PIL.Image.new(mode='RGB', size=(1, 1))