From 528be4653affa2ea2ea7fd70991a7607a34cb196 Mon Sep 17 00:00:00 2001 From: z3z1ma Date: Mon, 30 Dec 2024 16:50:12 -0700 Subject: [PATCH] wip: continue working on functional rewrite --- src/dbt_osmosis/core/osmosis.py | 34 +++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/src/dbt_osmosis/core/osmosis.py b/src/dbt_osmosis/core/osmosis.py index 5ff5486..28c0f3f 100644 --- a/src/dbt_osmosis/core/osmosis.py +++ b/src/dbt_osmosis/core/osmosis.py @@ -687,7 +687,9 @@ def inject_missing_columns(context: YamlRefactorContext, node: ResultNode | None for node in context.project.manifest.nodes.values(): inject_missing_columns(context, node) return - yaml_section = _get_member_yaml(context, node) or {} + yaml_section = _get_member_yaml(context, node) + if yaml_section is None: + return current_columns = { normalize_column_name(c["name"], context.project.config.credentials.type) for c in yaml_section.get("columns", []) @@ -698,13 +700,40 @@ def inject_missing_columns(context: YamlRefactorContext, node: ResultNode | None logger.info( f"Detected and reconciling missing column {incoming_name} in node {node.unique_id}" ) - gen_col = {"name": incoming_name, "description": incoming_meta.comment} + gen_col = {"name": incoming_name, "description": incoming_meta.comment or ""} if dtype := incoming_meta.type: gen_col["data_type"] = dtype.lower() if context.settings.output_to_lower else dtype node.columns[incoming_name] = ColumnInfo.from_dict(gen_col) yaml_section.setdefault("columns", []).append(gen_col) +def remove_columns_not_in_database( + context: YamlRefactorContext, node: ResultNode | None = None +) -> None: + """Remove columns from a dbt node and its corresponding yaml section that are not present in the database. 
Changes are implicitly buffered until commit_yamls is called.""" + if context.settings.skip_add_columns: + return + if node is None: + for node in context.project.manifest.nodes.values(): + remove_columns_not_in_database(context, node) + return + yaml_section = _get_member_yaml(context, node) + if yaml_section is None: + return + current_columns = { + normalize_column_name(c["name"], context.project.config.credentials.type) + for c in yaml_section.get("columns", []) + } + incoming_columns = get_columns(context, get_table_ref(node)) + extra_columns = current_columns - set(incoming_columns.keys()) + for extra_column in extra_columns: + logger.info(f"Detected and removing extra column {extra_column} in node {node.unique_id}") + _ = node.columns.pop(extra_column, None) + yaml_section["columns"] = [ + c for c in yaml_section.get("columns", []) if c["name"] != extra_column + ] + + def normalize_column_name(column: str, credentials_type: str) -> str: """Apply case normalization to a column name based on the credentials type.""" if credentials_type == "snowflake" and column.startswith('"') and column.endswith('"'): @@ -1165,4 +1194,5 @@ def run_example_compilation_flow() -> None: plan = draft_restructure_delta_plan(yaml_context) apply_restructure_plan(yaml_context, plan, confirm=True) inject_missing_columns(yaml_context) + remove_columns_not_in_database(yaml_context) commit_yamls(yaml_context)