Skip to content

Commit

Permalink
Support impersonate service-account for BigQuery
Browse files Browse the repository at this point in the history
Signed-off-by: Ching Yi, Chan <[email protected]>
  • Loading branch information
qrtt1 committed Oct 17, 2023
1 parent a1a9823 commit ed4e8b5
Showing 1 changed file with 87 additions and 2 deletions.
89 changes: 87 additions & 2 deletions piperider_cli/datasource/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import base64
import json
import os
from typing import List
from typing import List, Optional

import google.auth
import inquirer
import sqlalchemy
from google.api_core import client_info
from google.auth import impersonated_credentials
from google.cloud import bigquery
from google.oauth2 import service_account
from sqlalchemy_bigquery import _helpers

from piperider_cli.error import PipeRiderConnectorError

from . import DataSource
from .field import PathField, ListField, DataSourceField, _default_validate_func
from .field import DataSourceField, ListField, _default_validate_func

APPLICATION_DEFAULT_CREDENTIALS = os.path.join(os.path.expanduser('~'), '.config', 'gcloud',
'application_default_credentials.json')
Expand All @@ -24,6 +33,79 @@
AUTH_METHOD_SERVICE_ACCOUNT = 'service-account'
AUTH_METHOD_SERVICE_ACCOUNT_JSON = 'service-account-json'

USER_AGENT_TEMPLATE = "sqlalchemy/{}"
SCOPES = (
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/drive",
)

target_service_account_email: Optional[str] = None


def google_client_info():
user_agent = USER_AGENT_TEMPLATE.format(sqlalchemy.__version__)
return client_info.ClientInfo(user_agent=user_agent)


def create_impersonated_bigquery_client(
credentials_info=None,
credentials_path=None,
credentials_base64=None,
default_query_job_config=None,
location=None,
project_id=None,
):
default_project = None

if credentials_base64:
credentials_info = json.loads(base64.b64decode(credentials_base64))

if credentials_path:
credentials = service_account.Credentials.from_service_account_file(
credentials_path
)
credentials = credentials.with_scopes(SCOPES)
default_project = credentials.project_id
elif credentials_info:
credentials = service_account.Credentials.from_service_account_info(
credentials_info
)
credentials = credentials.with_scopes(SCOPES)
default_project = credentials.project_id
else:
credentials, default_project = google.auth.default(scopes=SCOPES)

if project_id is None:
project_id = default_project

if target_service_account_email:
impersonated_creds = impersonated_credentials.Credentials(
source_credentials=credentials,
target_principal=target_service_account_email,
target_scopes=SCOPES,
lifetime=3600 # Duration for which the token is valid (in seconds)
)
return bigquery.Client(
client_info=google_client_info(),
project=project_id,
credentials=impersonated_creds,
location=location,
default_query_job_config=default_query_job_config,
)

return bigquery.Client(
client_info=google_client_info(),
project=project_id,
credentials=credentials,
location=location,
default_query_job_config=default_query_job_config,
)


# monkey-patch
_helpers.create_bigquery_client = create_impersonated_bigquery_client


class HiddenProjectListFromOAuthField(DataSourceField):

Expand Down Expand Up @@ -217,11 +299,14 @@ def to_database_url(self, database):
return f'bigquery://{project}/{dataset}'

def engine_args(self):
global target_service_account_email
args = dict()
if self.credential.get('method') == AUTH_METHOD_SERVICE_ACCOUNT:
args['credentials_path'] = self.credential.get('keyfile')
elif self.credential.get('method') == AUTH_METHOD_SERVICE_ACCOUNT_JSON:
args['credentials_info'] = self.credential.get('keyfile_json', {})

target_service_account_email = self.credential.get('impersonate_service_account', None)
return args

def verify_connector(self):
Expand Down

0 comments on commit ed4e8b5

Please sign in to comment.