Skip to content

Commit

Permalink
Add async run operator
Browse files Browse the repository at this point in the history
Remove print stmt

Fix query
Fix query

Remove oss execute method code
  • Loading branch information
pankajkoti authored and pankajastro committed Sep 30, 2024
1 parent 1068025 commit faa706d
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 7 deletions.
14 changes: 14 additions & 0 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any, Callable, Iterator

import yaml
from airflow.version import version as airflow_version

from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled
Expand Down Expand Up @@ -286,6 +287,19 @@ def validate_profiles_yml(self) -> None:
if self.profiles_yml_filepath and not Path(self.profiles_yml_filepath).exists():
raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.")

def get_profile_type(self) -> str:
    """
    Return the dbt adapter type (e.g. ``"bigquery"``) of the configured profile.

    If a profile mapping is set and exposes a truthy ``dbt_profile_type``, that value is returned.
    Otherwise the profiles file located by ``_get_profile_path()`` is parsed and the ``type`` of the
    output named ``target_name`` under the profile ``profile_name`` is returned.

    :returns: The profile type declared by the mapping or by the profiles file.
    :raises KeyError: If the profile, its ``outputs``, the target or its ``type`` is missing from the file.
    """
    # The mapping may be unset (presumably when only a profiles file is configured — confirm against
    # the class validation); only consult it when present instead of failing with AttributeError.
    profile_mapping = self.profile_mapping
    if profile_mapping is not None and profile_mapping.dbt_profile_type:
        return profile_mapping.dbt_profile_type

    profile_path = self._get_profile_path()

    with open(profile_path) as file:
        profiles = yaml.safe_load(file)

    return profiles[self.profile_name]["outputs"][self.target_name]["type"]

def _get_profile_path(self, use_mock_values: bool = False) -> Path:
"""
Handle the profile caching mechanism.
Expand Down
3 changes: 3 additions & 0 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __init__(
task_group: TaskGroup | None = None,
operator_args: dict[str, Any] | None = None,
on_warning_callback: Callable[..., Any] | None = None,
async_op_args: dict[str, Any] | None = None,
*args: Any,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -256,6 +257,7 @@ def __init__(
cache_identifier=cache_identifier,
dbt_vars=dbt_vars,
airflow_metadata=cache._get_airflow_metadata(dag, task_group),
async_op_args=async_op_args,
)
self.dbt_graph.load(method=render_config.load_method, execution_mode=execution_config.execution_mode)

Expand Down Expand Up @@ -301,6 +303,7 @@ def __init__(
dbt_project_name=project_config.project_name,
on_warning_callback=on_warning_callback,
render_config=render_config,
async_op_args=async_op_args,
)

current_time = time.perf_counter()
Expand Down
7 changes: 6 additions & 1 deletion cosmos/core/airflow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import importlib

from airflow.models import BaseOperator
Expand All @@ -10,7 +12,7 @@
logger = get_logger(__name__)


def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) -> BaseOperator:
"""
Get the Airflow Operator class for a Task.
Expand All @@ -29,6 +31,9 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None
if task.owner != "":
task_kwargs["owner"] = task.owner

if task.async_op_args:
task_kwargs["async_op_args"] = task.async_op_args

airflow_task = Operator(
task_id=task.id,
dag=dag,
Expand Down
1 change: 1 addition & 0 deletions cosmos/core/graph/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,4 @@ class Task(CosmosEntity):
operator_class: str = "airflow.operators.empty.EmptyOperator"
arguments: Dict[str, Any] = field(default_factory=dict)
extra_context: Dict[str, Any] = field(default_factory=dict)
async_op_args: Dict[str, Any] = field(default_factory=dict)
3 changes: 2 additions & 1 deletion cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,15 @@ def __init__(
dbt_vars: dict[str, str] | None = None,
airflow_metadata: dict[str, str] | None = None,
operator_args: dict[str, Any] | None = None,
async_op_args: dict[str, Any] | None = None,
):
self.project = project
self.render_config = render_config
self.profile_config = profile_config
self.execution_config = execution_config
self.cache_dir = cache_dir
self.airflow_metadata = airflow_metadata or {}
self.async_op_args = async_op_args
if cache_identifier:
self.dbt_ls_cache_key = cache.create_cache_key(cache_identifier)
else:
Expand Down Expand Up @@ -467,7 +469,6 @@ def should_use_dbt_ls_cache(self) -> bool:

def load_via_dbt_ls_cache(self) -> bool:
"""(Try to) load dbt ls cache from an Airflow Variable"""

logger.info(f"Trying to parse the dbt project using dbt ls cache {self.dbt_ls_cache_key}...")
if self.should_use_dbt_ls_cache():
project_path = self.project_path
Expand Down
76 changes: 71 additions & 5 deletions cosmos/operators/airflow_async.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,34 @@
from __future__ import annotations

from typing import Any

from airflow.io.path import ObjectStoragePath
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos.operators.base import DbtCompileMixin
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtCompileLocalOperator,
DbtDepsLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsCloudLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLocalBaseOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
DbtRunOperationLocalOperator,
DbtCompileLocalOperator,
)

from cosmos.settings import remote_target_path, remote_target_path_conn_id

# dbt profile types accepted by DbtRunAirflowAsyncOperator; its __init__ checks the profile type against this list.
_SUPPORTED_DATABASES = ["bigquery"]


class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
    # Build operator under the async naming; adds nothing over DbtBuildLocalOperator.
    pass
Expand All @@ -35,8 +50,50 @@ class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
pass
class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator):
    """
    Run a dbt model by submitting its SQL as a deferrable BigQuery insert job.

    The SQL text is read at execution time from the remote target path
    (``remote_target_path`` / ``remote_target_path_conn_id`` settings) rather than
    being produced by invoking dbt inside the task.

    Recognised keyword arguments: ``task_id``, ``profile_config``, ``project_dir``,
    ``job_id``, ``impersonation_chain``, ``project_id`` and ``async_op_args`` (a dict of
    extra arguments forwarded to ``BigQueryInsertJobOperator``; ``deferrable`` is always
    forced to ``True``).

    :raises CosmosValueError: If the profile type is not listed in ``_SUPPORTED_DATABASES``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.configuration: dict[str, Any] = {}
        self.job_id = kwargs.get("job_id") or ""
        self.impersonation_chain = kwargs.get("impersonation_chain") or ""
        self.project_id = kwargs.get("project_id") or ""

        self.profile_config = kwargs.get("profile_config")
        self.project_dir = kwargs.get("project_dir")

        # Work on a copy: the same dict may be shared by the caller across tasks, and an explicit
        # ``async_op_args=None`` must not break the item assignment below.
        self.async_op_args: dict[str, Any] = dict(kwargs.get("async_op_args") or {})
        self.async_op_args["deferrable"] = True
        # NOTE(review): only ``task_id`` and ``async_op_args`` reach the parent constructor; any other
        # keyword argument is not forwarded — confirm this is intended.
        super().__init__(*args, configuration=self.configuration, task_id=kwargs.get("task_id"), **self.async_op_args)

        self.profile_type = self.profile_config.get_profile_type()
        if self.profile_type not in _SUPPORTED_DATABASES:
            # Imported here so this block does not depend on the module's import list.
            from cosmos.exceptions import CosmosValueError

            # A bare string cannot be raised in Python; raise a real exception with the same message intent.
            raise CosmosValueError(
                f"Async run is only supported for profile types {_SUPPORTED_DATABASES}, got {self.profile_type!r}"
            )

        self.reattach_states: set[str] = self.async_op_args.get("reattach_states") or set()

    def get_remote_sql(self) -> str:
        """
        Read and return the SQL of this task's model from the remote target path.

        The model name is taken as the part of ``task_id`` before the first ``.``; models whose name
        starts with ``stg_`` are looked up under ``models/staging``, all others under ``models``.
        """
        # Strip a trailing slash so that "path/to/project/" still yields "project".
        project_name = str(self.project_dir).rstrip("/").split("/")[-1]
        # NOTE(review): assumes the model name is the first dot-separated part of task_id — confirm
        # against how task ids are generated.
        model_name: str = self.task_id.split(".")[0]
        if model_name.startswith("stg_"):
            remote_model_path = f"{remote_target_path}/{project_name}/models/staging/{model_name}.sql"
        else:
            remote_model_path = f"{remote_target_path}/{project_name}/models/{model_name}.sql"

        self.log.info("remote_model_path: %s", remote_model_path)
        object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
        with object_storage_path.open() as fp:
            return fp.read()

    def execute(self, context: Context) -> Any | None:
        """Load the model SQL, set it as the job's standard-SQL query and delegate to the parent operator."""
        sql = self.get_remote_sql()
        self.log.info("sql: %s", sql)
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        self.log.info("async_op_args: %s", self.async_op_args)
        # Propagate the parent's result instead of discarding it.
        return super().execute(context)



class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
Expand All @@ -51,6 +108,10 @@ class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
pass


class DbtDocsCloudAirflowAsyncOperator(DbtDocsCloudLocalOperator):
    # Docs-cloud operator under the async naming; adds nothing over DbtDocsCloudLocalOperator.
    pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
    # Docs-S3 operator under the async naming; adds nothing over DbtDocsS3LocalOperator.
    pass

Expand All @@ -65,3 +126,8 @@ class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):

class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
    # Compile operator under the async naming; adds nothing over DbtCompileLocalOperator.
    pass

class DbtDepsAirflowAsyncOperator(DbtDepsLocalOperator):
    # Deps operator under the async naming; adds nothing over DbtDepsLocalOperator.
    pass


2 changes: 2 additions & 0 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
dbt_cmd_global_flags: list[str] | None = None,
cache_dir: Path | None = None,
extra_context: dict[str, Any] | None = None,
# configuration: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.project_dir = project_dir
Expand Down Expand Up @@ -140,6 +141,7 @@ def __init__(
self.cache_dir = cache_dir
self.extra_context = extra_context or {}
kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes
# kwargs["configuration"] = {}
super().__init__(**kwargs)

def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]:
Expand Down
4 changes: 4 additions & 0 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(
should_store_compiled_sql: bool = True,
should_upload_compiled_sql: bool = False,
append_env: bool = True,
async_op_args: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.profile_config = profile_config
Expand All @@ -148,6 +149,7 @@ def __init__(
self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult]
self.handle_exception: Callable[..., None]
self._dbt_runner: dbtRunner | None = None
self.async_op_args = async_op_args
if self.invocation_mode:
self._set_invocation_methods()
super().__init__(**kwargs)
Expand Down Expand Up @@ -289,6 +291,7 @@ def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:

return _configured_target_path, remote_conn_id


def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
"""
Uploads the compiled SQL files from the dbt compile output to the remote store.
Expand All @@ -297,6 +300,7 @@ def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
return

dest_target_dir, dest_conn_id = self._configure_remote_target_path()

if not dest_target_dir:
raise CosmosValueError(
"You're trying to upload compiled SQL files, but the remote target path is not configured. "
Expand Down

0 comments on commit faa706d

Please sign in to comment.