Skip to content

Commit

Permalink
feat: sync node to yaml instead of co-manipulation, decouples mutations
Browse files Browse the repository at this point in the history
  • Loading branch information
z3z1ma committed Jan 1, 2025
1 parent 6a786d3 commit 6ffb027
Showing 1 changed file with 144 additions and 23 deletions.
167 changes: 144 additions & 23 deletions src/dbt_osmosis/core/osmosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,147 @@ def _remove_sources(existing_doc: dict[str, t.Any], nodes: list[ResultNode]) ->
existing_doc["sources"] = keep_sources


def _sync_doc_section(
    context: YamlRefactorContext, node: ResultNode, doc_section: dict[str, t.Any]
) -> None:
    """Overwrite ``doc_section`` in place with the description and columns of ``node``.

    The node is treated as the single source of truth. The section's description is
    set from the node (or removed when the node has none) and its ``columns`` list is
    rebuilt, one entry per node column and in the node's column order. Keys that only
    exist on the matching YAML column entry are carried over; keys present on both
    sides take the node's value. YAML columns with no counterpart on the node are
    dropped.

    Args:
        context: Refactor context; only ``project.config.credentials.type`` is read
            and passed to ``normalize_column_name`` when matching column names.
        node: Node whose ``description`` and ``columns`` are written to the section.
        doc_section: Mapping for this node inside the YAML document. Mutated in place.
    """

    def _is_unset(value: t.Any) -> bool:
        # Only None and empty strings/containers are noise. Explicit values such as
        # False or 0 (e.g. `quote: false`) are real settings and must be written out.
        return value is None or (
            isinstance(value, (str, list, tuple, dict, set)) and not value
        )

    if node.description:
        doc_section["description"] = node.description
    else:
        doc_section.pop("description", None)

    # Loop-invariant: resolve the adapter type once for all name normalizations.
    adapter_type = context.project.config.credentials.type

    # Index the columns currently in the YAML by normalized name. `or []` also covers
    # a `columns:` key that is present but null.
    current_map: dict[str, dict[str, t.Any]] = {}
    for existing in doc_section.get("columns") or []:
        current_map[normalize_column_name(existing["name"], adapter_type)] = existing

    incoming_columns: list[dict[str, t.Any]] = []
    for name, meta in node.columns.items():
        cdict = meta.to_dict()
        cdict["name"] = name
        current_yaml = current_map.get(normalize_column_name(name, adapter_type), {})

        # YAML-only keys survive; node values win on conflict. Existing keys keep
        # their position, new keys are appended.
        merged: dict[str, t.Any] = {**current_yaml, **cdict}

        # A description is only worth keeping when it has content.
        if not merged.get("description"):
            merged.pop("description", None)

        incoming_columns.append(
            {key: value for key, value in merged.items() if not _is_unset(value)}
        )

    doc_section["columns"] = incoming_columns


def _find_or_create_named_entry(
    entries: list[dict[str, t.Any]], name: str, **defaults: t.Any
) -> dict[str, t.Any]:
    """Return the first mapping in ``entries`` whose ``name`` equals ``name``.

    When no entry matches, a new ``{"name": name, **defaults}`` mapping is appended to
    ``entries`` and returned.
    """
    for entry in entries:
        if entry.get("name") == name:
            return entry
    created: dict[str, t.Any] = {"name": name, **defaults}
    entries.append(created)
    return created


def sync_node_to_yaml(context: YamlRefactorContext, node: ResultNode | None = None) -> None:
    """Synchronize a node's description and columns into its YAML document.

    The node is treated as the single source of truth, so its section in the YAML
    document is overwritten to match. This is a one-way sync (node => YAML); every
    change to the node should be made before calling this function.

    - The document is read from the node's current YAML path, or from its target
      path when there is no current path or the file does not exist.
    - When nothing is read back, a minimal ``{"version": 2}`` document is used.
    - The node's section (a source table, a seed or a model) is located by name,
      created when missing, and then synced.

    Args:
        context: Refactor context handed to the path, read, write and sync helpers.
        node: Node to sync. When ``None``, every node yielded by
            ``filter_models(context)`` is synced in turn.
    """
    if node is None:
        for _, each_node in filter_models(context):
            sync_node_to_yaml(context, each_node)
        return

    current_path = get_current_yaml_path(context, node)
    if not current_path or not current_path.exists():
        current_path = get_target_yaml_path(context, node)

    doc: dict[str, t.Any] = _read_yaml(context, current_path)
    if not doc:
        doc = {"version": 2}

    doc.setdefault("models", [])
    doc.setdefault("sources", [])
    doc.setdefault("seeds", [])

    if node.resource_type == NodeType.Source:
        # Document shape: sources: [ { "name": <source_name>, "tables": [...] }, ... ]
        doc_source = _find_or_create_named_entry(doc["sources"], node.source_name, tables=[])

        # An existing source entry may have no `tables` key (or a null one);
        # create the list instead of failing with a KeyError/TypeError.
        tables = doc_source.get("tables")
        if tables is None:
            tables = doc_source["tables"] = []

        # For sources, description and columns live on the table entry.
        section = _find_or_create_named_entry(tables, node.name, columns=[])
    else:
        # Seeds and models are flat lists of { "name", "description", "columns", ... }.
        sync_list_key = "seeds" if node.resource_type == NodeType.Seed else "models"
        section = _find_or_create_named_entry(doc[sync_list_key], node.name, columns=[])

    _sync_doc_section(context, node, section)
    _write_yaml(context, current_path, doc)


def apply_restructure_plan(
context: YamlRefactorContext, plan: RestructureDeltaPlan, *, confirm: bool = False
) -> None:
Expand Down Expand Up @@ -1183,7 +1324,6 @@ def inherit_upstream_column_knowledge(
if extra not in inheritable:
inheritable.append(extra)

yaml_section = _get_member_yaml(context, node)
column_knowledge_graph = _build_column_knowledge_graph(context, node)
kwargs = None
for name, node_column in node.columns.items():
Expand All @@ -1194,15 +1334,6 @@ def inherit_upstream_column_knowledge(
updated_metadata = {k: v for k, v in kwargs.items() if v is not None and k in inheritable}
node.columns[name] = node_column.replace(**updated_metadata)

if not yaml_section:
continue
for column in yaml_section.get("columns", []):
yaml_name = normalize_column_name(
column["name"], context.project.config.credentials.type
)
if yaml_name == name:
column.update(**updated_metadata)


def inject_missing_columns(context: YamlRefactorContext, node: ResultNode | None = None) -> None:
"""Add missing columns to a dbt node and its corresponding yaml section. Changes are implicitly buffered until commit_yamls is called."""
Expand All @@ -1229,7 +1360,6 @@ def inject_missing_columns(context: YamlRefactorContext, node: ResultNode | None
if dtype := incoming_meta.type:
gen_col["data_type"] = dtype.lower() if context.settings.output_to_lower else dtype
node.columns[incoming_name] = ColumnInfo.from_dict(gen_col)
yaml_section.setdefault("columns", []).append(gen_col)


def remove_columns_not_in_database(
Expand All @@ -1252,9 +1382,6 @@ def remove_columns_not_in_database(
for extra_column in extra_columns:
logger.info(f"Detected and removing extra column {extra_column} in node {node.unique_id}")
_ = node.columns.pop(extra_column, None)
yaml_section["columns"] = [
c for c in yaml_section.get("columns", []) if c["name"] != extra_column
]


def sort_columns_as_in_database(
Expand All @@ -1265,9 +1392,6 @@ def sort_columns_as_in_database(
for _, node in filter_models(context):
sort_columns_as_in_database(context, node)
return
yaml_section = _get_member_yaml(context, node)
if yaml_section is None:
return
incoming_columns = get_columns(context, get_table_ref(node))

def _position(column: dict[str, t.Any]):
Expand All @@ -1276,7 +1400,6 @@ def _position(column: dict[str, t.Any]):
return 99999
return db_info.index

t.cast(list[dict[str, t.Any]], yaml_section["columns"]).sort(key=_position)
node.columns = {
k: v for k, v in sorted(node.columns.items(), key=lambda i: _position(i[1].to_dict()))
}
Expand All @@ -1291,10 +1414,6 @@ def sort_columns_alphabetically(
for _, node in filter_models(context):
sort_columns_alphabetically(context, node)
return
yaml_section = _get_member_yaml(context, node)
if yaml_section is None:
return
t.cast(list[dict[str, t.Any]], yaml_section["columns"]).sort(key=lambda c: c["name"])
node.columns = {k: v for k, v in sorted(node.columns.items(), key=lambda i: i[0])}
context.register_mutations(1)

Expand Down Expand Up @@ -1378,7 +1497,8 @@ def run_example_compilation_flow(c: DbtConfiguration) -> None:
c = DbtConfiguration(
project_dir="demo_duckdb", profiles_dir="demo_duckdb", vars={"dbt-osmosis": {}}
)
run_example_compilation_flow(c)

# run_example_compilation_flow(c)

project = create_dbt_project_context(c)
yaml_context = YamlRefactorContext(
Expand All @@ -1393,6 +1513,7 @@ def run_example_compilation_flow(c: DbtConfiguration) -> None:
(remove_columns_not_in_database, (yaml_context,), {}),
(inherit_upstream_column_knowledge, (yaml_context,), {}),
(sort_columns_as_in_database, (yaml_context,), {}),
(sync_node_to_yaml, (yaml_context,), {}),
(commit_yamls, (yaml_context,), {}),
)
steps = iter(t.cast(t.Any, steps))
Expand Down

0 comments on commit 6ffb027

Please sign in to comment.