
Source code for airflow_dbt_python.hooks.dbt

-"""Provides a hook to interact with a dbt project."""
-from __future__ import annotations
-
-import json
-import logging
-import sys
-from contextlib import contextmanager
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterable,
-    Iterator,
-    NamedTuple,
-    Optional,
-    Tuple,
-    Union,
-)
-from urllib.parse import urlparse
-
-from airflow.exceptions import AirflowException
-from airflow.hooks.base import BaseHook
-from airflow.models.connection import Connection
-
-if TYPE_CHECKING:
-    from dbt.contracts.results import RunResult
-    from dbt.task.base import BaseTask
-
-    from airflow_dbt_python.hooks.remote import DbtRemoteHook
-    from airflow_dbt_python.utils.configs import BaseConfig
-    from airflow_dbt_python.utils.url import URLLike
-
-    DbtRemoteHooksDict = Dict[Tuple[str, Optional[str]], DbtRemoteHook]
-

class DbtTaskResult(NamedTuple):
    """A tuple returned after a dbt task executes.

    Attributes:
        success: Whether the task succeeded or not.
        run_results: Results from the dbt task, if available.
        artifacts: A dictionary of saved dbt artifacts. It may be empty.
    """

    success: bool
    run_results: Optional[RunResult]
    artifacts: dict[str, Any]
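
For instance, callers can check success and read saved artifacts by name, like any NamedTuple. A trivial sketch with placeholder values:

# Placeholder values; DbtHook.run_dbt_task (below) builds this tuple for real runs.
result = DbtTaskResult(success=True, run_results=None, artifacts={})

if not result.success:
    raise RuntimeError("dbt task failed")
manifest = result.artifacts.get("manifest.json")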

class DbtConnectionParam(NamedTuple):
    """A tuple indicating connection parameters relevant to dbt.

    Attributes:
        name: The name of the connection parameter. This name will be used to get
            the parameter from an Airflow Connection or its extras.
        store_override_name: A new name for the connection parameter. If not None,
            this is the name used in a dbt profiles.yml.
        default: A default value if the parameter is not found.
    """

    name: str
    store_override_name: Optional[str] = None
    default: Optional[Any] = None

    @property
    def override_name(self):
        """Return store_override_name if defined, otherwise default to name.

        >>> DbtConnectionParam("login", "user").override_name
        'user'
        >>> DbtConnectionParam("port").override_name
        'port'
        """
        if self.store_override_name is None:
            return self.name
        return self.store_override_name

class DbtTemporaryDirectory(TemporaryDirectory):
    """A wrapper on TemporaryDirectory for older versions of Python.

    Support for ignore_cleanup_errors was added in Python 3.10. There is a very
    obscure error that can happen when cleaning up a directory, even though
    everything should be cleaned. We would like to use ignore_cleanup_errors to
    provide clean up on a best-effort basis. For the time being, we are addressing
    this only for Python>=3.10.
    """

    def __init__(self, suffix=None, prefix=None, dir=None, ignore_cleanup_errors=True):
        if sys.version_info.minor < 10 and sys.version_info.major == 3:
            super().__init__(suffix=suffix, prefix=prefix, dir=dir)
        else:
            super().__init__(
                suffix=suffix,
                prefix=prefix,
                dir=dir,
                ignore_cleanup_errors=ignore_cleanup_errors,
            )
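
A minimal sketch of the wrapper's behavior, independent of dbt:

# On Python >= 3.10 cleanup errors are ignored on exit; on older versions
# the ignore_cleanup_errors flag is silently dropped.
with DbtTemporaryDirectory(prefix="example_") as tmp_dir:
    print(tmp_dir)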

class DbtHook(BaseHook):
    """A hook to interact with dbt.

    Allows for running dbt tasks and provides required configurations for each task.
    """

    conn_name_attr = "dbt_conn_id"
    default_conn_name = "dbt_default"
    conn_type = "dbt"
    hook_name = "dbt Hook"

    conn_params: list[Union[DbtConnectionParam, str]] = [
        DbtConnectionParam("conn_type", "type"),
        "host",
        DbtConnectionParam("conn_id", "dbname"),
        "schema",
        DbtConnectionParam("login", "user"),
        "password",
        "port",
    ]
    conn_extra_params: list[Union[DbtConnectionParam, str]] = []

    def __init__(
        self,
        *args,
        dbt_conn_id: Optional[str] = default_conn_name,
        project_conn_id: Optional[str] = None,
        profiles_conn_id: Optional[str] = None,
        **kwargs,
    ):
        self.remotes: DbtRemoteHooksDict = {}
        self.dbt_conn_id = dbt_conn_id
        self.project_conn_id = project_conn_id
        self.profiles_conn_id = profiles_conn_id
        super().__init__(*args, **kwargs)

    def get_remote(self, scheme: str, conn_id: Optional[str]) -> DbtRemoteHook:
        """Get a remote to interact with dbt files.

        RemoteHooks are defined by the scheme we are looking for and an optional
        connection id if we are looking to interface with any Airflow hook that
        uses a connection.
        """
        from .remote import get_remote

        # Remotes are cached by (scheme, conn_id) so each is only created once.
        try:
            return self.remotes[(scheme, conn_id)]
        except KeyError:
            remote = get_remote(scheme, conn_id)
            self.remotes[(scheme, conn_id)] = remote
            return remote

    def download_dbt_profiles(
        self,
        profiles_dir: URLLike,
        destination: URLLike,
    ) -> Path:
        """Pull a dbt profiles.yml file from a given profiles_dir.

        This operation is delegated to a DbtRemoteHook. An optional connection id is
        supported for remotes that require it.
        """
        scheme = urlparse(str(profiles_dir)).scheme
        remote = self.get_remote(scheme, self.profiles_conn_id)

        return remote.download_dbt_profiles(profiles_dir, destination)

    def download_dbt_project(
        self,
        project_dir: URLLike,
        destination: URLLike,
    ) -> Path:
        """Pull a dbt project from a given project_dir.

        This operation is delegated to a DbtRemoteHook. An optional connection id is
        supported for remotes that require it.
        """
        scheme = urlparse(str(project_dir)).scheme
        remote = self.get_remote(scheme, self.project_conn_id)

        return remote.download_dbt_project(project_dir, destination)

    def upload_dbt_project(
        self,
        project_dir: URLLike,
        destination: URLLike,
        replace: bool = False,
        delete_before: bool = False,
    ) -> None:
        """Push a dbt project from a given project_dir.

        This operation is delegated to a DbtRemoteHook. An optional connection id is
        supported for remotes that require it.
        """
        scheme = urlparse(str(destination)).scheme
        remote = self.get_remote(scheme, self.project_conn_id)

        return remote.upload_dbt_project(
            project_dir, destination, replace=replace, delete_before=delete_before
        )
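
For illustration, a hedged sketch of the download/upload round trip (the S3 URL, local path, and connection ID are placeholders, and an S3-capable remote must be installed):

hook = DbtHook(project_conn_id="aws_default")

# Pull the project into a local directory, then push any changes back.
hook.download_dbt_project("s3://my-bucket/project/", "/tmp/dbt-project")
hook.upload_dbt_project("/tmp/dbt-project", "s3://my-bucket/project/", replace=True)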

    def run_dbt_task(
        self,
        command: str,
        upload_dbt_project: bool = False,
        delete_before_upload: bool = False,
        replace_on_upload: bool = False,
        artifacts: Optional[Iterable[str]] = None,
        env_vars: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> DbtTaskResult:
        """Run a dbt task with a given configuration and return the results.

        The configuration used determines the task that will be run.

        Returns:
            A tuple containing a boolean indicating success and optionally the
            results of running the dbt command.
        """
        from dbt.adapters.factory import register_adapter
        from dbt.config.runtime import UnsetProfileConfig
        from dbt.main import adapter_management, track_run
        from dbt.task.base import move_to_nearest_project_dir

        config = self.get_dbt_task_config(command, **kwargs)
        extra_target = self.get_dbt_target_from_connection(config.target)

        with self.dbt_directory(
            config,
            upload_dbt_project=upload_dbt_project,
            delete_before_upload=delete_before_upload,
            replace_on_upload=replace_on_upload,
            env_vars=env_vars,
        ) as dbt_dir:
            config.dbt_task.pre_init_hook(config)
            self.ensure_profiles(config.profiles_dir)

            task, runtime_config = config.create_dbt_task(extra_target)

            # When creating tasks via from_args, dbt switches to the project
            # directory. We have to do that here as we are not using from_args.
            move_to_nearest_project_dir(config)

            self.setup_dbt_logging(task, config.debug)

            if not isinstance(runtime_config, UnsetProfileConfig):
                if runtime_config is not None:
                    # The deps command installs the dependencies, which means they
                    # may not exist before deps runs and the following would raise
                    # a CompilationError.
                    runtime_config.load_dependencies()

            results = None
            with adapter_management():
                if not isinstance(runtime_config, UnsetProfileConfig):
                    if runtime_config is not None:
                        register_adapter(runtime_config)

                with track_run(task):
                    results = task.run()

                success = task.interpret_results(results)

            if artifacts is None:
                return DbtTaskResult(success, results, {})

            saved_artifacts = {}
            for artifact in artifacts:
                artifact_path = Path(dbt_dir) / "target" / artifact

                if not artifact_path.exists():
                    self.log.warning(
                        "Required dbt artifact %s was not found. "
                        "Perhaps dbt failed and couldn't generate it.",
                        artifact,
                    )
                    continue

                with open(artifact_path) as artifact_file:
                    json_artifact = json.load(artifact_file)

                saved_artifacts[artifact] = json_artifact

            return DbtTaskResult(success, results, saved_artifacts)
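
A hedged sketch of calling run_dbt_task directly (extra keyword arguments are forwarded to the task configuration; the URLs and artifact name here are assumptions, and in a DAG the dbt operators normally make this call for you):

hook = DbtHook()
result = hook.run_dbt_task(
    "run",
    project_dir="s3://my-bucket/project/",    # Placeholder remote URL.
    profiles_dir="s3://my-bucket/profiles/",  # Placeholder remote URL.
    artifacts=["run_results.json"],           # Saved from the target/ directory.
)
if result.success:
    print(result.artifacts["run_results.json"])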

    def get_dbt_task_config(self, command: str, **config_kwargs) -> BaseConfig:
        """Initialize a configuration for given dbt command with given kwargs."""
        from airflow_dbt_python.utils.configs import ConfigFactory

        return ConfigFactory.from_str(command).create_config(**config_kwargs)

    @contextmanager
    def dbt_directory(
        self,
        config,
        upload_dbt_project: bool = False,
        delete_before_upload: bool = False,
        replace_on_upload: bool = False,
        env_vars: Optional[Dict[str, Any]] = None,
    ) -> Iterator[str]:
        """Provides a temporary directory to execute dbt.

        Creates a temporary directory for dbt to run in and prepares the dbt files
        if they need to be pulled from a remote, like S3. If a remote is being used
        and upload_dbt_project is True, we push the project back to the remote
        before leaving the temporary directory. Pushing back a project enables
        commands like deps or docs generate.

        Yields:
            The temporary directory's name.
        """
        from airflow_dbt_python.utils.env import update_environment

        store_profiles_dir = config.profiles_dir
        store_project_dir = config.project_dir

        with update_environment(env_vars):
            with DbtTemporaryDirectory(prefix="airflow_tmp") as tmp_dir:
                self.log.info("Initializing temporary directory: %s", tmp_dir)

                try:
                    project_dir, profiles_dir = self.prepare_directory(
                        tmp_dir,
                        store_project_dir,
                        store_profiles_dir,
                    )
                except Exception as e:
                    raise AirflowException(
                        "Failed to prepare temporary directory for dbt execution"
                    ) from e

                config.project_dir = project_dir
                config.profiles_dir = profiles_dir

                if getattr(config, "state", None) is not None:
                    state = Path(getattr(config, "state", ""))
                    # Since we are running in a temporary directory, we need to
                    # make state paths relative to this temporary directory.
                    if not state.is_absolute():
                        setattr(config, "state", str(Path(tmp_dir) / state))

                yield tmp_dir

                if upload_dbt_project is True:
                    self.log.info("Uploading dbt project to: %s", store_project_dir)
                    self.upload_dbt_project(
                        tmp_dir,
                        store_project_dir,
                        replace=replace_on_upload,
                        delete_before=delete_before_upload,
                    )

        config.profiles_dir = store_profiles_dir
        config.project_dir = store_project_dir

    def prepare_directory(
        self,
        tmp_dir: str,
        project_dir: URLLike,
        profiles_dir: Optional[URLLike] = None,
    ) -> tuple[str, Optional[str]]:
        """Prepares a dbt directory for execution of a dbt task.

        Preparation involves downloading the required dbt project files and
        profiles.yml.
        """
        project_dir_path = self.download_dbt_project(
            project_dir,
            tmp_dir,
        )
        new_project_dir = str(project_dir_path) + "/"

        if (project_dir_path / "profiles.yml").exists():
            # We may have downloaded the profiles.yml file together
            # with the project.
            return (new_project_dir, new_project_dir)

        if profiles_dir is not None:
            profiles_file_path = self.download_dbt_profiles(
                profiles_dir,
                tmp_dir,
            )
            new_profiles_dir = str(profiles_file_path.parent) + "/"
        else:
            new_profiles_dir = None

        return (new_project_dir, new_profiles_dir)

    def setup_dbt_logging(self, task: BaseTask, debug: Optional[bool]):
        """Set up dbt logging.

        Starting with dbt v1, dbt initializes two loggers: default_file and
        default_stdout. As these are initialized by the CLI app, we need to
        initialize them here.
        """
        from dbt.events.functions import setup_event_logger

        log_path = None
        if task.config is not None:
            log_path = getattr(task.config, "log_path", None)

        setup_event_logger(log_path or "logs")

        configured_file = logging.getLogger("configured_file")
        file_log = logging.getLogger("file_log")

        if not debug:
            # We have to do this after setting logs up as dbt hasn't configured
            # the loggers before the call to setup_event_logger. In the future,
            # handlers may also be cleared or set up to use Airflow's.
            file_log.setLevel("INFO")
            file_log.propagate = False
            configured_file.setLevel("INFO")
            configured_file.propagate = False

    def ensure_profiles(self, profiles_dir: Optional[str]):
        """Ensure a profiles file exists."""
        if profiles_dir is not None:
            # We expect one to exist given that we have passed a profiles_dir.
            return

        profiles_path = Path.home() / ".dbt/profiles.yml"
        if not profiles_path.exists():
            profiles_path.parent.mkdir(exist_ok=True)
            profiles_path.touch()

    def get_dbt_target_from_connection(
        self, target: Optional[str]
    ) -> Optional[dict[str, Any]]:
        """Return a dictionary of connection details to use as a dbt target.

        The connection details are fetched from an Airflow connection identified by
        target or self.dbt_conn_id.

        Args:
            target: The target name to use as an Airflow connection ID. If omitted,
                we will use self.dbt_conn_id.

        Returns:
            A dictionary with a configuration for a dbt target, or None if a
            matching Airflow connection is not found for the given dbt target.
        """
        conn_id = target or self.dbt_conn_id

        if conn_id is None:
            return None

        try:
            conn = self.get_connection(conn_id)
        except AirflowException:
            self.log.debug(
                "No Airflow connection matching dbt target %s was found.", target
            )
            return None

        details = self.get_dbt_details_from_connection(conn)

        return {conn_id: details}

    def get_dbt_details_from_connection(self, conn: Connection) -> dict[str, Any]:
        """Extract dbt connection details from an Airflow Connection.

        dbt connection details may be present as Airflow Connection attributes or
        in the Connection's extras. This class's conn_params and conn_extra_params
        will be used to fetch required values from the Connection's attributes and
        extras, respectively. If conn_extra_params is empty, we merge the
        parameters with all extras.

        Subclasses may override these class attributes to narrow down the
        connection details for a specific dbt target (like Postgres, or Redshift).

        Args:
            conn: The Airflow Connection to extract dbt connection details from.

        Returns:
            A dictionary of dbt connection details.
        """
        dbt_details = {}
        for param in self.conn_params:
            if isinstance(param, DbtConnectionParam):
                key = param.override_name
                value = getattr(conn, param.name, param.default)
            else:
                key = param
                value = getattr(conn, key, None)

            if value is None:
                continue

            dbt_details[key] = value

        extra = conn.extra_dejson

        if not self.conn_extra_params:
            return {**dbt_details, **extra}

        for param in self.conn_extra_params:
            if isinstance(param, DbtConnectionParam):
                key = param.override_name
                value = extra.get(param.name, param.default)
            else:
                key = param
                value = extra.get(key, None)

            if value is None:
                continue

            dbt_details[key] = value

        return dbt_details
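
To illustrate the mapping (the connection values below are made up), the base DbtHook turns an Airflow Connection into a dictionary of dbt target details, storing login under user and, since conn_extra_params is empty by default, merging in all extras:

from airflow.models.connection import Connection

conn = Connection(
    conn_id="warehouse",
    conn_type="postgres",
    host="db.example.com",
    login="airflow",
    password="secret",
    port=5432,
    extra='{"threads": 4}',
)

details = DbtHook().get_dbt_details_from_connection(conn)
# details includes type="postgres", user="airflow", threads=4, and
# dbname="warehouse" (from the conn_id parameter mapping above).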

class DbtPostgresHook(DbtHook):
    """A hook to interact with dbt using a Postgres connection."""

    conn_type = "postgres"
    hook_name = "dbt Postgres Hook"
    conn_params = [
        DbtConnectionParam("conn_type", "type", "postgres"),
        "host",
        "schema",
        DbtConnectionParam("login", "user"),
        "password",
        "port",
    ]
    conn_extra_params = [
        "dbname",
        "threads",
        "keepalives_idle",
        "connect_timeout",
        "retries",
        "search_path",
        "role",
        "sslmode",
    ]

class DbtRedshiftHook(DbtPostgresHook):
    """A hook to interact with dbt using a Redshift connection."""

    conn_type = "redshift"
    hook_name = "dbt Redshift Hook"
    conn_extra_params = DbtPostgresHook.conn_extra_params + [
        "ra3_node",
        "iam_profile",
        "iam_duration_seconds",
        "autocreate",
        "db_groups",
    ]

class DbtSnowflakeHook(DbtHook):
    """A hook to interact with dbt using a Snowflake connection."""

    conn_type = "snowflake"
    hook_name = "dbt Snowflake Hook"
    conn_params = [
        DbtConnectionParam("conn_type", "type", "snowflake"),
        "host",
        "schema",
        DbtConnectionParam("login", "user"),
        "password",
    ]
    conn_extra_params = [
        "account",
        "role",
        "database",
        "warehouse",
        "threads",
        "client_session_keep_alive",
        "query_tag",
        "connect_retries",
        "connect_timeout",
        "retry_on_database_errors",
        "retry_all",
    ]
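
Following the same pattern, a hedged sketch of narrowing connection details for a custom target (every name here is made up for illustration; it is not part of the library):

class DbtMyWarehouseHook(DbtHook):
    """A hypothetical hook for a fictional 'mywarehouse' dbt adapter."""

    conn_type = "mywarehouse"
    hook_name = "dbt MyWarehouse Hook"
    conn_params = [
        DbtConnectionParam("conn_type", "type", "mywarehouse"),
        "host",
        "schema",
        DbtConnectionParam("login", "user"),
        "password",
    ]
    # Only these keys are read from the Connection's extras.
    conn_extra_params = ["threads", "role"]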