Skip to content

Commit

Permalink
Support sql check (#44)
Browse files Browse the repository at this point in the history
Co-authored-by: Michiel De Smet <[email protected]>
  • Loading branch information
mdesmet and Michiel De Smet authored Nov 20, 2024
1 parent f20c81b commit 35d2598
Show file tree
Hide file tree
Showing 12 changed files with 154 additions and 9 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def read(*names, **kwargs):
"ruamel.yaml==0.18.6",
"tabulate==0.9.0",
"requests==2.31.0",
"sqlglot==25.30.0",
],
extras_require={
# eg:
Expand Down
11 changes: 2 additions & 9 deletions src/datapilot/core/insights/sql/base/insight.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
from abc import abstractmethod
from typing import Optional

from datapilot.core.insights.base.insight import Insight
from datapilot.schemas.sql import Dialect
from datapilot.core.platforms.dbt.insights.checks.base import ChecksInsight


class SqlInsight(Insight):
class SqlInsight(ChecksInsight):
NAME = "SqlInsight"

def __init__(self, sql: str, dialect: Optional[Dialect], *args, **kwargs):
self.sql = sql
self.dialect = dialect
super().__init__(*args, **kwargs)

@abstractmethod
def generate(self, *args, **kwargs) -> dict:
pass
2 changes: 2 additions & 0 deletions src/datapilot/core/platforms/dbt/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self.macros = self.manifest_wrapper.get_macros()
self.sources = self.manifest_wrapper.get_sources()
self.exposures = self.manifest_wrapper.get_exposures()
self.adapter_type = self.manifest_wrapper.get_adapter_type()
self.seeds = self.manifest_wrapper.get_seeds()
self.children_map = self.manifest_wrapper.parent_to_child_map(self.nodes)
self.tests = self.manifest_wrapper.get_tests()
Expand Down Expand Up @@ -112,6 +113,7 @@ def run(self):
children_map=self.children_map,
tests=self.tests,
project_name=self.project_name,
adapter_type=self.adapter_type,
config=self.config,
selected_models=self.selected_models,
excluded_models=self.excluded_models,
Expand Down
2 changes: 2 additions & 0 deletions src/datapilot/core/platforms/dbt/insights/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from datapilot.core.platforms.dbt.insights.modelling.unused_sources import DBTUnusedSources
from datapilot.core.platforms.dbt.insights.performance.chain_view_linking import DBTChainViewLinking
from datapilot.core.platforms.dbt.insights.performance.exposure_parent_materializations import DBTExposureParentMaterialization
from datapilot.core.platforms.dbt.insights.sql.sql_check import SqlCheck
from datapilot.core.platforms.dbt.insights.structure.model_directories_structure import DBTModelDirectoryStructure
from datapilot.core.platforms.dbt.insights.structure.model_naming_conventions import DBTModelNamingConvention
from datapilot.core.platforms.dbt.insights.structure.source_directories_structure import DBTSourceDirectoryStructure
Expand Down Expand Up @@ -112,4 +113,5 @@
CheckSourceHasTests,
CheckSourceTableHasDescription,
CheckSourceTags,
SqlCheck,
]
3 changes: 3 additions & 0 deletions src/datapilot/core/platforms/dbt/insights/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import ClassVar
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

from datapilot.config.utils import get_insight_config
Expand Down Expand Up @@ -33,6 +34,7 @@ def __init__(
macros: Dict[str, AltimateManifestMacroNode],
children_map: Dict[str, List[str]],
project_name: str,
adapter_type: Optional[str],
selected_models: Union[List[str], None] = None,
excluded_models: Union[List[str], None] = None,
*args,
Expand All @@ -47,6 +49,7 @@ def __init__(
self.seeds = seeds
self.children_map = children_map
self.project_name = project_name
self.adapter_type = adapter_type
self.selected_models = selected_models
self.excluded_models = excluded_models
super().__init__(*args, **kwargs)
Expand Down
Empty file.
23 changes: 23 additions & 0 deletions src/datapilot/core/platforms/dbt/insights/sql/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import abstractmethod
from typing import Tuple

from datapilot.core.platforms.dbt.insights.base import DBTInsight


class SqlInsight(DBTInsight):
    """
    Abstract base for SQL-oriented dbt insights.

    Insights deriving from this class are reported under the ``governance`` type
    and can only run when a manifest is available.
    """

    TYPE = "governance"

    @abstractmethod
    def generate(self, *args, **kwargs) -> dict:
        """Produce the insight results; concrete subclasses must implement this."""

    @classmethod
    def has_all_required_data(cls, has_manifest: bool, **kwargs) -> Tuple[bool, str]:
        """
        Tell whether the prerequisites for running this insight are satisfied.
        :param has_manifest: Whether a manifest is available.
        :return: A ``(ok, message)`` tuple; ``message`` is empty when ``ok`` is True
            and explains what is missing otherwise.
        """
        if has_manifest:
            return True, ""
        return False, "manifest is required for insight to run."
101 changes: 101 additions & 0 deletions src/datapilot/core/platforms/dbt/insights/sql/sql_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import inspect
from typing import List

from sqlglot import parse_one
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

from datapilot.core.insights.sql.base.insight import SqlInsight
from datapilot.core.insights.utils import get_severity
from datapilot.core.platforms.dbt.insights.schema import DBTInsightResult
from datapilot.core.platforms.dbt.insights.schema import DBTModelInsightResponse

# sqlglot optimizer rules that SqlCheck.generate applies, in this order, to every
# qualified query. Each rule receives the output of the previous one; a rule whose
# resulting SQL differs from its input is reported as an optimization opportunity.
RULES = (
pushdown_projections,
normalize,
unnest_subqueries,
eliminate_subqueries,
eliminate_joins,
eliminate_ctes,
)


class SqlCheck(SqlInsight):
    """
    This class identifies DBT models with SQL optimization issues.

    The compiled SQL of each node is parsed with the project's adapter dialect,
    qualified, and then passed through every optimizer rule in ``RULES`` in order.
    Whenever a rule changes the rendered SQL, one insight is emitted for that
    node/rule pair, carrying the rewritten query as the recommendation.
    """

    NAME = "sql optimization issues"
    ALIAS = "check_sql_optimization"
    DESCRIPTION = "Checks if the model has SQL optimization issues. "
    REASON_TO_FLAG = "The query can be optimized."
    FAILURE_MESSAGE = "The query for model `{model_unique_id}` has optimization opportunities:\n{rule_name}. "
    RECOMMENDATION = "Please adapt the query of the model `{model_unique_id}` as in following example:\n{optimized_sql}"

    def _build_failure_result(self, model_unique_id: str, rule_name: str, optimized_sql: str) -> DBTInsightResult:
        """
        Constructs a failure result for a given model with sql optimization issues.
        :param model_unique_id: The unique id of the dbt model.
        :param rule_name: The rule that generated this failure result.
        :param optimized_sql: The optimized sql.
        :return: An instance of DBTInsightResult containing failure details.
        """
        failure_message = self.FAILURE_MESSAGE.format(model_unique_id=model_unique_id, rule_name=rule_name)
        recommendation = self.RECOMMENDATION.format(model_unique_id=model_unique_id, optimized_sql=optimized_sql)
        return DBTInsightResult(
            type=self.TYPE,
            name=self.NAME,
            message=failure_message,
            recommendation=recommendation,
            reason_to_flag=self.REASON_TO_FLAG,
            metadata={"model_unique_id": model_unique_id, "rule_name": rule_name},
        )

    def generate(self, *args, **kwargs) -> List[DBTModelInsightResponse]:
        """
        Generates insights for each DBT model in the project, focusing on sql optimization issues.
        :param kwargs: Extra options merged over the default optimizer options; they are
            passed to ``qualify`` and to any rule whose signature accepts them.
        :return: A list of DBTModelInsightResponse objects, one per (node, rule) pair
            for which the rule changed the query.
        """
        self.logger.debug("Generating sql insights for DBT models")
        insights = []

        possible_kwargs = {
            "db": None,
            "catalog": None,
            "dialect": self.adapter_type,
            "isolate_tables": True,  # needed for other optimizations to perform well
            "quote_identifiers": False,
            **kwargs,
        }

        # The keyword arguments each rule accepts depend only on the rule signature
        # and on possible_kwargs, so resolve them once instead of once per node.
        rules_with_kwargs = [
            (
                rule,
                {param: possible_kwargs[param] for param in inspect.getfullargspec(rule).args if param in possible_kwargs},
            )
            for rule in RULES
        ]

        for node_id, node in self.nodes.items():
            # Best effort: a node that cannot be parsed or optimized is logged and
            # skipped so that it does not prevent the remaining nodes from being checked.
            try:
                compiled_query = node.compiled_code
                if not compiled_query:
                    continue

                parsed_query = parse_one(compiled_query, dialect=self.adapter_type)
                changed = qualify(parsed_query, **possible_kwargs)
                for rule, rule_kwargs in rules_with_kwargs:
                    # Rules may rewrite the tree in place, so capture the rendered SQL
                    # before applying the rule. Render with the adapter dialect so the
                    # suggested query matches the warehouse the model is written for.
                    sql_before = changed.sql(dialect=self.adapter_type)
                    changed = rule(changed, **rule_kwargs)
                    sql_after = changed.sql(dialect=self.adapter_type)
                    if sql_after == sql_before:
                        continue
                    insights.append(
                        DBTModelInsightResponse(
                            unique_id=node_id,
                            package_name=node.package_name,
                            path=node.original_file_path,
                            original_file_path=node.original_file_path,
                            insight=self._build_failure_result(node_id, rule.__name__, sql_after),
                            severity=get_severity(self.config, self.ALIAS, self.DEFAULT_SEVERITY),
                        )
                    )
            except Exception as e:
                # Include the node id so the failing model can be identified from the log.
                self.logger.error(f"Failed to check SQL optimizations for node {node_id}: {e}")
        return insights
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
from typing import Optional
from typing import Set

from dbt_artifacts_parser.parsers.manifest.manifest_v10 import GenericTestNode
Expand Down Expand Up @@ -67,6 +68,7 @@ def _get_node(self, node: ManifestNode) -> AltimateManifestNode:
depends_on_macros = node.depends_on.macros if node.depends_on else None
compiled_path = node.compiled_path
compiled = node.compiled
compiled_code = node.compiled_code
raw_code = node.raw_code
language = node.language
contract = AltimateDBTContract(**node.contract.__dict__) if node.contract else None
Expand Down Expand Up @@ -381,6 +383,9 @@ def get_seeds(self) -> Dict[str, AltimateSeedNode]:
seeds[seed.unique_id] = self._get_seed(seed)
return seeds

def get_adapter_type(self) -> Optional[str]:
    """Return the ``adapter_type`` stored in the manifest metadata (may be ``None``)."""
    metadata = self.manifest.metadata
    return metadata.adapter_type

def parent_to_child_map(self, nodes: Dict[str, AltimateManifestNode]) -> Dict[str, Set[str]]:
"""
Current manifest contains information about parents
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
from typing import Optional
from typing import Set

from dbt_artifacts_parser.parsers.manifest.manifest_v11 import GenericTestNode
Expand Down Expand Up @@ -67,6 +68,7 @@ def _get_node(self, node: ManifestNode) -> AltimateManifestNode:
depends_on_macros = node.depends_on.macros if node.depends_on else None
compiled_path = node.compiled_path
compiled = node.compiled
compiled_code = node.compiled_code
raw_code = node.raw_code
language = node.language
contract = AltimateDBTContract(**node.contract.__dict__) if node.contract else None
Expand Down Expand Up @@ -381,6 +383,9 @@ def get_seeds(self) -> Dict[str, AltimateSeedNode]:
seeds[seed.unique_id] = self._get_seed(seed)
return seeds

def get_adapter_type(self) -> Optional[str]:
    """Return the ``adapter_type`` stored in the manifest metadata (may be ``None``)."""
    metadata = self.manifest.metadata
    return metadata.adapter_type

def parent_to_child_map(self, nodes: Dict[str, AltimateManifestNode]) -> Dict[str, Set[str]]:
"""
Current manifest contains information about parents
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
from typing import Optional
from typing import Set

from dbt_artifacts_parser.parsers.manifest.manifest_v12 import ManifestV12
Expand Down Expand Up @@ -67,6 +68,7 @@ def _get_node(self, node: ManifestNode) -> AltimateManifestNode:
depends_on_macros = node.depends_on.macros if node.depends_on else None
compiled_path = node.compiled_path
compiled = node.compiled
compiled_code = node.compiled_code
raw_code = node.raw_code
language = node.language
contract = AltimateDBTContract(**node.contract.__dict__) if node.contract else None
Expand Down Expand Up @@ -393,6 +395,9 @@ def get_seeds(self) -> Dict[str, AltimateSeedNode]:
seeds[seed.unique_id] = self._get_seed(seed)
return seeds

def get_adapter_type(self) -> Optional[str]:
    """Return the ``adapter_type`` stored in the manifest metadata (may be ``None``)."""
    metadata = self.manifest.metadata
    return metadata.adapter_type

def parent_to_child_map(self, nodes: Dict[str, AltimateManifestNode]) -> Dict[str, Set[str]]:
"""
Current manifest contains information about parents
Expand Down
5 changes: 5 additions & 0 deletions src/datapilot/core/platforms/dbt/wrappers/manifest/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Optional
from typing import Set

from datapilot.core.platforms.dbt.schemas.manifest import AltimateManifestExposureNode
Expand All @@ -26,6 +27,10 @@ def get_sources(self) -> Dict[str, AltimateManifestSourceNode]:
def get_exposures(self) -> Dict[str, AltimateManifestExposureNode]:
pass

@abstractmethod
def get_adapter_type(self) -> Optional[str]:
    """Return the adapter type for the manifest, or ``None`` when it is not available.

    Abstract: every concrete manifest wrapper must provide the implementation.
    """

@abstractmethod
def parent_to_child_map(self, nodes: Dict[str, AltimateManifestNode]) -> Dict[str, Set[str]]:
pass
Expand Down

0 comments on commit 35d2598

Please sign in to comment.