Skip to content

Commit

Permalink
Add BigQueryVectorWriterConfig tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Claude authored and claudevdm committed Dec 19, 2024
1 parent a22a48c commit 0542aa0
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 2
"modification": 3
}
36 changes: 16 additions & 20 deletions sdks/python/apache_beam/ml/rag/ingestion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,14 @@ class VectorDatabaseWriteConfig(ABC):
3. Transform handles converting Chunks to database-specific format
Example implementation:
```python
class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig):
def __init__(self, table: str):
self.embedding_column = embedding_column
def create_write_transform(self):
return beam.io.WriteToBigQuery(
table=self.table
)
```
>>> class BigQueryVectorWriterConfig(VectorDatabaseWriteConfig):
... def __init__(self, table: str):
... self.embedding_column = embedding_column
...
... def create_write_transform(self):
... return beam.io.WriteToBigQuery(
... table=self.table
... )
"""
@abstractmethod
def create_write_transform(self) -> beam.PTransform:
Expand All @@ -67,16 +65,14 @@ class VectorDatabaseWriteTransform(beam.PTransform):
the database-specific write transform.
Example usage:
```python
config = BigQueryVectorConfig(
table='project.dataset.embeddings',
embedding_column='embedding'
)
with beam.Pipeline() as p:
chunks = p | beam.Create([...]) # PCollection[Chunk]
chunks | VectorDatabaseWriteTransform(config)
```
>>> config = BigQueryVectorConfig(
... table='project.dataset.embeddings',
... embedding_column='embedding'
... )
>>>
>>> with beam.Pipeline() as p:
... chunks = p | beam.Create([...]) # PCollection[Chunk]
... chunks | VectorDatabaseWriteTransform(config)
Args:
database_config: Configuration for the target vector database.
Expand Down
72 changes: 34 additions & 38 deletions sdks/python/apache_beam/ml/rag/ingestion/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ class SchemaConfig:
Attributes:
schema: BigQuery TableSchema dict defining the table structure.
Example:
{
'fields': [
{'name': 'id', 'type': 'STRING'},
{'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
{'name': 'custom_field', 'type': 'STRING'}
]
}
>>> {
... 'fields': [
... {'name': 'id', 'type': 'STRING'},
... {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
... {'name': 'custom_field', 'type': 'STRING'}
... ]
... }
chunk_to_dict_fn: Function that converts a Chunk to a dict matching the
schema. Takes a Chunk and returns Dict[str, Any] with keys matching
schema fields.
Example:
def chunk_to_dict(chunk: Chunk) -> Dict[str, Any]:
return {
'id': chunk.id,
'embedding': chunk.embedding.dense_embedding,
'custom_field': chunk.metadata.get('custom_field')
}
>>> def chunk_to_dict(chunk: Chunk) -> Dict[str, Any]:
... return {
... 'id': chunk.id,
... 'embedding': chunk.embedding.dense_embedding,
... 'custom_field': chunk.metadata.get('custom_field')
... }
"""
schema: Dict
chunk_to_dict_fn: ChunkToDictFn
Expand All @@ -66,40 +66,36 @@ def __init__(
self,
write_config: Dict[str, Any],
*, # Force keyword arguments
schema_config: Optional[SchemaConfig]
schema_config: Optional[SchemaConfig] = None
):
"""Configuration for writing vectors to BigQuery using managed transforms.
Supports both default schema (id, embedding, content, metadata columns) and
custom schemas through SchemaConfig.
Example with default schema:
```python
config = BigQueryVectorWriterConfig(
write_config={'table': 'project.dataset.embeddings'})
```
>>> config = BigQueryVectorWriterConfig(
... write_config={'table': 'project.dataset.embeddings'})
Example with custom schema:
```python
schema_config = SchemaConfig(
schema={
'fields': [
{'name': 'id', 'type': 'STRING'},
{'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
{'name': 'source_url', 'type': 'STRING'}
]
},
chunk_to_dict_fn=lambda chunk: {
'id': chunk.id,
'embedding': chunk.embedding.dense_embedding,
'source_url': chunk.metadata.get('url')
}
)
config = BigQueryVectorWriterConfig(
write_config={'table': 'project.dataset.embeddings'},
schema_config=schema_config
)
```
>>> schema_config = SchemaConfig(
... schema={
... 'fields': [
... {'name': 'id', 'type': 'STRING'},
... {'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'},
... {'name': 'source_url', 'type': 'STRING'}
... ]
... },
... chunk_to_dict_fn=lambda chunk: {
... 'id': chunk.id,
... 'embedding': chunk.embedding.dense_embedding,
... 'source_url': chunk.metadata.get('url')
... }
... )
>>> config = BigQueryVectorWriterConfig(
... write_config={'table': 'project.dataset.embeddings'},
... schema_config=schema_config
... )
Args:
write_config: BigQuery write configuration dict. Must include 'table'.
Expand Down
194 changes: 194 additions & 0 deletions sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import logging
import secrets
import time
import unittest

import hamcrest as hc
import pytest

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.ml.rag.ingestion.bigquery import BigQueryVectorWriterConfig, SchemaConfig
from apache_beam.ml.rag.types import Chunk, Content, Embedding
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
from apache_beam.transforms.periodicsequence import PeriodicImpulse
import os


@pytest.mark.uses_gcp_java_expansion_service
@unittest.skipUnless(
os.environ.get('EXPANSION_JARS'),
"EXPANSION_JARS environment var is not provided, "
"indicating that jars have not been built")
class BigQueryVectorWriterConfigTest(unittest.TestCase):
BIG_QUERY_DATASET_ID = 'python_rag_bigquery_'

def setUp(self):
self.test_pipeline = TestPipeline(is_integration_test=True)
self._runner = type(self.test_pipeline.runner).__name__
self.project = self.test_pipeline.get_option('project')

self.bigquery_client = BigQueryWrapper()
self.dataset_id = '%s%d%s' % (
self.BIG_QUERY_DATASET_ID, int(time.time()), secrets.token_hex(3))
self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
_LOGGER = logging.getLogger(__name__)
_LOGGER.info(
"Created dataset %s in project %s", self.dataset_id, self.project)

def tearDown(self):
request = bigquery.BigqueryDatasetsDeleteRequest(
projectId=self.project, datasetId=self.dataset_id, deleteContents=True)
try:
_LOGGER = logging.getLogger(__name__)
_LOGGER.info(
"Deleting dataset %s in project %s", self.dataset_id, self.project)
self.bigquery_client.client.datasets.Delete(request)
# Failing to delete a dataset should not cause a test failure.
except Exception:
_LOGGER = logging.getLogger(__name__)
_LOGGER.debug(
'Failed to clean up dataset %s in project %s',
self.dataset_id,
self.project)

def test_default_schema(self):
table_name = 'python_default_schema_table'
table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name)

config = BigQueryVectorWriterConfig(write_config={'table': table_id})
chunks = [
Chunk(
id="1",
embedding=Embedding(dense_embedding=[0.1, 0.2]),
content=Content(text="foo"),
metadata={"a": "b"}),
Chunk(
id="2",
embedding=Embedding(dense_embedding=[0.3, 0.4]),
content=Content(text="bar"),
metadata={"c": "d"})
]

pipeline_verifiers = [
BigqueryFullResultMatcher(
project=self.project,
query="SELECT id, content, embedding, metadata FROM %s" % table_id,
data=[("1", "foo", [0.1, 0.2], [{
"key": "a", "value": "b"
}]), ("2", "bar", [0.3, 0.4], [{
"key": "c", "value": "d"
}])])
]

args = self.test_pipeline.get_full_options_as_args(
on_success_matcher=hc.all_of(*pipeline_verifiers))
with beam.Pipeline(argv=args) as p:
_ = (p | beam.Create(chunks) | config.create_write_transform())

def test_custom_schema(self):
table_name = 'python_custom_schema_table'
table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name)

schema_config = SchemaConfig(
schema={
'fields': [{
'name': 'id', 'type': 'STRING'
},
{
'name': 'embedding',
'type': 'FLOAT64',
'mode': 'REPEATED'
}, {
'name': 'source', 'type': 'STRING'
}]
},
chunk_to_dict_fn=lambda chunk: {
'id': chunk.id,
'embedding': chunk.embedding.dense_embedding,
'source': chunk.metadata.get('source')
})
config = BigQueryVectorWriterConfig(
write_config={'table': table_id}, schema_config=schema_config)
chunks = [
Chunk(
id="1",
embedding=Embedding(dense_embedding=[0.1, 0.2]),
content=Content(text="foo content"),
metadata={"source": "foo"}),
Chunk(
id="2",
embedding=Embedding(dense_embedding=[0.3, 0.4]),
content=Content(text="bar content"),
metadata={"source": "bar"})
]

pipeline_verifiers = [
BigqueryFullResultMatcher(
project=self.project,
query="SELECT id, embedding, source FROM %s" % table_id,
data=[("1", [0.1, 0.2], "foo"), ("2", [0.3, 0.4], "bar")])
]

args = self.test_pipeline.get_full_options_as_args(
on_success_matcher=hc.all_of(*pipeline_verifiers))

with beam.Pipeline(argv=args) as p:
_ = (p | beam.Create(chunks) | config.create_write_transform())

def test_streaming_default_schema(self):
self.skip_if_not_dataflow_runner()

table_name = 'python_streaming_default_schema_table'
table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name)

config = BigQueryVectorWriterConfig(write_config={'table': table_id})
chunks = [
Chunk(
id="1",
embedding=Embedding(dense_embedding=[0.1, 0.2]),
content=Content(text="foo"),
metadata={"a": "b"}),
Chunk(
id="2",
embedding=Embedding(dense_embedding=[0.3, 0.4]),
content=Content(text="bar"),
metadata={"c": "d"})
]

pipeline_verifiers = [
BigqueryFullResultMatcher(
project=self.project,
query="SELECT id, content, embedding, metadata FROM %s" % table_id,
data=[("0", "foo", [0.1, 0.2], [{
"key": "a", "value": "b"
}]), ("2", "bar", [0.3, 0.4], [{
"key": "c", "value": "d"
}])])
]
args = self.test_pipeline.get_full_options_as_args(
on_success_matcher=hc.all_of(*pipeline_verifiers),
streaming=True,
allow_unsafe_triggers=True)

with beam.Pipeline(argv=args) as p:
_ = (
p
| PeriodicImpulse(0, 4, 1)
| beam.Map(lambda t: chunks[t])
| config.create_write_transform())

def skip_if_not_dataflow_runner(self):
# skip if dataflow runner is not specified
if not self._runner or "dataflowrunner" not in self._runner.lower():
self.skipTest(
"Streaming with exactly-once route has the requirement "
"`beam:requirement:pardo:on_window_expiration:v1`, "
"which is currently only supported by the Dataflow runner")


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
1 change: 1 addition & 0 deletions sdks/python/test-suites/direct/common.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata ->
pythonPipelineOptions: [
"--runner=TestDirectRunner",
"--project=${gcpProject}",
"--temp_location=gs://temp-storage-for-end-to-end-tests/temp-it",
],
pytestOptions: [
"--capture=no", // print stdout instantly
Expand Down

0 comments on commit 0542aa0

Please sign in to comment.