Linter fixes

claudevdm committed Dec 8, 2024
1 parent 55cacf2 commit 0b6bdde
Showing 14 changed files with 44 additions and 81 deletions.
3 changes: 1 addition & 2 deletions sdks/python/apache_beam/ml/rag/chunking/base.py
@@ -18,7 +18,7 @@
import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransformProvider
from apache_beam.ml.rag.types import Chunk
-from typing import List, Optional
+from typing import Optional
from collections.abc import Callable
import abc
import uuid
@@ -37,7 +37,6 @@ def assign_chunk_id(chunk_id_fn: ChunkIdFn, chunk: Chunk):


class ChunkingTransformProvider(MLTransformProvider):
-
def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None):
self.assign_chunk_id_fn = functools.partial(
assign_chunk_id,
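
The provider keeps its optional chunk_id_fn hook, so callers can still control how chunk IDs are assigned (the uuid import above backs the default). A minimal sketch of wiring in a deterministic ID function — the splitter DoFn, field names, and ID scheme are illustrative, not part of this commit:

import apache_beam as beam
from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider
from apache_beam.ml.rag.types import Chunk, Content

def source_and_index_id(chunk: Chunk) -> str:
  # Deterministic ID, stable across pipeline reruns (unlike a random uuid).
  return f"{chunk.metadata['source']}_{chunk.index}"

class ParagraphSplitter(beam.DoFn):
  # Hypothetical splitter: one Chunk per blank-line-separated paragraph.
  def process(self, element):
    for i, para in enumerate(element['text'].split('\n\n')):
      yield Chunk(
          content=Content(text=para),
          index=i,
          metadata={'source': element['source']})

class ParagraphChunkingProvider(ChunkingTransformProvider):
  def __init__(self):
    super().__init__(chunk_id_fn=source_and_index_id)

  def get_text_splitter_transform(self) -> beam.DoFn:
    return ParagraphSplitter()
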
6 changes: 1 addition & 5 deletions sdks/python/apache_beam/ml/rag/chunking/base_test.py
@@ -25,11 +25,10 @@
from apache_beam.testing.util import equal_to
from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider, ChunkIdFn
from apache_beam.ml.rag.types import Chunk, Content
-from typing import List, Optional
+from typing import Optional


class WordSplitter(beam.DoFn):
-
def process(self, element):
words = element['text'].split()
for i, word in enumerate(words):
@@ -40,7 +39,6 @@ def process(self, element):


class MockChunkingProvider(ChunkingTransformProvider):
-
def __init__(self, chunk_id_fn: Optional[ChunkIdFn] = None):
super().__init__(chunk_id_fn=chunk_id_fn)

@@ -67,7 +65,6 @@ def id_equals(expected, actual):

@pytest.mark.uses_transformers
class ChunkingTransformProviderTest(unittest.TestCase):
-
def setUp(self):
self.test_doc = {'text': 'hello world test', 'source': 'test.txt'}

@@ -100,7 +97,6 @@ def test_chunking_transform(self):

def test_custom_chunk_id_fn(self):
"""Test a custom chunk id function."""
-
def source_index_id_fn(chunk: Chunk):
return f"{chunk.metadata['source']}_{chunk.index}"

4 changes: 1 addition & 3 deletions sdks/python/apache_beam/ml/rag/chunking/langchain.py
@@ -23,12 +23,11 @@


class LangChainChunkingProvider(ChunkingTransformProvider):
-
def __init__(
self,
text_splitter: TextSplitter,
document_field: str,
-metadata_fields: List[str] = [],
+metadata_fields: List[str],
chunk_id_fn: Optional[ChunkIdFn] = None):
if not isinstance(text_splitter, TextSplitter):
raise TypeError("text_splitter must be a LangChain TextSplitter")
@@ -48,7 +47,6 @@ def get_text_splitter_transform(self) -> beam.DoFn:


class LangChainTextSplitter(beam.DoFn):
-
def __init__(
self,
text_splitter: TextSplitter,
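
Dropping the mutable default (an empty list as a default argument is a classic Python lint error) makes metadata_fields a required argument, which the test updates below reflect by passing metadata_fields=[] explicitly. A usage sketch under the signature shown above — the field names are taken from the tests, not fixed by the API:

from langchain.text_splitter import RecursiveCharacterTextSplitter
from apache_beam.ml.rag.chunking.langchain import LangChainChunkingProvider

splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
provider = LangChainChunkingProvider(
    text_splitter=splitter,
    document_field='content',    # input dict field holding the document text
    metadata_fields=['source'])  # now required; pass [] for no metadata
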
24 changes: 14 additions & 10 deletions sdks/python/apache_beam/ml/rag/chunking/langchain_test.py
@@ -19,17 +19,21 @@
import unittest

import apache_beam as beam

from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

-from langchain.text_splitter import (
-RecursiveCharacterTextSplitter,
-CharacterTextSplitter,
-)

from apache_beam.ml.rag.chunking.langchain import LangChainChunkingProvider
-from apache_beam.ml.rag.types import Chunk, Content
+from apache_beam.ml.rag.types import Chunk

+try:
+from langchain.text_splitter import (
+RecursiveCharacterTextSplitter,
+CharacterTextSplitter,
+)
+LANGCHAIN_AVAILABLE = True
+except ImportError:
+LANGCHAIN_AVAILABLE = False

# Import optional dependencies
try:
@@ -49,7 +53,6 @@ def chunk_equals(expected, actual):


class LangChainChunkingTest(unittest.TestCase):
-
def setUp(self):
self.simple_text = {
'content': 'This is a simple test document. It has multiple sentences. '
@@ -72,7 +75,7 @@ def test_no_metadata_fields(self):
"""Test chunking with no metadata fields specified."""
splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20)
provider = LangChainChunkingProvider(
-document_field='content', text_splitter=splitter)
+document_field='content', metadata_fields=[], text_splitter=splitter)

with TestPipeline() as p:
chunks = (
@@ -102,7 +105,8 @@ def test_multiple_metadata_fields(self):

assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks')
assert_that(
-chunks, lambda x: all(
+chunks,
+lambda x: all(
c.metadata == {
'source': 'simple.txt', 'language': 'en'
} for c in x))
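
The guarded langchain import added above follows the optional-dependency pattern already used for sentence-transformers elsewhere in this commit: a module-level availability flag paired with unittest.skipIf. A generic sketch of the pattern (names here are placeholders):

import unittest

try:
  import langchain  # optional dependency
  LANGCHAIN_AVAILABLE = True
except ImportError:
  LANGCHAIN_AVAILABLE = False

@unittest.skipIf(not LANGCHAIN_AVAILABLE, "langchain not available")
class SomeLangChainTest(unittest.TestCase):
  def test_placeholder(self):
    self.assertTrue(LANGCHAIN_AVAILABLE)
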
12 changes: 8 additions & 4 deletions sdks/python/apache_beam/ml/rag/embeddings/base.py
@@ -32,9 +32,13 @@ def create_rag_adapter() -> EmbeddingTypeAdapter:
"""
return EmbeddingTypeAdapter(
input_fn=lambda chunks: [chunk.content.text for chunk in chunks],
-output_fn=lambda chunks, embeddings: [
+output_fn=lambda chunks,
+embeddings: [
Embedding(
-id=chunk.id, dense_embedding=embeddings, sparse_embedding=None,
-metadata=chunk.metadata, content=chunk.content)
-for chunk, embeddings in zip(chunks, embeddings)
+id=chunk.id,
+dense_embedding=embeddings,
+sparse_embedding=None,
+metadata=chunk.metadata,
+content=chunk.content) for chunk,
+embeddings in zip(chunks, embeddings)
])
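
The yapf reflow above is hard to read, not least because the loop variable shadows the embeddings parameter. A functionally equivalent helper, for clarity only (the committed code keeps the lambda):

from apache_beam.ml.rag.types import Embedding

def chunks_to_embeddings(chunks, embeddings):
  # Pair each Chunk with its vector, carrying id, metadata, and content
  # through; this adapter never populates sparse embeddings.
  return [
      Embedding(
          id=chunk.id,
          dense_embedding=vector,
          sparse_embedding=None,
          metadata=chunk.metadata,
          content=chunk.content) for chunk, vector in zip(chunks, embeddings)
  ]
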
1 change: 0 additions & 1 deletion sdks/python/apache_beam/ml/rag/embeddings/base_test.py
@@ -20,7 +20,6 @@


class RAGBaseEmbeddingsTest(unittest.TestCase):
-
def setUp(self):
self.test_chunks = [
Chunk(
1 change: 0 additions & 1 deletion sdks/python/apache_beam/ml/rag/embeddings/huggingface.py
@@ -38,7 +38,6 @@ class HuggingfaceTextEmbeddings(EmbeddingsManager):
- Copies Chunk.metadata to Embedding.metadata
- Converts model output to Embedding.dense_embedding
"""

def __init__(
self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs):
"""Initialize RAG embeddings.
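
A minimal wiring sketch for the class above, using the same model name the tests below use; treat the MLTransform plumbing as an assumption about typical usage rather than documented API for this class:

import tempfile
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings

embedder = HuggingfaceTextEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2")
# In a pipeline:
#   chunks | MLTransform(
#       write_artifact_location=tempfile.mkdtemp()).with_transform(embedder)
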
8 changes: 3 additions & 5 deletions sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py
@@ -26,16 +26,16 @@
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.ml.transforms.base import MLTransform
-from apache_beam.ml.rag.types import Chunk, Content, Embedding
-from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings

+# pylint: disable=unused-import
try:
from sentence_transformers import SentenceTransformer
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
SENTENCE_TRANSFORMERS_AVAILABLE = False

+from apache_beam.ml.rag.types import Chunk, Content, Embedding
+from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings


def embedding_approximately_equals(expected, actual):
"""Compare embeddings allowing for numerical differences."""
@@ -52,7 +52,6 @@ def embedding_approximately_equals(expected, actual):
@unittest.skipIf(
not SENTENCE_TRANSFORMERS_AVAILABLE, "sentence-transformers not available")
class HuggingfaceTextEmbeddingsTest(unittest.TestCase):
-
def setUp(self):
self.artifact_location = tempfile.mkdtemp(prefix='sentence_transformers_')
self.test_chunks = [
@@ -89,7 +88,6 @@ def test_embedding_pipeline(self):
},
content=Content(text="This is a test sentence."))
]
"""Test the complete embedding pipeline."""
embedder = HuggingfaceTextEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2")

@@ -100,7 +100,6 @@ class BigQueryVectorSearchEnrichmentHandler(
EnrichmentSourceHandler[Union[Embedding, List[Embedding]],
Union[Embedding, List[Embedding]]]):
"""Enrichment handler for BigQuery vector search."""

def __init__(
self,
project: str,
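
Only the project parameter of the handler is visible in this hunk; the rest of the signature is collapsed. A sketch of how an EnrichmentSourceHandler like this one is typically applied (the Enrichment transform is standard Beam; the argument list here is assumed, not shown by this diff):

from apache_beam.transforms.enrichment import Enrichment

# BigQueryVectorSearchEnrichmentHandler is defined in the file above.
handler = BigQueryVectorSearchEnrichmentHandler(project='my-gcp-project')
# In a pipeline: embeddings | Enrichment(handler)
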
2 changes: 0 additions & 2 deletions sdks/python/apache_beam/ml/rag/ingestion/base.py
@@ -24,7 +24,6 @@ class VectorDatabaseConfig(ABC):
Implementations should provide database-specific configuration and
create appropriate write transforms.
"""

@abstractmethod
def create_write_transform(self) -> beam.PTransform:
"""Creates a PTransform that writes to the vector database.
@@ -40,7 +39,6 @@ class VectorDatabaseWriteTransform(beam.PTransform):
Uses the provided database config to create an appropriate write transform.
"""

def __init__(self, database_config: VectorDatabaseConfig):
"""Initialize transform with database config.
3 changes: 0 additions & 3 deletions sdks/python/apache_beam/ml/rag/ingestion/base_test.py
@@ -25,14 +25,12 @@

class MockWriteTransform(beam.PTransform):
"""Mock transform that returns element."""

def expand(self, pcoll):
return pcoll | beam.Map(lambda x: x)


class MockDatabaseConfig(VectorDatabaseConfig):
"""Mock database config for testing."""

def __init__(self):
self.write_transform = MockWriteTransform()

@@ -41,7 +39,6 @@ def create_write_transform(self) -> beam.PTransform:


class VectorDatabaseBaseTest(unittest.TestCase):
-
def test_write_transform_creation(self):
"""Test that write transform is created correctly."""
config = MockDatabaseConfig()
9 changes: 3 additions & 6 deletions sdks/python/apache_beam/ml/rag/ingestion/bigquery.py
@@ -56,8 +56,7 @@ def create_write_transform(self) -> beam.PTransform:


class _WriteToBigQueryVectorDatabase(beam.PTransform):
"""Implementation of BigQuery vector database write."""

"""Implementation of BigQuery vector database write. """
def __init__(self, config: BigQueryVectorWriterConfig):
self.config = config

@@ -69,9 +68,7 @@ def expand(self, pcoll: beam.PCollection[Embedding]):
id=lambda x: str(x.id),
embedding=lambda x: [float(v) for v in x.dense_embedding],
content=lambda x: str(x.content.text),
-metadata=lambda x: {
-str(k): str(v)
-for k, v in x.metadata.items()
-})
+metadata=lambda x: {str(k): str(v)
+for k, v in x.metadata.items()})
| "Write to BigQuery" >> beam.managed.Write(
beam.managed.BIGQUERY, config=self.config.write_config))
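
The reflowed metadata lambda is a dict comprehension that stringifies keys and values before the managed BigQuery write. In standalone form, for illustration only:

def stringify_metadata(metadata):
  # Coerce keys and values to str so every row carries a uniform
  # string-to-string metadata map.
  return {str(k): str(v) for k, v in metadata.items()}

print(stringify_metadata({'page': 3, 'lang': 'en'}))  # {'page': '3', 'lang': 'en'}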