From 2a146bb51aa0135fbc36cb39d3c00e895c403ee6 Mon Sep 17 00:00:00 2001
From: Michelle Ark
Date: Tue, 23 Jul 2024 13:56:59 -0400
Subject: [PATCH] first stab: microbatch

---
 core/dbt/artifacts/resources/v1/components.py |  1 +
 .../resources/v1/source_definition.py         |  1 +
 core/dbt/context/providers.py                 | 46 +++++++++++++++++--
 core/dbt/contracts/graph/nodes.py             |  4 ++
 core/dbt/task/run.py                          | 10 +++-
 tests/unit/context/test_providers.py          | 10 +++-
 tests/unit/contracts/graph/test_manifest.py   |  3 ++
 7 files changed, 68 insertions(+), 7 deletions(-)

diff --git a/core/dbt/artifacts/resources/v1/components.py b/core/dbt/artifacts/resources/v1/components.py
index 6e6605c18ab..ce790dcd323 100644
--- a/core/dbt/artifacts/resources/v1/components.py
+++ b/core/dbt/artifacts/resources/v1/components.py
@@ -219,6 +219,7 @@ class CompiledResource(ParsedResource):
     extra_ctes: List[InjectedCTE] = field(default_factory=list)
     _pre_injected_sql: Optional[str] = None
     contract: Contract = field(default_factory=Contract)
+    event_time: Optional[str] = None
 
     def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
         dct = super().__post_serialize__(dct, context)
diff --git a/core/dbt/artifacts/resources/v1/source_definition.py b/core/dbt/artifacts/resources/v1/source_definition.py
index ac0fcfca1b2..7e3d61fd0b8 100644
--- a/core/dbt/artifacts/resources/v1/source_definition.py
+++ b/core/dbt/artifacts/resources/v1/source_definition.py
@@ -70,3 +70,4 @@ class SourceDefinition(ParsedSourceMandatory):
     unrendered_config: Dict[str, Any] = field(default_factory=dict)
     relation_name: Optional[str] = None
     created_at: float = field(default_factory=lambda: time.time())
+    event_time: Optional[str] = None
diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py
index bbb5f269c93..30b59a937aa 100644
--- a/core/dbt/context/providers.py
+++ b/core/dbt/context/providers.py
@@ -20,6 +20,7 @@
 
 from dbt import selected_resources
 from dbt.adapters.base.column import Column
+from dbt.adapters.base.relation import EventTimeFilter
 from dbt.adapters.contracts.connection import AdapterResponse
 from dbt.adapters.exceptions import MissingConfigError
 from dbt.adapters.factory import (
@@ -230,6 +231,26 @@ def Relation(self):
     def resolve_limit(self) -> Optional[int]:
         return 0 if getattr(self.config.args, "EMPTY", False) else None
 
+    @property
+    def resolve_event_time_filter(self) -> Optional[EventTimeFilter]:
+        """Build the event-time filter for the model being compiled.
+
+        Returns None unless event_time, start_time and end_time are all set
+        (truthy) on ``self.model``.
+        """
+        field_name = getattr(self.model, "event_time", None)
+        start_time = getattr(self.model, "start_time", None)
+        end_time = getattr(self.model, "end_time", None)
+
+        if start_time and end_time and field_name:
+            return EventTimeFilter(
+                field_name=field_name,
+                start_time=start_time,
+                end_time=end_time,
+            )
+
+        return None
+
     @abc.abstractmethod
     def __call__(self, *args: str) -> Union[str, RelationProxy, MetricReference]:
         pass
@@ -545,7 +566,11 @@ def resolve(
     def create_relation(self, target_model: ManifestNode) -> RelationProxy:
         if target_model.is_ephemeral_model:
             self.model.set_cte(target_model.unique_id, None)
-            return self.Relation.create_ephemeral_from(target_model, limit=self.resolve_limit)
+            return self.Relation.create_ephemeral_from(
+                target_model,
+                limit=self.resolve_limit,
+                event_time_filter=self.resolve_event_time_filter,
+            )
         elif (
             hasattr(target_model, "defer_relation")
             and target_model.defer_relation
@@ -563,10 +588,18 @@ def create_relation(self, target_model: ManifestNode) -> RelationProxy:
             )
         ):
             return self.Relation.create_from(
-                self.config, target_model.defer_relation, limit=self.resolve_limit
+                self.config,
+                target_model.defer_relation,
+                limit=self.resolve_limit,
+                event_time_filter=self.resolve_event_time_filter,
             )
         else:
-            return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit)
+            return self.Relation.create_from(
+                self.config,
+                target_model,
+                limit=self.resolve_limit,
+                event_time_filter=self.resolve_event_time_filter,
+            )
 
     def validate(
         self,
@@ -633,7 +666,12 @@ def resolve(self, source_name: str, table_name: str):
                 target_kind="source",
                 disabled=(isinstance(target_source, Disabled)),
             )
-        return self.Relation.create_from(self.config, target_source, limit=self.resolve_limit)
+        return self.Relation.create_from(
+            self.config,
+            target_source,
+            limit=self.resolve_limit,
+            event_time_filter=self.resolve_event_time_filter,
+        )
 
 
 class RuntimeUnitTestSourceResolver(BaseSourceResolver):
diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py
index da42fb7d766..f9dcfae3e00 100644
--- a/core/dbt/contracts/graph/nodes.py
+++ b/core/dbt/contracts/graph/nodes.py
@@ -379,6 +379,10 @@ class CompiledNode(CompiledResource, ParsedNode):
     """Contains attributes necessary for SQL files and nodes with refs, sources, etc,
     so all ManifestNodes except SeedNode."""
 
+    # TODO: should these go here? and get set during execution?
+    start_time: Optional[datetime] = None
+    end_time: Optional[datetime] = None
+
     @property
     def empty(self):
         return not self.raw_code.strip()
diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py
index 6263ee66b46..120c4809b41 100644
--- a/core/dbt/task/run.py
+++ b/core/dbt/task/run.py
@@ -1,7 +1,7 @@
 import functools
 import threading
 import time
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from dbt import tracking, utils
@@ -214,6 +214,14 @@ def print_result_line(self, result):
             )
 
     def before_execute(self):
+        if self.node.config.get("microbatch"):
+            # TODO: actually use partition_grain (self.node.config.get("partition_grain"))
+            # Window is [midnight `lookback` days ago, now]; "lookback" must be set.
+            lookback = self.node.config.get("lookback")
+            self.node.end_time = datetime.now()
+            start_time = self.node.end_time - timedelta(days=lookback)
+            self.node.start_time = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
+
         self.print_start_line()
 
     def after_execute(self, result):
diff --git a/tests/unit/context/test_providers.py b/tests/unit/context/test_providers.py
index 224675143e4..5b9b3e2d5ee 100644
--- a/tests/unit/context/test_providers.py
+++ b/tests/unit/context/test_providers.py
@@ -41,9 +41,12 @@ def resolver(self):
         mock_db_wrapper = mock.Mock()
         mock_db_wrapper.Relation = BaseRelation
 
+        mock_model = mock.Mock()
+        mock_model.event_time = None
+
         return RuntimeRefResolver(
             db_wrapper=mock_db_wrapper,
-            model=mock.Mock(),
+            model=mock_model,
             config=mock.Mock(),
             manifest=mock.Mock(),
         )
@@ -82,9 +85,12 @@ def resolver(self):
         mock_db_wrapper = mock.Mock()
         mock_db_wrapper.Relation = BaseRelation
 
+        mock_model = mock.Mock()
+        mock_model.event_time = None
+
         return RuntimeSourceResolver(
             db_wrapper=mock_db_wrapper,
-            model=mock.Mock(),
+            model=mock_model,
             config=mock.Mock(),
             manifest=mock.Mock(),
         )
diff --git a/tests/unit/contracts/graph/test_manifest.py b/tests/unit/contracts/graph/test_manifest.py
index 35e96308da7..fa1a49da422 100644
--- a/tests/unit/contracts/graph/test_manifest.py
+++ b/tests/unit/contracts/graph/test_manifest.py
@@ -94,6 +94,9 @@
         "constraints",
         "deprecation_date",
         "defer_relation",
+        "event_time",
+        "start_time",
+        "end_time",
     }
 )
 