diff --git a/piperider_cli/datasource/bigquery.py b/piperider_cli/datasource/bigquery.py index 8ead16104..2820d56a2 100644 --- a/piperider_cli/datasource/bigquery.py +++ b/piperider_cli/datasource/bigquery.py @@ -1,12 +1,13 @@ import json import os -from typing import List +from typing import List, Optional import inquirer from piperider_cli.error import PipeRiderConnectorError + from . import DataSource -from .field import PathField, ListField, DataSourceField, _default_validate_func +from .field import DataSourceField, ListField, _default_validate_func APPLICATION_DEFAULT_CREDENTIALS = os.path.join(os.path.expanduser('~'), '.config', 'gcloud', 'application_default_credentials.json') @@ -222,6 +223,14 @@ def engine_args(self): args['credentials_path'] = self.credential.get('keyfile') elif self.credential.get('method') == AUTH_METHOD_SERVICE_ACCOUNT_JSON: args['credentials_info'] = self.credential.get('keyfile_json', {}) + + target_service_account_email = self.credential.get('impersonate_service_account', None) + if target_service_account_email: + # monkey-patch + import piperider_cli.datasource.bigquery_patch + from sqlalchemy_bigquery import _helpers + piperider_cli.datasource.bigquery_patch.target_service_account_email = target_service_account_email + _helpers.create_bigquery_client = piperider_cli.datasource.bigquery_patch.create_impersonated_bigquery_client return args def verify_connector(self): diff --git a/piperider_cli/datasource/bigquery_patch.py b/piperider_cli/datasource/bigquery_patch.py new file mode 100644 index 000000000..9c21eefc9 --- /dev/null +++ b/piperider_cli/datasource/bigquery_patch.py @@ -0,0 +1,79 @@ +import base64 +import json +from typing import Optional + +import google.auth +import sqlalchemy +from google.api_core import client_info +from google.auth import impersonated_credentials +from google.cloud import bigquery +from google.oauth2 import service_account + +USER_AGENT_TEMPLATE = "sqlalchemy/{}" +SCOPES = ( + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/drive", +) + +target_service_account_email: Optional[str] = None + + +def google_client_info(): + user_agent = USER_AGENT_TEMPLATE.format(sqlalchemy.__version__) + return client_info.ClientInfo(user_agent=user_agent) + + +def create_impersonated_bigquery_client( + credentials_info=None, + credentials_path=None, + credentials_base64=None, + default_query_job_config=None, + location=None, + project_id=None, +): + default_project = None + + if credentials_base64: + credentials_info = json.loads(base64.b64decode(credentials_base64)) + + if credentials_path: + credentials = service_account.Credentials.from_service_account_file( + credentials_path + ) + credentials = credentials.with_scopes(SCOPES) + default_project = credentials.project_id + elif credentials_info: + credentials = service_account.Credentials.from_service_account_info( + credentials_info + ) + credentials = credentials.with_scopes(SCOPES) + default_project = credentials.project_id + else: + credentials, default_project = google.auth.default(scopes=SCOPES) + + if project_id is None: + project_id = default_project + + if target_service_account_email: + impersonated_creds = impersonated_credentials.Credentials( + source_credentials=credentials, + target_principal=target_service_account_email, + target_scopes=SCOPES, + lifetime=3600 # Duration for which the token is valid (in seconds) + ) + return bigquery.Client( + client_info=google_client_info(), + project=project_id, + credentials=impersonated_creds, + location=location, + default_query_job_config=default_query_job_config, + ) + + return bigquery.Client( + client_info=google_client_info(), + project=project_id, + credentials=credentials, + location=location, + default_query_job_config=default_query_job_config, + )