Skip to content

Commit

Permalink
Add BigQueryVectorWriterConfig tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Claude authored and claudevdm committed Dec 19, 2024
1 parent a22a48c commit 2924a08
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 2
"modification": 3
}
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/ml/rag/ingestion/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self,
write_config: Dict[str, Any],
*, # Force keyword arguments
schema_config: Optional[SchemaConfig]
schema_config: Optional[SchemaConfig] = None
):
"""Configuration for writing vectors to BigQuery using managed transforms.
Expand Down
194 changes: 194 additions & 0 deletions sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import logging
import secrets
import time
import unittest

import hamcrest as hc
import pytest

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.ml.rag.ingestion.bigquery import BigQueryVectorWriterConfig, SchemaConfig
from apache_beam.ml.rag.types import Chunk, Content, Embedding
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
from apache_beam.transforms.periodicsequence import PeriodicImpulse
import os


@pytest.mark.uses_gcp_java_expansion_service
@unittest.skipUnless(
    os.environ.get('EXPANSION_JARS'),
    "EXPANSION_JARS environment var is not provided, "
    "indicating that jars have not been built")
class BigQueryVectorWriterConfigTest(unittest.TestCase):
  """Integration tests for writing RAG chunks with BigQueryVectorWriterConfig.

  The whole class is skipped unless the ``EXPANSION_JARS`` environment
  variable is set. Each test writes into its own table inside a dataset that
  ``setUp`` creates and ``tearDown`` deletes on a best-effort basis.
  """
  BIG_QUERY_DATASET_ID = 'python_rag_bigquery_'

  def setUp(self):
    """Create a uniquely named BigQuery dataset for the test to write into."""
    self.test_pipeline = TestPipeline(is_integration_test=True)
    self._runner = type(self.test_pipeline.runner).__name__
    self.project = self.test_pipeline.get_option('project')

    self.bigquery_client = BigQueryWrapper()
    # Prefix + epoch seconds + random hex keeps concurrent runs from
    # colliding on the dataset name.
    self.dataset_id = '%s%d%s' % (
        self.BIG_QUERY_DATASET_ID, int(time.time()), secrets.token_hex(3))
    self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
    _LOGGER = logging.getLogger(__name__)
    _LOGGER.info(
        "Created dataset %s in project %s", self.dataset_id, self.project)

  def tearDown(self):
    """Delete the per-test dataset, including any tables written into it."""
    _LOGGER = logging.getLogger(__name__)
    request = bigquery.BigqueryDatasetsDeleteRequest(
        projectId=self.project, datasetId=self.dataset_id, deleteContents=True)
    try:
      _LOGGER.info(
          "Deleting dataset %s in project %s", self.dataset_id, self.project)
      self.bigquery_client.client.datasets.Delete(request)
    # Failing to delete a dataset should not cause a test failure.
    except Exception:
      _LOGGER.debug(
          'Failed to clean up dataset %s in project %s',
          self.dataset_id,
          self.project,
          exc_info=True)

  def _table_id(self, table_name):
    """Return ``project.dataset.table`` for a table in this test's dataset."""
    return '{}.{}.{}'.format(self.project, self.dataset_id, table_name)

  def test_default_schema(self):
    """Batch write without a SchemaConfig, then verify the stored rows."""
    table_id = self._table_id('python_default_schema_table')

    config = BigQueryVectorWriterConfig(write_config={'table': table_id})
    chunks = [
        Chunk(
            id="1",
            embedding=Embedding(dense_embedding=[0.1, 0.2]),
            content=Content(text="foo"),
            metadata={"a": "b"}),
        Chunk(
            id="2",
            embedding=Embedding(dense_embedding=[0.3, 0.4]),
            content=Content(text="bar"),
            metadata={"c": "d"})
    ]

    pipeline_verifiers = [
        BigqueryFullResultMatcher(
            project=self.project,
            query="SELECT id, content, embedding, metadata FROM %s" % table_id,
            data=[("1", "foo", [0.1, 0.2], [{
                "key": "a", "value": "b"
            }]), ("2", "bar", [0.3, 0.4], [{
                "key": "c", "value": "d"
            }])])
    ]

    args = self.test_pipeline.get_full_options_as_args(
        on_success_matcher=hc.all_of(*pipeline_verifiers))
    with beam.Pipeline(argv=args) as p:
      _ = (p | beam.Create(chunks) | config.create_write_transform())

  def test_custom_schema(self):
    """Batch write with a custom schema and chunk-to-row conversion."""
    table_id = self._table_id('python_custom_schema_table')

    schema_config = SchemaConfig(
        schema={
            'fields': [{
                'name': 'id', 'type': 'STRING'
            },
                       {
                           'name': 'embedding',
                           'type': 'FLOAT64',
                           'mode': 'REPEATED'
                       }, {
                           'name': 'source', 'type': 'STRING'
                       }]
        },
        chunk_to_dict_fn=lambda chunk: {
            'id': chunk.id,
            'embedding': chunk.embedding.dense_embedding,
            'source': chunk.metadata.get('source')
        })
    config = BigQueryVectorWriterConfig(
        write_config={'table': table_id}, schema_config=schema_config)
    chunks = [
        Chunk(
            id="1",
            embedding=Embedding(dense_embedding=[0.1, 0.2]),
            content=Content(text="foo content"),
            metadata={"source": "foo"}),
        Chunk(
            id="2",
            embedding=Embedding(dense_embedding=[0.3, 0.4]),
            content=Content(text="bar content"),
            metadata={"source": "bar"})
    ]

    pipeline_verifiers = [
        BigqueryFullResultMatcher(
            project=self.project,
            query="SELECT id, embedding, source FROM %s" % table_id,
            data=[("1", [0.1, 0.2], "foo"), ("2", [0.3, 0.4], "bar")])
    ]

    args = self.test_pipeline.get_full_options_as_args(
        on_success_matcher=hc.all_of(*pipeline_verifiers))

    with beam.Pipeline(argv=args) as p:
      _ = (p | beam.Create(chunks) | config.create_write_transform())

  def test_streaming_default_schema(self):
    """Streaming write with the default schema (Dataflow runner only)."""
    self.skip_if_not_dataflow_runner()

    table_id = self._table_id('python_streaming_default_schema_table')

    config = BigQueryVectorWriterConfig(write_config={'table': table_id})
    chunks = [
        Chunk(
            id="1",
            embedding=Embedding(dense_embedding=[0.1, 0.2]),
            content=Content(text="foo"),
            metadata={"a": "b"}),
        Chunk(
            id="2",
            embedding=Embedding(dense_embedding=[0.3, 0.4]),
            content=Content(text="bar"),
            metadata={"c": "d"})
    ]

    # Expected rows must carry the same ids as the chunks written above.
    pipeline_verifiers = [
        BigqueryFullResultMatcher(
            project=self.project,
            query="SELECT id, content, embedding, metadata FROM %s" % table_id,
            data=[("1", "foo", [0.1, 0.2], [{
                "key": "a", "value": "b"
            }]), ("2", "bar", [0.3, 0.4], [{
                "key": "c", "value": "d"
            }])])
    ]
    args = self.test_pipeline.get_full_options_as_args(
        on_success_matcher=hc.all_of(*pipeline_verifiers),
        streaming=True,
        allow_unsafe_triggers=True)

    with beam.Pipeline(argv=args) as p:
      _ = (
          p
          # Bound the impulse by the number of chunks so the index lookup
          # below cannot run past the end of the list.
          # NOTE(review): assumes the impulse emits offsets usable as list
          # indices (0 .. len(chunks) - 1) -- confirm against PeriodicImpulse.
          | PeriodicImpulse(0, len(chunks), 1)
          | beam.Map(lambda t: chunks[t])
          | config.create_write_transform())

  def skip_if_not_dataflow_runner(self) -> None:
    """Skip the calling test unless the runner name contains DataflowRunner."""
    # skip if dataflow runner is not specified
    if not self._runner or "dataflowrunner" not in self._runner.lower():
      self.skipTest(
          "Streaming with exactly-once route has the requirement "
          "`beam:requirement:pardo:on_window_expiration:v1`, "
          "which is currently only supported by the Dataflow runner")


if __name__ == '__main__':
  # Raise the root logger to INFO before handing control to unittest.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()
1 change: 1 addition & 0 deletions sdks/python/test-suites/direct/common.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata ->
pythonPipelineOptions: [
"--runner=TestDirectRunner",
"--project=${gcpProject}",
"--temp_location=gs://temp-storage-for-end-to-end-tests/temp-it",
],
pytestOptions: [
"--capture=no", // print stdout instantly
Expand Down

0 comments on commit 2924a08

Please sign in to comment.