From 0d87c4000f48072d9b06913536bbc14a0300f37a Mon Sep 17 00:00:00 2001
From: Brian K
Date: Wed, 20 Nov 2024 21:23:15 -0800
Subject: [PATCH] feat: add output-to-lower

This option allows column names and data types in the generated schema
YAML to be written in lowercase.
---
 src/dbt_osmosis/core/osmosis.py | 195 ++++++++++++++++++++++----------
 src/dbt_osmosis/main.py         |  27 ++++-
 2 files changed, 161 insertions(+), 61 deletions(-)

diff --git a/src/dbt_osmosis/core/osmosis.py b/src/dbt_osmosis/core/osmosis.py
index 397177c..b5fa875 100644
--- a/src/dbt_osmosis/core/osmosis.py
+++ b/src/dbt_osmosis/core/osmosis.py
@@ -128,6 +128,7 @@ def __init__(
         use_unrendered_descriptions: bool = False,
         profile: Optional[str] = None,
         add_inheritance_for_specified_keys: Optional[List[str]] = None,
+        output_to_lower: bool = False,
     ):
         """Initializes the DbtYamlManager class."""
         super().__init__(target, profiles_dir, project_dir, threads, vars=vars, profile=profile)
@@ -145,6 +146,7 @@
         self.add_progenitor_to_meta = add_progenitor_to_meta
         self.use_unrendered_descriptions = use_unrendered_descriptions
         self.add_inheritance_for_specified_keys = add_inheritance_for_specified_keys or []
+        self.output_to_lower = output_to_lower
 
         if len(list(self.filtered_models())) == 0:
             logger().warning(
@@ -169,9 +171,14 @@ def yaml_handler(self):
             self._yaml_handler = YamlHandler()
         return self._yaml_handler
 
-    def column_casing(self, column: str) -> str:
+    def column_casing(self, column: str, output_to_lower: bool) -> str:
         """Converts a column name to the correct casing for the target database."""
-        if self.config.credentials.type == "snowflake":
+        # Leave the column name as-is if it is enclosed in quotes.
+        if self.config.credentials.type == "snowflake" and '"' == column[0] and '"' == column[-1]:
+            return column
+        elif output_to_lower:
+            return column.lower()
+        elif self.config.credentials.type == "snowflake":
             return column.upper()
         return column
 
@@ -320,40 +327,40 @@ def get_catalog_key(node: ManifestNode) -> CatalogKey:
             return CatalogKey(node.database, node.schema, getattr(node, "identifier", node.name))
         return CatalogKey(node.database, node.schema, getattr(node, "alias", node.name))
 
-    def get_base_model(self, node: ManifestNode) -> Dict[str, Any]:
+    def get_base_model(self, node: ManifestNode, output_to_lower: bool) -> Dict[str, Any]:
         """Construct a base model object with model name, column names populated from database"""
-        columns = self.get_columns(self.get_catalog_key(node))
+        columns = self.get_columns(self.get_catalog_key(node), output_to_lower)
         return {
             "name": node.name,
             "columns": [{"name": column_name, "description": ""} for column_name in columns],
         }
 
     def augment_existing_model(
-        self, documentation: Dict[str, Any], node: ManifestNode
+        self, documentation: Dict[str, Any], node: ManifestNode, output_to_lower: bool
     ) -> Dict[str, Any]:
         """Injects columns from database into existing model if not found"""
         model_columns: List[str] = [c["name"] for c in documentation.get("columns", [])]
-        database_columns = self.get_columns(self.get_catalog_key(node))
+        database_columns = self.get_columns(self.get_catalog_key(node), output_to_lower)
         for column in (
             c for c in database_columns if not any(c.lower() == m.lower() for m in model_columns)
         ):
             logger().info(
                 ":syringe: Injecting column %s into dbt schema for %s",
-                self.column_casing(column),
+                self.column_casing(column, output_to_lower),
                 node.unique_id,
             )
             documentation.setdefault("columns", []).append(
                 {
-                    "name": self.column_casing(column),
+                    "name": self.column_casing(column, output_to_lower),
                     "description": getattr(column, "description", ""),
                 }
             )
         return documentation
 
-    def get_columns(self, catalog_key: CatalogKey) -> List[str]:
+    def get_columns(self, catalog_key: CatalogKey, output_to_lower: bool) -> List[str]:
         """Get all columns in a list for a model"""
-        return list(self.get_columns_meta(catalog_key).keys())
+        return list(self.get_columns_meta(catalog_key, output_to_lower).keys())
 
     @property
     def catalog(self) -> Optional[CatalogArtifact]:
@@ -382,7 +389,9 @@ def _get_column_type(self, column: Column) -> str:
         return column.dtype
 
     @lru_cache(maxsize=5000)
-    def get_columns_meta(self, catalog_key: CatalogKey) -> Dict[str, ColumnMetadata]:
+    def get_columns_meta(
+        self, catalog_key: CatalogKey, output_to_lower: bool = False
+    ) -> Dict[str, ColumnMetadata]:
         """Get all columns in a list for a model"""
         columns = OrderedDict()
         blacklist = self.config.vars.vars.get("dbt-osmosis", {}).get("_blacklist", [])
@@ -399,8 +408,8 @@ def get_columns_meta(self, catalog_key: CatalogKey) -> Dict[str, ColumnMetadata]:
         for col in matching_models_or_sources[0].columns.values():
             if any(re.match(pattern, col.name) for pattern in blacklist):
                 continue
-            columns[self.column_casing(col.name)] = ColumnMetadata(
-                name=self.column_casing(col.name),
+            columns[self.column_casing(col.name, output_to_lower)] = ColumnMetadata(
+                name=self.column_casing(col.name, output_to_lower),
                 type=col.type,
                 index=col.index,
                 comment=col.comment,
@@ -424,8 +433,8 @@
                 for c in self.adapter.get_columns_in_relation(table):
                     if any(re.match(pattern, c.name) for pattern in blacklist):
                         continue
-                    columns[self.column_casing(c.name)] = ColumnMetadata(
-                        name=self.column_casing(c.name),
+                    columns[self.column_casing(c.name, output_to_lower)] = ColumnMetadata(
+                        name=self.column_casing(c.name, output_to_lower),
                         type=self._get_column_type(c),
                         index=None,  # type: ignore
                         comment=getattr(c, "comment", None),
@@ -434,11 +443,13 @@
                     for exp in c.flatten():
                         if any(re.match(pattern, exp.name) for pattern in blacklist):
                             continue
-                        columns[self.column_casing(exp.name)] = ColumnMetadata(
-                            name=self.column_casing(exp.name),
-                            type=self._get_column_type(exp),
-                            index=None,  # type: ignore
-                            comment=getattr(exp, "comment", None),
+                        columns[self.column_casing(exp.name, output_to_lower)] = (
+                            ColumnMetadata(
+                                name=self.column_casing(exp.name, output_to_lower),
+                                type=self._get_column_type(exp),
+                                index=None,  # type: ignore
+                                comment=getattr(exp, "comment", None),
+                            )
                         )
         except Exception as error:
             logger().info(
@@ -449,7 +460,7 @@
             )
         return columns
 
-    def bootstrap_sources(self) -> None:
+    def bootstrap_sources(self, output_to_lower: bool = False) -> None:
         """Bootstrap sources from the dbt-osmosis vars config"""
         performed_disk_mutation = False
         blacklist = self.config.vars.vars.get("dbt-osmosis", {}).get("_blacklist", [])
@@ -493,7 +504,7 @@
                             "columns": (
                                 [
                                     {
-                                        "name": self.column_casing(exp.name),
+                                        "name": self.column_casing(exp.name, output_to_lower),
                                         "description": getattr(
                                             exp, "description", getattr(c, "description", "")
                                         ),
@@ -537,11 +548,11 @@
             logger().info("...reloading project to pick up new sources")
             self.safe_parse_project(reinit=True)
 
-    def build_schema_folder_mapping(self) -> Dict[str, SchemaFileLocation]:
+    def build_schema_folder_mapping(self, output_to_lower: bool) -> Dict[str, SchemaFileLocation]:
         """Builds a mapping of models or sources to their existing and target schema file paths"""
 
         # Resolve target nodes
-        self.bootstrap_sources()
+        self.bootstrap_sources(output_to_lower)
 
         # Container for output
         schema_map = {}
@@ -559,7 +570,13 @@
 
         return schema_map
 
-    def _draft(self, schema_file: SchemaFileLocation, unique_id: str, blueprint: dict) -> None:
+    def _draft(
+        self,
+        schema_file: SchemaFileLocation,
+        unique_id: str,
+        blueprint: dict,
+        output_to_lower: bool,
+    ) -> None:
         try:
             with self.mutex:
                 blueprint.setdefault(
@@ -579,7 +596,9 @@
                 # NodeType.Source files are guaranteed to exist by this point
                 with self.mutex:
                     assert schema_file.node_type == NodeType.Model
-                    blueprint[schema_file.target].output["models"].append(self.get_base_model(node))
+                    blueprint[schema_file.target].output["models"].append(
+                        self.get_base_model(node, output_to_lower)
+                    )
             else:
                 # Sanity check that the file exists before we try to load it, this should never be false
                 assert schema_file.current.exists(), f"File {schema_file.current} does not exist"
@@ -592,7 +611,9 @@
                         model for model in models_in_file if model["name"] == node.name
                     ):
                         # Augment Documented Model
-                        augmented_model = self.augment_existing_model(documented_model, node)
+                        augmented_model = self.augment_existing_model(
+                            documented_model, node, output_to_lower
+                        )
                         with self.mutex:
                             blueprint[schema_file.target].output["models"].append(augmented_model)
                             # Target to supersede current
@@ -608,7 +629,9 @@
                         if table["name"] == node.name
                     ):
                         # Augment Documented Source
-                        augmented_model = self.augment_existing_model(documented_model, node)
+                        augmented_model = self.augment_existing_model(
+                            documented_model, node, output_to_lower
+                        )
                         with self.mutex:
                             if not any(
                                 s["name"] == node.source_name
@@ -644,7 +667,9 @@
             )
             raise e
 
-    def draft_project_structure_update_plan(self) -> Dict[Path, SchemaFileMigration]:
+    def draft_project_structure_update_plan(
+        self, output_to_lower: bool = False
+    ) -> Dict[Path, SchemaFileMigration]:
         """Build project structure update plan based on `dbt-osmosis:` configs set across
         dbt_project.yml and model files. The update plan includes injection of undocumented models.
         Unless this plan is constructed and executed by the `commit_project_restructure` function,
@@ -664,9 +689,13 @@
         )
         futs = []
         with self.adapter.connection_named("dbt-osmosis"):
-            for unique_id, schema_file in self.build_schema_folder_mapping().items():
+            for unique_id, schema_file in self.build_schema_folder_mapping(output_to_lower).items():
                 if not schema_file.is_valid:
-                    futs.append(self.tpe.submit(self._draft, schema_file, unique_id, blueprint))
+                    futs.append(
+                        self.tpe.submit(
+                            self._draft, schema_file, unique_id, blueprint, output_to_lower
+                        )
+                    )
             wait(futs)
         return blueprint
 
@@ -681,7 +710,9 @@ def cleanup_blueprint(self, blueprint: dict) -> None:
         return blueprint
 
     def commit_project_restructure_to_disk(
-        self, blueprint: Optional[Dict[Path, SchemaFileMigration]] = None
+        self,
+        blueprint: Optional[Dict[Path, SchemaFileMigration]] = None,
+        output_to_lower: bool = False,
     ) -> bool:
         """Given a project restrucure plan of pathlib Paths to a mapping of output and supersedes
         which is in itself a mapping of Paths to model names, commit changes to filesystem to
@@ -691,6 +722,7 @@
         Args:
             blueprint (Dict[Path, SchemaFileMigration]): Project restructure plan as typically
                 created by `build_project_structure_update_plan`
+            output_to_lower (bool): Set column casing to lowercase.
 
         Returns:
             bool: True if the project was restructured, False if no action was required
@@ -698,7 +730,7 @@
 
         # Build blueprint if not user supplied
         if not blueprint:
-            blueprint = self.draft_project_structure_update_plan()
+            blueprint = self.draft_project_structure_update_plan(output_to_lower)
 
         blueprint = self.cleanup_blueprint(blueprint)
 
@@ -846,6 +878,7 @@ def _run(
         node: ManifestNode,
         schema_map: Dict[str, SchemaFileLocation],
         force_inheritance: bool = False,
+        output_to_lower: bool = False,
     ):
         try:
             with self.mutex:
@@ -861,8 +894,8 @@
 
                 # Build Sets
                 logger().info(":mag: Resolving columns in database")
-                database_columns_ordered = self.get_columns(self.get_catalog_key(node))
-                columns_db_meta = self.get_columns_meta(self.get_catalog_key(node))
+                database_columns_ordered = self.get_columns(self.get_catalog_key(node), output_to_lower)
+                columns_db_meta = self.get_columns_meta(self.get_catalog_key(node), output_to_lower)
                 database_columns: Set[str] = set(database_columns_ordered)
                 yaml_columns_ordered = [column for column in node.columns]
                 yaml_columns: Set[str] = set(yaml_columns_ordered)
@@ -930,6 +963,7 @@
                             node,
                             section,
                             columns_db_meta,
+                            output_to_lower,
                         )
                         if (
                             n_cols_added
@@ -953,7 +987,9 @@
                    def _sort_columns(column_info: dict) -> int:
                        nonlocal last_ix
                        try:
-                            normalized_name = self.column_casing(column_info["name"])
+                            normalized_name = self.column_casing(
+                                column_info["name"], output_to_lower
+                            )
                            return database_columns_ordered.index(normalized_name)
                        except IndexError:
                            last_ix += 1
@@ -1002,13 +1038,17 @@
             logger().error("Error occurred while processing model %s: %s", unique_id, e)
             raise e
 
-    def propagate_documentation_downstream(self, force_inheritance: bool = False) -> None:
-        schema_map = self.build_schema_folder_mapping()
+    def propagate_documentation_downstream(
+        self, force_inheritance: bool = False, output_to_lower: bool = False
+    ) -> None:
+        schema_map = self.build_schema_folder_mapping(output_to_lower)
         futs = []
         with self.adapter.connection_named("dbt-osmosis"):
             for unique_id, node in self.filtered_models():
                 futs.append(
-                    self.tpe.submit(self._run, unique_id, node, schema_map, force_inheritance)
+                    self.tpe.submit(
+                        self._run, unique_id, node, schema_map, force_inheritance, output_to_lower
+                    )
                 )
             wait(futs)
 
@@ -1041,12 +1081,13 @@ def update_columns_attribute(
         attribute_name: str,
         meta_key: str,
         skip_attribute_update: Any,
+        output_to_lower: bool = False,
     ) -> int:
         changes_committed = 0
         if (skip_attribute_update is True) or (skip_attribute_update is None):
             return changes_committed
         for column in columns_db_meta:
-            cased_column_name = self.column_casing(column)
+            cased_column_name = self.column_casing(column, output_to_lower)
             if cased_column_name in node.columns:
                 column_meta_obj = columns_db_meta.get(cased_column_name)
                 if column_meta_obj:
@@ -1058,8 +1099,14 @@
                         continue
                     setattr(node.columns[cased_column_name], attribute_name, column_meta)
                     for model_column in yaml_file_model_section["columns"]:
-                        if self.column_casing(model_column["name"]) == cased_column_name:
-                            model_column.update({attribute_name: column_meta})
+                        if (
+                            self.column_casing(model_column["name"], output_to_lower)
+                            == cased_column_name
+                        ):
+                            if output_to_lower:
+                                model_column.update({attribute_name: column_meta.lower()})
+                            else:
+                                model_column.update({attribute_name: column_meta})
                             changes_committed += 1
         return changes_committed
 
@@ -1069,25 +1116,42 @@ def add_missing_cols_to_node_and_model(
         node: ManifestNode,
         yaml_file_model_section: Dict[str, Any],
         columns_db_meta: Dict[str, ColumnMetadata],
+        output_to_lower: bool,
     ) -> int:
         """Add missing columns to node and model simultaneously
         THIS MUTATES THE NODE AND MODEL OBJECTS so that state is always accurate"""
         changes_committed = 0
         for column in missing_columns:
-            node.columns[column] = ColumnInfo.from_dict(
-                {
-                    "name": column,
-                    "description": columns_db_meta[column].comment or "",
-                    "data_type": columns_db_meta[column].type,
-                }
-            )
-            yaml_file_model_section.setdefault("columns", []).append(
-                {
-                    "name": column,
-                    "description": columns_db_meta[column].comment or "",
-                    "data_type": columns_db_meta[column].type,
-                }
-            )
+            if output_to_lower:
+                node.columns[column] = ColumnInfo.from_dict(
+                    {
+                        "name": column.lower(),
+                        "description": columns_db_meta[column].comment or "",
+                        "data_type": columns_db_meta[column].type.lower(),
+                    }
+                )
+                yaml_file_model_section.setdefault("columns", []).append(
+                    {
+                        "name": column.lower(),
+                        "description": columns_db_meta[column].comment or "",
+                        "data_type": columns_db_meta[column].type.lower(),
+                    }
+                )
+            else:
+                node.columns[column] = ColumnInfo.from_dict(
+                    {
+                        "name": column,
+                        "description": columns_db_meta[column].comment or "",
+                        "data_type": columns_db_meta[column].type,
+                    }
+                )
+                yaml_file_model_section.setdefault("columns", []).append(
+                    {
+                        "name": column,
+                        "description": columns_db_meta[column].comment or "",
+                        "data_type": columns_db_meta[column].type,
+                    }
+                )
             changes_committed += 1
             logger().info(
                 ":syringe: Injecting column %s into dbt schema for model %s", column, node.unique_id
@@ -1102,13 +1166,14 @@ def update_schema_file_and_node(
         node: ManifestNode,
         section: Dict[str, Any],
         columns_db_meta: Dict[str, ColumnMetadata],
+        output_to_lower: bool,
     ) -> Tuple[int, int, int, int, int]:
         """Take action on a schema file mirroring changes in the node."""
         logger().info(":microscope: Looking for actions for %s", node.unique_id)
         n_cols_added = 0
         if not self.skip_add_columns:
             n_cols_added = self.add_missing_cols_to_node_and_model(
-                missing_columns, node, section, columns_db_meta
+                missing_columns, node, section, columns_db_meta, output_to_lower
             )
 
         knowledge = ColumnLevelKnowledgePropagator.get_node_columns_with_inherited_knowledge(
@@ -1131,10 +1196,22 @@
             )
         )
         n_cols_data_type_updated = self.update_columns_attribute(
-            node, section, columns_db_meta, "data_type", "type", self.skip_add_data_types
+            node,
+            section,
+            columns_db_meta,
+            "data_type",
+            "type",
+            self.skip_add_data_types,
+            self.output_to_lower,
         )
         n_cols_description_updated = self.update_columns_attribute(
-            node, section, columns_db_meta, "description", "comment", self.catalog_file
+            node,
+            section,
+            columns_db_meta,
+            "description",
+            "comment",
+            self.catalog_file,
+            self.output_to_lower,
         )
         n_cols_removed = self.remove_columns_not_in_database(extra_columns, node, section)
         return (
diff --git a/src/dbt_osmosis/main.py b/src/dbt_osmosis/main.py
index b42999e..e62e7af 100644
--- a/src/dbt_osmosis/main.py
+++ b/src/dbt_osmosis/main.py
@@ -193,6 +193,11 @@ def wrapper(*args, **kwargs):
     type=click.STRING,
     help="If specified, will add inheritance for the specified keys.",
 )
+@click.option(
+    "--output-to-lower",
+    is_flag=True,
+    help="If specified, write the output YAML in lowercase where possible.",
+)
 @click.argument("models", nargs=-1)
 def refactor(
     target: Optional[str] = None,
@@ -215,6 +220,7 @@
     vars: Optional[str] = None,
     use_unrendered_descriptions: bool = False,
     add_inheritance_for_specified_keys: Optional[List[str]] = None,
+    output_to_lower: bool = False,
 ):
     """Executes organize which syncs yaml files with database schema and organizes the dbt models
     directory, reparses the project, then executes document passing down inheritable documentation
@@ -249,12 +255,15 @@
         vars=vars,
         use_unrendered_descriptions=use_unrendered_descriptions,
         add_inheritance_for_specified_keys=add_inheritance_for_specified_keys,
+        output_to_lower=output_to_lower,
     )
 
     # Conform project structure & bootstrap undocumented models injecting columns
     if runner.commit_project_restructure_to_disk():
         runner.safe_parse_project(reinit=True)
 
-    runner.propagate_documentation_downstream(force_inheritance=force_inheritance)
+    runner.propagate_documentation_downstream(
+        force_inheritance=force_inheritance, output_to_lower=output_to_lower
+    )
     if check and runner.mutations > 0:
         exit(1)
@@ -345,6 +354,11 @@
     type=click.STRING,
     help="If specified, will add inheritance for the specified keys.",
 )
+@click.option(
+    "--output-to-lower",
+    is_flag=True,
+    help="If specified, write the output YAML in lowercase where possible.",
+)
 @click.argument("models", nargs=-1)
 def organize(
     target: Optional[str] = None,
@@ -364,6 +378,7 @@
     profile: Optional[str] = None,
     vars: Optional[str] = None,
     add_inheritance_for_specified_keys: Optional[List[str]] = None,
+    output_to_lower: bool = False,
 ):
     """Organizes schema ymls based on config and injects undocumented models
 
@@ -395,6 +410,7 @@
         profile=profile,
         vars=vars,
         add_inheritance_for_specified_keys=add_inheritance_for_specified_keys,
+        output_to_lower=output_to_lower,
     )
 
     # Conform project structure & bootstrap undocumented models injecting columns
@@ -514,6 +530,11 @@
     type=click.STRING,
     help="If specified, will add inheritance for the specified keys.",
 )
+@click.option(
+    "--output-to-lower",
+    is_flag=True,
+    help="If specified, write the output YAML in lowercase where possible.",
+)
 @click.argument("models", nargs=-1)
 def document(
     target: Optional[str] = None,
@@ -536,6 +557,7 @@
     vars: Optional[str] = None,
     use_unrendered_descriptions: bool = False,
     add_inheritance_for_specified_keys: Optional[List[str]] = None,
+    output_to_lower: bool = False,
 ):
     """Column level documentation inheritance for existing models
 
@@ -569,10 +591,11 @@
         vars=vars,
         use_unrendered_descriptions=use_unrendered_descriptions,
         add_inheritance_for_specified_keys=add_inheritance_for_specified_keys,
+        output_to_lower=output_to_lower,
     )
 
     # Propagate documentation & inject/remove schema file columns to align with model in database
-    runner.propagate_documentation_downstream(force_inheritance)
+    runner.propagate_documentation_downstream(force_inheritance, output_to_lower)
     if check and runner.mutations > 0:
         exit(1)
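
Note on the casing rules: column_casing() now resolves precedence as quoted
Snowflake identifier (left untouched) > --output-to-lower > Snowflake uppercase
default > pass-through. Below is a minimal standalone sketch of that precedence,
not the method itself: credentials_type is a hypothetical stand-in for
self.config.credentials.type, and startswith/endswith replace the patch's index
checks.

    def column_casing(column: str, credentials_type: str, output_to_lower: bool) -> str:
        # Quoted Snowflake identifiers are case-sensitive, so they pass through untouched.
        if credentials_type == "snowflake" and column.startswith('"') and column.endswith('"'):
            return column
        # The new flag takes precedence over Snowflake's default uppercasing.
        if output_to_lower:
            return column.lower()
        if credentials_type == "snowflake":
            return column.upper()
        return column

    assert column_casing('"MixedCase"', "snowflake", True) == '"MixedCase"'
    assert column_casing("ORDER_ID", "snowflake", True) == "order_id"
    assert column_casing("order_id", "snowflake", False) == "ORDER_ID"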
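End to end, the flag threads from the CLI (e.g. dbt-osmosis yaml refactor
--output-to-lower, assuming the commands are mounted under the usual yaml group)
through the runner into add_missing_cols_to_node_and_model and
update_columns_attribute. Newly injected columns are then written with lowercased
name and data_type. Note that update_schema_file_and_node passes
self.output_to_lower to update_columns_attribute for both the data_type and the
description attributes, so description strings pulled from the database are
lowercased as well when the flag is set. A sketch of the YAML entry produced for
a missing column under the flag, using hypothetical adapter-returned values:

    # Hypothetical values, as if returned by the adapter/catalog for one column.
    column, dtype, comment = "ORDER_ID", "NUMBER(38,0)", None

    # Shape of the entry appended to the schema YAML "columns" list when
    # --output-to-lower is set (mirrors add_missing_cols_to_node_and_model).
    entry = {
        "name": column.lower(),        # "order_id"
        "description": comment or "",  # catalog comment when available, else empty
        "data_type": dtype.lower(),    # "number(38,0)"
    }
    print(entry)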