From 5f1ecaac695b2fafff28771fb67e9bf62833dadb Mon Sep 17 00:00:00 2001
From: pankajastro
Date: Mon, 30 Sep 2024 19:58:16 +0530
Subject: [PATCH] Fix query

---
 cosmos/operators/airflow_async.py              | 62 ++++++++++++++-----
 .../jaffle_shop/models/staging/sources.yml     | 31 ----------
 2 files changed, 46 insertions(+), 47 deletions(-)
 delete mode 100644 dev/dags/dbt/jaffle_shop/models/staging/sources.yml

diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py
index 4f1fae227..48caaf97f 100644
--- a/cosmos/operators/airflow_async.py
+++ b/cosmos/operators/airflow_async.py
@@ -3,9 +3,11 @@
 from typing import Any
 
 from airflow.io.path import ObjectStoragePath
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
 from airflow.utils.context import Context
 
+from cosmos.exceptions import CosmosValueError
 from cosmos.operators.local import (
     DbtBuildLocalOperator,
     DbtCompileLocalOperator,
@@ -47,25 +49,31 @@ class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
     pass
 
 
-class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator):
-    def __init__(self, *args, **kwargs):
-        self.configuration = {}
-        self.job_id = kwargs.get("job_id", {}) or ""
-        self.impersonation_chain = kwargs.get("impersonation_chain", {}) or ""
-        self.project_id = kwargs.get("project_id", {}) or ""
-
+class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator):
+    def __init__(self, *args, full_refresh: bool = False, **kwargs):
+        # dbt task params
         self.profile_config = kwargs.get("profile_config")
         self.project_dir = kwargs.get("project_dir")
+        self.profile_type = self.profile_config.get_profile_type()
+        self.full_refresh = full_refresh
 
-        self.async_op_args = kwargs.get("async_op_args", {})
+        # airflow task params
+        self.async_op_args = kwargs.pop("async_op_args", {})
+        self.configuration = {}
+        self.job_id = self.async_op_args.get("job_id", "")
+        self.impersonation_chain = self.async_op_args.get("impersonation_chain", "")
+        self.gcp_project = self.async_op_args.get("project_id", "astronomer-dag-authoring")
+        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id
+        self.dataset = self.async_op_args.pop("dataset", "my_dataset")  # not a BigQueryInsertJobOperator arg; must not be forwarded
+        self.location = self.async_op_args.get("location", "US")
         self.async_op_args["deferrable"] = True
+        self.reattach_states: set[str] = self.async_op_args.get("reattach_states") or set()
+
         super().__init__(*args, configuration=self.configuration, task_id=kwargs.get("task_id"), **self.async_op_args)
-        self.profile_type = self.profile_config.get_profile_type()
+
         if self.profile_type not in _SUPPORTED_DATABASES:
-            raise f"Async run are only supported: {_SUPPORTED_DATABASES}"
-        self.reattach_states: set[str] = self.async_op_args.get("reattach_states") or set()
-
+            raise CosmosValueError(f"Async execution is only supported for the following databases: {_SUPPORTED_DATABASES}")
 
     def get_remote_sql(self):
         project_name = str(self.project_dir).split("/")[-1]
         model_name: str = self.task_id.split(".")[0]
@@ -79,17 +87,39 @@ def get_remote_sql(self):
         with object_storage_path.open() as fp:
             return fp.read()
 
-    def execute(self, context: Context) -> Any | None:
-        sql = self.get_remote_sql()
-        print("sql: ", sql)
+    def drop_table_sql(self):
+        model_name = self.task_id.split(".")[0]
+        sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"
+        hook = BigQueryHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
         self.configuration = {
             "query": {
                 "query": sql,
                 "useLegacySql": False,
             }
         }
-        print("async_op_args: ", self.async_op_args)
-        super().execute(context)
+        hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)
+
+    def execute(self, context: Context) -> Any | None:
+        if not self.full_refresh:
+            raise CosmosValueError("Async execution is only supported when full_refresh=True")
+
+        self.drop_table_sql()
+
+        sql = self.get_remote_sql()
+        model_name = self.task_id.split(".")[0]
+        # Prefix an explicit CREATE TABLE statement so the compiled SELECT materializes the model
+        sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
+
+        self.configuration = {
+            "query": {
+                "query": sql,
+                "useLegacySql": False,
+            }
+        }
+        super().execute(context)
 
 
 class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
diff --git a/dev/dags/dbt/jaffle_shop/models/staging/sources.yml b/dev/dags/dbt/jaffle_shop/models/staging/sources.yml
deleted file mode 100644
index a3139b585..000000000
--- a/dev/dags/dbt/jaffle_shop/models/staging/sources.yml
+++ /dev/null
@@ -1,31 +0,0 @@
-
-version: 2
-
-sources:
-  - name: postgres_db
-    database: "{{ env_var('POSTGRES_DB') }}"
-    schema: "{{ env_var('POSTGRES_SCHEMA') }}"
-    tables:
-      - name: raw_customers
-        columns:
-          - name: id
-            tests:
-              - unique
-              - not_null
-      - name: raw_payments
-        columns:
-          - name: id
-            tests:
-              - unique
-              - not_null
-      - name: raw_orders
-        columns:
-          - name: id
-            tests:
-              - unique
-              - not_null
-        freshness:
-          warn_after:
-            count: 3650
-            period: day
-        loaded_at_field: CAST(order_date AS TIMESTAMP)
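
---

Reviewer note (illustration only, not part of the patch): a minimal sketch of how the reworked DbtRunAirflowAsyncOperator might be wired into a DAG, assuming the async_op_args contract this diff reads (job_id, impersonation_chain, project_id, dataset, location, reattach_states). The DAG scaffolding, connection id, paths, and profile values below are placeholders, not anything this patch defines.

    from datetime import datetime

    from airflow import DAG

    from cosmos.config import ProfileConfig
    from cosmos.operators.airflow_async import DbtRunAirflowAsyncOperator
    from cosmos.profiles import GoogleCloudServiceAccountFileProfileMapping

    # Placeholder profile; __init__ reads profile_config.profile_mapping.conn_id
    # to build the BigQueryHook used by drop_table_sql().
    profile_config = ProfileConfig(
        profile_name="jaffle_shop",
        target_name="dev",
        profile_mapping=GoogleCloudServiceAccountFileProfileMapping(
            conn_id="gcp_conn",  # assumed Airflow connection id
            profile_args={"dataset": "my_dataset"},
        ),
    )

    with DAG(dag_id="jaffle_shop_async", start_date=datetime(2024, 9, 30), schedule=None):
        DbtRunAirflowAsyncOperator(
            task_id="stg_customers.run",  # the model name is parsed from the task_id prefix before the first dot
            project_dir="/usr/local/airflow/dbt/jaffle_shop",  # placeholder path
            profile_config=profile_config,
            full_refresh=True,  # execute() raises CosmosValueError otherwise
            async_op_args={
                "project_id": "my-gcp-project",  # defaults in the diff if omitted
                "dataset": "my_dataset",
                "location": "US",
            },
        )

With this wiring, execute() drops any existing table via the BigQueryHook, fetches the compiled SELECT from remote storage, and submits a deferrable CREATE TABLE ... AS job through BigQueryInsertJobOperator.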