Experimental BQ support to run dbt models with ExecutionMode.AIRFLOW_ASYNC #1230

Merged: 51 commits, Oct 3, 2024

Changes from 15 commits

Commits
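For orientation, a minimal sketch of how the new execution mode might be enabled, based on the classes and parameters touched in this PR; the connection ID, GCP project, dataset, and project path are illustrative assumptions rather than values taken from the PR:

# Hypothetical usage sketch of ExecutionMode.AIRFLOW_ASYNC; names below are assumptions.
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig
from cosmos.profiles import GoogleCloudServiceAccountDictProfileMapping

profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=GoogleCloudServiceAccountDictProfileMapping(
        conn_id="gcp_conn",  # assumed Airflow connection holding the service-account keyfile dict
        profile_args={"project": "my-gcp-project", "dataset": "my_dataset"},
    ),
)

simple_dag_async = DbtDag(
    project_config=ProjectConfig("/usr/local/airflow/dags/dbt/jaffle_shop"),
    profile_config=profile_config,
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
    # extra keyword arguments forwarded to the underlying BigQueryInsertJobOperator
    async_op_args={"location": "US"},
    operator_args={"full_refresh": True},  # the async run path currently requires full_refresh
    schedule=None,
    start_date=datetime(2024, 1, 1),
    catchup=False,
    dag_id="simple_dag_async",
)

The compiled SQL for each model is expected to be uploaded to the configured remote target path (see the settings and local operator changes below), from which the async run operator fetches it at execution time.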
851564f
Draft: dbt compile task
pankajkoti Sep 25, 2024
9dc2c9c
Put compiled files under dag_id folder & refactor few snippets
pankajkoti Sep 29, 2024
0ce662e
Add tests & minor refactorings
pankajkoti Sep 29, 2024
1b6f57e
Apply suggestions from code review
pankajkoti Sep 29, 2024
cc48161
Install deps for the newly added example DAG
pankajkoti Sep 29, 2024
1068025
Add docs
pankajkoti Sep 30, 2024
faa706d
Add async run operator
pankajkoti Sep 25, 2024
0e155e4
Fix remote sql path and async args
pankajastro Sep 30, 2024
5f1ecaa
Fix query
pankajastro Sep 30, 2024
1278847
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
b3d6cf3
Use dbt node's filepath to construct remote path to fetch compiled SQ…
pankajkoti Sep 30, 2024
78bc069
Merge branch 'main' into execute-async-task
tatiana Sep 30, 2024
9ca5e85
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2024
99bf7c0
Fix unittests
tatiana Sep 30, 2024
3aaaf9e
Improve code
tatiana Sep 30, 2024
43158be
Working with deferrable=False, not working with deferrable=True
tatiana Oct 1, 2024
83b1010
Working with deferrable=False, not working with deferrable=True
tatiana Oct 1, 2024
bd6657a
Fix issue when using BQ deferrable operator - it requires location
tatiana Oct 1, 2024
1195955
Add limitation in docs
pankajastro Oct 1, 2024
2bdd9bb
Add full_refresh as templated field
pankajastro Oct 1, 2024
4a44603
Add more template fields
pankajastro Oct 1, 2024
c3c51cb
Construct & relay 'dbt dag-task group' identifier to upload & downloa…
pankajkoti Oct 1, 2024
72c6164
Fix model_name retrieval; get from dbt_node_config
pankajkoti Oct 1, 2024
e67098e
Fix unit tests
pankajkoti Oct 1, 2024
3e550bf
Fix subsequent failing unit tests
pankajkoti Oct 1, 2024
0730d0f
Fix type check failures
pankajkoti Oct 1, 2024
745768e
Add back the deleted sources.yml from jaffle_shop as it has dependenc…
pankajkoti Oct 1, 2024
43d62ea
Install dbt bigquery adapter for running simple_dag_async
pankajkoti Oct 1, 2024
9656248
Install dbt bigquery adapter in our CI setup scripts
pankajkoti Oct 1, 2024
a654f49
Update gcp conn in dev/dags/simple_dag_async.py
pankajkoti Oct 1, 2024
e60ace2
Refactor args in DbtRunAirflowAsyncOperator
tatiana Oct 1, 2024
7f055bc
Use GoogleCloudServiceAccountDictProfileMapping in profilemapping
pankajkoti Oct 1, 2024
ad057c8
set should_upload_compiled_sql to True
pankajkoti Oct 1, 2024
a70ca46
Remove async_op_args
tatiana Oct 1, 2024
7c6a1b2
remove install_deps from DAG
pankajkoti Oct 1, 2024
64a31d0
Merge branch 'main' into execute-async-task
tatiana Oct 1, 2024
c1aeff0
Fix test_build_airflow_graph_with_dbt_compile_task by passing needed …
pankajkoti Oct 1, 2024
02f7985
Specify required project id in the GoogleCloudServiceAccountDictProfi…
pankajkoti Oct 2, 2024
af454a9
Pass gcp_conn_id to super class init, otherwise it is lost & uses the…
pankajkoti Oct 2, 2024
9081e6a
Adapt manifest DAG to use & adapt to the newer GCP conn secret that i…
pankajkoti Oct 2, 2024
2dccf84
Release 1.7.0a1
tatiana Oct 2, 2024
7adeb99
Retrigger GH actions
tatiana Oct 2, 2024
7e6de30
temporarily move out simple_dag_async.py
tatiana Oct 2, 2024
16a87ea
Fix CI issue
tatiana Oct 2, 2024
05db6a0
Fix dbt-compile dependency by using Airflow tasks instead of dbt nodes
pankajkoti Oct 2, 2024
8fc4ae2
Apply suggestions from code review
pankajkoti Oct 2, 2024
ea5816b
Apply suggestions from code review
pankajkoti Oct 2, 2024
85f86a4
Add install instruction
pankajastro Oct 3, 2024
402f823
Add min airflow version in limitation
pankajastro Oct 3, 2024
621a4de
Ignore Async DAG for dbt <=1.5
pankajastro Oct 3, 2024
a0cb147
Ignore Async DAG for dbt <=1.5
pankajastro Oct 3, 2024
Files changed
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -2,7 +2,7 @@ name: test

on:
push: # Run on pushes to the default branch
branches: [main]
branches: [main,poc-dbt-compile-task]
pull_request_target: # Also run on pull requests originated from forks
branches: [main]

4 changes: 4 additions & 0 deletions cosmos/airflow/graph.py
@@ -81,6 +81,7 @@ def create_test_task_metadata(
on_warning_callback: Callable[..., Any] | None = None,
node: DbtNode | None = None,
render_config: RenderConfig | None = None,
async_op_args: dict | None = None,
) -> TaskMetadata:
"""
Create the metadata that will be used to instantiate the Airflow Task that will be used to run the Dbt test node.
@@ -125,6 +126,7 @@
),
arguments=task_args,
extra_context=extra_context,
async_op_args=async_op_args,
)


@@ -288,6 +290,7 @@ def build_airflow_graph(
render_config: RenderConfig,
task_group: TaskGroup | None = None,
on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command
async_op_args: dict | None = None,
) -> None:
"""
Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
@@ -352,6 +355,7 @@
task_args=task_args,
on_warning_callback=on_warning_callback,
render_config=render_config,
async_op_args=async_op_args,
)
test_task = create_airflow_task(test_meta, dag, task_group=task_group)
leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes)
16 changes: 16 additions & 0 deletions cosmos/config.py
@@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any, Callable, Iterator

import yaml
from airflow.version import version as airflow_version

from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled
@@ -286,6 +287,21 @@ def validate_profiles_yml(self) -> None:
if self.profiles_yml_filepath and not Path(self.profiles_yml_filepath).exists():
raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.")

def get_profile_type(self) -> str:
if isinstance(self.profile_mapping, BaseProfileMapping):
return str(self.profile_mapping.dbt_profile_type)

profile_path = self._get_profile_path()

with open(profile_path) as file:
profiles = yaml.safe_load(file)

profile = profiles[self.profile_name]
target_type = profile["outputs"][self.target_name]["type"]
return str(target_type)

return "undefined"

def _get_profile_path(self, use_mock_values: bool = False) -> Path:
"""
Handle the profile caching mechanism.
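A quick illustration of the fallback branch in get_profile_type() above, assuming profiles are supplied via a plain profiles.yml rather than a profile mapping; the profile name, target name, and BigQuery settings are made up for the example:

# Shape of a loaded profiles.yml that get_profile_type() falls back to reading;
# profile/target names and the BigQuery settings are illustrative assumptions.
profiles = {
    "default": {
        "target": "dev",
        "outputs": {
            "dev": {
                "type": "bigquery",
                "method": "service-account",
                "project": "my-gcp-project",
                "dataset": "my_dataset",
            }
        },
    }
}

profile_name, target_name = "default", "dev"
target_type = profiles[profile_name]["outputs"][target_name]["type"]
assert target_type == "bigquery"  # this is the value later checked against _SUPPORTED_DATABASES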
3 changes: 3 additions & 0 deletions cosmos/converter.py
@@ -207,6 +207,7 @@ def __init__(
task_group: TaskGroup | None = None,
operator_args: dict[str, Any] | None = None,
on_warning_callback: Callable[..., Any] | None = None,
async_op_args: dict[str, Any] | None = None,
*args: Any,
**kwargs: Any,
) -> None:
@@ -256,6 +257,7 @@ def __init__(
cache_identifier=cache_identifier,
dbt_vars=dbt_vars,
airflow_metadata=cache._get_airflow_metadata(dag, task_group),
async_op_args=async_op_args,
)
self.dbt_graph.load(method=render_config.load_method, execution_mode=execution_config.execution_mode)

@@ -301,6 +303,7 @@ def __init__(
dbt_project_name=project_config.project_name,
on_warning_callback=on_warning_callback,
render_config=render_config,
async_op_args=async_op_args,
)

current_time = time.perf_counter()
3 changes: 3 additions & 0 deletions cosmos/core/airflow.py
@@ -31,6 +31,9 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None)
if task.owner != "":
task_kwargs["owner"] = task.owner

if task.async_op_args:
task_kwargs["async_op_args"] = task.async_op_args

airflow_task = Operator(
task_id=task.id,
dag=dag,
1 change: 1 addition & 0 deletions cosmos/core/graph/entities.py
@@ -61,3 +61,4 @@ class Task(CosmosEntity):
operator_class: str = "airflow.operators.empty.EmptyOperator"
arguments: Dict[str, Any] = field(default_factory=dict)
extra_context: Dict[str, Any] = field(default_factory=dict)
async_op_args: Dict[str, Any] = field(default_factory=dict)
3 changes: 2 additions & 1 deletion cosmos/dbt/graph.py
@@ -217,13 +217,15 @@ def __init__(
dbt_vars: dict[str, str] | None = None,
airflow_metadata: dict[str, str] | None = None,
operator_args: dict[str, Any] | None = None,
async_op_args: dict[str, Any] | None = None,
):
self.project = project
self.render_config = render_config
self.profile_config = profile_config
self.execution_config = execution_config
self.cache_dir = cache_dir
self.airflow_metadata = airflow_metadata or {}
self.async_op_args = async_op_args
if cache_identifier:
self.dbt_ls_cache_key = cache.create_cache_key(cache_identifier)
else:
@@ -467,7 +469,6 @@ def should_use_dbt_ls_cache(self) -> bool:

def load_via_dbt_ls_cache(self) -> bool:
"""(Try to) load dbt ls cache from an Airflow Variable"""

logger.info(f"Trying to parse the dbt project using dbt ls cache {self.dbt_ls_cache_key}...")
if self.should_use_dbt_ls_cache():
project_path = self.project_path
111 changes: 88 additions & 23 deletions cosmos/operators/airflow_async.py
@@ -1,18 +1,26 @@
from __future__ import annotations

from typing import Any

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.exceptions import CosmosValueError
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCompileLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)
from cosmos.settings import remote_target_path, remote_target_path_conn_id

_SUPPORTED_DATABASES = ["bigquery"]


class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
@@ -35,8 +43,81 @@ class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
pass
class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator): # type: ignore
def __init__(self, *args, full_refresh: bool = False, **kwargs): # type: ignore
# dbt task param
self.profile_config = kwargs.get("profile_config")
self.project_dir = kwargs.get("project_dir")
self.file_path = kwargs.get("extra_context", {}).get("dbt_node_config", {}).get("file_path")
self.profile_type: str = self.profile_config.get_profile_type() # type: ignore
self.full_refresh = full_refresh

# airflow task param
self.async_op_args = kwargs.pop("async_op_args", {})
self.configuration: dict[str, object] = {}
self.job_id = self.async_op_args.get("job_id", "")
self.impersonation_chain = self.async_op_args.get("impersonation_chain", "")
self.gcp_project = self.async_op_args.get("project_id", "astronomer-dag-authoring")
self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore
self.dataset = self.async_op_args.get("dataset", "my_dataset")
self.location = self.async_op_args.get("location", "US")
self.async_op_args["deferrable"] = True
self.reattach_states: set[str] = self.async_op_args.get("reattach_states") or set()

super().__init__(*args, configuration=self.configuration, task_id=kwargs.get("task_id"), **self.async_op_args)

if self.profile_type not in _SUPPORTED_DATABASES:
raise CosmosValueError(f"Async run are only supported: {_SUPPORTED_DATABASES}")

def get_remote_sql(self) -> str:
if not settings.AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.")
from airflow.io.path import ObjectStoragePath

if not self.file_path or not self.project_dir:
raise CosmosValueError("file_path and project_dir are required to be set on the task for async execution")
project_dir_parent = str(self.project_dir.parent)
relative_file_path = str(self.file_path).replace(project_dir_parent, "").lstrip("/")
remote_model_path = f"{str(remote_target_path).rstrip('/')}/{self.dag_id}/{relative_file_path}"

print("remote_model_path: ", remote_model_path)
object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
with object_storage_path.open() as fp: # type: ignore
return fp.read() # type: ignore

def drop_table_sql(self) -> None:
model_name = self.task_id.split(".")[0]
sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"
hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

def execute(self, context: Context) -> Any | None:
if not self.full_refresh:
raise CosmosValueError("The async execution only supported for full_refresh")
else:
self.drop_table_sql()

sql = self.get_remote_sql()
model_name = self.task_id.split(".")[0]
# prefix explicit create command to create table
sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"

self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
return super().execute(context)


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
@@ -47,21 +128,5 @@ class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator):
pass


class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
pass


class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator):
pass


class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
pass


class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
pass
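To make the new run path concrete, a rough trace of what DbtRunAirflowAsyncOperator.execute() submits for a single model, following the code above; the bucket, DAG id, and model/file names are assumptions, while the project, dataset, and location are the defaults hard-coded in the operator:

# Illustrative trace of the async run path for one model (names are assumptions).
remote_target_path = "gs://my-bucket/target"  # [cosmos] remote_target_path setting
dag_id = "simple_dag_async"
project_dir = "/usr/local/airflow/dags/dbt/jaffle_shop"
file_path = f"{project_dir}/models/staging/stg_customers.sql"  # dbt_node_config["file_path"]

# get_remote_sql(): build the remote path of the compiled SQL and read it back
project_dir_parent = "/usr/local/airflow/dags/dbt"  # str(Path(project_dir).parent)
relative_file_path = file_path.replace(project_dir_parent, "").lstrip("/")
remote_model_path = f"{remote_target_path.rstrip('/')}/{dag_id}/{relative_file_path}"
# -> gs://my-bucket/target/simple_dag_async/jaffle_shop/models/staging/stg_customers.sql
compiled_sql = "SELECT * FROM raw.customers"  # contents read via ObjectStoragePath

# execute(): drop the existing table, then recreate it from the compiled SELECT
model_name = "stg_customers"  # derived from task_id.split(".")[0]
drop_sql = f"DROP TABLE IF EXISTS astronomer-dag-authoring.my_dataset.{model_name};"
create_sql = f"CREATE TABLE astronomer-dag-authoring.my_dataset.{model_name} AS {compiled_sql}"
# Each statement is submitted as a BigQuery query job (useLegacySql=False) in location "US",
# with the create step going through the deferrable BigQueryInsertJobOperator machinery.

Because execute() raises unless full_refresh is set, this drop-and-recreate sequence is currently the only supported behaviour for async runs.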
2 changes: 2 additions & 0 deletions cosmos/operators/base.py
@@ -110,6 +110,7 @@ def __init__(
dbt_cmd_global_flags: list[str] | None = None,
cache_dir: Path | None = None,
extra_context: dict[str, Any] | None = None,
# configuration: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.project_dir = project_dir
@@ -140,6 +141,7 @@ def __init__(
self.cache_dir = cache_dir
self.extra_context = extra_context or {}
kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes
# kwargs["configuration"] = {}
super().__init__(**kwargs)

def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]:
8 changes: 6 additions & 2 deletions cosmos/operators/local.py
@@ -32,7 +32,7 @@
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError, CosmosValueError
from cosmos.settings import AIRFLOW_IO_AVAILABLE, remote_target_path, remote_target_path_conn_id
from cosmos.settings import remote_target_path, remote_target_path_conn_id

try:
from airflow.datasets import Dataset
@@ -143,6 +143,7 @@ def __init__(
should_store_compiled_sql: bool = True,
should_upload_compiled_sql: bool = False,
append_env: bool = True,
async_op_args: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.task_id = task_id
@@ -157,6 +158,7 @@ def __init__(
self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult]
self.handle_exception: Callable[..., None]
self._dbt_runner: dbtRunner | None = None
self.async_op_args = async_op_args
if self.invocation_mode:
self._set_invocation_methods()

@@ -294,7 +296,7 @@ def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
if remote_conn_id is None:
return None, None

if not AIRFLOW_IO_AVAILABLE:
if not settings.AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(
f"You're trying to specify remote target path {target_path_str}, but the required "
f"Object Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to "
@@ -340,6 +342,7 @@ def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
return

dest_target_dir, dest_conn_id = self._configure_remote_target_path()

if not dest_target_dir:
raise CosmosValueError(
"You're trying to upload compiled SQL files, but the remote target path is not configured. "
@@ -349,6 +352,7 @@ def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:

source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled"
files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()]

for file_path in files:
dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir, context)
dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
1 change: 0 additions & 1 deletion cosmos/settings.py
@@ -34,7 +34,6 @@
# This will be merged with the `cache_dir` config parameter in upcoming releases.
remote_cache_dir = conf.get("cosmos", "remote_cache_dir", fallback=None)
remote_cache_dir_conn_id = conf.get("cosmos", "remote_cache_dir_conn_id", fallback=None)

remote_target_path = conf.get("cosmos", "remote_target_path", fallback=None)
remote_target_path_conn_id = conf.get("cosmos", "remote_target_path_conn_id", fallback=None)

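The remote target path used by the upload step and the async run operator comes from the [cosmos] settings above. One way to set them, shown with an assumed GCS bucket and connection ID, is via the standard Airflow environment-variable form of the same config options:

# Assumed bucket and connection ID; equivalent to setting
# [cosmos] remote_target_path / remote_target_path_conn_id in airflow.cfg.
import os

os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH"] = "gs://my-bucket/target"
os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID"] = "gcp_conn"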
31 changes: 0 additions & 31 deletions dev/dags/dbt/jaffle_shop/models/staging/sources.yml

This file was deleted.

1 change: 1 addition & 0 deletions docs/getting_started/execution-modes.rst
@@ -243,6 +243,7 @@ Each task will create a new Cloud Run Job execution, giving full isolation. The
},
)


Airflow Async (experimental)
----------------------------
