From e3f0628a6fe0cdab576a455860f1f1bd114459f6 Mon Sep 17 00:00:00 2001 From: Chenyu Li Date: Tue, 17 Dec 2024 16:30:38 -0800 Subject: [PATCH] Custom SQL for get source maxLoadedAt --- .../resources/v1/source_definition.py | 1 + core/dbt/context/providers.py | 16 +++- core/dbt/contracts/graph/nodes.py | 1 + core/dbt/contracts/graph/unparsed.py | 4 + core/dbt/parser/schema_renderer.py | 17 +++- core/dbt/parser/sources.py | 19 ++++ core/dbt/task/freshness.py | 20 +++- dev-requirements.txt | 2 +- tests/functional/sources/fixtures.py | 23 +++++ .../sources/test_source_freshness.py | 16 ++++ tests/unit/parser/test_parser.py | 94 +++++++++++++++++++ tests/unit/parser/test_schema_renderer.py | 2 + 12 files changed, 207 insertions(+), 8 deletions(-) diff --git a/core/dbt/artifacts/resources/v1/source_definition.py b/core/dbt/artifacts/resources/v1/source_definition.py index 9044307563e..e09095fa0af 100644 --- a/core/dbt/artifacts/resources/v1/source_definition.py +++ b/core/dbt/artifacts/resources/v1/source_definition.py @@ -59,6 +59,7 @@ class ParsedSourceMandatory(GraphResource, HasRelationMetadata): class SourceDefinition(ParsedSourceMandatory): quoting: Quoting = field(default_factory=Quoting) loaded_at_field: Optional[str] = None + loaded_at_query: Optional[str] = None freshness: Optional[FreshnessThreshold] = None external: Optional[ExternalTable] = None description: str = "" diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index f9d436a7840..c9a4084e9d4 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -880,7 +880,7 @@ class OperationProvider(RuntimeProvider): # Base context collection, used for parsing configs. class ProviderContext(ManifestContext): - # subclasses are MacroContext, ModelContext, TestContext + # subclasses are MacroContext, ModelContext, TestContext, SourceContext def __init__( self, model, @@ -1558,6 +1558,20 @@ def __init__( self._search_package = search_package +class SourceContext(ProviderContext): + # SourceContext is being used to render jinja SQL during execution of + # custom SQL in source freshness. It is not used for parsing. + model: SourceDefinition + + @contextproperty() + def this(self) -> Optional[RelationProxy]: + return self.db_wrapper.Relation.create_from(self.config, self.model) + + @contextproperty() + def source_node(self) -> SourceDefinition: + return self.model + + class ModelContext(ProviderContext): model: ManifestNode diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 4bb70db5d9c..2aae15b1175 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -1732,6 +1732,7 @@ class ParsedSingularTestPatch(ParsedPatch): ManifestNode = Union[ ManifestSQLNode, SeedNode, + SourceDefinition, ] ResultNode = Union[ diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index f78ba15a50f..b01b9b7c34c 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -317,6 +317,7 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasColumnAndTestProps): config: Dict[str, Any] = field(default_factory=dict) loaded_at_field: Optional[str] = None loaded_at_field_present: Optional[bool] = None + loaded_at_query: Optional[str] = None identifier: Optional[str] = None quoting: Quoting = field(default_factory=Quoting) freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold) @@ -342,6 +343,7 @@ class UnparsedSourceDefinition(dbtClassMixin): freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold) loaded_at_field: Optional[str] = None loaded_at_field_present: Optional[bool] = None + loaded_at_query: Optional[str] = None tables: List[UnparsedSourceTableDefinition] = field(default_factory=list) tags: List[str] = field(default_factory=list) config: Dict[str, Any] = field(default_factory=dict) @@ -379,6 +381,7 @@ class SourceTablePatch(dbtClassMixin): docs: Optional[Docs] = None loaded_at_field: Optional[str] = None loaded_at_field_present: Optional[bool] = None + loaded_at_query: Optional[str] = None identifier: Optional[str] = None quoting: Quoting = field(default_factory=Quoting) freshness: Optional[FreshnessThreshold] = field(default_factory=FreshnessThreshold) @@ -422,6 +425,7 @@ class SourcePatch(dbtClassMixin): freshness: Optional[Optional[FreshnessThreshold]] = field(default_factory=FreshnessThreshold) loaded_at_field: Optional[str] = None loaded_at_field_present: Optional[bool] = None + loaded_at_query: Optional[str] = None tables: Optional[List[SourceTablePatch]] = None tags: Optional[List[str]] = None diff --git a/core/dbt/parser/schema_renderer.py b/core/dbt/parser/schema_renderer.py index b187c4f673f..4ca2f9832ba 100644 --- a/core/dbt/parser/schema_renderer.py +++ b/core/dbt/parser/schema_renderer.py @@ -37,11 +37,21 @@ def _is_norender_key(self, keypath: Keypath) -> bool: "tests" and "data_tests" are both currently supported but "tests" has been deprecated """ # top level descriptions and data_tests - if len(keypath) >= 1 and keypath[0] in ("tests", "data_tests", "description"): + if len(keypath) >= 1 and keypath[0] in ( + "tests", + "data_tests", + "description", + "loaded_at_query", + ): return True # columns descriptions and data_tests - if len(keypath) == 2 and keypath[1] in ("tests", "data_tests", "description"): + if len(keypath) == 2 and keypath[1] in ( + "tests", + "data_tests", + "description", + "loaded_at_query", + ): return True # pre- and post-hooks @@ -69,9 +79,8 @@ def _is_norender_key(self, keypath: Keypath) -> bool: def should_render_keypath(self, keypath: Keypath) -> bool: if len(keypath) < 1: return True - if self.key == "sources": - if keypath[0] == "description": + if keypath[0] in ("description", "loaded_at_query"): return False if keypath[0] == "tables": if self._is_norender_key(keypath[2:]): diff --git a/core/dbt/parser/sources.py b/core/dbt/parser/sources.py index 0fe882750ae..4c0e8d5f953 100644 --- a/core/dbt/parser/sources.py +++ b/core/dbt/parser/sources.py @@ -26,6 +26,7 @@ UnparsedSourceTableDefinition, ) from dbt.events.types import FreshnessConfigProblem, UnusedTables +from dbt.exceptions import ParsingError from dbt.node_types import NodeType from dbt.parser.common import ParserRef from dbt.parser.schema_generic_tests import SchemaGenericTestParser @@ -131,11 +132,28 @@ def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition: # We need to be able to tell the difference between explicitly setting the loaded_at_field to None/null # and when it's simply not set. This allows a user to override the source level loaded_at_field so that # specific table can default to metadata-based freshness. + if table.loaded_at_field_present and table.loaded_at_query: + raise ParsingError( + "Cannot specify both loaded_at_field and loaded_at_query at table level." + ) + if source.loaded_at_field and source.loaded_at_query: + raise ParsingError( + "Cannot specify both loaded_at_field and loaded_at_query at source level." + ) + if table.loaded_at_field_present or table.loaded_at_field is not None: loaded_at_field = table.loaded_at_field else: loaded_at_field = source.loaded_at_field # may be None, that's okay + loaded_at_query: Optional[str] + if table.loaded_at_query is not None: + loaded_at_query = table.loaded_at_query + else: + if table.loaded_at_field_present: + loaded_at_query = None + else: + loaded_at_query = source.loaded_at_query freshness = merge_freshness(source.freshness, table.freshness) quoting = source.quoting.merged(table.quoting) # path = block.path.original_file_path @@ -185,6 +203,7 @@ def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition: meta=meta, loader=source.loader, loaded_at_field=loaded_at_field, + loaded_at_query=loaded_at_query, freshness=freshness, quoting=quoting, resource_type=NodeType.Source, diff --git a/core/dbt/task/freshness.py b/core/dbt/task/freshness.py index 06e78b17c7b..aa434ada14e 100644 --- a/core/dbt/task/freshness.py +++ b/core/dbt/task/freshness.py @@ -15,6 +15,8 @@ PartialSourceFreshnessResult, SourceFreshnessResult, ) +from dbt.clients import jinja +from dbt.context.providers import RuntimeProvider, SourceContext from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import HookNode, SourceDefinition from dbt.contracts.results import RunStatus @@ -114,7 +116,22 @@ def execute(self, compiled_node, manifest): adapter_response: Optional[AdapterResponse] = None freshness: Optional[FreshnessResponse] = None - if compiled_node.loaded_at_field is not None: + if compiled_node.loaded_at_query is not None: + # within the context user can have access to `this`, `source_node`(`model` will point to the same thing), etc + compiled_code = jinja.get_rendered( + compiled_node.loaded_at_query, + SourceContext( + compiled_node, self.config, manifest, RuntimeProvider(), None + ).to_dict(), + compiled_node, + ) + adapter_response, freshness = self.adapter.calculate_freshness_from_custom_sql( + relation, + compiled_code, + macro_resolver=manifest, + ) + status = compiled_node.freshness.status(freshness["age"]) + elif compiled_node.loaded_at_field is not None: adapter_response, freshness = self.adapter.calculate_freshness( relation, compiled_node.loaded_at_field, @@ -146,7 +163,6 @@ def execute(self, compiled_node, manifest): raise DbtRuntimeError( f"Could not compute freshness for source {compiled_node.name}: no 'loaded_at_field' provided and {self.adapter.type()} adapter does not support metadata-based freshness checks." ) - # adapter_response was not returned in previous versions, so this will be None # we cannot call to_dict() on NoneType if adapter_response: diff --git a/dev-requirements.txt b/dev-requirements.txt index 5f393349744..9bdac757abe 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/dbt-labs/dbt-adapters.git@main +git+https://github.com/dbt-labs/dbt-adapters.git@cl/custom_freshness_sql git+https://github.com/dbt-labs/dbt-adapters.git@main#subdirectory=dbt-tests-adapter git+https://github.com/dbt-labs/dbt-common.git@main git+https://github.com/dbt-labs/dbt-postgres.git@main diff --git a/tests/functional/sources/fixtures.py b/tests/functional/sources/fixtures.py index b40b1869541..5777f1fad1c 100644 --- a/tests/functional/sources/fixtures.py +++ b/tests/functional/sources/fixtures.py @@ -472,3 +472,26 @@ - name: test_table identifier: source """ + +freshness_via_custom_sql_schema_yml = """version: 2 +sources: + - name: test_source + freshness: + warn_after: {count: 10, period: hour} + schema: "{{ var(env_var('DBT_TEST_SCHEMA_NAME_VARIABLE')) }}" + quoting: + identifier: True + tags: + - my_test_source_tag + tables: + - name: source_a + identifier: source + loaded_at_field: "{{ var('test_loaded_at') | as_text }}" + - name: source_b + identifier: source + loaded_at_query: "select max({{ var('test_loaded_at') | as_text }}) from {{this}}" + - name: source_c + identifier: source + loaded_at_query: "select {{current_timestamp()}}" + +""" diff --git a/tests/functional/sources/test_source_freshness.py b/tests/functional/sources/test_source_freshness.py index 70c8866869e..e78c30e5e97 100644 --- a/tests/functional/sources/test_source_freshness.py +++ b/tests/functional/sources/test_source_freshness.py @@ -17,6 +17,7 @@ error_models_model_sql, error_models_schema_yml, filtered_models_schema_yml, + freshness_via_custom_sql_schema_yml, freshness_via_metadata_schema_yml, override_freshness_models_schema_yml, ) @@ -578,3 +579,18 @@ def test_hooks_do_not_run_for_source_freshness( ) # default behaviour - no hooks run in source freshness self._assert_project_hooks_not_called(log_output) + + +class TestSourceFreshnessCustomSQL(SuccessfulSourceFreshnessTest): + @pytest.fixture(scope="class") + def models(self): + return {"schema.yml": freshness_via_custom_sql_schema_yml} + + def test_source_freshness_custom_sql(self, project): + result = self.run_dbt_with_vars(project, ["source", "freshness"], expect_pass=True) + # They are the same source but different queries were executed for each + assert {r.node.name: r.status for r in result} == { + "source_a": "warn", + "source_b": "warn", + "source_c": "pass", + } diff --git a/tests/unit/parser/test_parser.py b/tests/unit/parser/test_parser.py index 8894e47ce84..59d64f679fc 100644 --- a/tests/unit/parser/test_parser.py +++ b/tests/unit/parser/test_parser.py @@ -404,6 +404,48 @@ def assertEqualNodes(node_one, node_two): - unique """ +SOURCE_CUSTOM_FRESHNESS_AT_SOURCE = """ +sources: + - name: my_source + loaded_at_query: "select 1 as id" + tables: + - name: my_table +""" +SOURCE_CUSTOM_FRESHNESS_AT_SOURCE_FIELD_AT_TABLE = """ +sources: + - name: my_source + loaded_at_query: "select 1 as id" + tables: + - name: my_table + loaded_at_field: test +""" +SOURCE_FIELD_AT_SOURCE_CUSTOM_FRESHNESS_AT_TABLE = """ +sources: + - name: my_source + loaded_at_field: test + tables: + - name: my_table + loaded_at_query: "select 1 as id" +""" +SOURCE_FIELD_AT_CUSTOM_FRESHNESS_BOTH_AT_TABLE = """ +sources: + - name: my_source + loaded_at_field: test + tables: + - name: my_table + loaded_at_query: "select 1 as id" + loaded_at_field: test +""" +SOURCE_FIELD_AT_CUSTOM_FRESHNESS_BOTH_AT_SOURCE = """ +sources: + - name: my_source + loaded_at_field: test + loaded_at_query: "select 1 as id" + tables: + - name: my_table + loaded_at_field: test +""" + class SchemaParserTest(BaseParserTest): def setUp(self): @@ -448,6 +490,58 @@ def test__read_basic_source(self): self.assertEqual(source_values[0].table.description, "") self.assertEqual(len(source_values[0].table.columns), 0) + @mock.patch("dbt.parser.sources.get_adapter") + def test_parse_source_custom_freshness_at_source(self, _): + block = self.file_block_for(SOURCE_CUSTOM_FRESHNESS_AT_SOURCE, "test_one.yml") + dct = yaml_from_file(block.file) + self.parser.parse_file(block, dct) + unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] + src_default = self.source_patcher.parse_source(unpatched_src_default) + assert src_default.loaded_at_query == "select 1 as id" + + @mock.patch("dbt.parser.sources.get_adapter") + def test_parse_source_custom_freshness_at_source_field_at_table(self, _): + block = self.file_block_for( + SOURCE_CUSTOM_FRESHNESS_AT_SOURCE_FIELD_AT_TABLE, "test_one.yml" + ) + dct = yaml_from_file(block.file) + self.parser.parse_file(block, dct) + unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] + src_default = self.source_patcher.parse_source(unpatched_src_default) + # source loaded_at_query not propagate to table since there's loaded_at_field defined + assert src_default.loaded_at_query is None + + @mock.patch("dbt.parser.sources.get_adapter") + def test_parse_source_field_at_source_custom_freshness_at_table(self, _): + block = self.file_block_for( + SOURCE_FIELD_AT_SOURCE_CUSTOM_FRESHNESS_AT_TABLE, "test_one.yml" + ) + dct = yaml_from_file(block.file) + self.parser.parse_file(block, dct) + unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] + src_default = self.source_patcher.parse_source(unpatched_src_default) + assert src_default.loaded_at_query == "select 1 as id" + + @mock.patch("dbt.parser.sources.get_adapter") + def test_parse_source_field_at_custom_freshness_both_at_table_fails(self, _): + block = self.file_block_for(SOURCE_FIELD_AT_CUSTOM_FRESHNESS_BOTH_AT_TABLE, "test_one.yml") + dct = yaml_from_file(block.file) + self.parser.parse_file(block, dct) + unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] + with self.assertRaises(ParsingError): + self.source_patcher.parse_source(unpatched_src_default) + + @mock.patch("dbt.parser.sources.get_adapter") + def test_parse_source_field_at_custom_freshness_both_at_source_fails(self, _): + block = self.file_block_for( + SOURCE_FIELD_AT_CUSTOM_FRESHNESS_BOTH_AT_SOURCE, "test_one.yml" + ) + dct = yaml_from_file(block.file) + self.parser.parse_file(block, dct) + unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] + with self.assertRaises(ParsingError): + self.source_patcher.parse_source(unpatched_src_default) + def test__parse_basic_source(self): block = self.file_block_for(SINGLE_TABLE_SOURCE, "test_one.yml") dct = yaml_from_file(block.file) diff --git a/tests/unit/parser/test_schema_renderer.py b/tests/unit/parser/test_schema_renderer.py index 9a703617aea..c2640066047 100644 --- a/tests/unit/parser/test_schema_renderer.py +++ b/tests/unit/parser/test_schema_renderer.py @@ -56,10 +56,12 @@ def test__sources(self): dct = { "name": "my_source", "description": "{{ alt_var }}", + "loaded_at_query": "select max(ordered_at) from {{ this }}", "tables": [ { "name": "my_table", "description": "{{ alt_var }}", + "loaded_at_query": "select max(ordered_at) from {{ this }}", "columns": [ { "name": "id",