Skip to content

Commit

Permalink
backport pr #578 (#598)
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-rogers-dbt authored Mar 10, 2023
1 parent eb908d7 commit e2620c6
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 32 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230303-132509.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: add dataproc serverless config to profile
time: 2023-03-03T13:25:09.02695-08:00
custom:
Author: colin-rogers-dbt torkjel
Issue: "530"
19 changes: 17 additions & 2 deletions dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import re
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from mashumaro.helper import pass_through

from functools import lru_cache
import agate
from requests.exceptions import ConnectionError
Expand Down Expand Up @@ -35,7 +37,7 @@
from dbt.events.types import SQLQuery
from dbt.version import __version__ as dbt_version

from dbt.dataclass_schema import StrEnum
from dbt.dataclass_schema import ExtensibleDbtClassMixin, StrEnum

logger = AdapterLogger("BigQuery")

Expand Down Expand Up @@ -92,6 +94,12 @@ class BigQueryAdapterResponse(AdapterResponse):
slot_ms: Optional[int] = None


@dataclass
class DataprocBatchConfig(ExtensibleDbtClassMixin):
def __init__(self, batch_config):
self.batch_config = batch_config


@dataclass
class BigQueryCredentials(Credentials):
method: BigQueryConnectionMethod
Expand Down Expand Up @@ -124,6 +132,13 @@ class BigQueryCredentials(Credentials):
dataproc_cluster_name: Optional[str] = None
gcs_bucket: Optional[str] = None

dataproc_batch: Optional[DataprocBatchConfig] = field(
metadata={
"serialization_strategy": pass_through,
},
default=None,
)

scopes: Optional[Tuple[str, ...]] = (
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/cloud-platform",
Expand Down
78 changes: 50 additions & 28 deletions dbt/adapters/bigquery/python_submissions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Dict, Union
import time

from dbt.adapters.base import PythonJobHelper
from dbt.adapters.bigquery import BigQueryConnectionManager, BigQueryCredentials
from dbt.adapters.bigquery.connections import DataprocBatchConfig
from google.api_core import retry
from google.api_core.client_options import ClientOptions
from google.cloud import storage, dataproc_v1 # type: ignore
from google.protobuf.json_format import ParseDict

OPERATION_RETRY_TIME = 10

Expand Down Expand Up @@ -67,14 +68,6 @@ def _get_job_client(
def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
raise NotImplementedError("_submit_dataproc_job not implemented")

def _wait_operation(self, operation):
# can't use due to https://github.com/googleapis/python-api-core/issues/458
# response = operation.result(retry=self.retry)
# Temp solution to wait for the job to finish
start = time.time()
while not operation.done(retry=None) and time.time() - start < self.timeout:
time.sleep(OPERATION_RETRY_TIME)


class ClusterDataprocHelper(BaseDataProcHelper):
def _get_job_client(self) -> dataproc_v1.JobControllerClient:
Expand Down Expand Up @@ -105,8 +98,7 @@ def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
"job": job,
}
)
self._wait_operation(operation)
response = operation.metadata
response = operation.result(retry=self.retry)
# check if job failed
if response.status.state == 6:
raise ValueError(response.status.details)
Expand All @@ -120,31 +112,17 @@ def _get_job_client(self) -> dataproc_v1.BatchControllerClient:
)

def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
# create the Dataproc Serverless job config
# need to pin dataproc version to 1.1 as it now defaults to 2.0
batch = dataproc_v1.Batch({"runtime_config": dataproc_v1.RuntimeConfig(version="1.1")})
batch.pyspark_batch.main_python_file_uri = self.gcs_location
# how to keep this up to date?
# we should probably also open this up to be configurable
jar_file_uri = self.parsed_model["config"].get(
"jar_file_uri",
"gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.21.1.jar",
)
batch.pyspark_batch.jar_file_uris = [jar_file_uri]
# should we make all of these spark/dataproc properties configurable?
# https://cloud.google.com/dataproc-serverless/docs/concepts/properties
# https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#runtimeconfig
batch.runtime_config.properties = {
"spark.executor.instances": "2",
}
batch = self._configure_batch()
parent = f"projects/{self.credential.execution_project}/locations/{self.credential.dataproc_region}"

request = dataproc_v1.CreateBatchRequest(
parent=parent,
batch=batch,
)
# make the request
operation = self.job_client.create_batch(request=request) # type: ignore
# this takes quite a while, waiting on GCP response to resolve
# (not a google-api-core issue, more likely a dataproc serverless issue)
response = operation.result(retry=self.retry)
return response
# there might be useful results here that we can parse and return
Expand All @@ -157,3 +135,47 @@ def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job:
# .blob(f"{matches.group(2)}.000000000")
# .download_as_string()
# )

def _configure_batch(self):
# create the Dataproc Serverless job config
# need to pin dataproc version to 1.1 as it now defaults to 2.0
# https://cloud.google.com/dataproc-serverless/docs/concepts/properties
# https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#runtimeconfig
batch = dataproc_v1.Batch(
{
"runtime_config": dataproc_v1.RuntimeConfig(
version="1.1",
properties={
"spark.executor.instances": "2",
},
)
}
)
# Apply defaults
batch.pyspark_batch.main_python_file_uri = self.gcs_location
jar_file_uri = self.parsed_model["config"].get(
"jar_file_uri",
"gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.12-0.21.1.jar",
)
batch.pyspark_batch.jar_file_uris = [jar_file_uri]

# Apply configuration from dataproc_batch key, possibly overriding defaults.
if self.credential.dataproc_batch:
self._update_batch_from_config(self.credential.dataproc_batch, batch)
return batch

@classmethod
def _update_batch_from_config(
cls, config_dict: Union[Dict, DataprocBatchConfig], target: dataproc_v1.Batch
):
try:
# updates in place
ParseDict(config_dict, target._pb)
except Exception as e:
docurl = (
"https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1"
"#google.cloud.dataproc.v1.Batch"
)
raise ValueError(
f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. {str(e)}"
) from e
3 changes: 3 additions & 0 deletions dbt/include/bigquery/macros/python_model/python.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{% macro bigquery__resolve_model_name(input_model_name) -%}
{{ input_model_name | string | replace('`', '') | replace('"', '\"') }}
{%- endmacro -%}
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ tox~=3.0;python_version=="3.7"
tox~=4.4;python_version>="3.8"
types-pytz~=2022.7
types-requests~=2.28
types-protobuf~=4.0
twine~=4.0
wheel~=0.38
56 changes: 56 additions & 0 deletions tests/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,43 @@ def setUp(self):
"threads": 1,
"location": "Solar Station",
},
'dataproc-serverless-configured' : {
'type': 'bigquery',
'method': 'oauth',
'schema': 'dummy_schema',
'threads': 1,
'gcs_bucket': 'dummy-bucket',
'dataproc_region': 'europe-west1',
'submission_method': 'serverless',
'dataproc_batch': {
'environment_config' : {
'execution_config' : {
'service_account': '[email protected]',
'subnetwork_uri': 'dataproc',
'network_tags': [ "foo", "bar" ]
}
},
'labels': {
'dbt': 'rocks',
'number': '1'
},
'runtime_config': {
'properties': {
'spark.executor.instances': '4',
'spark.driver.memory': '1g'
}
}
}
},
'dataproc-serverless-default' : {
'type': 'bigquery',
'method': 'oauth',
'schema': 'dummy_schema',
'threads': 1,
'gcs_bucket': 'dummy-bucket',
'dataproc_region': 'europe-west1',
'submission_method': 'serverless'
}
},
"target": "oauth",
}
Expand Down Expand Up @@ -184,6 +221,25 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection):
connection.handle
mock_open_connection.assert_called_once()

@patch('dbt.adapters.bigquery.connections.get_bigquery_defaults', return_value=('credentials', 'project_id'))
@patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn())
def test_acquire_connection_dataproc_serverless(self, mock_open_connection, mock_get_bigquery_defaults):
adapter = self.get_adapter('dataproc-serverless-configured')
mock_get_bigquery_defaults.assert_called_once()
try:
connection = adapter.acquire_connection('dummy')
self.assertEqual(connection.type, 'bigquery')

except dbt.exceptions.ValidationException as e:
self.fail('got ValidationException: {}'.format(str(e)))

except BaseException as e:
raise

mock_open_connection.assert_not_called()
connection.handle
mock_open_connection.assert_called_once()

@patch('dbt.adapters.bigquery.BigQueryConnectionManager.open', return_value=_bq_conn())
def test_acquire_connection_service_account_validations(self, mock_open_connection):
adapter = self.get_adapter('service_account')
Expand Down
57 changes: 57 additions & 0 deletions tests/unit/test_configure_dataproc_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from unittest.mock import patch

from dbt.adapters.bigquery.python_submissions import ServerlessDataProcHelper
from google.cloud import dataproc_v1

from .test_bigquery_adapter import BaseTestBigQueryAdapter

# Test application of dataproc_batch configuration to a
# google.cloud.dataproc_v1.Batch object.
# This reuses the machinery from BaseTestBigQueryAdapter to get hold of the
# parsed credentials
class TestConfigureDataprocBatch(BaseTestBigQueryAdapter):

@patch('dbt.adapters.bigquery.connections.get_bigquery_defaults', return_value=('credentials', 'project_id'))
def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults):
adapter = self.get_adapter('dataproc-serverless-configured')
mock_get_bigquery_defaults.assert_called_once()

credentials = adapter.acquire_connection('dummy').credentials
self.assertIsNotNone(credentials)

batchConfig = credentials.dataproc_batch
self.assertIsNotNone(batchConfig)

raw_batch_config = self.raw_profile['outputs']['dataproc-serverless-configured']['dataproc_batch']
raw_environment_config = raw_batch_config['environment_config']
raw_execution_config = raw_environment_config['execution_config']
raw_labels: dict[str, any] = raw_batch_config['labels']
raw_rt_config = raw_batch_config['runtime_config']

raw_batch_config = self.raw_profile['outputs']['dataproc-serverless-configured']['dataproc_batch']

batch = dataproc_v1.Batch()

ServerlessDataProcHelper._update_batch_from_config(raw_batch_config, batch)

# google's protobuf types expose maps as dict[str, str]
to_str_values = lambda d: dict([(k, str(v)) for (k, v) in d.items()])

self.assertEqual(batch.environment_config.execution_config.service_account, raw_execution_config['service_account'])
self.assertFalse(batch.environment_config.execution_config.network_uri)
self.assertEqual(batch.environment_config.execution_config.subnetwork_uri, raw_execution_config['subnetwork_uri'])
self.assertEqual(batch.environment_config.execution_config.network_tags, raw_execution_config['network_tags'])
self.assertEqual(batch.labels, to_str_values(raw_labels))
self.assertEqual(batch.runtime_config.properties, to_str_values(raw_rt_config['properties']))


@patch('dbt.adapters.bigquery.connections.get_bigquery_defaults', return_value=('credentials', 'project_id'))
def test_default_dataproc_serverless_batch(self, mock_get_bigquery_defaults):
adapter = self.get_adapter('dataproc-serverless-default')
mock_get_bigquery_defaults.assert_called_once()

credentials = adapter.acquire_connection('dummy').credentials
self.assertIsNotNone(credentials)

batchConfig = credentials.dataproc_batch
self.assertIsNone(batchConfig)
17 changes: 15 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,24 @@ passenv =
DBT_*
BIGQUERY_TEST_*
PYTEST_ADDOPTS
DATAPROC_*
GCS_BUCKET
commands =
bigquery: {envpython} -m pytest {posargs} -m profile_bigquery tests/integration
bigquery: {envpython} -m pytest {posargs} -vv tests/functional --profile service_account
deps =
-rdev-requirements.txt
-e.

[testenv:{python-tests,py37,py38,py39,py310,py311,py}]
description = python integration testing
skip_install = true
passenv =
DBT_*
BIGQUERY_TEST_*
PYTEST_ADDOPTS
DATAPROC_*
GCS_BUCKET
commands =
{envpython} -m pytest {posargs} -vv tests/functional -k "TestPython" --profile service_account
deps =
-rdev-requirements.txt
-e.

0 comments on commit e2620c6

Please sign in to comment.