Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move microbatch compilation to .compile method #11060

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20241121-125630.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add `batch` context object to model jinja context
time: 2024-11-21T12:56:30.715473-06:00
custom:
Author: QMalcolm
Issue: "11025"
5 changes: 3 additions & 2 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,10 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF
and self.model.config.materialized == "incremental"
and self.model.config.incremental_strategy == "microbatch"
and self.manifest.use_microbatch_batches(project_name=self.config.project_name)
and self.model.batch is not None
):
start = self.model.config.get("__dbt_internal_microbatch_event_time_start")
end = self.model.config.get("__dbt_internal_microbatch_event_time_end")
start = self.model.batch.event_time_start
end = self.model.batch.event_time_end

if start is not None or end is not None:
event_time_filter = EventTimeFilter(
Expand Down
22 changes: 22 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
ConstraintType,
ModelLevelConstraint,
)
from dbt_common.dataclass_schema import dbtClassMixin
from dbt_common.events.contextvars import set_log_contextvars
from dbt_common.events.functions import warn_or_error

Expand Down Expand Up @@ -442,9 +443,30 @@ def resource_class(cls) -> Type[HookNodeResource]:
return HookNodeResource


@dataclass
class BatchContext(dbtClassMixin):
    """Batch metadata exposed to the jinja context of a microbatch model run.

    Attached to ``ModelNode.batch`` for the duration of a single batch's
    compilation/execution so templates can read ``model.batch.*``.
    """

    # Compact identifier for the batch (formatted batch start with dashes stripped,
    # e.g. "20200101") — see MicrobatchBuilder.batch_id.
    id: str
    # Inclusive start of the batch's event-time window.
    event_time_start: datetime
    # Exclusive end of the batch's event-time window.
    event_time_end: datetime

    def __post_serialize__(self, data, context):
        # This is insane, but necessary, I apologize. Mashumaro handles the
        # dictification of this class via a compile time generated `to_dict`
        # method based off of the _typing_ of the class. By default `datetime`
        # types are converted to strings. We don't want that, we want them to
        # stay datetimes.
        # Note: This is safe because the `BatchContext` isn't part of the artifact
        # and thus doesn't get written out.
        new_data = super().__post_serialize__(data, context)
        new_data["event_time_start"] = self.event_time_start
        new_data["event_time_end"] = self.event_time_end
        return new_data


@dataclass
class ModelNode(ModelResource, CompiledNode):
previous_batch_results: Optional[BatchResults] = None
batch: Optional[BatchContext] = None
_has_this: Optional[bool] = None

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
Expand Down
25 changes: 12 additions & 13 deletions core/dbt/materializations/incremental/microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,25 +100,25 @@ def build_batches(self, start: datetime, end: datetime) -> List[BatchType]:

return batches

def build_jinja_context_for_batch(self, incremental_batch: bool) -> Dict[str, Any]:
    """Create the jinja context entries for executing one microbatch batch.

    Assumes ``self.model`` has been (re)-compiled with the necessary batch
    filters applied, so ``compiled_code`` reflects this batch's SQL.

    :param incremental_batch: True when the batch should run with incremental
        semantics (relation already exists), False for the initial/full batch.
    :return: dict of context members to merge into the model's jinja context.
    """
    jinja_context: Dict[str, Any] = {}

    # Microbatch model properties: expose the (batch-filtered) model and its SQL.
    jinja_context["model"] = self.model.to_dict()
    jinja_context["sql"] = self.model.compiled_code
    jinja_context["compiled_code"] = self.model.compiled_code

    # Incremental context variables, only for batches running incrementally.
    # Callables to mirror the real `is_incremental()` / `should_full_refresh()` macros.
    if incremental_batch:
        jinja_context["is_incremental"] = lambda: True
        jinja_context["should_full_refresh"] = lambda: False

    return jinja_context

@staticmethod
def offset_timestamp(timestamp: datetime, batch_size: BatchSize, offset: int) -> datetime:
Expand Down Expand Up @@ -193,12 +193,11 @@ def truncate_timestamp(timestamp: datetime, batch_size: BatchSize) -> datetime:
return truncated

@staticmethod
def batch_id(start_time: datetime, batch_size: BatchSize) -> str:
    """Return a compact batch identifier: the formatted batch start with
    dashes stripped (e.g. "2020-01-01" -> "20200101")."""
    return MicrobatchBuilder.format_batch_start(start_time, batch_size).replace("-", "")

@staticmethod
def format_batch_start(batch_start: datetime, batch_size: BatchSize) -> str:
    """Format a batch start timestamp for display / path suffixes.

    Hourly batches keep the full timestamp; coarser batch sizes are
    truncated to the date. (`batch_start` is a datetime, hence always
    truthy — no None guard needed with the non-Optional signature.)
    """
    return str(batch_start.date() if batch_size != BatchSize.hour else batch_start)
Expand Down
48 changes: 31 additions & 17 deletions core/dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from dbt.config import RuntimeConfig
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import HookNode, ModelNode, ResultNode
from dbt.contracts.graph.nodes import BatchContext, HookNode, ModelNode, ResultNode
from dbt.events.types import (
GenericExceptionOnRun,
LogHookEndLine,
Expand Down Expand Up @@ -341,6 +341,33 @@
self.batches: Dict[int, BatchType] = {}
self.relation_exists: bool = False

def compile(self, manifest: Manifest):
    """Compile the node for the currently selected batch (if any).

    When ``self.batch_idx`` is set, re-compiles the node with that batch's
    event-time window applied so refs resolve with event time filters
    rendered. For non-batch runs this is a no-op and the node is returned
    as-is.

    :param manifest: the manifest used for (re)compilation.
    :return: the (possibly recompiled) node.
    """
    if self.batch_idx is not None:
        batch = self.batches[self.batch_idx]

        # LEGACY: Set start/end in context prior to re-compiling (Will be removed for 1.10+)
        # TODO: REMOVE before 1.10 GA
        self.node.config["__dbt_internal_microbatch_event_time_start"] = batch[0]
        self.node.config["__dbt_internal_microbatch_event_time_end"] = batch[1]

        # Create batch context on model node prior to re-compiling, so the
        # jinja context can expose `model.batch.*` members.
        self.node.batch = BatchContext(
            id=MicrobatchBuilder.batch_id(batch[0], self.node.config.batch_size),
            event_time_start=batch[0],
            event_time_end=batch[1],
        )
        # Recompile node to re-resolve refs with event time filters rendered, update context
        self.compiler.compile_node(
            self.node,
            manifest,
            {},
            split_suffix=MicrobatchBuilder.format_batch_start(
                batch[0], self.node.config.batch_size
            ),
        )

    # Skips compilation for non-batch runs
    return self.node

def set_batch_idx(self, batch_idx: int) -> None:
self.batch_idx = batch_idx

Expand All @@ -353,7 +380,7 @@
def describe_node(self) -> str:
return f"{self.node.language} microbatch model {self.get_node_representation()}"

def describe_batch(self, batch_start: Optional[datetime]) -> str:
def describe_batch(self, batch_start: datetime) -> str:
# Only visualize date if batch_start year/month/day
formatted_batch_start = MicrobatchBuilder.format_batch_start(
batch_start, self.node.config.batch_size
Expand Down Expand Up @@ -530,24 +557,11 @@
# call materialization_macro to get a batch-level run result
start_time = time.perf_counter()
try:
# Set start/end in context prior to re-compiling
model.config["__dbt_internal_microbatch_event_time_start"] = batch[0]
model.config["__dbt_internal_microbatch_event_time_end"] = batch[1]

# Recompile node to re-resolve refs with event time filters rendered, update context
self.compiler.compile_node(
model,
manifest,
{},
split_suffix=MicrobatchBuilder.format_batch_start(
batch[0], model.config.batch_size
),
)
# Update jinja context with batch context members
batch_context = microbatch_builder.build_batch_context(
jinja_context = microbatch_builder.build_jinja_context_for_batch(
incremental_batch=self.relation_exists
)
context.update(batch_context)
context.update(jinja_context)

# Materialize batch and cache any materialized relations
result = MacroGenerator(
Expand Down
38 changes: 23 additions & 15 deletions tests/functional/microbatch/test_microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@
select * from {{ ref('microbatch_model') }}
"""

invalid_batch_context_macro_sql = """
{% macro check_invalid_batch_context() %}
invalid_batch_jinja_context_macro_sql = """
{% macro check_invalid_batch_jinja_context() %}

{% if model is not mapping %}
{{ exceptions.raise_compiler_error("`model` is invalid: expected mapping type") }}
Expand All @@ -83,9 +83,9 @@
"""

microbatch_model_with_context_checks_sql = """
{{ config(pre_hook="{{ check_invalid_batch_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
{{ config(pre_hook="{{ check_invalid_batch_jinja_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}

{{ check_invalid_batch_context() }}
{{ check_invalid_batch_jinja_context() }}
select * from {{ ref('input_model') }}
"""

Expand Down Expand Up @@ -404,7 +404,7 @@ class TestMicrobatchJinjaContext(BaseMicrobatchTest):

@pytest.fixture(scope="class")
def macros(self):
return {"check_batch_context.sql": invalid_batch_context_macro_sql}
return {"check_batch_jinja_context.sql": invalid_batch_jinja_context_macro_sql}

@pytest.fixture(scope="class")
def models(self):
Expand Down Expand Up @@ -498,6 +498,13 @@ def test_run_with_event_time(self, project):
{{ config(materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
{{ log("start: "~ model.config.__dbt_internal_microbatch_event_time_start, info=True)}}
{{ log("end: "~ model.config.__dbt_internal_microbatch_event_time_end, info=True)}}
{% if model.batch %}
{{ log("batch.event_time_start: "~ model.batch.event_time_start, info=True)}}
{{ log("batch.event_time_end: "~ model.batch.event_time_end, info=True)}}
{{ log("batch.id: "~ model.batch.id, info=True)}}
{{ log("start timezone: "~ model.batch.event_time_start.tzinfo, info=True)}}
{{ log("end timezone: "~ model.batch.event_time_end.tzinfo, info=True)}}
{% endif %}
select * from {{ ref('input_model') }}
"""

Expand All @@ -516,12 +523,23 @@ def test_run_with_event_time_logs(self, project):

assert "start: 2020-01-01 00:00:00+00:00" in logs
assert "end: 2020-01-02 00:00:00+00:00" in logs
assert "batch.event_time_start: 2020-01-01 00:00:00+00:00" in logs
assert "batch.event_time_end: 2020-01-02 00:00:00+00:00" in logs
assert "batch.id: 20200101" in logs
assert "start timezone: UTC" in logs
assert "end timezone: UTC" in logs

assert "start: 2020-01-02 00:00:00+00:00" in logs
assert "end: 2020-01-03 00:00:00+00:00" in logs
assert "batch.event_time_start: 2020-01-02 00:00:00+00:00" in logs
assert "batch.event_time_end: 2020-01-03 00:00:00+00:00" in logs
assert "batch.id: 20200102" in logs

assert "start: 2020-01-03 00:00:00+00:00" in logs
assert "end: 2020-01-03 13:57:00+00:00" in logs
assert "batch.event_time_start: 2020-01-03 00:00:00+00:00" in logs
assert "batch.event_time_end: 2020-01-03 13:57:00+00:00" in logs
assert "batch.id: 20200103" in logs


microbatch_model_failing_incremental_partition_sql = """
Expand Down Expand Up @@ -675,16 +693,6 @@ def test_run_with_event_time(self, project):
with patch_microbatch_end_time("2020-01-03 13:57:00"):
run_dbt(["run"])

# Compiled paths - compiled model without filter only
assert read_file(
project.project_root,
"target",
"compiled",
"test",
"models",
"microbatch_model.sql",
)

# Compiled paths - batch compilations
assert read_file(
project.project_root,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/contracts/graph/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"deprecation_date",
"defer_relation",
"time_spine",
"batch",
}
)

Expand Down
9 changes: 4 additions & 5 deletions tests/unit/materializations/incremental/test_microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,11 +489,11 @@ def test_build_batches(self, microbatch_model, start, end, batch_size, expected_
assert len(actual_batches) == len(expected_batches)
assert actual_batches == expected_batches

def test_build_batch_context_incremental_batch(self, microbatch_model):
def test_build_jinja_context_for_incremental_batch(self, microbatch_model):
microbatch_builder = MicrobatchBuilder(
model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
)
context = microbatch_builder.build_batch_context(incremental_batch=True)
context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=True)

assert context["model"] == microbatch_model.to_dict()
assert context["sql"] == microbatch_model.compiled_code
Expand All @@ -502,11 +502,11 @@ def test_build_batch_context_incremental_batch(self, microbatch_model):
assert context["is_incremental"]() is True
assert context["should_full_refresh"]() is False

def test_build_batch_context_incremental_batch_false(self, microbatch_model):
def test_build_jinja_context_for_incremental_batch_false(self, microbatch_model):
microbatch_builder = MicrobatchBuilder(
model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
)
context = microbatch_builder.build_batch_context(incremental_batch=False)
context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=False)

assert context["model"] == microbatch_model.to_dict()
assert context["sql"] == microbatch_model.compiled_code
Expand Down Expand Up @@ -605,7 +605,6 @@ def test_truncate_timestamp(self, timestamp, batch_size, expected_timestamp):
@pytest.mark.parametrize(
"batch_size,batch_start,expected_formatted_batch_start",
[
(None, None, None),
(BatchSize.year, datetime(2020, 1, 1, 1), "2020-01-01"),
(BatchSize.month, datetime(2020, 1, 1, 1), "2020-01-01"),
(BatchSize.day, datetime(2020, 1, 1, 1), "2020-01-01"),
Expand Down
Loading