From 315aeaedb708e9388750d39e2280ba61a30fa92b Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Wed, 29 May 2024 10:24:39 -0500 Subject: [PATCH 01/12] add query recording options (#135) * add query recording options * make record mode single var, add better types * fix tests, tweak logic * fix comment * clean up and fix test * changelog * add test * use test fixture --- .../Under the Hood-20240528-110518.yaml | 6 + dbt_common/record.py | 51 +++++- docs/guides/record_replay.md | 17 +- tests/unit/test_record.py | 159 ++++++++++++------ 4 files changed, 174 insertions(+), 59 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20240528-110518.yaml diff --git a/.changes/unreleased/Under the Hood-20240528-110518.yaml b/.changes/unreleased/Under the Hood-20240528-110518.yaml new file mode 100644 index 00000000..58b243c7 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240528-110518.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Allow dynamic selection of record types when recording. +time: 2024-05-28T11:05:18.290107-05:00 +custom: + Author: emmyoop + Issue: "140" diff --git a/dbt_common/record.py b/dbt_common/record.py index b2b5ba48..5edc9c7f 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -59,14 +59,18 @@ class Diff: class RecorderMode(Enum): RECORD = 1 REPLAY = 2 + RECORD_QUERIES = 3 class Recorder: _record_cls_by_name: Dict[str, Type] = {} _record_name_by_params_name: Dict[str, str] = {} - def __init__(self, mode: RecorderMode, recording_path: Optional[str] = None) -> None: + def __init__( + self, mode: RecorderMode, types: Optional[List], recording_path: Optional[str] = None + ) -> None: self.mode = mode + self.types = types self._records_by_type: Dict[str, List[Record]] = {} self._replay_diffs: List["Diff"] = [] @@ -118,13 +122,14 @@ def load(cls, file_name: str) -> Dict[str, List[Record]]: records_by_type: Dict[str, List[Record]] = {} for record_type_name in loaded_dct: + # TODO: this breaks with QueryRecord on replay since it's + # not in common so isn't part of cls._record_cls_by_name yet record_cls = cls._record_cls_by_name[record_type_name] rec_list = [] for record_dct in loaded_dct[record_type_name]: rec = record_cls.from_dict(record_dct) rec_list.append(rec) # type: ignore records_by_type[record_type_name] = rec_list - return records_by_type def expect_record(self, params: Any) -> Any: @@ -147,17 +152,46 @@ def print_diffs(self) -> None: def get_record_mode_from_env() -> Optional[RecorderMode]: - replay_val = os.environ.get("DBT_REPLAY") - if replay_val is not None and replay_val != "0" and replay_val.lower() != "false": - return RecorderMode.REPLAY + """ + Get the record mode from the environment variables. + + If the mode is not set to 'RECORD' or 'REPLAY', return None. + Expected format: 'DBT_RECORDER_MODE=RECORD' + """ + record_mode = os.environ.get("DBT_RECORDER_MODE") - record_val = os.environ.get("DBT_RECORD") - if record_val is not None and record_val != "0" and record_val.lower() != "false": + if record_mode is None: + return None + + if record_mode.lower() == "record": return RecorderMode.RECORD + # replaying requires a file path, otherwise treat as noop + elif ( + record_mode.lower() == "replay" and os.environ.get("DBT_RECORDER_REPLAY_PATH") is not None + ): + return RecorderMode.REPLAY + # if you don't specify record/replay it's a noop return None +def get_record_types_from_env() -> Optional[List]: + """ + Get the record subset from the environment variables. + + If no types are provided, there will be no filtering. + Invalid types will be ignored. + Expected format: 'DBT_RECORDER_TYPES=QueryRecord,FileLoadRecord,OtherRecord' + """ + record_types_str = os.environ.get("DBT_RECORDER_TYPES") + + # if all is specified we don't want any type filtering + if record_types_str is None or record_types_str.lower == "all": + return None + + return record_types_str.split(",") + + def record_function(record_type, method=False, tuple_result=False): def record_function_inner(func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the @@ -176,6 +210,9 @@ def record_replay_wrapper(*args, **kwargs): if recorder is None: return func_to_record(*args, **kwargs) + if recorder.types is not None and record_type.__name__ not in recorder.types: + return func_to_record(*args, **kwargs) + # For methods, peel off the 'self' argument before calling the # params constructor. param_args = args[1:] if method else args diff --git a/docs/guides/record_replay.md b/docs/guides/record_replay.md index 9d9d87f2..2103fd1b 100644 --- a/docs/guides/record_replay.md +++ b/docs/guides/record_replay.md @@ -28,7 +28,22 @@ Note also the `LoadFileRecord` class passed as a parameter to this decorator. Th The final detail needed is to define the classes specified by `params_cls` and `result_cls`, which must be dataclasses with properties whose order and names correspond to the parameters passed to the recorded function. In this case those are the `LoadFileParams` and `LoadFileResult` classes, respectively. -With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. +With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. + +## How to record/replay +If `DBT_RECORDER_MODE` is not `replay` or `record`, case insensitive, this is a no-op. Invalid values are ignored and do not throw exceptions. + +`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses. Any invalid types will be ignored. `all` is a valid type and behaves the same as not populating the env var. + + +```bash +DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run +``` + +replay need the file to replay +```bash +DBT_RECORDER_MODE=replay DBT_RECORDER_REPLAY_PATH=recording.json dbt run +``` ## Final Thoughts diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index aa7af69b..8dca1d19 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -1,5 +1,6 @@ import dataclasses import os +import pytest from typing import Optional from dbt_common.context import set_invocation_context, get_invocation_context @@ -24,58 +25,114 @@ class TestRecord(Record): result_cls = TestRecordResult -def test_decorator_records(): - prev = os.environ.get("DBT_RECORD", None) - try: - os.environ["DBT_RECORD"] = "True" - recorder = Recorder(RecorderMode.RECORD) - set_invocation_context({}) - get_invocation_context().recorder = recorder - - @record_function(TestRecord) - def test_func(a: int, b: str, c: Optional[str] = None) -> str: - return str(a) + b + (c if c else "") - - test_func(123, "abc") - - expected_record = TestRecord( - params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") - ) - - assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params - assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result - - finally: - if prev is None: - os.environ.pop("DBT_RECORD", None) - else: - os.environ["DBT_RECORD"] = prev - - -def test_decorator_replays(): - prev = os.environ.get("DBT_RECORD", None) - try: - os.environ["DBT_RECORD"] = "True" - recorder = Recorder(RecorderMode.REPLAY) - set_invocation_context({}) - get_invocation_context().recorder = recorder - - expected_record = TestRecord( - params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") - ) - - recorder._records_by_type["TestRecord"] = [expected_record] +@dataclasses.dataclass +class NotTestRecordParams: + a: int + b: str + c: Optional[str] = None - @record_function(TestRecord) - def test_func(a: int, b: str, c: Optional[str] = None) -> str: - raise Exception("This should not actually be called") - res = test_func(123, "abc") +@dataclasses.dataclass +class NotTestRecordResult: + return_val: str - assert res == "123abc" - finally: - if prev is None: - os.environ.pop("DBT_RECORD", None) - else: - os.environ["DBT_RECORD"] = prev +@Recorder.register_record_type +class NotTestRecord(Record): + params_cls = NotTestRecordParams + result_cls = NotTestRecordResult + + +@pytest.fixture(scope="function", autouse=True) +def setup(): + # capture the previous state of the environment variables + prev_mode = os.environ.get("DBT_RECORDER_MODE", None) + prev_type = os.environ.get("DBT_RECORDER_TYPES", None) + prev_fp = os.environ.get("DBT_RECORDER_REPLAY_PATH", None) + # clear the environment variables + os.environ.pop("DBT_RECORDER_MODE", None) + os.environ.pop("DBT_RECORDER_TYPES", None) + os.environ.pop("DBT_RECORDER_REPLAY_PATH", None) + yield + # reset the environment variables to their previous state + if prev_mode is None: + os.environ.pop("DBT_RECORDER_MODE", None) + else: + os.environ["DBT_RECORDER_MODE"] = prev_mode + if prev_type is None: + os.environ.pop("DBT_RECORDER_TYPES", None) + else: + os.environ["DBT_RECORDER_TYPES"] = prev_type + if prev_fp is None: + os.environ.pop("DBT_RECORDER_REPLAY_PATH", None) + else: + os.environ["DBT_RECORDER_REPLAY_PATH"] = prev_fp + + +def test_decorator_records(setup): + os.environ["DBT_RECORDER_MODE"] = "Record" + recorder = Recorder(RecorderMode.RECORD, None) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + test_func(123, "abc") + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params + assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result + + +def test_record_types(setup): + os.environ["DBT_RECORDER_MODE"] = "Record" + os.environ["DBT_RECORDER_TYPES"] = "TestRecord" + recorder = Recorder(RecorderMode.RECORD, ["TestRecord"]) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + @record_function(NotTestRecord) + def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + test_func(123, "abc") + not_test_func(456, "def") + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params + assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result + assert NotTestRecord not in recorder._records_by_type + + +def test_decorator_replays(setup): + os.environ["DBT_RECORDER_MODE"] = "Replay" + os.environ["DBT_RECORDER_REPLAY_PATH"] = "record.json" + recorder = Recorder(RecorderMode.REPLAY, None) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + recorder._records_by_type["TestRecord"] = [expected_record] + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + raise Exception("This should not actually be called") + + res = test_func(123, "abc") + + assert res == "123abc" From 50072e70b01c988bfc7dcd777d603f8a7d7738bd Mon Sep 17 00:00:00 2001 From: Kshitij Aranke Date: Wed, 29 May 2024 19:13:40 +0100 Subject: [PATCH 02/12] Move `StatsItem`, `StatsDict`, `TableMetadata` to `dbt-common` (#141) * Move catalog artifact schema to dbt_common * Add changie * Only move StatsItem, StatsDict, TableMetadata * Discard changes to dbt_common/artifacts/exceptions/__init__.py * Discard changes to dbt_common/artifacts/exceptions/schemas.py * Discard changes to dbt_common/artifacts/schemas/base.py * Discard changes to dbt_common/version.py * Clean up artifacts directory * Update Under the Hood-20240529-143154.yaml * move to contracts/metadata.py --- .../Under the Hood-20240529-143154.yaml | 6 +++++ .gitignore | 2 +- dbt_common/contracts/metadata.py | 26 +++++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 .changes/unreleased/Under the Hood-20240529-143154.yaml create mode 100644 dbt_common/contracts/metadata.py diff --git a/.changes/unreleased/Under the Hood-20240529-143154.yaml b/.changes/unreleased/Under the Hood-20240529-143154.yaml new file mode 100644 index 00000000..e9664254 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240529-143154.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Move StatsItem, StatsDict, TableMetadata to dbt-common +time: 2024-05-29T14:31:54.854468+01:00 +custom: + Author: aranke + Issue: "141" diff --git a/.gitignore b/.gitignore index 28030861..5ad1a6c2 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ diff --git a/dbt_common/contracts/metadata.py b/dbt_common/contracts/metadata.py new file mode 100644 index 00000000..bc738796 --- /dev/null +++ b/dbt_common/contracts/metadata.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Union + +from dbt_common.dataclass_schema import dbtClassMixin + + +@dataclass +class StatsItem(dbtClassMixin): + id: str + label: str + value: Union[bool, str, float, None] + include: bool + description: Optional[str] = None + + +StatsDict = Dict[str, StatsItem] + + +@dataclass +class TableMetadata(dbtClassMixin): + type: str + schema: str + name: str + database: Optional[str] = None + comment: Optional[str] = None + owner: Optional[str] = None From 074f15a77e64e4d3b17db8d2ac8b32899bf0891c Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Wed, 29 May 2024 13:35:13 -0500 Subject: [PATCH 03/12] rename var (#142) * rename var * remove diff --- dbt_common/record.py | 4 +--- docs/guides/record_replay.md | 2 +- tests/unit/test_record.py | 10 +++++----- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index 5edc9c7f..428f4bc1 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -166,9 +166,7 @@ def get_record_mode_from_env() -> Optional[RecorderMode]: if record_mode.lower() == "record": return RecorderMode.RECORD # replaying requires a file path, otherwise treat as noop - elif ( - record_mode.lower() == "replay" and os.environ.get("DBT_RECORDER_REPLAY_PATH") is not None - ): + elif record_mode.lower() == "replay" and os.environ.get("DBT_RECORDER_FILE_PATH") is not None: return RecorderMode.REPLAY # if you don't specify record/replay it's a noop diff --git a/docs/guides/record_replay.md b/docs/guides/record_replay.md index 2103fd1b..aff4c77c 100644 --- a/docs/guides/record_replay.md +++ b/docs/guides/record_replay.md @@ -42,7 +42,7 @@ DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run replay need the file to replay ```bash -DBT_RECORDER_MODE=replay DBT_RECORDER_REPLAY_PATH=recording.json dbt run +DBT_RECORDER_MODE=replay DBT_RECORDER_FILE_PATH=recording.json dbt run ``` ## Final Thoughts diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index 8dca1d19..b0371498 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -48,11 +48,11 @@ def setup(): # capture the previous state of the environment variables prev_mode = os.environ.get("DBT_RECORDER_MODE", None) prev_type = os.environ.get("DBT_RECORDER_TYPES", None) - prev_fp = os.environ.get("DBT_RECORDER_REPLAY_PATH", None) + prev_fp = os.environ.get("DBT_RECORDER_FILE_PATH", None) # clear the environment variables os.environ.pop("DBT_RECORDER_MODE", None) os.environ.pop("DBT_RECORDER_TYPES", None) - os.environ.pop("DBT_RECORDER_REPLAY_PATH", None) + os.environ.pop("DBT_RECORDER_FILE_PATH", None) yield # reset the environment variables to their previous state if prev_mode is None: @@ -64,9 +64,9 @@ def setup(): else: os.environ["DBT_RECORDER_TYPES"] = prev_type if prev_fp is None: - os.environ.pop("DBT_RECORDER_REPLAY_PATH", None) + os.environ.pop("DBT_RECORDER_FILE_PATH", None) else: - os.environ["DBT_RECORDER_REPLAY_PATH"] = prev_fp + os.environ["DBT_RECORDER_FILE_PATH"] = prev_fp def test_decorator_records(setup): @@ -118,7 +118,7 @@ def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: def test_decorator_replays(setup): os.environ["DBT_RECORDER_MODE"] = "Replay" - os.environ["DBT_RECORDER_REPLAY_PATH"] = "record.json" + os.environ["DBT_RECORDER_FILE_PATH"] = "record.json" recorder = Recorder(RecorderMode.REPLAY, None) set_invocation_context({}) get_invocation_context().recorder = recorder From ddb28aaa2ca51840aad46ee34a1dcf3838fbdf96 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Wed, 29 May 2024 14:05:21 -0500 Subject: [PATCH 04/12] minor version bump (#143) --- dbt_common/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/__about__.py b/dbt_common/__about__.py index b2b60a55..c203c5d8 100644 --- a/dbt_common/__about__.py +++ b/dbt_common/__about__.py @@ -1 +1 @@ -version = "1.1.0" +version = "1.2.0" From 99d9727f2da1b4d50be2caa5d90d8c38302e66d3 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke Date: Wed, 5 Jun 2024 00:17:05 +0100 Subject: [PATCH 05/12] Move CatalogKey, ColumnMetadata, ColumnMap, CatalogTable to dbt-common (#147) * Move CatalogKey, ColumnMetadata, ColumnMap, CatalogTable to dbt-common * Add changie --- .../Under the Hood-20240603-123631.yaml | 6 ++++ dbt_common/contracts/metadata.py | 35 ++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 .changes/unreleased/Under the Hood-20240603-123631.yaml diff --git a/.changes/unreleased/Under the Hood-20240603-123631.yaml b/.changes/unreleased/Under the Hood-20240603-123631.yaml new file mode 100644 index 00000000..18a649df --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240603-123631.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Move CatalogKey, ColumnMetadata, ColumnMap, CatalogTable to dbt-common +time: 2024-06-03T12:36:31.542118+02:00 +custom: + Author: aranke + Issue: "147" diff --git a/dbt_common/contracts/metadata.py b/dbt_common/contracts/metadata.py index bc738796..d71e79a7 100644 --- a/dbt_common/contracts/metadata.py +++ b/dbt_common/contracts/metadata.py @@ -1,7 +1,8 @@ from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, NamedTuple from dbt_common.dataclass_schema import dbtClassMixin +from dbt_common.utils.formatting import lowercase @dataclass @@ -24,3 +25,35 @@ class TableMetadata(dbtClassMixin): database: Optional[str] = None comment: Optional[str] = None owner: Optional[str] = None + + +CatalogKey = NamedTuple( + "CatalogKey", [("database", Optional[str]), ("schema", str), ("name", str)] +) + + +@dataclass +class ColumnMetadata(dbtClassMixin): + type: str + index: int + name: str + comment: Optional[str] = None + + +ColumnMap = Dict[str, ColumnMetadata] + + +@dataclass +class CatalogTable(dbtClassMixin): + metadata: TableMetadata + columns: ColumnMap + stats: StatsDict + # the same table with two unique IDs will just be listed two times + unique_id: Optional[str] = None + + def key(self) -> CatalogKey: + return CatalogKey( + lowercase(self.metadata.database), + self.metadata.schema.lower(), + self.metadata.name.lower(), + ) From e30f65f6b93eed0422295a4282a892cdfe4b59d5 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke Date: Wed, 5 Jun 2024 19:15:02 +0100 Subject: [PATCH 06/12] minor version bump to 1.3.0 (#148) --- dbt_common/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/__about__.py b/dbt_common/__about__.py index c203c5d8..d28b3ddc 100644 --- a/dbt_common/__about__.py +++ b/dbt_common/__about__.py @@ -1 +1 @@ -version = "1.2.0" +version = "1.3.0" From 930a2f0e57dcb72954b7369d29e903e3c061281d Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Tue, 18 Jun 2024 12:05:24 -0500 Subject: [PATCH 07/12] Support DIFF as a new mode for recording (#149) * WIP * dumb diffing * use diff class * comments * fix query record logic * WIP * create files in a dir * remove path logic * fix diff order * clean up, add tests * remove comment * changelog * remove non-root logic --- .../Under the Hood-20240617-204541.yaml | 6 + dbt_common/record.py | 105 +++++- pyproject.toml | 1 + tests/unit/test_diff.py | 305 ++++++++++++++++++ 4 files changed, 406 insertions(+), 11 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20240617-204541.yaml create mode 100644 tests/unit/test_diff.py diff --git a/.changes/unreleased/Under the Hood-20240617-204541.yaml b/.changes/unreleased/Under the Hood-20240617-204541.yaml new file mode 100644 index 00000000..b4e33c36 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240617-204541.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Add support for basic diff of run recordings. +time: 2024-06-17T20:45:41.123374-05:00 +custom: + Author: emmyoop + Issue: "144" diff --git a/dbt_common/record.py b/dbt_common/record.py index 428f4bc1..9fed4c3d 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -9,6 +9,8 @@ import dataclasses import json import os + +from deepdiff import DeepDiff # type: ignore from enum import Enum from typing import Any, Dict, List, Mapping, Optional, Type @@ -51,15 +53,68 @@ def from_dict(cls, dct: Mapping) -> "Record": class Diff: - """Marker class for diffs?""" + def __init__(self, current_recording_path: str, previous_recording_path: str) -> None: + self.current_recording_path = current_recording_path + self.previous_recording_path = previous_recording_path + + def diff_query_records(self, current: List, previous: List) -> Dict[str, Any]: + # some of the table results are returned as a stringified list of dicts that don't + # diff because order isn't consistent. convert it into a list of dicts so it can + # be diffed ignoring order + + for i in range(len(current)): + if current[i].get("result").get("table") is not None: + current[i]["result"]["table"] = json.loads(current[i]["result"]["table"]) + for i in range(len(previous)): + if previous[i].get("result").get("table") is not None: + previous[i]["result"]["table"] = json.loads(previous[i]["result"]["table"]) + + return DeepDiff(previous, current, ignore_order=True, verbose_level=2) + + def diff_env_records(self, current: List, previous: List) -> Dict[str, Any]: + # The mode and filepath may change. Ignore them. + + exclude_paths = [ + "root[0]['result']['env']['DBT_RECORDER_FILE_PATH']", + "root[0]['result']['env']['DBT_RECORDER_MODE']", + ] + + return DeepDiff( + previous, current, ignore_order=True, verbose_level=2, exclude_paths=exclude_paths + ) + + def diff_default(self, current: List, previous: List) -> Dict[str, Any]: + return DeepDiff(previous, current, ignore_order=True, verbose_level=2) + + def calculate_diff(self) -> Dict[str, Any]: + with open(self.current_recording_path) as current_recording: + current_dct = json.load(current_recording) + + with open(self.previous_recording_path) as previous_recording: + previous_dct = json.load(previous_recording) + + diff = {} + for record_type in current_dct: + if record_type == "QueryRecord": + diff[record_type] = self.diff_query_records( + current_dct[record_type], previous_dct[record_type] + ) + elif record_type == "GetEnvRecord": + diff[record_type] = self.diff_env_records( + current_dct[record_type], previous_dct[record_type] + ) + else: + diff[record_type] = self.diff_default( + current_dct[record_type], previous_dct[record_type] + ) - pass + return diff class RecorderMode(Enum): RECORD = 1 REPLAY = 2 - RECORD_QUERIES = 3 + DIFF = 3 # records and does diffing class Recorder: @@ -67,15 +122,31 @@ class Recorder: _record_name_by_params_name: Dict[str, str] = {} def __init__( - self, mode: RecorderMode, types: Optional[List], recording_path: Optional[str] = None + self, + mode: RecorderMode, + types: Optional[List], + current_recording_path: str = "recording.json", + previous_recording_path: Optional[str] = None, ) -> None: self.mode = mode self.types = types self._records_by_type: Dict[str, List[Record]] = {} self._replay_diffs: List["Diff"] = [] + self.diff: Diff + self.previous_recording_path = previous_recording_path + self.current_recording_path = current_recording_path + + if self.previous_recording_path is not None and self.mode in ( + RecorderMode.REPLAY, + RecorderMode.DIFF, + ): + self.diff = Diff( + current_recording_path=self.current_recording_path, + previous_recording_path=self.previous_recording_path, + ) - if recording_path is not None: - self._records_by_type = self.load(recording_path) + if self.mode == RecorderMode.REPLAY: + self._records_by_type = self.load(self.previous_recording_path) @classmethod def register_record_type(cls, rec_type) -> Any: @@ -101,8 +172,8 @@ def pop_matching_record(self, params: Any) -> Optional[Record]: return match - def write(self, file_name) -> None: - with open(file_name, "w") as file: + def write(self) -> None: + with open(self.current_recording_path, "w") as file: json.dump(self._to_dict(), file) def _to_dict(self) -> Dict: @@ -143,19 +214,19 @@ def expect_record(self, params: Any) -> Any: def write_diffs(self, diff_file_name) -> None: json.dump( - self._replay_diffs, + self.diff.calculate_diff(), open(diff_file_name, "w"), ) def print_diffs(self) -> None: - print(repr(self._replay_diffs)) + print(repr(self.diff.calculate_diff())) def get_record_mode_from_env() -> Optional[RecorderMode]: """ Get the record mode from the environment variables. - If the mode is not set to 'RECORD' or 'REPLAY', return None. + If the mode is not set to 'RECORD', 'DIFF' or 'REPLAY', return None. Expected format: 'DBT_RECORDER_MODE=RECORD' """ record_mode = os.environ.get("DBT_RECORDER_MODE") @@ -165,6 +236,9 @@ def get_record_mode_from_env() -> Optional[RecorderMode]: if record_mode.lower() == "record": return RecorderMode.RECORD + # diffing requires a file path, otherwise treat as noop + elif record_mode.lower() == "diff" and os.environ.get("DBT_RECORDER_FILE_PATH") is not None: + return RecorderMode.DIFF # replaying requires a file path, otherwise treat as noop elif record_mode.lower() == "replay" and os.environ.get("DBT_RECORDER_FILE_PATH") is not None: return RecorderMode.REPLAY @@ -190,6 +264,15 @@ def get_record_types_from_env() -> Optional[List]: return record_types_str.split(",") +def get_record_types_from_dict(fp: str) -> List: + """ + Get the record subset from the dict. + """ + with open(fp) as file: + loaded_dct = json.load(file) + return list(loaded_dct.keys()) + + def record_function(record_type, method=False, tuple_result=False): def record_function_inner(func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the diff --git a/pyproject.toml b/pyproject.toml index c1f4f281..64fc04fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ classifiers = [ dependencies = [ "agate>=1.7.0,<1.10", "colorama>=0.3.9,<0.5", + "deepdiff>=7.0,<8.0", "isodate>=0.6,<0.7", "jsonschema>=4.0,<5.0", "Jinja2>=3.1.3,<4", diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py new file mode 100644 index 00000000..791263f3 --- /dev/null +++ b/tests/unit/test_diff.py @@ -0,0 +1,305 @@ +import json +import pytest +from dbt_common.record import Diff + + +@pytest.fixture +def current_query(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "key", + "table": '[{"a": 5},{"b": 7}]', + }, + } + ] + + +@pytest.fixture +def query_modified_order(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "key", + "table": '[{"b": 7},{"a": 5}]', + }, + } + ] + + +@pytest.fixture +def query_modified_value(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "key", + "table": '[{"a": 5},{"b": 10}]', + }, + } + ] + + +@pytest.fixture +def current_simple(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "cat", + }, + } + ] + + +@pytest.fixture +def current_simple_modified(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ] + + +@pytest.fixture +def env_record(): + return [ + { + "params": {}, + "result": { + "env": { + "DBT_RECORDER_FILE_PATH": "record.json", + "ANOTHER_ENV_VAR": "dogs", + }, + }, + } + ] + + +@pytest.fixture +def modified_env_record(): + return [ + { + "params": {}, + "result": { + "env": { + "DBT_RECORDER_FILE_PATH": "another_record.json", + "ANOTHER_ENV_VAR": "cats", + }, + }, + } + ] + + +def test_diff_query_records_no_diff(current_query, query_modified_order): + # Setup: Create an instance of Diff + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + result = diff_instance.diff_query_records(current_query, query_modified_order) + # the order changed but the diff should be empty + expected_result = {} + assert result == expected_result # Replace expected_result with what you actually expect + + +def test_diff_query_records_with_diff(current_query, query_modified_value): + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + result = diff_instance.diff_query_records(current_query, query_modified_value) + # the values changed this time + expected_result = { + "values_changed": {"root[0]['result']['table'][1]['b']": {"new_value": 7, "old_value": 10}} + } + assert result == expected_result + + +def test_diff_env_records(env_record, modified_env_record): + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + result = diff_instance.diff_env_records(env_record, modified_env_record) + expected_result = { + "values_changed": { + "root[0]['result']['env']['ANOTHER_ENV_VAR']": { + "new_value": "dogs", + "old_value": "cats", + } + } + } + assert result == expected_result + + +def test_diff_default_no_diff(current_simple): + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + # use the same list to ensure no diff + result = diff_instance.diff_default(current_simple, current_simple) + expected_result = {} + assert result == expected_result + + +def test_diff_default_with_diff(current_simple, current_simple_modified): + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + result = diff_instance.diff_default(current_simple, current_simple_modified) + expected_result = { + "values_changed": {"root[0]['result']['this']": {"new_value": "cat", "old_value": "dog"}} + } + assert result == expected_result + + +# Mock out reading the files so we don't have to +class MockFile: + def __init__(self, json_data): + self.json_data = json_data + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def read(self): + return json.dumps(self.json_data) + + +# Create a Mock Open Function +def mock_open(mock_files): + def open_mock(file, *args, **kwargs): + if file in mock_files: + return MockFile(mock_files[file]) + raise FileNotFoundError(f"No mock file found for {file}") + + return open_mock + + +def test_calculate_diff_no_diff(monkeypatch): + # Mock data for the files + current_recording_data = { + "GetEnvRecord": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ], + "DefaultKey": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ], + } + previous_recording_data = { + "GetEnvRecord": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ], + "DefaultKey": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ], + } + current_recording_path = "/path/to/current_recording.json" + previous_recording_path = "/path/to/previous_recording.json" + mock_files = { + current_recording_path: current_recording_data, + previous_recording_path: previous_recording_data, + } + monkeypatch.setattr("builtins.open", mock_open(mock_files)) + + # test the diff + diff_instance = Diff( + current_recording_path=current_recording_path, + previous_recording_path=previous_recording_path, + ) + result = diff_instance.calculate_diff() + expected_result = {"GetEnvRecord": {}, "DefaultKey": {}} + assert result == expected_result + + +def test_calculate_diff_with_diff(monkeypatch): + # Mock data for the files + current_recording_data = { + "GetEnvRecord": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ] + } + previous_recording_data = { + "GetEnvRecord": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "cats", + }, + } + ] + } + current_recording_path = "/path/to/current_recording.json" + previous_recording_path = "/path/to/previous_recording.json" + mock_files = { + current_recording_path: current_recording_data, + previous_recording_path: previous_recording_data, + } + monkeypatch.setattr("builtins.open", mock_open(mock_files)) + + # test the diff + diff_instance = Diff( + current_recording_path=current_recording_path, + previous_recording_path=previous_recording_path, + ) + result = diff_instance.calculate_diff() + expected_result = { + "GetEnvRecord": { + "values_changed": { + "root[0]['result']['this']": {"new_value": "dog", "old_value": "cats"} + } + } + } + assert result == expected_result From 00fbf918f55e72a9a6cf08767e07ee55e15cdc93 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke Date: Tue, 18 Jun 2024 21:09:45 +0100 Subject: [PATCH 08/12] minor version bump to 1.4.0 (#150) Update __about__.py --- dbt_common/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/__about__.py b/dbt_common/__about__.py index d28b3ddc..d619c757 100644 --- a/dbt_common/__about__.py +++ b/dbt_common/__about__.py @@ -1 +1 @@ -version = "1.3.0" +version = "1.4.0" From dbb2308e126d7306594d99a0b0521ba73822b724 Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Wed, 26 Jun 2024 13:30:23 -0400 Subject: [PATCH 09/12] Support "just in time" loading of records, and add ID fields (#154) * Deserialize records "just in time" in order to avoid import order issues. * Add changelog entry * Typing and formatting fixes * Typing * Add id field to record. * Add more informative error message. * Tweak the way results are stored. * Typing and formatting fixes. --------- Co-authored-by: Emily Rockman --- .../Under the Hood-20240618-155025.yaml | 6 ++ dbt_common/clients/system.py | 4 +- dbt_common/record.py | 71 ++++++++++++------- 3 files changed, 54 insertions(+), 27 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20240618-155025.yaml diff --git a/.changes/unreleased/Under the Hood-20240618-155025.yaml b/.changes/unreleased/Under the Hood-20240618-155025.yaml new file mode 100644 index 00000000..b540d3d7 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240618-155025.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Deserialize Record objects on a just-in-time basis. +time: 2024-06-18T15:50:25.985387-04:00 +custom: + Author: peterallenwebb + Issue: "151" diff --git a/dbt_common/clients/system.py b/dbt_common/clients/system.py index bcf798d2..00a1ac69 100644 --- a/dbt_common/clients/system.py +++ b/dbt_common/clients/system.py @@ -62,7 +62,9 @@ def _include(self) -> bool: # Do not record or replay filesystem searches that were performed against # files which are actually part of dbt's implementation. return ( - "dbt/include/global_project" not in self.root_path + "dbt/include" + not in self.root_path # TODO: This actually obviates the next two checks but is probably too coarse? + and "dbt/include/global_project" not in self.root_path and "/plugins/postgres/dbt/include/" not in self.root_path ) diff --git a/dbt_common/record.py b/dbt_common/record.py index 9fed4c3d..c204faa6 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -2,7 +2,7 @@ external systems during a command invocation, so that the command can be re-run later with the recording 'replayed' to dbt. -The rationale for and architecture of this module is described in detail in the +The rationale for and architecture of this module are described in detail in the docs/guides/record_replay.md document in this repository. """ import functools @@ -12,7 +12,7 @@ from deepdiff import DeepDiff # type: ignore from enum import Enum -from typing import Any, Dict, List, Mapping, Optional, Type +from typing import Any, Callable, Dict, List, Mapping, Optional, Type from dbt_common.context import get_invocation_context @@ -129,10 +129,11 @@ def __init__( previous_recording_path: Optional[str] = None, ) -> None: self.mode = mode - self.types = types + self.recorded_types = types self._records_by_type: Dict[str, List[Record]] = {} + self._unprocessed_records_by_type: Dict[str, List[Dict[str, Any]]] = {} self._replay_diffs: List["Diff"] = [] - self.diff: Diff + self.diff: Optional[Diff] = None self.previous_recording_path = previous_recording_path self.current_recording_path = current_recording_path @@ -146,7 +147,7 @@ def __init__( ) if self.mode == RecorderMode.REPLAY: - self._records_by_type = self.load(self.previous_recording_path) + self._unprocessed_records_by_type = self.load(self.previous_recording_path) @classmethod def register_record_type(cls, rec_type) -> Any: @@ -161,7 +162,14 @@ def add_record(self, record: Record) -> None: self._records_by_type[rec_cls_name].append(record) def pop_matching_record(self, params: Any) -> Optional[Record]: - rec_type_name = self._record_name_by_params_name[type(params).__name__] + rec_type_name = self._record_name_by_params_name.get(type(params).__name__) + + if rec_type_name is None: + raise Exception( + f"A record of type {type(params).__name__} was requested, but no such type has been registered." + ) + + self._ensure_records_processed(rec_type_name) records = self._records_by_type[rec_type_name] match: Optional[Record] = None for rec in records: @@ -186,22 +194,20 @@ def _to_dict(self) -> Dict: return dct @classmethod - def load(cls, file_name: str) -> Dict[str, List[Record]]: + def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]: with open(file_name) as file: - loaded_dct = json.load(file) + return json.load(file) - records_by_type: Dict[str, List[Record]] = {} + def _ensure_records_processed(self, record_type_name: str) -> None: + if record_type_name in self._records_by_type: + return - for record_type_name in loaded_dct: - # TODO: this breaks with QueryRecord on replay since it's - # not in common so isn't part of cls._record_cls_by_name yet - record_cls = cls._record_cls_by_name[record_type_name] - rec_list = [] - for record_dct in loaded_dct[record_type_name]: - rec = record_cls.from_dict(record_dct) - rec_list.append(rec) # type: ignore - records_by_type[record_type_name] = rec_list - return records_by_type + rec_list = [] + record_cls = self._record_cls_by_name[record_type_name] + for record_dct in self._unprocessed_records_by_type[record_type_name]: + rec = record_cls.from_dict(record_dct) + rec_list.append(rec) # type: ignore + self._records_by_type[record_type_name] = rec_list def expect_record(self, params: Any) -> Any: record = self.pop_matching_record(params) @@ -209,16 +215,19 @@ def expect_record(self, params: Any) -> Any: if record is None: raise Exception() + if record.result is None: + return None + result_tuple = dataclasses.astuple(record.result) return result_tuple[0] if len(result_tuple) == 1 else result_tuple def write_diffs(self, diff_file_name) -> None: - json.dump( - self.diff.calculate_diff(), - open(diff_file_name, "w"), - ) + assert self.diff is not None + with open(diff_file_name, "w") as f: + json.dump(self.diff.calculate_diff(), f) def print_diffs(self) -> None: + assert self.diff is not None print(repr(self.diff.calculate_diff())) @@ -273,7 +282,12 @@ def get_record_types_from_dict(fp: str) -> List: return list(loaded_dct.keys()) -def record_function(record_type, method=False, tuple_result=False): +def record_function( + record_type, + method: bool = False, + tuple_result: bool = False, + id_field_name: Optional[str] = None, +) -> Callable: def record_function_inner(func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the # record/replay decorator if a relevant env var is set. @@ -291,12 +305,17 @@ def record_replay_wrapper(*args, **kwargs): if recorder is None: return func_to_record(*args, **kwargs) - if recorder.types is not None and record_type.__name__ not in recorder.types: + if ( + recorder.recorded_types is not None + and record_type.__name__ not in recorder.recorded_types + ): return func_to_record(*args, **kwargs) # For methods, peel off the 'self' argument before calling the # params constructor. param_args = args[1:] if method else args + if method and id_field_name is not None: + param_args = (getattr(args[0], id_field_name),) + param_args params = record_type.params_cls(*param_args, **kwargs) @@ -313,7 +332,7 @@ def record_replay_wrapper(*args, **kwargs): r = func_to_record(*args, **kwargs) result = ( None - if r is None or record_type.result_cls is None + if record_type.result_cls is None else record_type.result_cls(*r) if tuple_result else record_type.result_cls(r) From 002377d299e24912d8f43132c3ff71b9ffe335f9 Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Wed, 26 Jun 2024 16:10:15 -0400 Subject: [PATCH 10/12] Deserialize records "just in time" in order to avoid import order isssue (#151) * Deserialize records "just in time" in order to avoid import order issues. * Add changelog entry * Typing and formatting fixes * Typing --------- Co-authored-by: Emily Rockman From d25c29a73925a39fdf6fb0737868683dc61ae84b Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Tue, 2 Jul 2024 12:00:38 -0500 Subject: [PATCH 11/12] bump version to 1.5.0 (#160) --- dbt_common/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/__about__.py b/dbt_common/__about__.py index d619c757..e3a0f015 100644 --- a/dbt_common/__about__.py +++ b/dbt_common/__about__.py @@ -1 +1 @@ -version = "1.4.0" +version = "1.5.0" From feda22dc08950efb21e4ae4feb1b4fd20a9b090b Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Thu, 4 Jul 2024 17:26:03 -0400 Subject: [PATCH 12/12] Improve Type Annotations Coverage (#162) * Add some type annotations * More type annotation. * Freedom from typlessness. * More annotations. * Clean up linter issues --- dbt_common/context.py | 5 +- dbt_common/dataclass_schema.py | 8 +- dbt_common/exceptions/base.py | 6 +- dbt_common/helper_types.py | 4 +- dbt_common/record.py | 8 +- dbt_common/semver.py | 30 ++-- dbt_common/utils/casting.py | 6 +- dbt_common/utils/executor.py | 9 +- dbt_common/utils/jinja.py | 8 +- tests/unit/test_agate_helper.py | 22 +-- tests/unit/test_connection_retries.py | 11 +- tests/unit/test_contextvars.py | 2 +- tests/unit/test_contracts_util.py | 2 +- tests/unit/test_core_dbt_utils.py | 48 ++--- tests/unit/test_diff.py | 8 +- tests/unit/test_event_handler.py | 4 +- tests/unit/test_helper_types.py | 27 +-- tests/unit/test_invocation_context.py | 8 +- tests/unit/test_jinja.py | 246 +++++++++++++++----------- tests/unit/test_model_config.py | 49 +++-- tests/unit/test_proto_events.py | 4 +- tests/unit/test_record.py | 4 +- tests/unit/test_semver.py | 24 +-- tests/unit/test_system_client.py | 56 +++--- tests/unit/test_ui.py | 4 +- tests/unit/test_utils.py | 14 +- 26 files changed, 343 insertions(+), 274 deletions(-) diff --git a/dbt_common/context.py b/dbt_common/context.py index a46b1dd2..d1775c55 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -2,6 +2,7 @@ from typing import List, Mapping, Optional from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX +from dbt_common.record import Recorder class InvocationContext: @@ -9,7 +10,7 @@ def __init__(self, env: Mapping[str, str]): self._env = {k: v for k, v in env.items() if not k.startswith(PRIVATE_ENV_PREFIX)} self._env_secrets: Optional[List[str]] = None self._env_private = {k: v for k, v in env.items() if k.startswith(PRIVATE_ENV_PREFIX)} - self.recorder = None + self.recorder: Optional[Recorder] = None # This class will also eventually manage the invocation_id, flags, event manager, etc. @property @@ -32,7 +33,7 @@ def env_secrets(self) -> List[str]: _INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR") -def reliably_get_invocation_var() -> ContextVar: +def reliably_get_invocation_var() -> ContextVar[InvocationContext]: invocation_var: Optional[ContextVar] = next( (cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None ) diff --git a/dbt_common/dataclass_schema.py b/dbt_common/dataclass_schema.py index 0bad081f..867d5a4c 100644 --- a/dbt_common/dataclass_schema.py +++ b/dbt_common/dataclass_schema.py @@ -1,4 +1,4 @@ -from typing import ClassVar, cast, get_type_hints, List, Tuple, Dict, Any, Optional +from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple import re import jsonschema from dataclasses import fields, Field @@ -26,7 +26,7 @@ class ValidationError(jsonschema.ValidationError): class DateTimeSerialization(SerializationStrategy): - def serialize(self, value) -> str: + def serialize(self, value: datetime) -> str: out = value.isoformat() # Assume UTC if timezone is missing if value.tzinfo is None: @@ -127,7 +127,7 @@ def _get_fields(cls) -> List[Tuple[Field, str]]: # copied from hologram. Used in tests @classmethod - def _get_field_names(cls): + def _get_field_names(cls) -> List[str]: return [element[1] for element in cls._get_fields()] @@ -152,7 +152,7 @@ def validate(cls, value): # These classes must be in this order or it doesn't work class StrEnum(str, SerializableType, Enum): - def __str__(self): + def __str__(self) -> str: return self.value # https://docs.python.org/3.6/library/enum.html#using-automatic-values diff --git a/dbt_common/exceptions/base.py b/dbt_common/exceptions/base.py index db619326..d966a28d 100644 --- a/dbt_common/exceptions/base.py +++ b/dbt_common/exceptions/base.py @@ -1,5 +1,5 @@ import builtins -from typing import List, Any, Optional +from typing import Any, List, Optional import os from dbt_common.constants import SECRET_ENV_PREFIX @@ -37,7 +37,7 @@ def __init__(self, msg: str): self.msg = scrub_secrets(msg, env_secrets()) @property - def type(self): + def type(self) -> str: return "Internal" def process_stack(self): @@ -59,7 +59,7 @@ def process_stack(self): return lines - def __str__(self): + def __str__(self) -> str: if hasattr(self.msg, "split"): split_msg = self.msg.split("\n") else: diff --git a/dbt_common/helper_types.py b/dbt_common/helper_types.py index 0ca435b7..8611f39f 100644 --- a/dbt_common/helper_types.py +++ b/dbt_common/helper_types.py @@ -19,7 +19,7 @@ class NVEnum(StrEnum): novalue = "novalue" - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, NVEnum) @@ -59,7 +59,7 @@ def includes(self, item_name: str) -> bool: item_name in self.include or self.include in self.INCLUDE_ALL ) and item_name not in self.exclude - def _validate_items(self, items: List[str]): + def _validate_items(self, items: List[str]) -> None: pass diff --git a/dbt_common/record.py b/dbt_common/record.py index c204faa6..8fe068bb 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -14,8 +14,6 @@ from enum import Enum from typing import Any, Callable, Dict, List, Mapping, Optional, Type -from dbt_common.context import get_invocation_context - class Record: """An instance of this abstract Record class represents a request made by dbt @@ -295,9 +293,11 @@ def record_function_inner(func_to_record): return func_to_record @functools.wraps(func_to_record) - def record_replay_wrapper(*args, **kwargs): - recorder: Recorder = None + def record_replay_wrapper(*args, **kwargs) -> Any: + recorder: Optional[Recorder] = None try: + from dbt_common.context import get_invocation_context + recorder = get_invocation_context().recorder except LookupError: pass diff --git a/dbt_common/semver.py b/dbt_common/semver.py index 951f4e8e..fbdcefa5 100644 --- a/dbt_common/semver.py +++ b/dbt_common/semver.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import re -from typing import List +from typing import List, Iterable import dbt_common.exceptions.base from dbt_common.exceptions import VersionsNotCompatibleError @@ -74,7 +74,7 @@ def _cmp(a, b): @dataclass class VersionSpecifier(VersionSpecification): - def to_version_string(self, skip_matcher=False): + def to_version_string(self, skip_matcher: bool = False) -> str: prerelease = "" build = "" matcher = "" @@ -92,7 +92,7 @@ def to_version_string(self, skip_matcher=False): ) @classmethod - def from_version_string(cls, version_string): + def from_version_string(cls, version_string: str) -> "VersionSpecifier": match = _VERSION_REGEX.match(version_string) if not match: @@ -104,7 +104,7 @@ def from_version_string(cls, version_string): return cls.from_dict(matched) - def __str__(self): + def __str__(self) -> str: return self.to_version_string() def to_range(self) -> "VersionRange": @@ -192,32 +192,32 @@ def compare(self, other): return 0 - def __lt__(self, other): + def __lt__(self, other) -> bool: return self.compare(other) == -1 - def __gt__(self, other): + def __gt__(self, other) -> bool: return self.compare(other) == 1 - def __eq___(self, other): + def __eq___(self, other) -> bool: return self.compare(other) == 0 def __cmp___(self, other): return self.compare(other) @property - def is_unbounded(self): + def is_unbounded(self) -> bool: return False @property - def is_lower_bound(self): + def is_lower_bound(self) -> bool: return self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL] @property - def is_upper_bound(self): + def is_upper_bound(self) -> bool: return self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL] @property - def is_exact(self): + def is_exact(self) -> bool: return self.matcher == Matchers.EXACT @classmethod @@ -418,7 +418,7 @@ def reduce_versions(*args): return to_return -def versions_compatible(*args): +def versions_compatible(*args) -> bool: if len(args) == 1: return True @@ -429,7 +429,7 @@ def versions_compatible(*args): return False -def find_possible_versions(requested_range, available_versions): +def find_possible_versions(requested_range, available_versions: Iterable[str]): possible_versions = [] for version_string in available_versions: @@ -442,7 +442,9 @@ def find_possible_versions(requested_range, available_versions): return [v.to_version_string(skip_matcher=True) for v in sorted_versions] -def resolve_to_specific_version(requested_range, available_versions): +def resolve_to_specific_version( + requested_range, available_versions: Iterable[str] +) -> Optional[str]: max_version = None max_version_string = None diff --git a/dbt_common/utils/casting.py b/dbt_common/utils/casting.py index 811ea376..f366db7f 100644 --- a/dbt_common/utils/casting.py +++ b/dbt_common/utils/casting.py @@ -1,7 +1,7 @@ # This is useful for proto generated classes in particular, since # the default for protobuf for strings is the empty string, so # Optional[str] types don't work for generated Python classes. -from typing import Optional +from typing import Any, Dict, Optional def cast_to_str(string: Optional[str]) -> str: @@ -18,8 +18,8 @@ def cast_to_int(integer: Optional[int]) -> int: return integer -def cast_dict_to_dict_of_strings(dct): - new_dct = {} +def cast_dict_to_dict_of_strings(dct: Dict[Any, Any]) -> Dict[str, str]: + new_dct: Dict[str, str] = {} for k, v in dct.items(): new_dct[str(k)] = str(v) return new_dct diff --git a/dbt_common/utils/executor.py b/dbt_common/utils/executor.py index 0dd8490c..529b02be 100644 --- a/dbt_common/utils/executor.py +++ b/dbt_common/utils/executor.py @@ -1,9 +1,12 @@ import concurrent.futures from contextlib import contextmanager -from contextvars import ContextVar from typing import Protocol, Optional -from dbt_common.context import get_invocation_context, reliably_get_invocation_var +from dbt_common.context import ( + get_invocation_context, + reliably_get_invocation_var, + InvocationContext, +) class ConnectingExecutor(concurrent.futures.Executor): @@ -63,7 +66,7 @@ class HasThreadingConfig(Protocol): threads: Optional[int] -def _thread_initializer(invocation_context: ContextVar) -> None: +def _thread_initializer(invocation_context: InvocationContext) -> None: invocation_var = reliably_get_invocation_var() invocation_var.set(invocation_context) diff --git a/dbt_common/utils/jinja.py b/dbt_common/utils/jinja.py index 36464cbe..260ccb6a 100644 --- a/dbt_common/utils/jinja.py +++ b/dbt_common/utils/jinja.py @@ -5,19 +5,21 @@ DOCS_PREFIX = "dbt_docs__" -def get_dbt_macro_name(name): +def get_dbt_macro_name(name) -> str: if name is None: raise DbtInternalError("Got None for a macro name!") return f"{MACRO_PREFIX}{name}" -def get_dbt_docs_name(name): +def get_dbt_docs_name(name) -> str: if name is None: raise DbtInternalError("Got None for a doc name!") return f"{DOCS_PREFIX}{name}" -def get_materialization_macro_name(materialization_name, adapter_type=None, with_prefix=True): +def get_materialization_macro_name( + materialization_name, adapter_type=None, with_prefix=True +) -> str: if adapter_type is None: adapter_type = "default" name = f"materialization_{materialization_name}_{adapter_type}" diff --git a/tests/unit/test_agate_helper.py b/tests/unit/test_agate_helper.py index 4c12bcd8..fff0d4c6 100644 --- a/tests/unit/test_agate_helper.py +++ b/tests/unit/test_agate_helper.py @@ -46,13 +46,13 @@ class TestAgateHelper(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() - def tearDown(self): + def tearDown(self) -> None: rmtree(self.tempdir) - def test_from_csv(self): + def test_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -61,7 +61,7 @@ def test_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_bom_from_csv(self): + def test_bom_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_BOM_DATA.encode("utf-8")) @@ -70,7 +70,7 @@ def test_bom_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_from_csv_all_reserved(self): + def test_from_csv_all_reserved(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -79,7 +79,7 @@ def test_from_csv_all_reserved(self): for expected, row in zip(EXPECTED_STRINGS, tbl): self.assertEqual(list(row), expected) - def test_from_data(self): + def test_from_data(self) -> None: column_names = ["a", "b", "c", "d", "e", "f", "g"] data = [ { @@ -106,7 +106,7 @@ def test_from_data(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_datetime_formats(self): + def test_datetime_formats(self) -> None: path = os.path.join(self.tempdir, "input.csv") datetimes = [ "20180806T11:33:29.000Z", @@ -120,7 +120,7 @@ def test_datetime_formats(self): tbl = agate_helper.from_csv(path, ()) self.assertEqual(tbl[0][0], expected) - def test_merge_allnull(self): + def test_merge_allnull(self) -> None: t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c")) t2 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c")) result = agate_helper.merge_tables([t1, t2]) @@ -130,7 +130,7 @@ def test_merge_allnull(self): assert isinstance(result.column_types[2], agate_helper.Integer) self.assertEqual(len(result), 4) - def test_merge_mixed(self): + def test_merge_mixed(self) -> None: t1 = agate_helper.table_from_rows( [(1, "a", None, None), (2, "b", None, None)], ("a", "b", "c", "d") ) @@ -181,7 +181,7 @@ def test_merge_mixed(self): assert isinstance(result.column_types[3], agate.data_types.Number) self.assertEqual(len(result), 6) - def test_nocast_string_types(self): + def test_nocast_string_types(self) -> None: # String fields should not be coerced into a representative type # See: https://github.com/dbt-labs/dbt-core/issues/2984 @@ -202,7 +202,7 @@ def test_nocast_string_types(self): for i, row in enumerate(tbl): self.assertEqual(list(row), expected[i]) - def test_nocast_bool_01(self): + def test_nocast_bool_01(self) -> None: # True and False values should not be cast to 1 and 0, and vice versa # See: https://github.com/dbt-labs/dbt-core/issues/4511 diff --git a/tests/unit/test_connection_retries.py b/tests/unit/test_connection_retries.py index 817af7a2..44fc72f5 100644 --- a/tests/unit/test_connection_retries.py +++ b/tests/unit/test_connection_retries.py @@ -19,20 +19,23 @@ def test_no_retry(self): assert result == expected -def no_success_fn(): +def no_success_fn() -> str: raise RequestException("You'll never pass") return "failure" class TestMaxRetries: - def test_no_retry(self): + def test_no_retry(self) -> None: fn_to_retry = functools.partial(no_success_fn) with pytest.raises(ConnectionError): connection_exception_retry(fn_to_retry, 3) -def single_retry_fn(): +counter = 0 + + +def single_retry_fn() -> str: global counter if counter == 0: counter += 1 @@ -45,7 +48,7 @@ def single_retry_fn(): class TestSingleRetry: - def test_no_retry(self): + def test_no_retry(self) -> None: global counter counter = 0 diff --git a/tests/unit/test_contextvars.py b/tests/unit/test_contextvars.py index 4eb58e6c..1aa9425f 100644 --- a/tests/unit/test_contextvars.py +++ b/tests/unit/test_contextvars.py @@ -1,7 +1,7 @@ from dbt_common.events.contextvars import log_contextvars, get_node_info, set_log_contextvars -def test_contextvars(): +def test_contextvars() -> None: node_info = { "unique_id": "model.test.my_model", "started_at": None, diff --git a/tests/unit/test_contracts_util.py b/tests/unit/test_contracts_util.py index 2a620370..d2fc4493 100644 --- a/tests/unit/test_contracts_util.py +++ b/tests/unit/test_contracts_util.py @@ -13,7 +13,7 @@ class ExampleMergableClass(Mergeable): class TestMergableClass(unittest.TestCase): - def test_mergeability(self): + def test_mergeability(self) -> None: mergeable1 = ExampleMergableClass( attr_a="loses", attr_b=None, attr_c=["I'll", "still", "exist"] ) diff --git a/tests/unit/test_core_dbt_utils.py b/tests/unit/test_core_dbt_utils.py index 8a0e836e..7419cd8d 100644 --- a/tests/unit/test_core_dbt_utils.py +++ b/tests/unit/test_core_dbt_utils.py @@ -7,30 +7,30 @@ class TestCommonDbtUtils(unittest.TestCase): - def test_connection_exception_retry_none(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add(self), 5) + def test_connection_exception_retry_none(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add(), 5) self.assertEqual(1, counter) - def test_connection_exception_retry_success_requests_exception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_requests_exception(self), 5) + def test_connection_exception_retry_success_requests_exception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_requests_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned None, plus 1 retry - def test_connection_exception_retry_max(self): - Counter._reset(self) + def test_connection_exception_retry_max(self) -> None: + Counter._reset() with self.assertRaises(ConnectionError): - connection_exception_retry(lambda: Counter._add_with_exception(self), 5) + connection_exception_retry(lambda: Counter._add_with_exception(), 5) self.assertEqual(6, counter) # 6 = original attempt plus 5 retries - def test_connection_exception_retry_success_failed_untar(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_untar_exception(self), 5) + def test_connection_exception_retry_success_failed_untar(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_untar_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned ReadError, plus 1 retry - def test_connection_exception_retry_success_failed_eofexception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_eof_exception(self), 5) + def test_connection_exception_retry_success_failed_eofexception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_eof_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned EOFError, plus 1 retry @@ -38,36 +38,42 @@ def test_connection_exception_retry_success_failed_eofexception(self): class Counter: - def _add(self): + @classmethod + def _add(cls) -> None: global counter counter += 1 # All exceptions that Requests explicitly raises inherit from # requests.exceptions.RequestException so we want to make sure that raises plus one exception # that inherit from it for sanity - def _add_with_requests_exception(self): + @classmethod + def _add_with_requests_exception(cls) -> None: global counter counter += 1 if counter < 2: raise requests.exceptions.RequestException - def _add_with_exception(self): + @classmethod + def _add_with_exception(cls) -> None: global counter counter += 1 raise requests.exceptions.ConnectionError - def _add_with_untar_exception(self): + @classmethod + def _add_with_untar_exception(cls) -> None: global counter counter += 1 if counter < 2: raise tarfile.ReadError - def _add_with_eof_exception(self): + @classmethod + def _add_with_eof_exception(cls) -> None: global counter counter += 1 if counter < 2: raise EOFError - def _reset(self): + @classmethod + def _reset(cls) -> None: global counter counter = 0 diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py index 791263f3..54f735e3 100644 --- a/tests/unit/test_diff.py +++ b/tests/unit/test_diff.py @@ -1,4 +1,6 @@ import json +from typing import Any, Dict + import pytest from dbt_common.record import Diff @@ -191,7 +193,7 @@ def open_mock(file, *args, **kwargs): return open_mock -def test_calculate_diff_no_diff(monkeypatch): +def test_calculate_diff_no_diff(monkeypatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ @@ -251,11 +253,11 @@ def test_calculate_diff_no_diff(monkeypatch): previous_recording_path=previous_recording_path, ) result = diff_instance.calculate_diff() - expected_result = {"GetEnvRecord": {}, "DefaultKey": {}} + expected_result: Dict[str, Any] = {"GetEnvRecord": {}, "DefaultKey": {}} assert result == expected_result -def test_calculate_diff_with_diff(monkeypatch): +def test_calculate_diff_with_diff(monkeypatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ diff --git a/tests/unit/test_event_handler.py b/tests/unit/test_event_handler.py index 80d5ae2b..f38938b6 100644 --- a/tests/unit/test_event_handler.py +++ b/tests/unit/test_event_handler.py @@ -5,7 +5,7 @@ from dbt_common.events.event_manager import TestEventManager -def test_event_logging_handler_emits_records_correctly(): +def test_event_logging_handler_emits_records_correctly() -> None: event_manager = TestEventManager() handler = DbtEventLoggingHandler(event_manager=event_manager, level=logging.DEBUG) log = logging.getLogger("test") @@ -27,7 +27,7 @@ def test_event_logging_handler_emits_records_correctly(): assert event_manager.event_history[5][1] == EventLevel.ERROR -def test_set_package_logging_sets_level_correctly(): +def test_set_package_logging_sets_level_correctly() -> None: event_manager = TestEventManager() log = logging.getLogger("test") set_package_logging("test", logging.DEBUG, event_manager) diff --git a/tests/unit/test_helper_types.py b/tests/unit/test_helper_types.py index 1a9519de..ba98803c 100644 --- a/tests/unit/test_helper_types.py +++ b/tests/unit/test_helper_types.py @@ -1,11 +1,12 @@ import pytest +from typing import List, Union from dbt_common.helper_types import IncludeExclude, WarnErrorOptions from dbt_common.dataclass_schema import ValidationError class TestIncludeExclude: - def test_init_invalid(self): + def test_init_invalid(self) -> None: with pytest.raises(ValidationError): IncludeExclude(include="invalid") @@ -22,14 +23,16 @@ def test_init_invalid(self): (["ItemA", "ItemB"], [], True), ], ) - def test_includes(self, include, exclude, expected_includes): + def test_includes( + self, include: Union[str, List[str]], exclude: List[str], expected_includes: bool + ) -> None: include_exclude = IncludeExclude(include=include, exclude=exclude) assert include_exclude.includes("ItemA") == expected_includes class TestWarnErrorOptions: - def test_init_invalid_error(self): + def test_init_invalid_error(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"], valid_error_names=set(["ValidError"])) @@ -38,14 +41,14 @@ def test_init_invalid_error(self): include="*", exclude=["InvalidError"], valid_error_names=set(["ValidError"]) ) - def test_init_invalid_error_default_valid_error_names(self): + def test_init_invalid_error_default_valid_error_names(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"]) with pytest.raises(ValidationError): WarnErrorOptions(include="*", exclude=["InvalidError"]) - def test_init_valid_error(self): + def test_init_valid_error(self) -> None: warn_error_options = WarnErrorOptions( include=["ValidError"], valid_error_names=set(["ValidError"]) ) @@ -58,18 +61,18 @@ def test_init_valid_error(self): assert warn_error_options.include == "*" assert warn_error_options.exclude == ["ValidError"] - def test_init_default_silence(self): + def test_init_default_silence(self) -> None: my_options = WarnErrorOptions(include="*") assert my_options.silence == [] - def test_init_invalid_silence_event(self): + def test_init_invalid_silence_event(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include="*", silence=["InvalidError"]) - def test_init_valid_silence_event(self): + def test_init_valid_silence_event(self) -> None: all_events = ["MySilencedEvent"] my_options = WarnErrorOptions( - include="*", silence=all_events, valid_error_names=all_events + include="*", silence=all_events, valid_error_names=set(all_events) ) assert my_options.silence == all_events @@ -81,14 +84,16 @@ def test_init_valid_silence_event(self): ("*", ["ItemB"], True), ], ) - def test_includes(self, include, silence, expected_includes): + def test_includes( + self, include: Union[str, List[str]], silence: List[str], expected_includes: bool + ) -> None: include_exclude = WarnErrorOptions( include=include, silence=silence, valid_error_names={"ItemA", "ItemB"} ) assert include_exclude.includes("ItemA") == expected_includes - def test_silenced(self): + def test_silenced(self) -> None: my_options = WarnErrorOptions(include="*", silence=["ItemA"], valid_error_names={"ItemA"}) assert my_options.silenced("ItemA") assert not my_options.silenced("ItemB") diff --git a/tests/unit/test_invocation_context.py b/tests/unit/test_invocation_context.py index b6697f8e..3dc832d3 100644 --- a/tests/unit/test_invocation_context.py +++ b/tests/unit/test_invocation_context.py @@ -2,13 +2,13 @@ from dbt_common.context import InvocationContext -def test_invocation_context_env(): +def test_invocation_context_env() -> None: test_env = {"VAR_1": "value1", "VAR_2": "value2"} ic = InvocationContext(env=test_env) assert ic.env == test_env -def test_invocation_context_secrets(): +def test_invocation_context_secrets() -> None: test_env = { f"{SECRET_ENV_PREFIX}_VAR_1": "secret1", f"{SECRET_ENV_PREFIX}VAR_2": "secret2", @@ -16,10 +16,10 @@ def test_invocation_context_secrets(): f"foo{SECRET_ENV_PREFIX}": "non-secret", } ic = InvocationContext(env=test_env) - assert set(ic.env_secrets) == set(["secret1", "secret2"]) + assert set(ic.env_secrets) == {"secret1", "secret2"} -def test_invocation_context_private(): +def test_invocation_context_private() -> None: test_env = { f"{PRIVATE_ENV_PREFIX}_VAR_1": "private1", f"{PRIVATE_ENV_PREFIX}VAR_2": "private2", diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index f038a1ec..e906a0ac 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -1,23 +1,26 @@ import unittest +from dbt_common.clients._jinja_blocks import BlockTag from dbt_common.clients.jinja import extract_toplevel_blocks from dbt_common.exceptions import CompilationError class TestBlockLexer(unittest.TestCase): - def test_basic(self): + def test_basic(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" blocks = extract_toplevel_blocks( block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_multiple(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_multiple(self) -> None: body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' body_two = ( "{{ config(bar=1)}}\r\nselect * from {% if foo %} thing " @@ -37,7 +40,7 @@ def test_multiple(self): ) self.assertEqual(len(blocks), 2) - def test_comments(self): + def test_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = "{# my comment #}" block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" @@ -45,12 +48,14 @@ def test_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_evil_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_evil_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = ( "{# external comment {% othertype bar %} select * from " @@ -61,12 +66,14 @@ def test_evil_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_nested_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_nested_comments(self) -> None: body = ( '{# my comment #} {{ config(foo="bar") }}' "\r\nselect * from {# my other comment embedding {% endmytype %} #} this.that\r\n" @@ -80,33 +87,43 @@ def test_nested_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_complex_file(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_complex_file(self) -> None: blocks = extract_toplevel_blocks( complex_snapshot_file, allowed_blocks={"mytype", "myothertype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 3) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].full_block, "{% mytype foo %} some stuff {% endmytype %}") - self.assertEqual(blocks[0].contents, " some stuff ") - self.assertEqual(blocks[1].block_type_name, "mytype") - self.assertEqual(blocks[1].block_name, "bar") - self.assertEqual(blocks[1].full_block, bar_block) - self.assertEqual(blocks[1].contents, bar_block[16:-15].rstrip()) - self.assertEqual(blocks[2].block_type_name, "myothertype") - self.assertEqual(blocks[2].block_name, "x") - self.assertEqual(blocks[2].full_block, x_block.strip()) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.full_block, "{% mytype foo %} some stuff {% endmytype %}") + self.assertEqual(b0.contents, " some stuff ") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "mytype") + self.assertEqual(b1.block_name, "bar") + self.assertEqual(b1.full_block, bar_block) + self.assertEqual(b1.contents, bar_block[16:-15].rstrip()) + + b2 = blocks[2] + assert isinstance(b2, BlockTag) + self.assertEqual(b2.block_type_name, "myothertype") + self.assertEqual(b2.block_name, "x") + self.assertEqual(b2.full_block, x_block.strip()) self.assertEqual( - blocks[2].contents, + b2.contents, x_block[len("\n{% myothertype x %}") : -len("{% endmyothertype %}\n")], ) - def test_peaceful_macro_coexistence(self): + def test_peaceful_macro_coexistence(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing " "{%- endmacro %} {# my model #} {% a b %} test {% enda %}" @@ -116,15 +133,22 @@ def test_peaceful_macro_coexistence(self): ) self.assertEqual(len(blocks), 4) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") - def test_macro_with_trailing_data(self): + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + + def test_macro_with_trailing_data(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} " "{# my model #} {% a b %} test {% enda %} raw data so cool" @@ -134,16 +158,24 @@ def test_macro_with_trailing_data(self): ) self.assertEqual(len(blocks), 5) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") + + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + self.assertEqual(blocks[4].full_block, " raw data so cool") - def test_macro_with_crazy_args(self): + def test_macro_with_crazy_args(self) -> None: body = ( """{% macro foo(a, b=asdf("cool this is 'embedded'" * 3) + external_var, c)%}""" "cool{# block comment with {% endmacro %} in it #} stuff here " @@ -151,38 +183,44 @@ def test_macro_with_crazy_args(self): ) blocks = extract_toplevel_blocks(body, allowed_blocks={"macro"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "macro") - self.assertEqual(blocks[0].block_name, "foo") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "macro") + self.assertEqual(b0.block_name, "foo") self.assertEqual( blocks[0].contents, "cool{# block comment with {% endmacro %} in it #} stuff here " ) - def test_materialization_parse(self): + def test_materialization_parse(self) -> None: body = "{% materialization xxx, default %} ... {% endmaterialization %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) body = '{% materialization xxx, adapter="other" %} ... {% endmaterialization %}' blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) + b0 = blocks[0] + assert isinstance(b0, BlockTag) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) - def test_nested_not_ok(self): + def test_nested_not_ok(self) -> None: # we don't allow nesting same blocks body = "{% myblock a %} {% myblock b %} {% endmyblock %} {% endmyblock %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock"}) - def test_incomplete_block_failure(self): + def test_incomplete_block_failure(self) -> None: fullbody = "{% myblock foo %} {% endmyblock %}" for length in range(len("{% myblock foo %}"), len(fullbody) - 1): body = fullbody[:length] @@ -194,45 +232,45 @@ def test_wrong_end_failure(self): with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}) - def test_comment_no_end_failure(self): + def test_comment_no_end_failure(self) -> None: body = "{# " with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_comment_only(self): + def test_comment_only(self) -> None: body = "{# myblock #}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_comment_block_self_closing(self): + def test_comment_block_self_closing(self) -> None: # test the case where a comment start looks a lot like it closes itself # (but it doesn't in jinja!) body = "{#} {% myblock foo %} {#}" blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_embedded_self_closing_comment_block(self): + def test_embedded_self_closing_comment_block(self) -> None: body = "{% myblock foo %} {#}{% endmyblock %} {#}{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) self.assertEqual(blocks[0].contents, " {#}{% endmyblock %} {#}") - def test_set_statement(self): + def test_set_statement(self) -> None: body = "{% set x = 1 %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_set_block(self): + def test_set_block(self) -> None: body = "{% set x %}1{% endset %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_set_statement(self): + def test_crazy_set_statement(self) -> None: body = ( '{% set x = (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% set y = otherthing("{% myblock foo %}") %}' @@ -244,19 +282,19 @@ def test_crazy_set_statement(self): self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}") self.assertEqual(blocks[0].block_type_name, "otherblock") - def test_do_statement(self): + def test_do_statement(self) -> None: body = "{% do thing.update() %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_deceptive_do_statement(self): + def test_deceptive_do_statement(self) -> None: body = "{% do thing %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_do_block(self): + def test_do_block(self) -> None: body = "{% do %}thing.update(){% enddo %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"do", "myblock"}, collect_raw_data=False @@ -266,7 +304,7 @@ def test_do_block(self): self.assertEqual(blocks[0].block_type_name, "do") self.assertEqual(blocks[1].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_do_statement(self): + def test_crazy_do_statement(self) -> None: body = ( '{% do (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% do otherthing("{% myblock foo %}") %}{% myblock x %}hi{% endmyblock %}' @@ -280,7 +318,7 @@ def test_crazy_do_statement(self): self.assertEqual(blocks[1].full_block, "{% myblock x %}hi{% endmyblock %}") self.assertEqual(blocks[1].block_type_name, "myblock") - def test_awful_jinja(self): + def test_awful_jinja(self) -> None: blocks = extract_toplevel_blocks( if_you_do_this_you_are_awful, allowed_blocks={"snapshot", "materialization"}, @@ -304,63 +342,71 @@ def test_awful_jinja(self): self.assertEqual(blocks[1].block_type_name, "materialization") self.assertEqual(blocks[1].contents, "\nhi\n") - def test_quoted_endblock_within_block(self): + def test_quoted_endblock_within_block(self) -> None: body = '{% myblock something -%} {% set x = ("{% endmyblock %}") %} {% endmyblock %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].block_type_name, "myblock") self.assertEqual(blocks[0].contents, '{% set x = ("{% endmyblock %}") %} ') - def test_docs_block(self): + def test_docs_block(self) -> None: body = ( "{% docs __my_doc__ %} asdf {# nope {% enddocs %}} #} {% enddocs %}" '{% docs __my_other_doc__ %} asdf "{% enddocs %}' ) blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 2) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, " asdf {# nope {% enddocs %}} #} ") - self.assertEqual(blocks[0].block_name, "__my_doc__") - self.assertEqual(blocks[1].block_type_name, "docs") - self.assertEqual(blocks[1].contents, ' asdf "') - self.assertEqual(blocks[1].block_name, "__my_other_doc__") - - def test_docs_block_expr(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, " asdf {# nope {% enddocs %}} #} ") + self.assertEqual(b0.block_name, "__my_doc__") + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "docs") + self.assertEqual(b1.contents, ' asdf "') + self.assertEqual(b1.block_name, "__my_other_doc__") + + def test_docs_block_expr(self) -> None: body = '{% docs more_doc %} asdf {{ "{% enddocs %}" ~ "}}" }}{% enddocs %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') - self.assertEqual(blocks[0].block_name, "more_doc") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') + self.assertEqual(b0.block_name, "more_doc") - def test_unclosed_model_quotes(self): + def test_unclosed_model_quotes(self) -> None: # test case for https://github.com/dbt-labs/dbt-core/issues/1533 body = '{% model my_model -%} select * from "something"."something_else{% endmodel %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"model"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "model") - self.assertEqual(blocks[0].contents, 'select * from "something"."something_else') - self.assertEqual(blocks[0].block_name, "my_model") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "model") + self.assertEqual(b0.contents, 'select * from "something"."something_else') + self.assertEqual(b0.block_name, "my_model") - def test_if(self): + def test_if(self) -> None: # if you conditionally define your macros/models, don't body = "{% if true %}{% macro my_macro() %} adsf {% endmacro %}{% endif %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_if_innocuous(self): + def test_if_innocuous(self) -> None: body = "{% if true %}{% something %}asdfasd{% endsomething %}{% endif %}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_for(self): + def test_for(self) -> None: # no for-loops over macros. body = "{% for x in range(10) %}{% macro my_macro() %} adsf {% endmacro %}{% endfor %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_for_innocuous(self): + def test_for_innocuous(self) -> None: # no for-loops over macros. body = ( "{% for x in range(10) %}{% something my_something %} adsf " @@ -370,7 +416,7 @@ def test_for_innocuous(self): self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_endif(self): + def test_endif(self) -> None: body = "{% snapshot foo %}select * from thing{% endsnapshot%}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -382,7 +428,7 @@ def test_endif(self): str(err.exception), ) - def test_if_endfor(self): + def test_if_endfor(self) -> None: body = "{% if x %}...{% endfor %}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -391,7 +437,7 @@ def test_if_endfor(self): str(err.exception), ) - def test_if_endfor_newlines(self): + def test_if_endfor_newlines(self) -> None: body = "{% if x %}\n ...\n {% endfor %}\n{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) diff --git a/tests/unit/test_model_config.py b/tests/unit/test_model_config.py index 0cc1e711..57a14438 100644 --- a/tests/unit/test_model_config.py +++ b/tests/unit/test_model_config.py @@ -14,7 +14,7 @@ class ThingWithMergeBehavior(dbtClassMixin): keysappended: Dict[str, int] = field(metadata={"merge": MergeBehavior.DictKeyAppend}) -def test_merge_behavior_meta(): +def test_merge_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(MergeBehavior) == { @@ -29,15 +29,14 @@ def test_merge_behavior_meta(): assert existing == initial_existing -def test_merge_behavior_from_field(): - fields = [f[0] for f in ThingWithMergeBehavior._get_fields()] - fields = {name: f for f, name in ThingWithMergeBehavior._get_fields()} - assert set(fields) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} - assert MergeBehavior.from_field(fields["default_behavior"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["appended"]) == MergeBehavior.Append - assert MergeBehavior.from_field(fields["updated"]) == MergeBehavior.Update - assert MergeBehavior.from_field(fields["clobbered"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["keysappended"]) == MergeBehavior.DictKeyAppend +def test_merge_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithMergeBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} + assert MergeBehavior.from_field(fields2["default_behavior"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["appended"]) == MergeBehavior.Append + assert MergeBehavior.from_field(fields2["updated"]) == MergeBehavior.Update + assert MergeBehavior.from_field(fields2["clobbered"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["keysappended"]) == MergeBehavior.DictKeyAppend @dataclass @@ -47,7 +46,7 @@ class ThingWithShowBehavior(dbtClassMixin): shown: float = field(metadata={"show_hide": ShowBehavior.Show}) -def test_show_behavior_meta(): +def test_show_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(ShowBehavior) == {ShowBehavior.Hide, ShowBehavior.Show} @@ -57,13 +56,12 @@ def test_show_behavior_meta(): assert existing == initial_existing -def test_show_behavior_from_field(): - fields = [f[0] for f in ThingWithShowBehavior._get_fields()] - fields = {name: f for f, name in ThingWithShowBehavior._get_fields()} - assert set(fields) == {"default_behavior", "hidden", "shown"} - assert ShowBehavior.from_field(fields["default_behavior"]) == ShowBehavior.Show - assert ShowBehavior.from_field(fields["hidden"]) == ShowBehavior.Hide - assert ShowBehavior.from_field(fields["shown"]) == ShowBehavior.Show +def test_show_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithShowBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "hidden", "shown"} + assert ShowBehavior.from_field(fields2["default_behavior"]) == ShowBehavior.Show + assert ShowBehavior.from_field(fields2["hidden"]) == ShowBehavior.Hide + assert ShowBehavior.from_field(fields2["shown"]) == ShowBehavior.Show @dataclass @@ -73,7 +71,7 @@ class ThingWithCompareBehavior(dbtClassMixin): excluded: str = field(metadata={"compare": CompareBehavior.Exclude}) -def test_compare_behavior_meta(): +def test_compare_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(CompareBehavior) == {CompareBehavior.Include, CompareBehavior.Exclude} @@ -83,10 +81,9 @@ def test_compare_behavior_meta(): assert existing == initial_existing -def test_compare_behavior_from_field(): - fields = [f[0] for f in ThingWithCompareBehavior._get_fields()] - fields = {name: f for f, name in ThingWithCompareBehavior._get_fields()} - assert set(fields) == {"default_behavior", "included", "excluded"} - assert CompareBehavior.from_field(fields["default_behavior"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["included"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["excluded"]) == CompareBehavior.Exclude +def test_compare_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithCompareBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "included", "excluded"} + assert CompareBehavior.from_field(fields2["default_behavior"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["included"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["excluded"]) == CompareBehavior.Exclude diff --git a/tests/unit/test_proto_events.py b/tests/unit/test_proto_events.py index 32eb08ae..d21b5062 100644 --- a/tests/unit/test_proto_events.py +++ b/tests/unit/test_proto_events.py @@ -18,7 +18,7 @@ } -def test_events(): +def test_events() -> None: # M020 event event_code = "M020" event = RetryExternalCall(attempt=3, max=5) @@ -45,7 +45,7 @@ def test_events(): assert new_msg.data.attempt == msg.data.attempt -def test_extra_dict_on_event(monkeypatch): +def test_extra_dict_on_event(monkeypatch) -> None: monkeypatch.setenv("DBT_ENV_CUSTOM_ENV_env_key", "env_value") reset_metadata_vars() diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index b0371498..6e02d710 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -69,7 +69,7 @@ def setup(): os.environ["DBT_RECORDER_FILE_PATH"] = prev_fp -def test_decorator_records(setup): +def test_decorator_records(setup) -> None: os.environ["DBT_RECORDER_MODE"] = "Record" recorder = Recorder(RecorderMode.RECORD, None) set_invocation_context({}) @@ -116,7 +116,7 @@ def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: assert NotTestRecord not in recorder._records_by_type -def test_decorator_replays(setup): +def test_decorator_replays(setup) -> None: os.environ["DBT_RECORDER_MODE"] = "Replay" os.environ["DBT_RECORDER_FILE_PATH"] = "record.json" recorder = Recorder(RecorderMode.REPLAY, None) diff --git a/tests/unit/test_semver.py b/tests/unit/test_semver.py index ae48e592..383d3479 100644 --- a/tests/unit/test_semver.py +++ b/tests/unit/test_semver.py @@ -1,6 +1,6 @@ import itertools import unittest -from typing import List +from typing import List, Optional from dbt_common.exceptions import VersionsNotCompatibleError from dbt_common.semver import ( @@ -23,9 +23,11 @@ def semver_regex_versioning(versions: List[str]) -> bool: return True -def create_range(start_version_string, end_version_string): - start = UnboundedVersionSpecifier() - end = UnboundedVersionSpecifier() +def create_range( + start_version_string: Optional[str], end_version_string: Optional[str] +) -> VersionRange: + start: VersionSpecifier = UnboundedVersionSpecifier() + end: VersionSpecifier = UnboundedVersionSpecifier() if start_version_string is not None: start = VersionSpecifier.from_version_string(start_version_string) @@ -37,24 +39,24 @@ def create_range(start_version_string, end_version_string): class TestSemver(unittest.TestCase): - def assertVersionSetResult(self, inputs, output_range): + def assertVersionSetResult(self, inputs, output_range) -> None: expected = create_range(*output_range) for permutation in itertools.permutations(inputs): self.assertEqual(reduce_versions(*permutation), expected) - def assertInvalidVersionSet(self, inputs): + def assertInvalidVersionSet(self, inputs) -> None: for permutation in itertools.permutations(inputs): with self.assertRaises(VersionsNotCompatibleError): reduce_versions(*permutation) - def test__versions_compatible(self): + def test__versions_compatible(self) -> None: self.assertTrue(versions_compatible("0.0.1", "0.0.1")) self.assertFalse(versions_compatible("0.0.1", "0.0.2")) self.assertTrue(versions_compatible(">0.0.1", "0.0.2")) self.assertFalse(versions_compatible("0.4.5a1", "0.4.5a2")) - def test__semver_regex_versions(self): + def test__semver_regex_versions(self) -> None: self.assertTrue( semver_regex_versioning( [ @@ -140,7 +142,7 @@ def test__semver_regex_versions(self): ) ) - def test__reduce_versions(self): + def test__reduce_versions(self) -> None: self.assertVersionSetResult(["0.0.1", "0.0.1"], ["=0.0.1", "=0.0.1"]) self.assertVersionSetResult(["0.0.1"], ["=0.0.1", "=0.0.1"]) @@ -175,7 +177,7 @@ def test__reduce_versions(self): self.assertInvalidVersionSet(["<0.0.3", ">=0.0.3"]) self.assertInvalidVersionSet(["<0.0.3", ">0.0.3"]) - def test__resolve_to_specific_version(self): + def test__resolve_to_specific_version(self) -> None: self.assertEqual( resolve_to_specific_version(create_range(">0.0.1", None), ["0.0.1", "0.0.2"]), "0.0.2" ) @@ -253,7 +255,7 @@ def test__resolve_to_specific_version(self): "0.9.1", ) - def test__filter_installable(self): + def test__filter_installable(self) -> None: installable = filter_installable( [ "1.1.0", diff --git a/tests/unit/test_system_client.py b/tests/unit/test_system_client.py index a4dcc323..d2cf27ed 100644 --- a/tests/unit/test_system_client.py +++ b/tests/unit/test_system_client.py @@ -12,39 +12,39 @@ class SystemClient(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.tmp_dir = mkdtemp() self.profiles_path = "{}/profiles.yml".format(self.tmp_dir) - def set_up_profile(self): + def set_up_profile(self) -> None: with open(self.profiles_path, "w") as f: f.write("ORIGINAL_TEXT") - def get_profile_text(self): + def get_profile_text(self) -> str: with open(self.profiles_path, "r") as f: return f.read() - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.tmp_dir) except Exception as e: # noqa: F841 pass - def test__make_file_when_exists(self): + def test__make_file_when_exists(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertFalse(written) self.assertEqual(self.get_profile_text(), "ORIGINAL_TEXT") - def test__make_file_when_not_exists(self): + def test__make_file_when_not_exists(self) -> None: written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_file_with_overwrite(self): + def test__make_file_with_overwrite(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file( self.profiles_path, contents="NEW_TEXT", overwrite=True @@ -53,12 +53,12 @@ def test__make_file_with_overwrite(self): self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_dir_from_str(self): + def test__make_dir_from_str(self) -> None: test_dir_str = self.tmp_dir + "/test_make_from_str/sub_dir" dbt_common.clients.system.make_directory(test_dir_str) self.assertTrue(Path(test_dir_str).is_dir()) - def test__make_dir_from_pathobj(self): + def test__make_dir_from_pathobj(self) -> None: test_dir_pathobj = Path(self.tmp_dir + "/test_make_from_pathobj/sub_dir") dbt_common.clients.system.make_directory(test_dir_pathobj) self.assertTrue(test_dir_pathobj.is_dir()) @@ -72,7 +72,7 @@ class TestRunCmd(unittest.TestCase): not_a_file = "zzzbbfasdfasdfsdaq" - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() self.run_dir = os.path.join(self.tempdir, "run_dir") self.does_not_exist = os.path.join(self.tempdir, "does_not_exist") @@ -86,10 +86,10 @@ def setUp(self): with open(self.empty_file, "w") as fp: # noqa: F841 pass # "touch" - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test__executable_does_not_exist(self): + def test__executable_does_not_exist(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.does_not_exist]) @@ -99,7 +99,7 @@ def test__executable_does_not_exist(self): self.assertIn("could not find", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__not_exe(self): + def test__not_exe(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.empty_file]) @@ -112,14 +112,14 @@ def test__not_exe(self): self.assertIn("permissions", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_does_not_exist(self): + def test__cwd_does_not_exist(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.does_not_exist, self.exists_cmd) msg = str(exc.exception).lower() self.assertIn("does not exist", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__cwd_not_directory(self): + def test__cwd_not_directory(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.empty_file, self.exists_cmd) @@ -127,7 +127,7 @@ def test__cwd_not_directory(self): self.assertIn("not a directory", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_no_permissions(self): + def test__cwd_no_permissions(self) -> None: # it would be nice to add a windows test. Possible path to that is via # `psexec` (to get SYSTEM privs), use `icacls` to set permissions on # the directory for the test user. I'm pretty sure windows users can't @@ -145,18 +145,18 @@ def test__cwd_no_permissions(self): self.assertIn("permissions", msg) self.assertIn(self.run_dir.lower(), msg) - def test__ok(self): + def test__ok(self) -> None: out, err = dbt_common.clients.system.run_cmd(self.run_dir, self.exists_cmd) self.assertEqual(out.strip(), b"hello") self.assertEqual(err.strip(), b"") class TestFindMatching(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) - def test_find_matching_lowercase_file_pattern(self): + def test_find_matching_lowercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -175,7 +175,7 @@ def test_find_matching_lowercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_uppercase_file_pattern(self): + def test_find_matching_uppercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQL", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -190,12 +190,12 @@ def test_find_matching_uppercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_file_pattern_not_found(self): + def test_find_matching_file_pattern_not_found(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQLT", dir=self.tempdir): out = dbt_common.clients.system.find_matching(self.tempdir, [""], "*.sql") self.assertEqual(out, []) - def test_ignore_spec(self): + def test_ignore_spec(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir): out = dbt_common.clients.system.find_matching( self.tempdir, @@ -207,7 +207,7 @@ def test_ignore_spec(self): ) self.assertEqual(out, []) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 @@ -215,18 +215,18 @@ def tearDown(self): class TestUntarPackage(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) self.tempdest = mkdtemp(dir=self.base_dir) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 pass - def test_untar_package_success(self): + def test_untar_package_success(self) -> None: # set up a valid tarball to test against with NamedTemporaryFile( prefix="my-package.2", suffix=".tar.gz", dir=self.tempdir, delete=False @@ -244,7 +244,7 @@ def test_untar_package_success(self): path = Path(os.path.join(self.tempdest, relative_file_a)) assert path.is_file() - def test_untar_package_failure(self): + def test_untar_package_failure(self) -> None: # create a text file then rename it as a tar (so it's invalid) with NamedTemporaryFile( prefix="a", suffix=".txt", dir=self.tempdir, delete=False @@ -259,7 +259,7 @@ def test_untar_package_failure(self): with self.assertRaises(tarfile.ReadError) as exc: # noqa: F841 dbt_common.clients.system.untar_package(tar_file_path, self.tempdest) - def test_untar_package_empty(self): + def test_untar_package_empty(self) -> None: # create a tarball with nothing in it with NamedTemporaryFile( prefix="my-empty-package.2", suffix=".tar.gz", dir=self.tempdir diff --git a/tests/unit/test_ui.py b/tests/unit/test_ui.py index 22e431d5..5b70b1d1 100644 --- a/tests/unit/test_ui.py +++ b/tests/unit/test_ui.py @@ -1,11 +1,11 @@ from dbt_common.ui import warning_tag, error_tag -def test_warning_tag(): +def test_warning_tag() -> None: tagged = warning_tag("hi") assert "WARNING" in tagged -def test_error_tag(): +def test_error_tag() -> None: tagged = error_tag("hi") assert "ERROR" in tagged diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 250c20cc..93c57046 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -5,7 +5,7 @@ class TestDeepMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -27,7 +27,7 @@ def test__simple_cases(self): class TestMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -49,7 +49,7 @@ def test__simple_cases(self): class TestDeepMap(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.input_value = { "foo": { "bar": "hello", @@ -74,7 +74,7 @@ def intify_all(value, _): except (TypeError, ValueError): return -1 - def test__simple_cases(self): + def test__simple_cases(self) -> None: expected = { "foo": { "bar": -1, @@ -104,7 +104,7 @@ def special_keypath(value, keypath): else: return value - def test__keypath(self): + def test__keypath(self) -> None: expected = { "foo": { "bar": "hello", @@ -128,11 +128,11 @@ def test__keypath(self): actual = dbt_common.utils.dict.deep_map_render(self.special_keypath, expected) self.assertEqual(actual, expected) - def test__noop(self): + def test__noop(self) -> None: actual = dbt_common.utils.dict.deep_map_render(lambda x, _: x, self.input_value) self.assertEqual(actual, self.input_value) - def test_trivial(self): + def test_trivial(self) -> None: cases = [[], {}, 1, "abc", None, True] for case in cases: result = dbt_common.utils.dict.deep_map_render(lambda x, _: x, case)