From 2924a08a0a7cb797dc7367bb5026fd768c52b20e Mon Sep 17 00:00:00 2001
From: Claude
Date: Wed, 18 Dec 2024 21:04:31 +0000
Subject: [PATCH] Add BigQueryVectorWriterConfig tests.

---
 ...am_PostCommit_Python_Xlang_Gcp_Direct.json |   2 +-
 .../apache_beam/ml/rag/ingestion/bigquery.py  |   2 +-
 .../ml/rag/ingestion/bigquery_it_test.py      | 194 ++++++++++++++++++
 sdks/python/test-suites/direct/common.gradle  |   1 +
 4 files changed, 197 insertions(+), 2 deletions(-)
 create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py

diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
index b26833333238..c537844dc84a 100644
--- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
+++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run",
-  "modification": 2
+  "modification": 3
 }
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py b/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py
index c2173995e2e5..db98eecef8b1 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/bigquery.py
@@ -66,7 +66,7 @@
   def __init__(
       self,
       write_config: Dict[str, Any],
       *,  # Force keyword arguments
-      schema_config: Optional[SchemaConfig]
+      schema_config: Optional[SchemaConfig] = None
   ):
     """Configuration for writing vectors to BigQuery using managed transforms.
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py
new file mode 100644
index 000000000000..5f468064b839
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/ingestion/bigquery_it_test.py
@@ -0,0 +1,194 @@
+import logging
+import secrets
+import time
+import unittest
+
+import hamcrest as hc
+import pytest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
+from apache_beam.io.gcp.internal.clients import bigquery
+from apache_beam.ml.rag.ingestion.bigquery import BigQueryVectorWriterConfig, SchemaConfig
+from apache_beam.ml.rag.types import Chunk, Content, Embedding
+from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
+from apache_beam.transforms.periodicsequence import PeriodicImpulse
+import os
+
+
+@pytest.mark.uses_gcp_java_expansion_service
+@unittest.skipUnless(
+    os.environ.get('EXPANSION_JARS'),
+    "EXPANSION_JARS environment var is not provided, "
+    "indicating that jars have not been built")
+class BigQueryVectorWriterConfigTest(unittest.TestCase):
+  BIG_QUERY_DATASET_ID = 'python_rag_bigquery_'
+
+  def setUp(self):
+    self.test_pipeline = TestPipeline(is_integration_test=True)
+    self._runner = type(self.test_pipeline.runner).__name__
+    self.project = self.test_pipeline.get_option('project')
+
+    self.bigquery_client = BigQueryWrapper()
+    self.dataset_id = '%s%d%s' % (
+        self.BIG_QUERY_DATASET_ID, int(time.time()), secrets.token_hex(3))
+    self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
+    _LOGGER = logging.getLogger(__name__)
+    _LOGGER.info(
+        "Created dataset %s in project %s", self.dataset_id, self.project)
+
+  def tearDown(self):
+    request = bigquery.BigqueryDatasetsDeleteRequest(
+        projectId=self.project, datasetId=self.dataset_id, deleteContents=True)
+    try:
+      _LOGGER = logging.getLogger(__name__)
+      _LOGGER.info(
"Deleting dataset %s in project %s", self.dataset_id, self.project) + self.bigquery_client.client.datasets.Delete(request) + # Failing to delete a dataset should not cause a test failure. + except Exception: + _LOGGER = logging.getLogger(__name__) + _LOGGER.debug( + 'Failed to clean up dataset %s in project %s', + self.dataset_id, + self.project) + + def test_default_schema(self): + table_name = 'python_default_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + config = BigQueryVectorWriterConfig(write_config={'table': table_id}) + chunks = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2]), + content=Content(text="foo"), + metadata={"a": "b"}), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4]), + content=Content(text="bar"), + metadata={"c": "d"}) + ] + + pipeline_verifiers = [ + BigqueryFullResultMatcher( + project=self.project, + query="SELECT id, content, embedding, metadata FROM %s" % table_id, + data=[("1", "foo", [0.1, 0.2], [{ + "key": "a", "value": "b" + }]), ("2", "bar", [0.3, 0.4], [{ + "key": "c", "value": "d" + }])]) + ] + + args = self.test_pipeline.get_full_options_as_args( + on_success_matcher=hc.all_of(*pipeline_verifiers)) + with beam.Pipeline(argv=args) as p: + _ = (p | beam.Create(chunks) | config.create_write_transform()) + + def test_custom_schema(self): + table_name = 'python_custom_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + schema_config = SchemaConfig( + schema={ + 'fields': [{ + 'name': 'id', 'type': 'STRING' + }, + { + 'name': 'embedding', + 'type': 'FLOAT64', + 'mode': 'REPEATED' + }, { + 'name': 'source', 'type': 'STRING' + }] + }, + chunk_to_dict_fn=lambda chunk: { + 'id': chunk.id, + 'embedding': chunk.embedding.dense_embedding, + 'source': chunk.metadata.get('source') + }) + config = BigQueryVectorWriterConfig( + write_config={'table': table_id}, schema_config=schema_config) + chunks = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2]), + content=Content(text="foo content"), + metadata={"source": "foo"}), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4]), + content=Content(text="bar content"), + metadata={"source": "bar"}) + ] + + pipeline_verifiers = [ + BigqueryFullResultMatcher( + project=self.project, + query="SELECT id, embedding, source FROM %s" % table_id, + data=[("1", [0.1, 0.2], "foo"), ("2", [0.3, 0.4], "bar")]) + ] + + args = self.test_pipeline.get_full_options_as_args( + on_success_matcher=hc.all_of(*pipeline_verifiers)) + + with beam.Pipeline(argv=args) as p: + _ = (p | beam.Create(chunks) | config.create_write_transform()) + + def test_streaming_default_schema(self): + self.skip_if_not_dataflow_runner() + + table_name = 'python_streaming_default_schema_table' + table_id = '{}.{}.{}'.format(self.project, self.dataset_id, table_name) + + config = BigQueryVectorWriterConfig(write_config={'table': table_id}) + chunks = [ + Chunk( + id="1", + embedding=Embedding(dense_embedding=[0.1, 0.2]), + content=Content(text="foo"), + metadata={"a": "b"}), + Chunk( + id="2", + embedding=Embedding(dense_embedding=[0.3, 0.4]), + content=Content(text="bar"), + metadata={"c": "d"}) + ] + + pipeline_verifiers = [ + BigqueryFullResultMatcher( + project=self.project, + query="SELECT id, content, embedding, metadata FROM %s" % table_id, + data=[("0", "foo", [0.1, 0.2], [{ + "key": "a", "value": "b" + }]), ("2", "bar", [0.3, 0.4], [{ + "key": "c", "value": "d" + }])]) + ] + args = 
+        on_success_matcher=hc.all_of(*pipeline_verifiers),
+        streaming=True,
+        allow_unsafe_triggers=True)
+
+    with beam.Pipeline(argv=args) as p:
+      _ = (
+          p
+          | PeriodicImpulse(0, 2, 1)
+          | beam.Map(lambda t: chunks[int(t)])
+          | config.create_write_transform())
+
+  def skip_if_not_dataflow_runner(self) -> bool:
+    # Skip if the Dataflow runner is not specified.
+    if not self._runner or "dataflowrunner" not in self._runner.lower():
+      self.skipTest(
+          "Streaming with exactly-once route has the requirement "
+          "`beam:requirement:pardo:on_window_expiration:v1`, "
+          "which is currently only supported by the Dataflow runner")
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle
index e290e8003b13..1dd15ecb09f9 100644
--- a/sdks/python/test-suites/direct/common.gradle
+++ b/sdks/python/test-suites/direct/common.gradle
@@ -447,6 +447,7 @@ project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata ->
     pythonPipelineOptions: [
       "--runner=TestDirectRunner",
       "--project=${gcpProject}",
+      "--temp_location=gs://temp-storage-for-end-to-end-tests/temp-it",
     ],
     pytestOptions: [
       "--capture=no",  // print stdout instantly
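
Usage note (illustrative sketch, not part of the diff): the tests above drive
BigQueryVectorWriterConfig through the TestPipeline harness; the snippet below
shows the same API in a plain pipeline, assuming the classes behave as the
patch describes. The project/dataset/table name and the "source" metadata key
are placeholders, the custom SchemaConfig mirrors the one in
test_custom_schema, and actually running it needs the same GCP and
expansion-service setup as the integration tests.

    import apache_beam as beam
    from apache_beam.ml.rag.ingestion.bigquery import (
        BigQueryVectorWriterConfig, SchemaConfig)
    from apache_beam.ml.rag.types import Chunk, Content, Embedding

    # Hypothetical fully qualified table name: <project>.<dataset>.<table>.
    TABLE_ID = "my-project.my_dataset.my_vector_table"

    # Optional custom schema: declare the BigQuery columns and how to fill
    # them from each Chunk. Leaving schema_config out (it now defaults to
    # None) falls back to the writer's default schema.
    schema_config = SchemaConfig(
        schema={
            'fields': [{
                'name': 'id', 'type': 'STRING'
            }, {
                'name': 'embedding', 'type': 'FLOAT64', 'mode': 'REPEATED'
            }, {
                'name': 'source', 'type': 'STRING'
            }]
        },
        chunk_to_dict_fn=lambda chunk: {
            'id': chunk.id,
            'embedding': chunk.embedding.dense_embedding,
            'source': chunk.metadata.get('source')
        })

    config = BigQueryVectorWriterConfig(
        write_config={'table': TABLE_ID}, schema_config=schema_config)

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([
              Chunk(
                  id="1",
                  embedding=Embedding(dense_embedding=[0.1, 0.2]),
                  content=Content(text="example text"),
                  metadata={"source": "example"})
          ])
          | config.create_write_transform())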