Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support using sql in unit testing fixtures #9873

Merged
merged 9 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240408-094132.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# changie changelog entry: .changes/unreleased/Features-20240408-094132.yaml
kind: Features
body: Support SQL in unit testing fixtures
time: 2024-04-08T09:41:32.15936-04:00
custom:
  Author: gshank
  Issue: "9405"
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/unit_test_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class UnitTestConfig(BaseConfig):
class UnitTestFormat(StrEnum):
    """Accepted formats for unit-test fixture data (`format:` key).

    SQL was added so fixtures can be supplied as raw SQL files/strings in
    addition to csv text and inline dict rows.
    """

    CSV = "csv"
    Dict = "dict"
    SQL = "sql"


@dataclass
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def insensitive_patterns(*patterns: str):
@dataclass
class UnitTestNodeConfig(NodeConfig):
    # Expected output rows when the `expect` fixture format is csv/dict.
    expected_rows: List[Dict[str, Any]] = field(default_factory=list)
    # Raw SQL text of the expected output when the `expect` fixture format is
    # sql (expected_rows is left empty in that case).
    expected_sql: Optional[str] = None


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ def same_contents(self, other: Optional["UnitTestDefinition"]) -> bool:
@dataclass
class UnitTestFileFixture(BaseNode):
    """A unit-test fixture file (presumably under the project's fixture paths).

    For .csv fixtures `rows` holds the parsed rows as a list of dicts; for
    .sql fixtures it holds the raw file contents as a single string.
    """

    resource_type: Literal[NodeType.Fixture]
    # List[Dict] for parsed csv fixtures, str for raw SQL fixtures.
    rows: Optional[Union[List[Dict[str, Any]], str]] = None


# ====================================
Expand Down
7 changes: 6 additions & 1 deletion core/dbt/parser/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,19 @@ def parse_file(self, file_block: FileBlock):
# NOTE(review): this is a diff hunk from FixtureParser.parse_file; the `def`
# line above is GitHub hunk-header context, not a source line, and any lines
# between the real `def` and this hunk are not visible here.
        assert isinstance(file_block.file, FixtureSourceFile)
        unique_id = self.generate_unique_id(file_block.name)

        # .sql fixtures are kept verbatim as a single string; .csv fixtures are
        # parsed into a list of row dicts.
        if file_block.file.path.relative_path.endswith(".sql"):
            rows = file_block.file.contents  # type: ignore
        else:  # endswith('.csv')
            rows = self.get_rows(file_block.file.contents)  # type: ignore

        fixture = UnitTestFileFixture(
            name=file_block.name,
            path=file_block.file.path.relative_path,
            original_file_path=file_block.path.original_file_path,
            package_name=self.project.project_name,
            unique_id=unique_id,
            resource_type=NodeType.Fixture,
            # NOTE(review): the next line is the pre-change (removed) keyword
            # argument retained by the diff view; the post-merge code passes
            # only `rows=rows`. As written, the duplicate kwarg is a
            # SyntaxError — this is diff residue, not real source.
            rows=self.get_rows(file_block.file.contents),
            rows=rows,
        )
        self.manifest.add_fixture(file_block.file, fixture)

Expand Down
6 changes: 3 additions & 3 deletions core/dbt/parser/read_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ def get_source_files(project, paths, extension, parse_file_type, saved_files, ig
if parse_file_type == ParseFileType.Seed:
fb_list.append(load_seed_source_file(fp, project.project_name))
# singular tests live in /tests but only generic tests live
# in /tests/generic so we want to skip those
# in /tests/generic and fixtures in /tests/fixture so we want to skip those
else:
if parse_file_type == ParseFileType.SingularTest:
path = pathlib.Path(fp.relative_path)
if path.parts[0] == "generic":
if path.parts[0] in ["generic", "fixtures"]:
continue
file = load_source_file(fp, parse_file_type, project.project_name, saved_files)
# only append the list if it has contents. added to fix #3568
Expand Down Expand Up @@ -431,7 +431,7 @@ def get_file_types_for_project(project):
},
ParseFileType.Fixture: {
"paths": project.fixture_paths,
"extensions": [".csv"],
"extensions": [".csv", ".sql"],
"parser": "FixtureParser",
},
}
Expand Down
51 changes: 39 additions & 12 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@
name = test_case.name
if tested_node.is_versioned:
name = name + f"_v{tested_node.version}"
expected_sql: Optional[str] = None
if test_case.expect.format == UnitTestFormat.SQL:
expected_rows: List[Dict[str, Any]] = []
expected_sql = test_case.expect.rows # type: ignore
else:
assert isinstance(test_case.expect.rows, List)
expected_rows = deepcopy(test_case.expect.rows)

assert isinstance(expected_rows, List)
unit_test_node = UnitTestNode(
name=name,
resource_type=NodeType.Unit,
Expand All @@ -76,8 +85,7 @@
original_file_path=test_case.original_file_path,
unique_id=test_case.unique_id,
config=UnitTestNodeConfig(
materialized="unit",
expected_rows=deepcopy(test_case.expect.rows), # type:ignore
materialized="unit", expected_rows=expected_rows, expected_sql=expected_sql
),
raw_code=tested_node.raw_code,
database=tested_node.database,
Expand Down Expand Up @@ -132,7 +140,7 @@
"schema": original_input_node.schema,
"fqn": original_input_node.fqn,
"checksum": FileHash.empty(),
"raw_code": self._build_fixture_raw_code(given.rows, None),
"raw_code": self._build_fixture_raw_code(given.rows, None, given.format),
"package_name": original_input_node.package_name,
"unique_id": f"model.{original_input_node.package_name}.{input_name}",
"name": input_name,
Expand Down Expand Up @@ -172,12 +180,15 @@
# Add unique ids of input_nodes to depends_on
unit_test_node.depends_on.nodes.append(input_node.unique_id)

def _build_fixture_raw_code(self, rows, column_name_to_data_types, fixture_format) -> str:
    """Render the raw_code used for a fixture input node.

    SQL-format fixtures are already SQL, so `rows` (the raw file contents /
    string) passes through unchanged. For csv/dict fixtures, wrap the rows in
    a `get_fixture_sql()` jinja call that renders the fixture at compile time.
    """
    # We're not currently using column_name_to_data_types, but leaving here for
    # possible future use.
    if fixture_format == UnitTestFormat.SQL:
        return rows
    else:
        # Doubled braces emit literal "{{"/"}}" so the result is a jinja call.
        return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
            rows=rows, column_name_to_data_types=column_name_to_data_types
        )

def _get_original_input_node(self, input: str, tested_node: ModelNode, test_case_name: str):
"""
Expand Down Expand Up @@ -352,13 +363,29 @@
)

if ut_fixture.fixture:
# find fixture file object and store unit_test_definition unique_id
fixture = self._get_fixture(ut_fixture.fixture, self.project.project_name)
fixture_source_file = self.manifest.files[fixture.file_id]
fixture_source_file.unit_tests.append(unit_test_definition.unique_id)
ut_fixture.rows = fixture.rows
ut_fixture.rows = self.get_fixture_file_rows(
ut_fixture.fixture, self.project.project_name, unit_test_definition.unique_id
)
else:
ut_fixture.rows = self._convert_csv_to_list_of_dicts(ut_fixture.rows)
elif ut_fixture.format == UnitTestFormat.SQL:
if not (isinstance(ut_fixture.rows, str) or isinstance(ut_fixture.fixture, str)):
raise ParsingError(
# Codecov/patch annotation: added line core/dbt/parser/unit_tests.py#L373 was not covered by tests.
f"Unit test {unit_test_definition.name} has {fixture_type} rows or fixtures "
f"which do not match format {ut_fixture.format}. Expected string."
)

if ut_fixture.fixture:
ut_fixture.rows = self.get_fixture_file_rows(
ut_fixture.fixture, self.project.project_name, unit_test_definition.unique_id
)

def get_fixture_file_rows(self, fixture_name, project_name, utdef_unique_id):
    """Resolve a named fixture, record the unit-test definition on the
    fixture's source file, and return the fixture's rows."""
    fixture_node = self._get_fixture(fixture_name, project_name)
    # Track which unit test definition consumes this fixture file so the
    # file -> unit-test linkage is available for partial parsing.
    self.manifest.files[fixture_node.file_id].unit_tests.append(utdef_unique_id)
    return fixture_node.rows

def _convert_csv_to_list_of_dicts(self, csv_string: str) -> List[Dict[str, Any]]:
dummy_file = StringIO(csv_string)
Expand Down
Loading