Skip to content

Commit

Permalink
chore(pipeline): generate airflow dag from dbt
Browse files Browse the repository at this point in the history
  • Loading branch information
vmttn committed Aug 12, 2024
1 parent c04c2ce commit 1b84641
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 122 deletions.
22 changes: 4 additions & 18 deletions pipeline/dags/dag_utils/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,10 @@ def dbt_operator_factory(
)


def get_staging_tasks(schedule=None):
def get_staging_tasks():
task_list = []

for source_id, src_meta in sorted(SOURCES_CONFIGS.items()):
if schedule and src_meta["schedule"] != schedule:
continue

for source_id in sorted(SOURCES_CONFIGS):
dbt_source_id = source_id.replace("-", "_")

stg_selector = f"path:models/staging/sources/**/stg_{dbt_source_id}__*.sql"
Expand Down Expand Up @@ -89,9 +86,9 @@ def get_staging_tasks(schedule=None):
return task_list


def get_before_geocoding_tasks():
def get_intermediate_tasks():
return dbt_operator_factory(
task_id="dbt_build_before_geocoding",
task_id="dbt_build_intermediate",
command="build",
select=" ".join(
[
Expand All @@ -104,17 +101,6 @@ def get_before_geocoding_tasks():
"path:models/intermediate/int__union_adresses.sql",
"path:models/intermediate/int__union_services.sql",
"path:models/intermediate/int__union_structures.sql",
]
),
)


def get_after_geocoding_tasks():
return dbt_operator_factory(
task_id="dbt_build_after_geocoding",
command="build",
select=" ".join(
[
"path:models/intermediate/extra",
"path:models/intermediate/int__deprecated_sirets.sql",
"path:models/intermediate/int__plausible_personal_emails.sql",
Expand Down
145 changes: 44 additions & 101 deletions pipeline/dags/main.py
Original file line number Diff line number Diff line change
@@ -1,117 +1,60 @@
import cosmos
import cosmos.airflow
import cosmos.constants
import cosmos.profiles
import pendulum

import airflow
from airflow.operators import empty, python
from airflow.models import Variable
from airflow.operators import empty

from dag_utils import date, marts
from dag_utils.dbt import (
dbt_operator_factory,
get_after_geocoding_tasks,
get_before_geocoding_tasks,
get_staging_tasks,
)
from dag_utils import date
from dag_utils.notifications import format_failure, notify_webhook
from dag_utils.virtualenvs import PYTHON_BIN_PATH
from dag_utils.virtualenvs import DBT_PYTHON_BIN_PATH

# Defaults applied to every task in this DAG.
default_args = {
    # Retry twice before marking a task failed.
    "retries": 2,
    # On failure, push a formatted notification to the Mattermost webhook.
    # NOTE(review): diff residue left both `)` and `),` closers here; the
    # trailing-comma form is the post-commit version kept below.
    "on_failure_callback": lambda context: notify_webhook(
        context, "mattermost", format_failure
    ),
}


def _geocode():
    """Geocode the union of addresses and persist the results.

    Reads address rows from ``public_intermediate.int__union_adresses``,
    geocodes them through the BAN (Base Adresse Nationale) API backend, and
    replaces the ``public.extra__geocoded_results`` table with the output.

    Imports are local so the callable can run inside an
    ``ExternalPythonOperator`` virtualenv.
    """
    import logging

    import sqlalchemy as sqla

    from airflow.models import Variable
    from airflow.providers.postgres.hooks.postgres import PostgresHook

    from dag_utils import geocoding
    from dag_utils.sources import utils

    logger = logging.getLogger(__name__)

    # Connection id "pg" must be configured in Airflow.
    pg_hook = PostgresHook(postgres_conn_id="pg")

    # 1. Retrieve input data: one row per address to geocode.
    input_df = pg_hook.get_pandas_df(
        sql="""
            SELECT
                _di_surrogate_id,
                adresse,
                code_postal,
                commune
            FROM public_intermediate.int__union_adresses;
        """
    )

    utils.log_df_info(input_df, logger=logger)

    # BAN API endpoint is injected via an Airflow Variable.
    geocoding_backend = geocoding.BaseAdresseNationaleBackend(
        base_url=Variable.get("BAN_API_URL")
    )

    # 2. Geocode the whole input frame in one pass.
    output_df = geocoding_backend.geocode(input_df)

    utils.log_df_info(output_df, logger=logger)

    # 3. Write result back. `if_exists="replace"` drops and recreates the
    # table on every run, inside a single transaction.
    engine = pg_hook.get_sqlalchemy_engine()

    with engine.connect() as conn:
        with conn.begin():
            output_df.to_sql(
                "extra__geocoded_results",
                schema="public",
                con=conn,
                if_exists="replace",
                index=False,
                # Explicit float columns so scores/coordinates are not
                # stored as text.
                dtype={
                    "latitude": sqla.Float,
                    "longitude": sqla.Float,
                    "result_score": sqla.Float,
                },
            )


# Main pipeline DAG, generated from the dbt project by astronomer-cosmos.
# NOTE(review): the diff residue interleaved the removed `with airflow.DAG`
# body with the added `cosmos.DbtDag` call; the reconstructed post-commit
# definition is kept below.
dag = cosmos.DbtDag(
    dag_id="main",
    start_date=pendulum.datetime(2022, 1, 1, tz=date.TIME_ZONE),
    default_args=default_args,
    # Daily run at 04:00 (DAG timezone).
    schedule="0 4 * * *",
    catchup=False,
    concurrency=4,
    project_config=cosmos.ProjectConfig(
        # Path to the dbt project, injected via an Airflow Variable.
        dbt_project_path=Variable.get("DBT_PROJECT_DIR"),
    ),
    profile_config=cosmos.ProfileConfig(
        profile_name="data_inclusion",
        target_name="dev",
        # Reuse the Airflow "pg" connection as the dbt Postgres profile.
        profile_mapping=cosmos.profiles.PostgresUserPasswordProfileMapping(
            conn_id="pg",
            profile_args={"schema": "public"},
        ),
    ),
    execution_config=cosmos.ExecutionConfig(
        # dbt executable lives in its own virtualenv, next to its python.
        dbt_executable_path=str(DBT_PYTHON_BIN_PATH.parent / "dbt")
    ),
    render_config=cosmos.RenderConfig(
        # Restrict the rendered graph to sources, staging, intermediate
        # and marts models.
        select=[
            "source:*",
            "path:models/staging/sources/**/*.sql",
            "path:models/intermediate/sources/**/*.sql",
            "path:models/intermediate/*.sql",
            "path:models/marts/**/*.sql",
        ],
        # Show the sources as (empty) start nodes in the graph.
        node_converters={
            cosmos.constants.DbtResourceType("source"): lambda dag,
            task_group,
            node,
            **kwargs: empty.EmptyOperator(
                dag=dag, task_group=task_group, task_id=f"source_{node.name}"
            ),
        },
    ),
)
2 changes: 1 addition & 1 deletion pipeline/requirements/airflow/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ apache-airflow-providers-discord==3.7.1
apache-airflow-providers-docker==3.12.2
apache-airflow-providers-elasticsearch==5.4.1
apache-airflow-providers-exasol==4.5.2
apache-airflow-providers-fab==1.2.1
apache-airflow-providers-fab==1.2.2
apache-airflow-providers-facebook==3.5.2
apache-airflow-providers-ftp==3.10.0
apache-airflow-providers-github==2.6.2
Expand Down
1 change: 1 addition & 0 deletions pipeline/requirements/airflow/requirements.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
apache-airflow[amazon,postgres,sentry,ssh]==2.9.3
astronomer-cosmos==1.*
43 changes: 42 additions & 1 deletion pipeline/requirements/airflow/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# This file was autogenerated by uv via the following command:
# uv pip compile --constraint airflow/constraints.txt airflow/requirements.in --output-file=airflow/requirements.txt
aenum==3.1.15
# via astronomer-cosmos
aiohttp==3.9.5
# via
# -c airflow/constraints.txt
Expand All @@ -12,6 +14,10 @@ alembic==1.13.2
# via
# -c airflow/constraints.txt
# apache-airflow
annotated-types==0.7.0
# via
# -c airflow/constraints.txt
# pydantic
anyio==4.4.0
# via
# -c airflow/constraints.txt
Expand All @@ -30,6 +36,7 @@ apache-airflow==2.9.3
# apache-airflow-providers-smtp
# apache-airflow-providers-sqlite
# apache-airflow-providers-ssh
# astronomer-cosmos
apache-airflow-providers-amazon==8.25.0
# via
# -c airflow/constraints.txt
Expand All @@ -45,7 +52,7 @@ apache-airflow-providers-common-sql==1.14.2
# apache-airflow-providers-amazon
# apache-airflow-providers-postgres
# apache-airflow-providers-sqlite
apache-airflow-providers-fab==1.2.1
apache-airflow-providers-fab==1.2.2
# via
# -c airflow/constraints.txt
# apache-airflow
Expand Down Expand Up @@ -96,11 +103,14 @@ asn1crypto==1.5.1
# via
# -c airflow/constraints.txt
# scramp
astronomer-cosmos==1.5.1
# via -r airflow/requirements.in
attrs==23.2.0
# via
# -c airflow/constraints.txt
# aiohttp
# apache-airflow
# astronomer-cosmos
# jsonschema
# referencing
babel==2.15.0
Expand Down Expand Up @@ -206,6 +216,10 @@ dill==0.3.8
# via
# -c airflow/constraints.txt
# apache-airflow
distlib==0.3.8
# via
# -c airflow/constraints.txt
# virtualenv
dnspython==2.6.1
# via
# -c airflow/constraints.txt
Expand All @@ -218,6 +232,10 @@ email-validator==2.2.0
# via
# -c airflow/constraints.txt
# flask-appbuilder
filelock==3.15.4
# via
# -c airflow/constraints.txt
# virtualenv
flask==2.2.5
# via
# -c airflow/constraints.txt
Expand Down Expand Up @@ -327,6 +345,7 @@ idna==3.7
importlib-metadata==6.11.0
# via
# -c airflow/constraints.txt
# apache-airflow
# opentelemetry-api
importlib-resources==6.4.0
# via
Expand All @@ -348,6 +367,7 @@ jinja2==3.1.4
# via
# -c airflow/constraints.txt
# apache-airflow
# astronomer-cosmos
# flask
# flask-babel
# python-nvd3
Expand Down Expand Up @@ -441,6 +461,8 @@ more-itertools==10.3.0
# via
# -c airflow/constraints.txt
# apache-airflow-providers-common-sql
msgpack==1.0.8
# via astronomer-cosmos
multidict==6.0.5
# via
# -c airflow/constraints.txt
Expand Down Expand Up @@ -495,6 +517,7 @@ packaging==24.1
# -c airflow/constraints.txt
# apache-airflow
# apispec
# astronomer-cosmos
# connexion
# gunicorn
# limits
Expand All @@ -515,6 +538,10 @@ pendulum==3.0.0
# via
# -c airflow/constraints.txt
# apache-airflow
platformdirs==4.2.2
# via
# -c airflow/constraints.txt
# virtualenv
pluggy==1.5.0
# via
# -c airflow/constraints.txt
Expand Down Expand Up @@ -548,6 +575,14 @@ pycparser==2.22
# via
# -c airflow/constraints.txt
# cffi
pydantic==2.8.2
# via
# -c airflow/constraints.txt
# astronomer-cosmos
pydantic-core==2.20.1
# via
# -c airflow/constraints.txt
# pydantic
pygments==2.18.0
# via
# -c airflow/constraints.txt
Expand Down Expand Up @@ -734,6 +769,8 @@ typing-extensions==4.12.2
# flask-limiter
# limits
# opentelemetry-sdk
# pydantic
# pydantic-core
tzdata==2024.1
# via
# -c airflow/constraints.txt
Expand All @@ -756,6 +793,10 @@ urllib3==2.0.7
# botocore
# requests
# sentry-sdk
virtualenv==20.26.3
# via
# -c airflow/constraints.txt
# astronomer-cosmos
watchtower==3.2.0
# via
# -c airflow/constraints.txt
Expand Down
Loading

0 comments on commit 1b84641

Please sign in to comment.