From 199ebf4ba05e434c71350c96bf27af0d6729c06d Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Thu, 19 Sep 2024 19:18:54 +0100 Subject: [PATCH 1/7] first pass: split_suffix --- core/dbt/compilation.py | 11 ++++++++--- core/dbt/contracts/graph/nodes.py | 14 +++++++++++++- core/dbt/task/run.py | 29 ++++++++++++++++++++++++----- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 72a5b6f016b..0ffa73df715 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -521,7 +521,9 @@ def write_graph_file(self, linker: Linker, manifest: Manifest): linker.write_graph(graph_path, manifest) # writes the "compiled_code" into the target/compiled directory - def _write_node(self, node: ManifestSQLNode) -> ManifestSQLNode: + def _write_node( + self, node: ManifestSQLNode, split_suffix: Optional[str] = None + ) -> ManifestSQLNode: if not node.extra_ctes_injected or node.resource_type in ( NodeType.Snapshot, NodeType.Seed, @@ -530,7 +532,9 @@ def _write_node(self, node: ManifestSQLNode) -> ManifestSQLNode: fire_event(WritingInjectedSQLForNode(node_info=get_node_info())) if node.compiled_code: - node.compiled_path = node.get_target_write_path(self.config.target_path, "compiled") + node.compiled_path = node.get_target_write_path( + self.config.target_path, "compiled", split_suffix + ) node.write_node(self.config.project_root, node.compiled_path, node.compiled_code) return node @@ -540,6 +544,7 @@ def compile_node( manifest: Manifest, extra_context: Optional[Dict[str, Any]] = None, write: bool = True, + split_suffix: Optional[str] = None, ) -> ManifestSQLNode: """This is the main entry point into this code. It's called by CompileRunner.compile, GenericRPCRunner.compile, and @@ -562,7 +567,7 @@ def compile_node( node, _ = self._recursively_prepend_ctes(node, manifest, extra_context) if write: - self._write_node(node) + self._write_node(node, split_suffix=split_suffix) return node diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index b28910c0de3..13c3df0e2e7 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -2,6 +2,7 @@ import os from dataclasses import dataclass, field from datetime import datetime +from pathlib import Path from typing import ( Any, Dict, @@ -243,7 +244,9 @@ def clear_event_status(self): @dataclass class ParsedNode(ParsedResource, NodeInfoMixin, ParsedNodeMandatory, SerializableType): - def get_target_write_path(self, target_path: str, subdirectory: str): + def get_target_write_path( + self, target_path: str, subdirectory: str, split_suffix: Optional[str] = None + ): # This is called for both the "compiled" subdirectory of "target" and the "run" subdirectory if os.path.basename(self.path) == os.path.basename(self.original_file_path): # One-to-one relationship of nodes to files. @@ -251,6 +254,15 @@ def get_target_write_path(self, target_path: str, subdirectory: str): else: # Many-to-one relationship of nodes to files. 
path = os.path.join(self.original_file_path, self.path) + + if split_suffix: + pathlib_path = Path(path) + path = str( + pathlib_path.parent + / pathlib_path.stem + / (pathlib_path.stem + f"_{split_suffix}" + pathlib_path.suffix) + ) + target_write_path = os.path.join(target_path, subdirectory, self.package_name, path) return target_write_path diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index e6e380b4063..531059bbad2 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -2,8 +2,19 @@ import os import threading import time -from datetime import datetime -from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple, Type +from datetime import date, datetime +from typing import ( + AbstractSet, + Any, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + Union, +) from dbt import tracking, utils from dbt.adapters.base import BaseRelation @@ -197,12 +208,18 @@ def describe_node(self) -> str: def describe_batch(self, batch_start: Optional[datetime]) -> str: # Only visualize date if batch_start year/month/day - formatted_batch_start = ( + formatted_batch_start = self.format_batch_start(batch_start) + + return f"batch {formatted_batch_start} of {self.get_node_representation()}" + + def format_batch_start( + self, batch_start: Optional[datetime] + ) -> Optional[Union[date, datetime]]: + return ( batch_start.date() if (batch_start and self.node.config.batch_size != BatchSize.hour) else batch_start ) - return f"batch {formatted_batch_start} of {self.get_node_representation()}" def print_start_line(self): fire_event( @@ -463,7 +480,9 @@ def _execute_microbatch_materialization( model.config["__dbt_internal_microbatch_event_time_end"] = batch[1] # Recompile node to re-resolve refs with event time filters rendered, update context - self.compiler.compile_node(model, manifest, {}) + self.compiler.compile_node( + model, manifest, {}, split_suffix=str(self.format_batch_start(batch[0])) + ) context["model"] = model context["sql"] = model.compiled_code context["compiled_code"] = model.compiled_code From 6c3aecca4a573bc68e3f0be58807fef07037a8c9 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 20 Sep 2024 15:47:30 +0100 Subject: [PATCH 2/7] tests, split_suffix in run artifacts --- core/dbt/context/providers.py | 15 ++++- core/dbt/contracts/graph/nodes.py | 11 ++++ core/dbt/task/run.py | 29 ++------ .../functional/microbatch/test_microbatch.py | 66 +++++++++++++++++++ 4 files changed, 95 insertions(+), 26 deletions(-) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 898437bf4da..7d9ce8befcf 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -51,6 +51,7 @@ Exposure, Macro, ManifestNode, + ModelNode, Resource, SeedNode, SemanticModel, @@ -972,7 +973,19 @@ def write(self, payload: str) -> str: # macros/source defs aren't 'writeable'. 
if isinstance(self.model, (Macro, SourceDefinition)): raise MacrosSourcesUnWriteableError(node=self.model) - self.model.build_path = self.model.get_target_write_path(self.config.target_path, "run") + + split_suffix = None + if ( + isinstance(self.model, ModelNode) + and self.model.config.get("incremental_strategy") == "microbatch" + ): + split_suffix = self.model.format_batch_start( + self.model.config.get("__dbt_internal_microbatch_event_time_start") + ) + + self.model.build_path = self.model.get_target_write_path( + self.config.target_path, "run", split_suffix=split_suffix + ) self.model.write_node(self.config.project_root, self.model.build_path, payload) return "" diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 13c3df0e2e7..4494d1c6efc 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -60,6 +60,7 @@ from dbt.artifacts.resources import SqlOperation as SqlOperationResource from dbt.artifacts.resources import TimeSpine from dbt.artifacts.resources import UnitTestDefinition as UnitTestDefinitionResource +from dbt.artifacts.resources.types import BatchSize from dbt.contracts.graph.model_config import UnitTestNodeConfig from dbt.contracts.graph.node_args import ModelNodeArgs from dbt.contracts.graph.unparsed import ( @@ -571,6 +572,16 @@ def infer_primary_key(self, data_tests: List["GenericTestNode"]) -> List[str]: return [] + def format_batch_start(self, batch_start: Optional[datetime]) -> Optional[str]: + if batch_start is None: + return batch_start + + return str( + batch_start.date() + if (batch_start and self.config.batch_size != BatchSize.hour) + else batch_start + ) + def same_contents(self, old, adapter_type) -> bool: return super().same_contents(old, adapter_type) and self.same_ref_representation(old) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 531059bbad2..ae3908f4596 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -2,19 +2,8 @@ import os import threading import time -from datetime import date, datetime -from typing import ( - AbstractSet, - Any, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Type, - Union, -) +from datetime import datetime +from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple, Type from dbt import tracking, utils from dbt.adapters.base import BaseRelation @@ -25,7 +14,6 @@ ) from dbt.adapters.exceptions import MissingMaterializationError from dbt.artifacts.resources import Hook -from dbt.artifacts.resources.types import BatchSize from dbt.artifacts.schemas.results import ( BaseResult, NodeStatus, @@ -208,19 +196,10 @@ def describe_node(self) -> str: def describe_batch(self, batch_start: Optional[datetime]) -> str: # Only visualize date if batch_start year/month/day - formatted_batch_start = self.format_batch_start(batch_start) + formatted_batch_start = self.node.format_batch_start(batch_start) return f"batch {formatted_batch_start} of {self.get_node_representation()}" - def format_batch_start( - self, batch_start: Optional[datetime] - ) -> Optional[Union[date, datetime]]: - return ( - batch_start.date() - if (batch_start and self.node.config.batch_size != BatchSize.hour) - else batch_start - ) - def print_start_line(self): fire_event( LogStartLine( @@ -481,7 +460,7 @@ def _execute_microbatch_materialization( # Recompile node to re-resolve refs with event time filters rendered, update context self.compiler.compile_node( - model, manifest, {}, split_suffix=str(self.format_batch_start(batch[0])) + model, manifest, 
{}, split_suffix=model.format_batch_start(batch[0]) ) context["model"] = model context["sql"] = model.compiled_code diff --git a/tests/functional/microbatch/test_microbatch.py b/tests/functional/microbatch/test_microbatch.py index cf8e018727f..7710e09b6e5 100644 --- a/tests/functional/microbatch/test_microbatch.py +++ b/tests/functional/microbatch/test_microbatch.py @@ -5,6 +5,7 @@ from dbt.tests.util import ( patch_microbatch_end_time, + read_file, relation_from_name, run_dbt, run_dbt_and_capture, @@ -442,3 +443,68 @@ def test_run_with_event_time(self, project): with patch_microbatch_end_time("2020-01-03 13:57:00"): run_dbt(["run", "--event-time-start", "2020-01-01"]) self.assert_row_count(project, "microbatch_model", 2) + + +class TestMicrobatchCompiledRunPaths(BaseMicrobatchTest): + @mock.patch.dict(os.environ, {"DBT_EXPERIMENTAL_MICROBATCH": "True"}) + def test_run_with_event_time(self, project): + # run all partitions from start - 2 expected rows in output, one failed + with patch_microbatch_end_time("2020-01-03 13:57:00"): + run_dbt(["run", "--event-time-start", "2020-01-01"]) + + # Compiled paths + assert read_file( + project.project_root, + "target", + "compiled", + "test", + "models", + "microbatch_model", + "microbatch_model_2020-01-01.sql", + ) + assert read_file( + project.project_root, + "target", + "compiled", + "test", + "models", + "microbatch_model", + "microbatch_model_2020-01-02.sql", + ) + assert read_file( + project.project_root, + "target", + "compiled", + "test", + "models", + "microbatch_model", + "microbatch_model_2020-01-03.sql", + ) + + assert read_file( + project.project_root, + "target", + "run", + "test", + "models", + "microbatch_model", + "microbatch_model_2020-01-01.sql", + ) + assert read_file( + project.project_root, + "target", + "run", + "test", + "models", + "microbatch_model", + "microbatch_model_2020-01-02.sql", + ) + assert read_file( + project.project_root, + "target", + "run", + "test", + "models", + "microbatch_model", + "microbatch_model_2020-01-03.sql", + ) From 9b92a1b08f87bc6cfca6ad2648f9bf742c2f5d9a Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 20 Sep 2024 17:15:20 +0100 Subject: [PATCH 3/7] unit tests --- tests/unit/graph/test_nodes.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/unit/graph/test_nodes.py b/tests/unit/graph/test_nodes.py index 79522d06427..a0e0a8d7e56 100644 --- a/tests/unit/graph/test_nodes.py +++ b/tests/unit/graph/test_nodes.py @@ -15,7 +15,7 @@ ) from dbt.artifacts.resources.v1.semantic_model import NodeRelation from dbt.contracts.graph.model_config import TestConfig -from dbt.contracts.graph.nodes import ColumnInfo, ModelNode, SemanticModel +from dbt.contracts.graph.nodes import ColumnInfo, ModelNode, ParsedNode, SemanticModel from dbt.node_types import NodeType from dbt_common.contracts.constraints import ( ColumnLevelConstraint, @@ -391,3 +391,35 @@ def test_disabled_unique_combo_multiple(): def assertSameContents(list1, list2): assert sorted(list1) == sorted(list2) + + +class TestParsedNode: + @pytest.fixture(scope="class") + def parsed_node(self) -> ParsedNode: + return ParsedNode( + resource_type=NodeType.Model, + unique_id="model.test_package.test_name", + name="test_name", + package_name="test_package", + schema="test_schema", + alias="test_alias", + fqn=["models", "test_name"], + original_file_path="test_original_file_path", + checksum=FileHash.from_contents("checksum"), + path="test_path.sql", + database=None, + ) + + def 
test_get_target_write_path(self, parsed_node): + write_path = parsed_node.get_target_write_path("target_path", "subdirectory") + assert ( + write_path + == "target_path/subdirectory/test_package/test_original_file_path/test_path.sql" + ) + + def test_get_target_write_path_split(self, parsed_node): + write_path = parsed_node.get_target_write_path("target_path", "subdirectory", "split") + assert ( + write_path + == "target_path/subdirectory/test_package/test_original_file_path/test_path/test_path_split.sql" + ) From 18c814be11e92e35306519df1b8e40a64d6a3975 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Fri, 20 Sep 2024 17:24:24 +0100 Subject: [PATCH 4/7] changelog entry --- .changes/unreleased/Features-20240920-172419.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Features-20240920-172419.yaml diff --git a/.changes/unreleased/Features-20240920-172419.yaml b/.changes/unreleased/Features-20240920-172419.yaml new file mode 100644 index 00000000000..1647d48f1da --- /dev/null +++ b/.changes/unreleased/Features-20240920-172419.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Write microbatch compiled/run targets to separate files, one per batch +time: 2024-09-20T17:24:19.219556+01:00 +custom: + Author: michelleark + Issue: "10714" From cdb0bcb2c4a92983cea7a864881bd461f6d491ed Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Mon, 23 Sep 2024 00:37:06 +0100 Subject: [PATCH 5/7] test non-batch compilation --- tests/functional/microbatch/test_microbatch.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/functional/microbatch/test_microbatch.py b/tests/functional/microbatch/test_microbatch.py index 7710e09b6e5..e38c60f4247 100644 --- a/tests/functional/microbatch/test_microbatch.py +++ b/tests/functional/microbatch/test_microbatch.py @@ -452,7 +452,17 @@ def test_run_with_event_time(self, project): with patch_microbatch_end_time("2020-01-03 13:57:00"): run_dbt(["run", "--event-time-start", "2020-01-01"]) - # Compiled paths + # Compiled paths - compiled model without filter only + assert read_file( + project.project_root, + "target", + "compiled", + "test", + "models", + "microbatch_model.sql", + ) + + # Compiled paths - batch compilations assert read_file( project.project_root, "target", From efe3a1dd0db47cc64770597ce66150a1660f74cf Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Mon, 23 Sep 2024 00:42:55 +0100 Subject: [PATCH 6/7] unit test format_batch_start --- tests/unit/graph/test_nodes.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/unit/graph/test_nodes.py b/tests/unit/graph/test_nodes.py index a0e0a8d7e56..ad60f54d917 100644 --- a/tests/unit/graph/test_nodes.py +++ b/tests/unit/graph/test_nodes.py @@ -13,6 +13,7 @@ Measure, TestMetadata, ) +from dbt.artifacts.resources.types import BatchSize from dbt.artifacts.resources.v1.semantic_model import NodeRelation from dbt.contracts.graph.model_config import TestConfig from dbt.contracts.graph.nodes import ColumnInfo, ModelNode, ParsedNode, SemanticModel @@ -110,6 +111,22 @@ def test_all_constraints( assert default_model_node.all_constraints == expected_all_constraints + @pytest.mark.parametrize( + "batch_size,batch_start,expected_formatted_batch_start", + [ + (None, None, None), + (BatchSize.year, datetime(2020, 1, 1, 1), "2020-01-01"), + (BatchSize.month, datetime(2020, 1, 1, 1), "2020-01-01"), + (BatchSize.day, datetime(2020, 1, 1, 1), "2020-01-01"), + (BatchSize.hour, datetime(2020, 1, 1, 1), "2020-01-01 01:00:00"), + ], + ) + def 
test_format_batch_start( + self, default_model_node, batch_size, batch_start, expected_formatted_batch_start + ): + default_model_node.config.batch_size = batch_size + assert default_model_node.format_batch_start(batch_start) == expected_formatted_batch_start + class TestSemanticModel: @pytest.fixture(scope="function") From 6da854a038472f957e0b488058a87f70446485bb Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Tue, 24 Sep 2024 12:32:59 +0100 Subject: [PATCH 7/7] move format_batch_start to MicrobatchBuilder --- core/dbt/context/providers.py | 6 ++++-- core/dbt/contracts/graph/nodes.py | 11 ----------- .../materializations/incremental/microbatch.py | 11 +++++++++++ core/dbt/task/run.py | 11 +++++++++-- tests/unit/graph/test_nodes.py | 17 ----------------- .../incremental/test_microbatch.py | 16 ++++++++++++++++ 6 files changed, 40 insertions(+), 32 deletions(-) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 7d9ce8befcf..5e2a4d14b2e 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -78,6 +78,7 @@ SecretEnvVarLocationError, TargetNotFoundError, ) +from dbt.materializations.incremental.microbatch import MicrobatchBuilder from dbt.node_types import ModelLanguage, NodeType from dbt.utils import MultiDict, args_to_dict from dbt_common.clients.jinja import MacroProtocol @@ -979,8 +980,9 @@ def write(self, payload: str) -> str: isinstance(self.model, ModelNode) and self.model.config.get("incremental_strategy") == "microbatch" ): - split_suffix = self.model.format_batch_start( - self.model.config.get("__dbt_internal_microbatch_event_time_start") + split_suffix = MicrobatchBuilder.format_batch_start( + self.model.config.get("__dbt_internal_microbatch_event_time_start"), + self.model.config.batch_size, ) self.model.build_path = self.model.get_target_write_path( diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 4494d1c6efc..13c3df0e2e7 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -60,7 +60,6 @@ from dbt.artifacts.resources import SqlOperation as SqlOperationResource from dbt.artifacts.resources import TimeSpine from dbt.artifacts.resources import UnitTestDefinition as UnitTestDefinitionResource -from dbt.artifacts.resources.types import BatchSize from dbt.contracts.graph.model_config import UnitTestNodeConfig from dbt.contracts.graph.node_args import ModelNodeArgs from dbt.contracts.graph.unparsed import ( @@ -572,16 +571,6 @@ def infer_primary_key(self, data_tests: List["GenericTestNode"]) -> List[str]: return [] - def format_batch_start(self, batch_start: Optional[datetime]) -> Optional[str]: - if batch_start is None: - return batch_start - - return str( - batch_start.date() - if (batch_start and self.config.batch_size != BatchSize.hour) - else batch_start - ) - def same_contents(self, old, adapter_type) -> bool: return super().same_contents(old, adapter_type) and self.same_ref_representation(old) diff --git a/core/dbt/materializations/incremental/microbatch.py b/core/dbt/materializations/incremental/microbatch.py index 5bd46eae5e9..4f538529d2d 100644 --- a/core/dbt/materializations/incremental/microbatch.py +++ b/core/dbt/materializations/incremental/microbatch.py @@ -162,3 +162,14 @@ def truncate_timestamp(timestamp: datetime, batch_size: BatchSize): truncated = datetime(timestamp.year, 1, 1, 0, 0, 0, 0, pytz.utc) return truncated + + @staticmethod + def format_batch_start( + batch_start: Optional[datetime], batch_size: BatchSize + ) -> 
Optional[str]: + if batch_start is None: + return batch_start + + return str( + batch_start.date() if (batch_start and batch_size != BatchSize.hour) else batch_start + ) diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index ae3908f4596..70db4d52920 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -196,7 +196,9 @@ def describe_node(self) -> str: def describe_batch(self, batch_start: Optional[datetime]) -> str: # Only visualize date if batch_start year/month/day - formatted_batch_start = self.node.format_batch_start(batch_start) + formatted_batch_start = MicrobatchBuilder.format_batch_start( + batch_start, self.node.config.batch_size + ) return f"batch {formatted_batch_start} of {self.get_node_representation()}" @@ -460,7 +462,12 @@ def _execute_microbatch_materialization( # Recompile node to re-resolve refs with event time filters rendered, update context self.compiler.compile_node( - model, manifest, {}, split_suffix=model.format_batch_start(batch[0]) + model, + manifest, + {}, + split_suffix=MicrobatchBuilder.format_batch_start( + batch[0], model.config.batch_size + ), ) context["model"] = model context["sql"] = model.compiled_code diff --git a/tests/unit/graph/test_nodes.py b/tests/unit/graph/test_nodes.py index ad60f54d917..a0e0a8d7e56 100644 --- a/tests/unit/graph/test_nodes.py +++ b/tests/unit/graph/test_nodes.py @@ -13,7 +13,6 @@ Measure, TestMetadata, ) -from dbt.artifacts.resources.types import BatchSize from dbt.artifacts.resources.v1.semantic_model import NodeRelation from dbt.contracts.graph.model_config import TestConfig from dbt.contracts.graph.nodes import ColumnInfo, ModelNode, ParsedNode, SemanticModel @@ -111,22 +110,6 @@ def test_all_constraints( assert default_model_node.all_constraints == expected_all_constraints - @pytest.mark.parametrize( - "batch_size,batch_start,expected_formatted_batch_start", - [ - (None, None, None), - (BatchSize.year, datetime(2020, 1, 1, 1), "2020-01-01"), - (BatchSize.month, datetime(2020, 1, 1, 1), "2020-01-01"), - (BatchSize.day, datetime(2020, 1, 1, 1), "2020-01-01"), - (BatchSize.hour, datetime(2020, 1, 1, 1), "2020-01-01 01:00:00"), - ], - ) - def test_format_batch_start( - self, default_model_node, batch_size, batch_start, expected_formatted_batch_start - ): - default_model_node.config.batch_size = batch_size - assert default_model_node.format_batch_start(batch_start) == expected_formatted_batch_start - class TestSemanticModel: @pytest.fixture(scope="function") diff --git a/tests/unit/materializations/incremental/test_microbatch.py b/tests/unit/materializations/incremental/test_microbatch.py index 68521a84e1e..5a5f445a104 100644 --- a/tests/unit/materializations/incremental/test_microbatch.py +++ b/tests/unit/materializations/incremental/test_microbatch.py @@ -444,3 +444,19 @@ def test_offset_timestamp(self, timestamp, batch_size, offset, expected_timestam ) def test_truncate_timestamp(self, timestamp, batch_size, expected_timestamp): assert MicrobatchBuilder.truncate_timestamp(timestamp, batch_size) == expected_timestamp + + @pytest.mark.parametrize( + "batch_size,batch_start,expected_formatted_batch_start", + [ + (None, None, None), + (BatchSize.year, datetime(2020, 1, 1, 1), "2020-01-01"), + (BatchSize.month, datetime(2020, 1, 1, 1), "2020-01-01"), + (BatchSize.day, datetime(2020, 1, 1, 1), "2020-01-01"), + (BatchSize.hour, datetime(2020, 1, 1, 1), "2020-01-01 01:00:00"), + ], + ) + def test_format_batch_start(self, batch_size, batch_start, expected_formatted_batch_start): + assert ( + 
MicrobatchBuilder.format_batch_start(batch_start, batch_size) + == expected_formatted_batch_start + )
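
For reviewers, a minimal standalone sketch of the path behavior this series introduces: each batch of a microbatch model gets its own compiled/run file under a per-model directory, with the batch start as a filename suffix. The sketch mirrors ParsedNode.get_target_write_path and MicrobatchBuilder.format_batch_start in simplified form; the helper functions below, the daily-batch assumption, and the example package/model names (taken from the functional tests above) are illustrative only, not dbt APIs.

    from datetime import datetime
    from pathlib import Path
    from typing import Optional

    def format_batch_start(batch_start: datetime, hourly: bool = False) -> str:
        # Batch sizes coarser than hourly render as a date; hourly keeps the full timestamp.
        return str(batch_start if hourly else batch_start.date())

    def target_write_path(target_path: str, subdirectory: str, package_name: str,
                          path: str, split_suffix: Optional[str] = None) -> str:
        # With a split suffix, the file lands in a per-model directory and the suffix
        # is appended to the stem, yielding one artifact per batch.
        if split_suffix:
            p = Path(path)
            path = str(p.parent / p.stem / f"{p.stem}_{split_suffix}{p.suffix}")
        return str(Path(target_path) / subdirectory / package_name / path)

    suffix = format_batch_start(datetime(2020, 1, 1))  # "2020-01-01"
    print(target_write_path("target", "compiled", "test", "models/microbatch_model.sql", suffix))
    # target/compiled/test/models/microbatch_model/microbatch_model_2020-01-01.sql

Hourly batches keep the full timestamp in the suffix (e.g. 2020-01-01 01:00:00), matching the parametrized test_format_batch_start cases in the final patch.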