diff --git a/.changes/unreleased/Under the Hood-20240528-110518.yaml b/.changes/unreleased/Under the Hood-20240528-110518.yaml new file mode 100644 index 00000000..58b243c7 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240528-110518.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Allow dynamic selection of record types when recording. +time: 2024-05-28T11:05:18.290107-05:00 +custom: + Author: emmyoop + Issue: "140" diff --git a/.changes/unreleased/Under the Hood-20240529-143154.yaml b/.changes/unreleased/Under the Hood-20240529-143154.yaml new file mode 100644 index 00000000..e9664254 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240529-143154.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Move StatsItem, StatsDict, TableMetadata to dbt-common +time: 2024-05-29T14:31:54.854468+01:00 +custom: + Author: aranke + Issue: "141" diff --git a/.changes/unreleased/Under the Hood-20240603-123631.yaml b/.changes/unreleased/Under the Hood-20240603-123631.yaml new file mode 100644 index 00000000..18a649df --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240603-123631.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Move CatalogKey, ColumnMetadata, ColumnMap, CatalogTable to dbt-common +time: 2024-06-03T12:36:31.542118+02:00 +custom: + Author: aranke + Issue: "147" diff --git a/.changes/unreleased/Under the Hood-20240617-204541.yaml b/.changes/unreleased/Under the Hood-20240617-204541.yaml new file mode 100644 index 00000000..b4e33c36 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240617-204541.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Add support for basic diff of run recordings. +time: 2024-06-17T20:45:41.123374-05:00 +custom: + Author: emmyoop + Issue: "144" diff --git a/.changes/unreleased/Under the Hood-20240618-155025.yaml b/.changes/unreleased/Under the Hood-20240618-155025.yaml new file mode 100644 index 00000000..b540d3d7 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240618-155025.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Deserialize Record objects on a just-in-time basis. +time: 2024-06-18T15:50:25.985387-04:00 +custom: + Author: peterallenwebb + Issue: "151" diff --git a/.gitignore b/.gitignore index 28030861..5ad1a6c2 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ diff --git a/dbt_common/__about__.py b/dbt_common/__about__.py index b2b60a55..e3a0f015 100644 --- a/dbt_common/__about__.py +++ b/dbt_common/__about__.py @@ -1 +1 @@ -version = "1.1.0" +version = "1.5.0" diff --git a/dbt_common/clients/system.py b/dbt_common/clients/system.py index bcf798d2..00a1ac69 100644 --- a/dbt_common/clients/system.py +++ b/dbt_common/clients/system.py @@ -62,7 +62,9 @@ def _include(self) -> bool: # Do not record or replay filesystem searches that were performed against # files which are actually part of dbt's implementation. return ( - "dbt/include/global_project" not in self.root_path + "dbt/include" + not in self.root_path # TODO: This actually obviates the next two checks but is probably too coarse? + and "dbt/include/global_project" not in self.root_path and "/plugins/postgres/dbt/include/" not in self.root_path ) diff --git a/dbt_common/context.py b/dbt_common/context.py index a46b1dd2..d1775c55 100644 --- a/dbt_common/context.py +++ b/dbt_common/context.py @@ -2,6 +2,7 @@ from typing import List, Mapping, Optional from dbt_common.constants import PRIVATE_ENV_PREFIX, SECRET_ENV_PREFIX +from dbt_common.record import Recorder class InvocationContext: @@ -9,7 +10,7 @@ def __init__(self, env: Mapping[str, str]): self._env = {k: v for k, v in env.items() if not k.startswith(PRIVATE_ENV_PREFIX)} self._env_secrets: Optional[List[str]] = None self._env_private = {k: v for k, v in env.items() if k.startswith(PRIVATE_ENV_PREFIX)} - self.recorder = None + self.recorder: Optional[Recorder] = None # This class will also eventually manage the invocation_id, flags, event manager, etc. @property @@ -32,7 +33,7 @@ def env_secrets(self) -> List[str]: _INVOCATION_CONTEXT_VAR: ContextVar[InvocationContext] = ContextVar("DBT_INVOCATION_CONTEXT_VAR") -def reliably_get_invocation_var() -> ContextVar: +def reliably_get_invocation_var() -> ContextVar[InvocationContext]: invocation_var: Optional[ContextVar] = next( (cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None ) diff --git a/dbt_common/contracts/metadata.py b/dbt_common/contracts/metadata.py new file mode 100644 index 00000000..d71e79a7 --- /dev/null +++ b/dbt_common/contracts/metadata.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Union, NamedTuple + +from dbt_common.dataclass_schema import dbtClassMixin +from dbt_common.utils.formatting import lowercase + + +@dataclass +class StatsItem(dbtClassMixin): + id: str + label: str + value: Union[bool, str, float, None] + include: bool + description: Optional[str] = None + + +StatsDict = Dict[str, StatsItem] + + +@dataclass +class TableMetadata(dbtClassMixin): + type: str + schema: str + name: str + database: Optional[str] = None + comment: Optional[str] = None + owner: Optional[str] = None + + +CatalogKey = NamedTuple( + "CatalogKey", [("database", Optional[str]), ("schema", str), ("name", str)] +) + + +@dataclass +class ColumnMetadata(dbtClassMixin): + type: str + index: int + name: str + comment: Optional[str] = None + + +ColumnMap = Dict[str, ColumnMetadata] + + +@dataclass +class CatalogTable(dbtClassMixin): + metadata: TableMetadata + columns: ColumnMap + stats: StatsDict + # the same table with two unique IDs will just be listed two times + unique_id: Optional[str] = None + + def key(self) -> CatalogKey: + return CatalogKey( + lowercase(self.metadata.database), + self.metadata.schema.lower(), + self.metadata.name.lower(), + ) diff --git a/dbt_common/dataclass_schema.py b/dbt_common/dataclass_schema.py index 0bad081f..867d5a4c 100644 --- a/dbt_common/dataclass_schema.py +++ b/dbt_common/dataclass_schema.py @@ -1,4 +1,4 @@ -from typing import ClassVar, cast, get_type_hints, List, Tuple, Dict, Any, Optional +from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple import re import jsonschema from dataclasses import fields, Field @@ -26,7 +26,7 @@ class ValidationError(jsonschema.ValidationError): class DateTimeSerialization(SerializationStrategy): - def serialize(self, value) -> str: + def serialize(self, value: datetime) -> str: out = value.isoformat() # Assume UTC if timezone is missing if value.tzinfo is None: @@ -127,7 +127,7 @@ def _get_fields(cls) -> List[Tuple[Field, str]]: # copied from hologram. Used in tests @classmethod - def _get_field_names(cls): + def _get_field_names(cls) -> List[str]: return [element[1] for element in cls._get_fields()] @@ -152,7 +152,7 @@ def validate(cls, value): # These classes must be in this order or it doesn't work class StrEnum(str, SerializableType, Enum): - def __str__(self): + def __str__(self) -> str: return self.value # https://docs.python.org/3.6/library/enum.html#using-automatic-values diff --git a/dbt_common/exceptions/base.py b/dbt_common/exceptions/base.py index db619326..d966a28d 100644 --- a/dbt_common/exceptions/base.py +++ b/dbt_common/exceptions/base.py @@ -1,5 +1,5 @@ import builtins -from typing import List, Any, Optional +from typing import Any, List, Optional import os from dbt_common.constants import SECRET_ENV_PREFIX @@ -37,7 +37,7 @@ def __init__(self, msg: str): self.msg = scrub_secrets(msg, env_secrets()) @property - def type(self): + def type(self) -> str: return "Internal" def process_stack(self): @@ -59,7 +59,7 @@ def process_stack(self): return lines - def __str__(self): + def __str__(self) -> str: if hasattr(self.msg, "split"): split_msg = self.msg.split("\n") else: diff --git a/dbt_common/helper_types.py b/dbt_common/helper_types.py index 0ca435b7..8611f39f 100644 --- a/dbt_common/helper_types.py +++ b/dbt_common/helper_types.py @@ -19,7 +19,7 @@ class NVEnum(StrEnum): novalue = "novalue" - def __eq__(self, other): + def __eq__(self, other) -> bool: return isinstance(other, NVEnum) @@ -59,7 +59,7 @@ def includes(self, item_name: str) -> bool: item_name in self.include or self.include in self.INCLUDE_ALL ) and item_name not in self.exclude - def _validate_items(self, items: List[str]): + def _validate_items(self, items: List[str]) -> None: pass diff --git a/dbt_common/record.py b/dbt_common/record.py index b2b5ba48..8fe068bb 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -2,17 +2,17 @@ external systems during a command invocation, so that the command can be re-run later with the recording 'replayed' to dbt. -The rationale for and architecture of this module is described in detail in the +The rationale for and architecture of this module are described in detail in the docs/guides/record_replay.md document in this repository. """ import functools import dataclasses import json import os -from enum import Enum -from typing import Any, Dict, List, Mapping, Optional, Type -from dbt_common.context import get_invocation_context +from deepdiff import DeepDiff # type: ignore +from enum import Enum +from typing import Any, Callable, Dict, List, Mapping, Optional, Type class Record: @@ -51,27 +51,101 @@ def from_dict(cls, dct: Mapping) -> "Record": class Diff: - """Marker class for diffs?""" + def __init__(self, current_recording_path: str, previous_recording_path: str) -> None: + self.current_recording_path = current_recording_path + self.previous_recording_path = previous_recording_path + + def diff_query_records(self, current: List, previous: List) -> Dict[str, Any]: + # some of the table results are returned as a stringified list of dicts that don't + # diff because order isn't consistent. convert it into a list of dicts so it can + # be diffed ignoring order + + for i in range(len(current)): + if current[i].get("result").get("table") is not None: + current[i]["result"]["table"] = json.loads(current[i]["result"]["table"]) + for i in range(len(previous)): + if previous[i].get("result").get("table") is not None: + previous[i]["result"]["table"] = json.loads(previous[i]["result"]["table"]) + + return DeepDiff(previous, current, ignore_order=True, verbose_level=2) + + def diff_env_records(self, current: List, previous: List) -> Dict[str, Any]: + # The mode and filepath may change. Ignore them. + + exclude_paths = [ + "root[0]['result']['env']['DBT_RECORDER_FILE_PATH']", + "root[0]['result']['env']['DBT_RECORDER_MODE']", + ] + + return DeepDiff( + previous, current, ignore_order=True, verbose_level=2, exclude_paths=exclude_paths + ) + + def diff_default(self, current: List, previous: List) -> Dict[str, Any]: + return DeepDiff(previous, current, ignore_order=True, verbose_level=2) + + def calculate_diff(self) -> Dict[str, Any]: + with open(self.current_recording_path) as current_recording: + current_dct = json.load(current_recording) + + with open(self.previous_recording_path) as previous_recording: + previous_dct = json.load(previous_recording) + + diff = {} + for record_type in current_dct: + if record_type == "QueryRecord": + diff[record_type] = self.diff_query_records( + current_dct[record_type], previous_dct[record_type] + ) + elif record_type == "GetEnvRecord": + diff[record_type] = self.diff_env_records( + current_dct[record_type], previous_dct[record_type] + ) + else: + diff[record_type] = self.diff_default( + current_dct[record_type], previous_dct[record_type] + ) - pass + return diff class RecorderMode(Enum): RECORD = 1 REPLAY = 2 + DIFF = 3 # records and does diffing class Recorder: _record_cls_by_name: Dict[str, Type] = {} _record_name_by_params_name: Dict[str, str] = {} - def __init__(self, mode: RecorderMode, recording_path: Optional[str] = None) -> None: + def __init__( + self, + mode: RecorderMode, + types: Optional[List], + current_recording_path: str = "recording.json", + previous_recording_path: Optional[str] = None, + ) -> None: self.mode = mode + self.recorded_types = types self._records_by_type: Dict[str, List[Record]] = {} + self._unprocessed_records_by_type: Dict[str, List[Dict[str, Any]]] = {} self._replay_diffs: List["Diff"] = [] + self.diff: Optional[Diff] = None + self.previous_recording_path = previous_recording_path + self.current_recording_path = current_recording_path + + if self.previous_recording_path is not None and self.mode in ( + RecorderMode.REPLAY, + RecorderMode.DIFF, + ): + self.diff = Diff( + current_recording_path=self.current_recording_path, + previous_recording_path=self.previous_recording_path, + ) - if recording_path is not None: - self._records_by_type = self.load(recording_path) + if self.mode == RecorderMode.REPLAY: + self._unprocessed_records_by_type = self.load(self.previous_recording_path) @classmethod def register_record_type(cls, rec_type) -> Any: @@ -86,7 +160,14 @@ def add_record(self, record: Record) -> None: self._records_by_type[rec_cls_name].append(record) def pop_matching_record(self, params: Any) -> Optional[Record]: - rec_type_name = self._record_name_by_params_name[type(params).__name__] + rec_type_name = self._record_name_by_params_name.get(type(params).__name__) + + if rec_type_name is None: + raise Exception( + f"A record of type {type(params).__name__} was requested, but no such type has been registered." + ) + + self._ensure_records_processed(rec_type_name) records = self._records_by_type[rec_type_name] match: Optional[Record] = None for rec in records: @@ -97,8 +178,8 @@ def pop_matching_record(self, params: Any) -> Optional[Record]: return match - def write(self, file_name) -> None: - with open(file_name, "w") as file: + def write(self) -> None: + with open(self.current_recording_path, "w") as file: json.dump(self._to_dict(), file) def _to_dict(self) -> Dict: @@ -111,21 +192,20 @@ def _to_dict(self) -> Dict: return dct @classmethod - def load(cls, file_name: str) -> Dict[str, List[Record]]: + def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]: with open(file_name) as file: - loaded_dct = json.load(file) + return json.load(file) - records_by_type: Dict[str, List[Record]] = {} + def _ensure_records_processed(self, record_type_name: str) -> None: + if record_type_name in self._records_by_type: + return - for record_type_name in loaded_dct: - record_cls = cls._record_cls_by_name[record_type_name] - rec_list = [] - for record_dct in loaded_dct[record_type_name]: - rec = record_cls.from_dict(record_dct) - rec_list.append(rec) # type: ignore - records_by_type[record_type_name] = rec_list - - return records_by_type + rec_list = [] + record_cls = self._record_cls_by_name[record_type_name] + for record_dct in self._unprocessed_records_by_type[record_type_name]: + rec = record_cls.from_dict(record_dct) + rec_list.append(rec) # type: ignore + self._records_by_type[record_type_name] = rec_list def expect_record(self, params: Any) -> Any: record = self.pop_matching_record(params) @@ -133,32 +213,79 @@ def expect_record(self, params: Any) -> Any: if record is None: raise Exception() + if record.result is None: + return None + result_tuple = dataclasses.astuple(record.result) return result_tuple[0] if len(result_tuple) == 1 else result_tuple def write_diffs(self, diff_file_name) -> None: - json.dump( - self._replay_diffs, - open(diff_file_name, "w"), - ) + assert self.diff is not None + with open(diff_file_name, "w") as f: + json.dump(self.diff.calculate_diff(), f) def print_diffs(self) -> None: - print(repr(self._replay_diffs)) + assert self.diff is not None + print(repr(self.diff.calculate_diff())) def get_record_mode_from_env() -> Optional[RecorderMode]: - replay_val = os.environ.get("DBT_REPLAY") - if replay_val is not None and replay_val != "0" and replay_val.lower() != "false": - return RecorderMode.REPLAY + """ + Get the record mode from the environment variables. + + If the mode is not set to 'RECORD', 'DIFF' or 'REPLAY', return None. + Expected format: 'DBT_RECORDER_MODE=RECORD' + """ + record_mode = os.environ.get("DBT_RECORDER_MODE") - record_val = os.environ.get("DBT_RECORD") - if record_val is not None and record_val != "0" and record_val.lower() != "false": + if record_mode is None: + return None + + if record_mode.lower() == "record": return RecorderMode.RECORD + # diffing requires a file path, otherwise treat as noop + elif record_mode.lower() == "diff" and os.environ.get("DBT_RECORDER_FILE_PATH") is not None: + return RecorderMode.DIFF + # replaying requires a file path, otherwise treat as noop + elif record_mode.lower() == "replay" and os.environ.get("DBT_RECORDER_FILE_PATH") is not None: + return RecorderMode.REPLAY + # if you don't specify record/replay it's a noop return None -def record_function(record_type, method=False, tuple_result=False): +def get_record_types_from_env() -> Optional[List]: + """ + Get the record subset from the environment variables. + + If no types are provided, there will be no filtering. + Invalid types will be ignored. + Expected format: 'DBT_RECORDER_TYPES=QueryRecord,FileLoadRecord,OtherRecord' + """ + record_types_str = os.environ.get("DBT_RECORDER_TYPES") + + # if all is specified we don't want any type filtering + if record_types_str is None or record_types_str.lower == "all": + return None + + return record_types_str.split(",") + + +def get_record_types_from_dict(fp: str) -> List: + """ + Get the record subset from the dict. + """ + with open(fp) as file: + loaded_dct = json.load(file) + return list(loaded_dct.keys()) + + +def record_function( + record_type, + method: bool = False, + tuple_result: bool = False, + id_field_name: Optional[str] = None, +) -> Callable: def record_function_inner(func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the # record/replay decorator if a relevant env var is set. @@ -166,9 +293,11 @@ def record_function_inner(func_to_record): return func_to_record @functools.wraps(func_to_record) - def record_replay_wrapper(*args, **kwargs): - recorder: Recorder = None + def record_replay_wrapper(*args, **kwargs) -> Any: + recorder: Optional[Recorder] = None try: + from dbt_common.context import get_invocation_context + recorder = get_invocation_context().recorder except LookupError: pass @@ -176,9 +305,17 @@ def record_replay_wrapper(*args, **kwargs): if recorder is None: return func_to_record(*args, **kwargs) + if ( + recorder.recorded_types is not None + and record_type.__name__ not in recorder.recorded_types + ): + return func_to_record(*args, **kwargs) + # For methods, peel off the 'self' argument before calling the # params constructor. param_args = args[1:] if method else args + if method and id_field_name is not None: + param_args = (getattr(args[0], id_field_name),) + param_args params = record_type.params_cls(*param_args, **kwargs) @@ -195,7 +332,7 @@ def record_replay_wrapper(*args, **kwargs): r = func_to_record(*args, **kwargs) result = ( None - if r is None or record_type.result_cls is None + if record_type.result_cls is None else record_type.result_cls(*r) if tuple_result else record_type.result_cls(r) diff --git a/dbt_common/semver.py b/dbt_common/semver.py index 951f4e8e..fbdcefa5 100644 --- a/dbt_common/semver.py +++ b/dbt_common/semver.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import re -from typing import List +from typing import List, Iterable import dbt_common.exceptions.base from dbt_common.exceptions import VersionsNotCompatibleError @@ -74,7 +74,7 @@ def _cmp(a, b): @dataclass class VersionSpecifier(VersionSpecification): - def to_version_string(self, skip_matcher=False): + def to_version_string(self, skip_matcher: bool = False) -> str: prerelease = "" build = "" matcher = "" @@ -92,7 +92,7 @@ def to_version_string(self, skip_matcher=False): ) @classmethod - def from_version_string(cls, version_string): + def from_version_string(cls, version_string: str) -> "VersionSpecifier": match = _VERSION_REGEX.match(version_string) if not match: @@ -104,7 +104,7 @@ def from_version_string(cls, version_string): return cls.from_dict(matched) - def __str__(self): + def __str__(self) -> str: return self.to_version_string() def to_range(self) -> "VersionRange": @@ -192,32 +192,32 @@ def compare(self, other): return 0 - def __lt__(self, other): + def __lt__(self, other) -> bool: return self.compare(other) == -1 - def __gt__(self, other): + def __gt__(self, other) -> bool: return self.compare(other) == 1 - def __eq___(self, other): + def __eq___(self, other) -> bool: return self.compare(other) == 0 def __cmp___(self, other): return self.compare(other) @property - def is_unbounded(self): + def is_unbounded(self) -> bool: return False @property - def is_lower_bound(self): + def is_lower_bound(self) -> bool: return self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL] @property - def is_upper_bound(self): + def is_upper_bound(self) -> bool: return self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL] @property - def is_exact(self): + def is_exact(self) -> bool: return self.matcher == Matchers.EXACT @classmethod @@ -418,7 +418,7 @@ def reduce_versions(*args): return to_return -def versions_compatible(*args): +def versions_compatible(*args) -> bool: if len(args) == 1: return True @@ -429,7 +429,7 @@ def versions_compatible(*args): return False -def find_possible_versions(requested_range, available_versions): +def find_possible_versions(requested_range, available_versions: Iterable[str]): possible_versions = [] for version_string in available_versions: @@ -442,7 +442,9 @@ def find_possible_versions(requested_range, available_versions): return [v.to_version_string(skip_matcher=True) for v in sorted_versions] -def resolve_to_specific_version(requested_range, available_versions): +def resolve_to_specific_version( + requested_range, available_versions: Iterable[str] +) -> Optional[str]: max_version = None max_version_string = None diff --git a/dbt_common/utils/casting.py b/dbt_common/utils/casting.py index 811ea376..f366db7f 100644 --- a/dbt_common/utils/casting.py +++ b/dbt_common/utils/casting.py @@ -1,7 +1,7 @@ # This is useful for proto generated classes in particular, since # the default for protobuf for strings is the empty string, so # Optional[str] types don't work for generated Python classes. -from typing import Optional +from typing import Any, Dict, Optional def cast_to_str(string: Optional[str]) -> str: @@ -18,8 +18,8 @@ def cast_to_int(integer: Optional[int]) -> int: return integer -def cast_dict_to_dict_of_strings(dct): - new_dct = {} +def cast_dict_to_dict_of_strings(dct: Dict[Any, Any]) -> Dict[str, str]: + new_dct: Dict[str, str] = {} for k, v in dct.items(): new_dct[str(k)] = str(v) return new_dct diff --git a/dbt_common/utils/executor.py b/dbt_common/utils/executor.py index 0dd8490c..529b02be 100644 --- a/dbt_common/utils/executor.py +++ b/dbt_common/utils/executor.py @@ -1,9 +1,12 @@ import concurrent.futures from contextlib import contextmanager -from contextvars import ContextVar from typing import Protocol, Optional -from dbt_common.context import get_invocation_context, reliably_get_invocation_var +from dbt_common.context import ( + get_invocation_context, + reliably_get_invocation_var, + InvocationContext, +) class ConnectingExecutor(concurrent.futures.Executor): @@ -63,7 +66,7 @@ class HasThreadingConfig(Protocol): threads: Optional[int] -def _thread_initializer(invocation_context: ContextVar) -> None: +def _thread_initializer(invocation_context: InvocationContext) -> None: invocation_var = reliably_get_invocation_var() invocation_var.set(invocation_context) diff --git a/dbt_common/utils/jinja.py b/dbt_common/utils/jinja.py index 36464cbe..260ccb6a 100644 --- a/dbt_common/utils/jinja.py +++ b/dbt_common/utils/jinja.py @@ -5,19 +5,21 @@ DOCS_PREFIX = "dbt_docs__" -def get_dbt_macro_name(name): +def get_dbt_macro_name(name) -> str: if name is None: raise DbtInternalError("Got None for a macro name!") return f"{MACRO_PREFIX}{name}" -def get_dbt_docs_name(name): +def get_dbt_docs_name(name) -> str: if name is None: raise DbtInternalError("Got None for a doc name!") return f"{DOCS_PREFIX}{name}" -def get_materialization_macro_name(materialization_name, adapter_type=None, with_prefix=True): +def get_materialization_macro_name( + materialization_name, adapter_type=None, with_prefix=True +) -> str: if adapter_type is None: adapter_type = "default" name = f"materialization_{materialization_name}_{adapter_type}" diff --git a/docs/guides/record_replay.md b/docs/guides/record_replay.md index 9d9d87f2..aff4c77c 100644 --- a/docs/guides/record_replay.md +++ b/docs/guides/record_replay.md @@ -28,7 +28,22 @@ Note also the `LoadFileRecord` class passed as a parameter to this decorator. Th The final detail needed is to define the classes specified by `params_cls` and `result_cls`, which must be dataclasses with properties whose order and names correspond to the parameters passed to the recorded function. In this case those are the `LoadFileParams` and `LoadFileResult` classes, respectively. -With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. +With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. + +## How to record/replay +If `DBT_RECORDER_MODE` is not `replay` or `record`, case insensitive, this is a no-op. Invalid values are ignored and do not throw exceptions. + +`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses. Any invalid types will be ignored. `all` is a valid type and behaves the same as not populating the env var. + + +```bash +DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run +``` + +replay need the file to replay +```bash +DBT_RECORDER_MODE=replay DBT_RECORDER_FILE_PATH=recording.json dbt run +``` ## Final Thoughts diff --git a/pyproject.toml b/pyproject.toml index c1f4f281..64fc04fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ classifiers = [ dependencies = [ "agate>=1.7.0,<1.10", "colorama>=0.3.9,<0.5", + "deepdiff>=7.0,<8.0", "isodate>=0.6,<0.7", "jsonschema>=4.0,<5.0", "Jinja2>=3.1.3,<4", diff --git a/tests/unit/test_agate_helper.py b/tests/unit/test_agate_helper.py index 4c12bcd8..fff0d4c6 100644 --- a/tests/unit/test_agate_helper.py +++ b/tests/unit/test_agate_helper.py @@ -46,13 +46,13 @@ class TestAgateHelper(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() - def tearDown(self): + def tearDown(self) -> None: rmtree(self.tempdir) - def test_from_csv(self): + def test_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -61,7 +61,7 @@ def test_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_bom_from_csv(self): + def test_bom_from_csv(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_BOM_DATA.encode("utf-8")) @@ -70,7 +70,7 @@ def test_bom_from_csv(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_from_csv_all_reserved(self): + def test_from_csv_all_reserved(self) -> None: path = os.path.join(self.tempdir, "input.csv") with open(path, "wb") as fp: fp.write(SAMPLE_CSV_DATA.encode("utf-8")) @@ -79,7 +79,7 @@ def test_from_csv_all_reserved(self): for expected, row in zip(EXPECTED_STRINGS, tbl): self.assertEqual(list(row), expected) - def test_from_data(self): + def test_from_data(self) -> None: column_names = ["a", "b", "c", "d", "e", "f", "g"] data = [ { @@ -106,7 +106,7 @@ def test_from_data(self): for idx, row in enumerate(tbl): self.assertEqual(list(row), EXPECTED[idx]) - def test_datetime_formats(self): + def test_datetime_formats(self) -> None: path = os.path.join(self.tempdir, "input.csv") datetimes = [ "20180806T11:33:29.000Z", @@ -120,7 +120,7 @@ def test_datetime_formats(self): tbl = agate_helper.from_csv(path, ()) self.assertEqual(tbl[0][0], expected) - def test_merge_allnull(self): + def test_merge_allnull(self) -> None: t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c")) t2 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c")) result = agate_helper.merge_tables([t1, t2]) @@ -130,7 +130,7 @@ def test_merge_allnull(self): assert isinstance(result.column_types[2], agate_helper.Integer) self.assertEqual(len(result), 4) - def test_merge_mixed(self): + def test_merge_mixed(self) -> None: t1 = agate_helper.table_from_rows( [(1, "a", None, None), (2, "b", None, None)], ("a", "b", "c", "d") ) @@ -181,7 +181,7 @@ def test_merge_mixed(self): assert isinstance(result.column_types[3], agate.data_types.Number) self.assertEqual(len(result), 6) - def test_nocast_string_types(self): + def test_nocast_string_types(self) -> None: # String fields should not be coerced into a representative type # See: https://github.com/dbt-labs/dbt-core/issues/2984 @@ -202,7 +202,7 @@ def test_nocast_string_types(self): for i, row in enumerate(tbl): self.assertEqual(list(row), expected[i]) - def test_nocast_bool_01(self): + def test_nocast_bool_01(self) -> None: # True and False values should not be cast to 1 and 0, and vice versa # See: https://github.com/dbt-labs/dbt-core/issues/4511 diff --git a/tests/unit/test_connection_retries.py b/tests/unit/test_connection_retries.py index 817af7a2..44fc72f5 100644 --- a/tests/unit/test_connection_retries.py +++ b/tests/unit/test_connection_retries.py @@ -19,20 +19,23 @@ def test_no_retry(self): assert result == expected -def no_success_fn(): +def no_success_fn() -> str: raise RequestException("You'll never pass") return "failure" class TestMaxRetries: - def test_no_retry(self): + def test_no_retry(self) -> None: fn_to_retry = functools.partial(no_success_fn) with pytest.raises(ConnectionError): connection_exception_retry(fn_to_retry, 3) -def single_retry_fn(): +counter = 0 + + +def single_retry_fn() -> str: global counter if counter == 0: counter += 1 @@ -45,7 +48,7 @@ def single_retry_fn(): class TestSingleRetry: - def test_no_retry(self): + def test_no_retry(self) -> None: global counter counter = 0 diff --git a/tests/unit/test_contextvars.py b/tests/unit/test_contextvars.py index 4eb58e6c..1aa9425f 100644 --- a/tests/unit/test_contextvars.py +++ b/tests/unit/test_contextvars.py @@ -1,7 +1,7 @@ from dbt_common.events.contextvars import log_contextvars, get_node_info, set_log_contextvars -def test_contextvars(): +def test_contextvars() -> None: node_info = { "unique_id": "model.test.my_model", "started_at": None, diff --git a/tests/unit/test_contracts_util.py b/tests/unit/test_contracts_util.py index 2a620370..d2fc4493 100644 --- a/tests/unit/test_contracts_util.py +++ b/tests/unit/test_contracts_util.py @@ -13,7 +13,7 @@ class ExampleMergableClass(Mergeable): class TestMergableClass(unittest.TestCase): - def test_mergeability(self): + def test_mergeability(self) -> None: mergeable1 = ExampleMergableClass( attr_a="loses", attr_b=None, attr_c=["I'll", "still", "exist"] ) diff --git a/tests/unit/test_core_dbt_utils.py b/tests/unit/test_core_dbt_utils.py index 8a0e836e..7419cd8d 100644 --- a/tests/unit/test_core_dbt_utils.py +++ b/tests/unit/test_core_dbt_utils.py @@ -7,30 +7,30 @@ class TestCommonDbtUtils(unittest.TestCase): - def test_connection_exception_retry_none(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add(self), 5) + def test_connection_exception_retry_none(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add(), 5) self.assertEqual(1, counter) - def test_connection_exception_retry_success_requests_exception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_requests_exception(self), 5) + def test_connection_exception_retry_success_requests_exception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_requests_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned None, plus 1 retry - def test_connection_exception_retry_max(self): - Counter._reset(self) + def test_connection_exception_retry_max(self) -> None: + Counter._reset() with self.assertRaises(ConnectionError): - connection_exception_retry(lambda: Counter._add_with_exception(self), 5) + connection_exception_retry(lambda: Counter._add_with_exception(), 5) self.assertEqual(6, counter) # 6 = original attempt plus 5 retries - def test_connection_exception_retry_success_failed_untar(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_untar_exception(self), 5) + def test_connection_exception_retry_success_failed_untar(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_untar_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned ReadError, plus 1 retry - def test_connection_exception_retry_success_failed_eofexception(self): - Counter._reset(self) - connection_exception_retry(lambda: Counter._add_with_eof_exception(self), 5) + def test_connection_exception_retry_success_failed_eofexception(self) -> None: + Counter._reset() + connection_exception_retry(lambda: Counter._add_with_eof_exception(), 5) self.assertEqual(2, counter) # 2 = original attempt returned EOFError, plus 1 retry @@ -38,36 +38,42 @@ def test_connection_exception_retry_success_failed_eofexception(self): class Counter: - def _add(self): + @classmethod + def _add(cls) -> None: global counter counter += 1 # All exceptions that Requests explicitly raises inherit from # requests.exceptions.RequestException so we want to make sure that raises plus one exception # that inherit from it for sanity - def _add_with_requests_exception(self): + @classmethod + def _add_with_requests_exception(cls) -> None: global counter counter += 1 if counter < 2: raise requests.exceptions.RequestException - def _add_with_exception(self): + @classmethod + def _add_with_exception(cls) -> None: global counter counter += 1 raise requests.exceptions.ConnectionError - def _add_with_untar_exception(self): + @classmethod + def _add_with_untar_exception(cls) -> None: global counter counter += 1 if counter < 2: raise tarfile.ReadError - def _add_with_eof_exception(self): + @classmethod + def _add_with_eof_exception(cls) -> None: global counter counter += 1 if counter < 2: raise EOFError - def _reset(self): + @classmethod + def _reset(cls) -> None: global counter counter = 0 diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py new file mode 100644 index 00000000..54f735e3 --- /dev/null +++ b/tests/unit/test_diff.py @@ -0,0 +1,307 @@ +import json +from typing import Any, Dict + +import pytest +from dbt_common.record import Diff + + +@pytest.fixture +def current_query(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "key", + "table": '[{"a": 5},{"b": 7}]', + }, + } + ] + + +@pytest.fixture +def query_modified_order(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "key", + "table": '[{"b": 7},{"a": 5}]', + }, + } + ] + + +@pytest.fixture +def query_modified_value(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "key", + "table": '[{"a": 5},{"b": 10}]', + }, + } + ] + + +@pytest.fixture +def current_simple(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "cat", + }, + } + ] + + +@pytest.fixture +def current_simple_modified(): + return [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ] + + +@pytest.fixture +def env_record(): + return [ + { + "params": {}, + "result": { + "env": { + "DBT_RECORDER_FILE_PATH": "record.json", + "ANOTHER_ENV_VAR": "dogs", + }, + }, + } + ] + + +@pytest.fixture +def modified_env_record(): + return [ + { + "params": {}, + "result": { + "env": { + "DBT_RECORDER_FILE_PATH": "another_record.json", + "ANOTHER_ENV_VAR": "cats", + }, + }, + } + ] + + +def test_diff_query_records_no_diff(current_query, query_modified_order): + # Setup: Create an instance of Diff + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + result = diff_instance.diff_query_records(current_query, query_modified_order) + # the order changed but the diff should be empty + expected_result = {} + assert result == expected_result # Replace expected_result with what you actually expect + + +def test_diff_query_records_with_diff(current_query, query_modified_value): + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + result = diff_instance.diff_query_records(current_query, query_modified_value) + # the values changed this time + expected_result = { + "values_changed": {"root[0]['result']['table'][1]['b']": {"new_value": 7, "old_value": 10}} + } + assert result == expected_result + + +def test_diff_env_records(env_record, modified_env_record): + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + result = diff_instance.diff_env_records(env_record, modified_env_record) + expected_result = { + "values_changed": { + "root[0]['result']['env']['ANOTHER_ENV_VAR']": { + "new_value": "dogs", + "old_value": "cats", + } + } + } + assert result == expected_result + + +def test_diff_default_no_diff(current_simple): + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + # use the same list to ensure no diff + result = diff_instance.diff_default(current_simple, current_simple) + expected_result = {} + assert result == expected_result + + +def test_diff_default_with_diff(current_simple, current_simple_modified): + diff_instance = Diff( + current_recording_path="path/to/current", previous_recording_path="path/to/previous" + ) + result = diff_instance.diff_default(current_simple, current_simple_modified) + expected_result = { + "values_changed": {"root[0]['result']['this']": {"new_value": "cat", "old_value": "dog"}} + } + assert result == expected_result + + +# Mock out reading the files so we don't have to +class MockFile: + def __init__(self, json_data): + self.json_data = json_data + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def read(self): + return json.dumps(self.json_data) + + +# Create a Mock Open Function +def mock_open(mock_files): + def open_mock(file, *args, **kwargs): + if file in mock_files: + return MockFile(mock_files[file]) + raise FileNotFoundError(f"No mock file found for {file}") + + return open_mock + + +def test_calculate_diff_no_diff(monkeypatch) -> None: + # Mock data for the files + current_recording_data = { + "GetEnvRecord": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ], + "DefaultKey": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ], + } + previous_recording_data = { + "GetEnvRecord": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ], + "DefaultKey": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ], + } + current_recording_path = "/path/to/current_recording.json" + previous_recording_path = "/path/to/previous_recording.json" + mock_files = { + current_recording_path: current_recording_data, + previous_recording_path: previous_recording_data, + } + monkeypatch.setattr("builtins.open", mock_open(mock_files)) + + # test the diff + diff_instance = Diff( + current_recording_path=current_recording_path, + previous_recording_path=previous_recording_path, + ) + result = diff_instance.calculate_diff() + expected_result: Dict[str, Any] = {"GetEnvRecord": {}, "DefaultKey": {}} + assert result == expected_result + + +def test_calculate_diff_with_diff(monkeypatch) -> None: + # Mock data for the files + current_recording_data = { + "GetEnvRecord": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "dog", + }, + } + ] + } + previous_recording_data = { + "GetEnvRecord": [ + { + "params": { + "a": 1, + }, + "result": { + "this": "cats", + }, + } + ] + } + current_recording_path = "/path/to/current_recording.json" + previous_recording_path = "/path/to/previous_recording.json" + mock_files = { + current_recording_path: current_recording_data, + previous_recording_path: previous_recording_data, + } + monkeypatch.setattr("builtins.open", mock_open(mock_files)) + + # test the diff + diff_instance = Diff( + current_recording_path=current_recording_path, + previous_recording_path=previous_recording_path, + ) + result = diff_instance.calculate_diff() + expected_result = { + "GetEnvRecord": { + "values_changed": { + "root[0]['result']['this']": {"new_value": "dog", "old_value": "cats"} + } + } + } + assert result == expected_result diff --git a/tests/unit/test_event_handler.py b/tests/unit/test_event_handler.py index 80d5ae2b..f38938b6 100644 --- a/tests/unit/test_event_handler.py +++ b/tests/unit/test_event_handler.py @@ -5,7 +5,7 @@ from dbt_common.events.event_manager import TestEventManager -def test_event_logging_handler_emits_records_correctly(): +def test_event_logging_handler_emits_records_correctly() -> None: event_manager = TestEventManager() handler = DbtEventLoggingHandler(event_manager=event_manager, level=logging.DEBUG) log = logging.getLogger("test") @@ -27,7 +27,7 @@ def test_event_logging_handler_emits_records_correctly(): assert event_manager.event_history[5][1] == EventLevel.ERROR -def test_set_package_logging_sets_level_correctly(): +def test_set_package_logging_sets_level_correctly() -> None: event_manager = TestEventManager() log = logging.getLogger("test") set_package_logging("test", logging.DEBUG, event_manager) diff --git a/tests/unit/test_helper_types.py b/tests/unit/test_helper_types.py index 1a9519de..ba98803c 100644 --- a/tests/unit/test_helper_types.py +++ b/tests/unit/test_helper_types.py @@ -1,11 +1,12 @@ import pytest +from typing import List, Union from dbt_common.helper_types import IncludeExclude, WarnErrorOptions from dbt_common.dataclass_schema import ValidationError class TestIncludeExclude: - def test_init_invalid(self): + def test_init_invalid(self) -> None: with pytest.raises(ValidationError): IncludeExclude(include="invalid") @@ -22,14 +23,16 @@ def test_init_invalid(self): (["ItemA", "ItemB"], [], True), ], ) - def test_includes(self, include, exclude, expected_includes): + def test_includes( + self, include: Union[str, List[str]], exclude: List[str], expected_includes: bool + ) -> None: include_exclude = IncludeExclude(include=include, exclude=exclude) assert include_exclude.includes("ItemA") == expected_includes class TestWarnErrorOptions: - def test_init_invalid_error(self): + def test_init_invalid_error(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"], valid_error_names=set(["ValidError"])) @@ -38,14 +41,14 @@ def test_init_invalid_error(self): include="*", exclude=["InvalidError"], valid_error_names=set(["ValidError"]) ) - def test_init_invalid_error_default_valid_error_names(self): + def test_init_invalid_error_default_valid_error_names(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include=["InvalidError"]) with pytest.raises(ValidationError): WarnErrorOptions(include="*", exclude=["InvalidError"]) - def test_init_valid_error(self): + def test_init_valid_error(self) -> None: warn_error_options = WarnErrorOptions( include=["ValidError"], valid_error_names=set(["ValidError"]) ) @@ -58,18 +61,18 @@ def test_init_valid_error(self): assert warn_error_options.include == "*" assert warn_error_options.exclude == ["ValidError"] - def test_init_default_silence(self): + def test_init_default_silence(self) -> None: my_options = WarnErrorOptions(include="*") assert my_options.silence == [] - def test_init_invalid_silence_event(self): + def test_init_invalid_silence_event(self) -> None: with pytest.raises(ValidationError): WarnErrorOptions(include="*", silence=["InvalidError"]) - def test_init_valid_silence_event(self): + def test_init_valid_silence_event(self) -> None: all_events = ["MySilencedEvent"] my_options = WarnErrorOptions( - include="*", silence=all_events, valid_error_names=all_events + include="*", silence=all_events, valid_error_names=set(all_events) ) assert my_options.silence == all_events @@ -81,14 +84,16 @@ def test_init_valid_silence_event(self): ("*", ["ItemB"], True), ], ) - def test_includes(self, include, silence, expected_includes): + def test_includes( + self, include: Union[str, List[str]], silence: List[str], expected_includes: bool + ) -> None: include_exclude = WarnErrorOptions( include=include, silence=silence, valid_error_names={"ItemA", "ItemB"} ) assert include_exclude.includes("ItemA") == expected_includes - def test_silenced(self): + def test_silenced(self) -> None: my_options = WarnErrorOptions(include="*", silence=["ItemA"], valid_error_names={"ItemA"}) assert my_options.silenced("ItemA") assert not my_options.silenced("ItemB") diff --git a/tests/unit/test_invocation_context.py b/tests/unit/test_invocation_context.py index b6697f8e..3dc832d3 100644 --- a/tests/unit/test_invocation_context.py +++ b/tests/unit/test_invocation_context.py @@ -2,13 +2,13 @@ from dbt_common.context import InvocationContext -def test_invocation_context_env(): +def test_invocation_context_env() -> None: test_env = {"VAR_1": "value1", "VAR_2": "value2"} ic = InvocationContext(env=test_env) assert ic.env == test_env -def test_invocation_context_secrets(): +def test_invocation_context_secrets() -> None: test_env = { f"{SECRET_ENV_PREFIX}_VAR_1": "secret1", f"{SECRET_ENV_PREFIX}VAR_2": "secret2", @@ -16,10 +16,10 @@ def test_invocation_context_secrets(): f"foo{SECRET_ENV_PREFIX}": "non-secret", } ic = InvocationContext(env=test_env) - assert set(ic.env_secrets) == set(["secret1", "secret2"]) + assert set(ic.env_secrets) == {"secret1", "secret2"} -def test_invocation_context_private(): +def test_invocation_context_private() -> None: test_env = { f"{PRIVATE_ENV_PREFIX}_VAR_1": "private1", f"{PRIVATE_ENV_PREFIX}VAR_2": "private2", diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index f038a1ec..e906a0ac 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -1,23 +1,26 @@ import unittest +from dbt_common.clients._jinja_blocks import BlockTag from dbt_common.clients.jinja import extract_toplevel_blocks from dbt_common.exceptions import CompilationError class TestBlockLexer(unittest.TestCase): - def test_basic(self): + def test_basic(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" blocks = extract_toplevel_blocks( block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_multiple(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_multiple(self) -> None: body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' body_two = ( "{{ config(bar=1)}}\r\nselect * from {% if foo %} thing " @@ -37,7 +40,7 @@ def test_multiple(self): ) self.assertEqual(len(blocks), 2) - def test_comments(self): + def test_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = "{# my comment #}" block_data = " \n\r\t{%- mytype foo %}" + body + "{%endmytype -%}" @@ -45,12 +48,14 @@ def test_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_evil_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_evil_comments(self) -> None: body = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' comment = ( "{# external comment {% othertype bar %} select * from " @@ -61,12 +66,14 @@ def test_evil_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_nested_comments(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_nested_comments(self) -> None: body = ( '{# my comment #} {{ config(foo="bar") }}' "\r\nselect * from {# my other comment embedding {% endmytype %} #} this.that\r\n" @@ -80,33 +87,43 @@ def test_nested_comments(self): comment + block_data, allowed_blocks={"mytype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].contents, body) - self.assertEqual(blocks[0].full_block, block_data) - - def test_complex_file(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.contents, body) + self.assertEqual(b0.full_block, block_data) + + def test_complex_file(self) -> None: blocks = extract_toplevel_blocks( complex_snapshot_file, allowed_blocks={"mytype", "myothertype"}, collect_raw_data=False ) self.assertEqual(len(blocks), 3) - self.assertEqual(blocks[0].block_type_name, "mytype") - self.assertEqual(blocks[0].block_name, "foo") - self.assertEqual(blocks[0].full_block, "{% mytype foo %} some stuff {% endmytype %}") - self.assertEqual(blocks[0].contents, " some stuff ") - self.assertEqual(blocks[1].block_type_name, "mytype") - self.assertEqual(blocks[1].block_name, "bar") - self.assertEqual(blocks[1].full_block, bar_block) - self.assertEqual(blocks[1].contents, bar_block[16:-15].rstrip()) - self.assertEqual(blocks[2].block_type_name, "myothertype") - self.assertEqual(blocks[2].block_name, "x") - self.assertEqual(blocks[2].full_block, x_block.strip()) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "mytype") + self.assertEqual(b0.block_name, "foo") + self.assertEqual(b0.full_block, "{% mytype foo %} some stuff {% endmytype %}") + self.assertEqual(b0.contents, " some stuff ") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "mytype") + self.assertEqual(b1.block_name, "bar") + self.assertEqual(b1.full_block, bar_block) + self.assertEqual(b1.contents, bar_block[16:-15].rstrip()) + + b2 = blocks[2] + assert isinstance(b2, BlockTag) + self.assertEqual(b2.block_type_name, "myothertype") + self.assertEqual(b2.block_name, "x") + self.assertEqual(b2.full_block, x_block.strip()) self.assertEqual( - blocks[2].contents, + b2.contents, x_block[len("\n{% myothertype x %}") : -len("{% endmyothertype %}\n")], ) - def test_peaceful_macro_coexistence(self): + def test_peaceful_macro_coexistence(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing " "{%- endmacro %} {# my model #} {% a b %} test {% enda %}" @@ -116,15 +133,22 @@ def test_peaceful_macro_coexistence(self): ) self.assertEqual(len(blocks), 4) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") - def test_macro_with_trailing_data(self): + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + + def test_macro_with_trailing_data(self) -> None: body = ( "{# my macro #} {% macro foo(a, b) %} do a thing {%- endmacro %} " "{# my model #} {% a b %} test {% enda %} raw data so cool" @@ -134,16 +158,24 @@ def test_macro_with_trailing_data(self): ) self.assertEqual(len(blocks), 5) self.assertEqual(blocks[0].full_block, "{# my macro #} ") - self.assertEqual(blocks[1].block_type_name, "macro") - self.assertEqual(blocks[1].block_name, "foo") - self.assertEqual(blocks[1].contents, " do a thing") + + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "macro") + self.assertEqual(b1.block_name, "foo") + self.assertEqual(b1.contents, " do a thing") + self.assertEqual(blocks[2].full_block, " {# my model #} ") - self.assertEqual(blocks[3].block_type_name, "a") - self.assertEqual(blocks[3].block_name, "b") - self.assertEqual(blocks[3].contents, " test ") + + b3 = blocks[3] + assert isinstance(b3, BlockTag) + self.assertEqual(b3.block_type_name, "a") + self.assertEqual(b3.block_name, "b") + self.assertEqual(b3.contents, " test ") + self.assertEqual(blocks[4].full_block, " raw data so cool") - def test_macro_with_crazy_args(self): + def test_macro_with_crazy_args(self) -> None: body = ( """{% macro foo(a, b=asdf("cool this is 'embedded'" * 3) + external_var, c)%}""" "cool{# block comment with {% endmacro %} in it #} stuff here " @@ -151,38 +183,44 @@ def test_macro_with_crazy_args(self): ) blocks = extract_toplevel_blocks(body, allowed_blocks={"macro"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "macro") - self.assertEqual(blocks[0].block_name, "foo") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "macro") + self.assertEqual(b0.block_name, "foo") self.assertEqual( blocks[0].contents, "cool{# block comment with {% endmacro %} in it #} stuff here " ) - def test_materialization_parse(self): + def test_materialization_parse(self) -> None: body = "{% materialization xxx, default %} ... {% endmaterialization %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) body = '{% materialization xxx, adapter="other" %} ... {% endmaterialization %}' blocks = extract_toplevel_blocks( body, allowed_blocks={"materialization"}, collect_raw_data=False ) + b0 = blocks[0] + assert isinstance(b0, BlockTag) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "materialization") - self.assertEqual(blocks[0].block_name, "xxx") - self.assertEqual(blocks[0].full_block, body) + self.assertEqual(b0.block_type_name, "materialization") + self.assertEqual(b0.block_name, "xxx") + self.assertEqual(b0.full_block, body) - def test_nested_not_ok(self): + def test_nested_not_ok(self) -> None: # we don't allow nesting same blocks body = "{% myblock a %} {% myblock b %} {% endmyblock %} {% endmyblock %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock"}) - def test_incomplete_block_failure(self): + def test_incomplete_block_failure(self) -> None: fullbody = "{% myblock foo %} {% endmyblock %}" for length in range(len("{% myblock foo %}"), len(fullbody) - 1): body = fullbody[:length] @@ -194,45 +232,45 @@ def test_wrong_end_failure(self): with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}) - def test_comment_no_end_failure(self): + def test_comment_no_end_failure(self) -> None: body = "{# " with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_comment_only(self): + def test_comment_only(self) -> None: body = "{# myblock #}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_comment_block_self_closing(self): + def test_comment_block_self_closing(self) -> None: # test the case where a comment start looks a lot like it closes itself # (but it doesn't in jinja!) body = "{#} {% myblock foo %} {#}" blocks = extract_toplevel_blocks(body, collect_raw_data=False) self.assertEqual(len(blocks), 0) - def test_embedded_self_closing_comment_block(self): + def test_embedded_self_closing_comment_block(self) -> None: body = "{% myblock foo %} {#}{% endmyblock %} {#}{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) self.assertEqual(blocks[0].contents, " {#}{% endmyblock %} {#}") - def test_set_statement(self): + def test_set_statement(self) -> None: body = "{% set x = 1 %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_set_block(self): + def test_set_block(self) -> None: body = "{% set x %}1{% endset %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_set_statement(self): + def test_crazy_set_statement(self) -> None: body = ( '{% set x = (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% set y = otherthing("{% myblock foo %}") %}' @@ -244,19 +282,19 @@ def test_crazy_set_statement(self): self.assertEqual(blocks[0].full_block, "{% otherblock bar %}x{% endotherblock %}") self.assertEqual(blocks[0].block_type_name, "otherblock") - def test_do_statement(self): + def test_do_statement(self) -> None: body = "{% do thing.update() %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_deceptive_do_statement(self): + def test_deceptive_do_statement(self) -> None: body = "{% do thing %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_do_block(self): + def test_do_block(self) -> None: body = "{% do %}thing.update(){% enddo %}{% myblock foo %}hi{% endmyblock %}" blocks = extract_toplevel_blocks( body, allowed_blocks={"do", "myblock"}, collect_raw_data=False @@ -266,7 +304,7 @@ def test_do_block(self): self.assertEqual(blocks[0].block_type_name, "do") self.assertEqual(blocks[1].full_block, "{% myblock foo %}hi{% endmyblock %}") - def test_crazy_do_statement(self): + def test_crazy_do_statement(self) -> None: body = ( '{% do (thing("{% myblock foo %}")) %}{% otherblock bar %}x{% endotherblock %}' '{% do otherthing("{% myblock foo %}") %}{% myblock x %}hi{% endmyblock %}' @@ -280,7 +318,7 @@ def test_crazy_do_statement(self): self.assertEqual(blocks[1].full_block, "{% myblock x %}hi{% endmyblock %}") self.assertEqual(blocks[1].block_type_name, "myblock") - def test_awful_jinja(self): + def test_awful_jinja(self) -> None: blocks = extract_toplevel_blocks( if_you_do_this_you_are_awful, allowed_blocks={"snapshot", "materialization"}, @@ -304,63 +342,71 @@ def test_awful_jinja(self): self.assertEqual(blocks[1].block_type_name, "materialization") self.assertEqual(blocks[1].contents, "\nhi\n") - def test_quoted_endblock_within_block(self): + def test_quoted_endblock_within_block(self) -> None: body = '{% myblock something -%} {% set x = ("{% endmyblock %}") %} {% endmyblock %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"myblock"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].block_type_name, "myblock") self.assertEqual(blocks[0].contents, '{% set x = ("{% endmyblock %}") %} ') - def test_docs_block(self): + def test_docs_block(self) -> None: body = ( "{% docs __my_doc__ %} asdf {# nope {% enddocs %}} #} {% enddocs %}" '{% docs __my_other_doc__ %} asdf "{% enddocs %}' ) blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 2) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, " asdf {# nope {% enddocs %}} #} ") - self.assertEqual(blocks[0].block_name, "__my_doc__") - self.assertEqual(blocks[1].block_type_name, "docs") - self.assertEqual(blocks[1].contents, ' asdf "') - self.assertEqual(blocks[1].block_name, "__my_other_doc__") - - def test_docs_block_expr(self): + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, " asdf {# nope {% enddocs %}} #} ") + self.assertEqual(b0.block_name, "__my_doc__") + b1 = blocks[1] + assert isinstance(b1, BlockTag) + self.assertEqual(b1.block_type_name, "docs") + self.assertEqual(b1.contents, ' asdf "') + self.assertEqual(b1.block_name, "__my_other_doc__") + + def test_docs_block_expr(self) -> None: body = '{% docs more_doc %} asdf {{ "{% enddocs %}" ~ "}}" }}{% enddocs %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"docs"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "docs") - self.assertEqual(blocks[0].contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') - self.assertEqual(blocks[0].block_name, "more_doc") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "docs") + self.assertEqual(b0.contents, ' asdf {{ "{% enddocs %}" ~ "}}" }}') + self.assertEqual(b0.block_name, "more_doc") - def test_unclosed_model_quotes(self): + def test_unclosed_model_quotes(self) -> None: # test case for https://github.com/dbt-labs/dbt-core/issues/1533 body = '{% model my_model -%} select * from "something"."something_else{% endmodel %}' blocks = extract_toplevel_blocks(body, allowed_blocks={"model"}, collect_raw_data=False) self.assertEqual(len(blocks), 1) - self.assertEqual(blocks[0].block_type_name, "model") - self.assertEqual(blocks[0].contents, 'select * from "something"."something_else') - self.assertEqual(blocks[0].block_name, "my_model") + b0 = blocks[0] + assert isinstance(b0, BlockTag) + self.assertEqual(b0.block_type_name, "model") + self.assertEqual(b0.contents, 'select * from "something"."something_else') + self.assertEqual(b0.block_name, "my_model") - def test_if(self): + def test_if(self) -> None: # if you conditionally define your macros/models, don't body = "{% if true %}{% macro my_macro() %} adsf {% endmacro %}{% endif %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_if_innocuous(self): + def test_if_innocuous(self) -> None: body = "{% if true %}{% something %}asdfasd{% endsomething %}{% endif %}" blocks = extract_toplevel_blocks(body) self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_for(self): + def test_for(self) -> None: # no for-loops over macros. body = "{% for x in range(10) %}{% macro my_macro() %} adsf {% endmacro %}{% endfor %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body) - def test_for_innocuous(self): + def test_for_innocuous(self) -> None: # no for-loops over macros. body = ( "{% for x in range(10) %}{% something my_something %} adsf " @@ -370,7 +416,7 @@ def test_for_innocuous(self): self.assertEqual(len(blocks), 1) self.assertEqual(blocks[0].full_block, body) - def test_endif(self): + def test_endif(self) -> None: body = "{% snapshot foo %}select * from thing{% endsnapshot%}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -382,7 +428,7 @@ def test_endif(self): str(err.exception), ) - def test_if_endfor(self): + def test_if_endfor(self) -> None: body = "{% if x %}...{% endfor %}{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) @@ -391,7 +437,7 @@ def test_if_endfor(self): str(err.exception), ) - def test_if_endfor_newlines(self): + def test_if_endfor_newlines(self) -> None: body = "{% if x %}\n ...\n {% endfor %}\n{% endif %}" with self.assertRaises(CompilationError) as err: extract_toplevel_blocks(body) diff --git a/tests/unit/test_model_config.py b/tests/unit/test_model_config.py index 0cc1e711..57a14438 100644 --- a/tests/unit/test_model_config.py +++ b/tests/unit/test_model_config.py @@ -14,7 +14,7 @@ class ThingWithMergeBehavior(dbtClassMixin): keysappended: Dict[str, int] = field(metadata={"merge": MergeBehavior.DictKeyAppend}) -def test_merge_behavior_meta(): +def test_merge_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(MergeBehavior) == { @@ -29,15 +29,14 @@ def test_merge_behavior_meta(): assert existing == initial_existing -def test_merge_behavior_from_field(): - fields = [f[0] for f in ThingWithMergeBehavior._get_fields()] - fields = {name: f for f, name in ThingWithMergeBehavior._get_fields()} - assert set(fields) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} - assert MergeBehavior.from_field(fields["default_behavior"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["appended"]) == MergeBehavior.Append - assert MergeBehavior.from_field(fields["updated"]) == MergeBehavior.Update - assert MergeBehavior.from_field(fields["clobbered"]) == MergeBehavior.Clobber - assert MergeBehavior.from_field(fields["keysappended"]) == MergeBehavior.DictKeyAppend +def test_merge_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithMergeBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "appended", "updated", "clobbered", "keysappended"} + assert MergeBehavior.from_field(fields2["default_behavior"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["appended"]) == MergeBehavior.Append + assert MergeBehavior.from_field(fields2["updated"]) == MergeBehavior.Update + assert MergeBehavior.from_field(fields2["clobbered"]) == MergeBehavior.Clobber + assert MergeBehavior.from_field(fields2["keysappended"]) == MergeBehavior.DictKeyAppend @dataclass @@ -47,7 +46,7 @@ class ThingWithShowBehavior(dbtClassMixin): shown: float = field(metadata={"show_hide": ShowBehavior.Show}) -def test_show_behavior_meta(): +def test_show_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(ShowBehavior) == {ShowBehavior.Hide, ShowBehavior.Show} @@ -57,13 +56,12 @@ def test_show_behavior_meta(): assert existing == initial_existing -def test_show_behavior_from_field(): - fields = [f[0] for f in ThingWithShowBehavior._get_fields()] - fields = {name: f for f, name in ThingWithShowBehavior._get_fields()} - assert set(fields) == {"default_behavior", "hidden", "shown"} - assert ShowBehavior.from_field(fields["default_behavior"]) == ShowBehavior.Show - assert ShowBehavior.from_field(fields["hidden"]) == ShowBehavior.Hide - assert ShowBehavior.from_field(fields["shown"]) == ShowBehavior.Show +def test_show_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithShowBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "hidden", "shown"} + assert ShowBehavior.from_field(fields2["default_behavior"]) == ShowBehavior.Show + assert ShowBehavior.from_field(fields2["hidden"]) == ShowBehavior.Hide + assert ShowBehavior.from_field(fields2["shown"]) == ShowBehavior.Show @dataclass @@ -73,7 +71,7 @@ class ThingWithCompareBehavior(dbtClassMixin): excluded: str = field(metadata={"compare": CompareBehavior.Exclude}) -def test_compare_behavior_meta(): +def test_compare_behavior_meta() -> None: existing = {"foo": "bar"} initial_existing = existing.copy() assert set(CompareBehavior) == {CompareBehavior.Include, CompareBehavior.Exclude} @@ -83,10 +81,9 @@ def test_compare_behavior_meta(): assert existing == initial_existing -def test_compare_behavior_from_field(): - fields = [f[0] for f in ThingWithCompareBehavior._get_fields()] - fields = {name: f for f, name in ThingWithCompareBehavior._get_fields()} - assert set(fields) == {"default_behavior", "included", "excluded"} - assert CompareBehavior.from_field(fields["default_behavior"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["included"]) == CompareBehavior.Include - assert CompareBehavior.from_field(fields["excluded"]) == CompareBehavior.Exclude +def test_compare_behavior_from_field() -> None: + fields2 = {name: f for f, name in ThingWithCompareBehavior._get_fields()} + assert set(fields2) == {"default_behavior", "included", "excluded"} + assert CompareBehavior.from_field(fields2["default_behavior"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["included"]) == CompareBehavior.Include + assert CompareBehavior.from_field(fields2["excluded"]) == CompareBehavior.Exclude diff --git a/tests/unit/test_proto_events.py b/tests/unit/test_proto_events.py index 32eb08ae..d21b5062 100644 --- a/tests/unit/test_proto_events.py +++ b/tests/unit/test_proto_events.py @@ -18,7 +18,7 @@ } -def test_events(): +def test_events() -> None: # M020 event event_code = "M020" event = RetryExternalCall(attempt=3, max=5) @@ -45,7 +45,7 @@ def test_events(): assert new_msg.data.attempt == msg.data.attempt -def test_extra_dict_on_event(monkeypatch): +def test_extra_dict_on_event(monkeypatch) -> None: monkeypatch.setenv("DBT_ENV_CUSTOM_ENV_env_key", "env_value") reset_metadata_vars() diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index aa7af69b..6e02d710 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -1,5 +1,6 @@ import dataclasses import os +import pytest from typing import Optional from dbt_common.context import set_invocation_context, get_invocation_context @@ -24,58 +25,114 @@ class TestRecord(Record): result_cls = TestRecordResult -def test_decorator_records(): - prev = os.environ.get("DBT_RECORD", None) - try: - os.environ["DBT_RECORD"] = "True" - recorder = Recorder(RecorderMode.RECORD) - set_invocation_context({}) - get_invocation_context().recorder = recorder - - @record_function(TestRecord) - def test_func(a: int, b: str, c: Optional[str] = None) -> str: - return str(a) + b + (c if c else "") - - test_func(123, "abc") - - expected_record = TestRecord( - params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") - ) - - assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params - assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result - - finally: - if prev is None: - os.environ.pop("DBT_RECORD", None) - else: - os.environ["DBT_RECORD"] = prev - - -def test_decorator_replays(): - prev = os.environ.get("DBT_RECORD", None) - try: - os.environ["DBT_RECORD"] = "True" - recorder = Recorder(RecorderMode.REPLAY) - set_invocation_context({}) - get_invocation_context().recorder = recorder - - expected_record = TestRecord( - params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") - ) - - recorder._records_by_type["TestRecord"] = [expected_record] +@dataclasses.dataclass +class NotTestRecordParams: + a: int + b: str + c: Optional[str] = None - @record_function(TestRecord) - def test_func(a: int, b: str, c: Optional[str] = None) -> str: - raise Exception("This should not actually be called") - res = test_func(123, "abc") +@dataclasses.dataclass +class NotTestRecordResult: + return_val: str - assert res == "123abc" - finally: - if prev is None: - os.environ.pop("DBT_RECORD", None) - else: - os.environ["DBT_RECORD"] = prev +@Recorder.register_record_type +class NotTestRecord(Record): + params_cls = NotTestRecordParams + result_cls = NotTestRecordResult + + +@pytest.fixture(scope="function", autouse=True) +def setup(): + # capture the previous state of the environment variables + prev_mode = os.environ.get("DBT_RECORDER_MODE", None) + prev_type = os.environ.get("DBT_RECORDER_TYPES", None) + prev_fp = os.environ.get("DBT_RECORDER_FILE_PATH", None) + # clear the environment variables + os.environ.pop("DBT_RECORDER_MODE", None) + os.environ.pop("DBT_RECORDER_TYPES", None) + os.environ.pop("DBT_RECORDER_FILE_PATH", None) + yield + # reset the environment variables to their previous state + if prev_mode is None: + os.environ.pop("DBT_RECORDER_MODE", None) + else: + os.environ["DBT_RECORDER_MODE"] = prev_mode + if prev_type is None: + os.environ.pop("DBT_RECORDER_TYPES", None) + else: + os.environ["DBT_RECORDER_TYPES"] = prev_type + if prev_fp is None: + os.environ.pop("DBT_RECORDER_FILE_PATH", None) + else: + os.environ["DBT_RECORDER_FILE_PATH"] = prev_fp + + +def test_decorator_records(setup) -> None: + os.environ["DBT_RECORDER_MODE"] = "Record" + recorder = Recorder(RecorderMode.RECORD, None) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + test_func(123, "abc") + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params + assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result + + +def test_record_types(setup): + os.environ["DBT_RECORDER_MODE"] = "Record" + os.environ["DBT_RECORDER_TYPES"] = "TestRecord" + recorder = Recorder(RecorderMode.RECORD, ["TestRecord"]) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + @record_function(NotTestRecord) + def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + test_func(123, "abc") + not_test_func(456, "def") + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params + assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result + assert NotTestRecord not in recorder._records_by_type + + +def test_decorator_replays(setup) -> None: + os.environ["DBT_RECORDER_MODE"] = "Replay" + os.environ["DBT_RECORDER_FILE_PATH"] = "record.json" + recorder = Recorder(RecorderMode.REPLAY, None) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + recorder._records_by_type["TestRecord"] = [expected_record] + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + raise Exception("This should not actually be called") + + res = test_func(123, "abc") + + assert res == "123abc" diff --git a/tests/unit/test_semver.py b/tests/unit/test_semver.py index ae48e592..383d3479 100644 --- a/tests/unit/test_semver.py +++ b/tests/unit/test_semver.py @@ -1,6 +1,6 @@ import itertools import unittest -from typing import List +from typing import List, Optional from dbt_common.exceptions import VersionsNotCompatibleError from dbt_common.semver import ( @@ -23,9 +23,11 @@ def semver_regex_versioning(versions: List[str]) -> bool: return True -def create_range(start_version_string, end_version_string): - start = UnboundedVersionSpecifier() - end = UnboundedVersionSpecifier() +def create_range( + start_version_string: Optional[str], end_version_string: Optional[str] +) -> VersionRange: + start: VersionSpecifier = UnboundedVersionSpecifier() + end: VersionSpecifier = UnboundedVersionSpecifier() if start_version_string is not None: start = VersionSpecifier.from_version_string(start_version_string) @@ -37,24 +39,24 @@ def create_range(start_version_string, end_version_string): class TestSemver(unittest.TestCase): - def assertVersionSetResult(self, inputs, output_range): + def assertVersionSetResult(self, inputs, output_range) -> None: expected = create_range(*output_range) for permutation in itertools.permutations(inputs): self.assertEqual(reduce_versions(*permutation), expected) - def assertInvalidVersionSet(self, inputs): + def assertInvalidVersionSet(self, inputs) -> None: for permutation in itertools.permutations(inputs): with self.assertRaises(VersionsNotCompatibleError): reduce_versions(*permutation) - def test__versions_compatible(self): + def test__versions_compatible(self) -> None: self.assertTrue(versions_compatible("0.0.1", "0.0.1")) self.assertFalse(versions_compatible("0.0.1", "0.0.2")) self.assertTrue(versions_compatible(">0.0.1", "0.0.2")) self.assertFalse(versions_compatible("0.4.5a1", "0.4.5a2")) - def test__semver_regex_versions(self): + def test__semver_regex_versions(self) -> None: self.assertTrue( semver_regex_versioning( [ @@ -140,7 +142,7 @@ def test__semver_regex_versions(self): ) ) - def test__reduce_versions(self): + def test__reduce_versions(self) -> None: self.assertVersionSetResult(["0.0.1", "0.0.1"], ["=0.0.1", "=0.0.1"]) self.assertVersionSetResult(["0.0.1"], ["=0.0.1", "=0.0.1"]) @@ -175,7 +177,7 @@ def test__reduce_versions(self): self.assertInvalidVersionSet(["<0.0.3", ">=0.0.3"]) self.assertInvalidVersionSet(["<0.0.3", ">0.0.3"]) - def test__resolve_to_specific_version(self): + def test__resolve_to_specific_version(self) -> None: self.assertEqual( resolve_to_specific_version(create_range(">0.0.1", None), ["0.0.1", "0.0.2"]), "0.0.2" ) @@ -253,7 +255,7 @@ def test__resolve_to_specific_version(self): "0.9.1", ) - def test__filter_installable(self): + def test__filter_installable(self) -> None: installable = filter_installable( [ "1.1.0", diff --git a/tests/unit/test_system_client.py b/tests/unit/test_system_client.py index a4dcc323..d2cf27ed 100644 --- a/tests/unit/test_system_client.py +++ b/tests/unit/test_system_client.py @@ -12,39 +12,39 @@ class SystemClient(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.tmp_dir = mkdtemp() self.profiles_path = "{}/profiles.yml".format(self.tmp_dir) - def set_up_profile(self): + def set_up_profile(self) -> None: with open(self.profiles_path, "w") as f: f.write("ORIGINAL_TEXT") - def get_profile_text(self): + def get_profile_text(self) -> str: with open(self.profiles_path, "r") as f: return f.read() - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.tmp_dir) except Exception as e: # noqa: F841 pass - def test__make_file_when_exists(self): + def test__make_file_when_exists(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertFalse(written) self.assertEqual(self.get_profile_text(), "ORIGINAL_TEXT") - def test__make_file_when_not_exists(self): + def test__make_file_when_not_exists(self) -> None: written = dbt_common.clients.system.make_file(self.profiles_path, contents="NEW_TEXT") self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_file_with_overwrite(self): + def test__make_file_with_overwrite(self) -> None: self.set_up_profile() written = dbt_common.clients.system.make_file( self.profiles_path, contents="NEW_TEXT", overwrite=True @@ -53,12 +53,12 @@ def test__make_file_with_overwrite(self): self.assertTrue(written) self.assertEqual(self.get_profile_text(), "NEW_TEXT") - def test__make_dir_from_str(self): + def test__make_dir_from_str(self) -> None: test_dir_str = self.tmp_dir + "/test_make_from_str/sub_dir" dbt_common.clients.system.make_directory(test_dir_str) self.assertTrue(Path(test_dir_str).is_dir()) - def test__make_dir_from_pathobj(self): + def test__make_dir_from_pathobj(self) -> None: test_dir_pathobj = Path(self.tmp_dir + "/test_make_from_pathobj/sub_dir") dbt_common.clients.system.make_directory(test_dir_pathobj) self.assertTrue(test_dir_pathobj.is_dir()) @@ -72,7 +72,7 @@ class TestRunCmd(unittest.TestCase): not_a_file = "zzzbbfasdfasdfsdaq" - def setUp(self): + def setUp(self) -> None: self.tempdir = mkdtemp() self.run_dir = os.path.join(self.tempdir, "run_dir") self.does_not_exist = os.path.join(self.tempdir, "does_not_exist") @@ -86,10 +86,10 @@ def setUp(self): with open(self.empty_file, "w") as fp: # noqa: F841 pass # "touch" - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test__executable_does_not_exist(self): + def test__executable_does_not_exist(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.does_not_exist]) @@ -99,7 +99,7 @@ def test__executable_does_not_exist(self): self.assertIn("could not find", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__not_exe(self): + def test__not_exe(self) -> None: with self.assertRaises(ExecutableError) as exc: dbt_common.clients.system.run_cmd(self.run_dir, [self.empty_file]) @@ -112,14 +112,14 @@ def test__not_exe(self): self.assertIn("permissions", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_does_not_exist(self): + def test__cwd_does_not_exist(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.does_not_exist, self.exists_cmd) msg = str(exc.exception).lower() self.assertIn("does not exist", msg) self.assertIn(self.does_not_exist.lower(), msg) - def test__cwd_not_directory(self): + def test__cwd_not_directory(self) -> None: with self.assertRaises(WorkingDirectoryError) as exc: dbt_common.clients.system.run_cmd(self.empty_file, self.exists_cmd) @@ -127,7 +127,7 @@ def test__cwd_not_directory(self): self.assertIn("not a directory", msg) self.assertIn(self.empty_file.lower(), msg) - def test__cwd_no_permissions(self): + def test__cwd_no_permissions(self) -> None: # it would be nice to add a windows test. Possible path to that is via # `psexec` (to get SYSTEM privs), use `icacls` to set permissions on # the directory for the test user. I'm pretty sure windows users can't @@ -145,18 +145,18 @@ def test__cwd_no_permissions(self): self.assertIn("permissions", msg) self.assertIn(self.run_dir.lower(), msg) - def test__ok(self): + def test__ok(self) -> None: out, err = dbt_common.clients.system.run_cmd(self.run_dir, self.exists_cmd) self.assertEqual(out.strip(), b"hello") self.assertEqual(err.strip(), b"") class TestFindMatching(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) - def test_find_matching_lowercase_file_pattern(self): + def test_find_matching_lowercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -175,7 +175,7 @@ def test_find_matching_lowercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_uppercase_file_pattern(self): + def test_find_matching_uppercase_file_pattern(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQL", dir=self.tempdir) as named_file: file_path = os.path.dirname(named_file.name) relative_path = os.path.basename(file_path) @@ -190,12 +190,12 @@ def test_find_matching_uppercase_file_pattern(self): ] self.assertEqual(out, expected_output) - def test_find_matching_file_pattern_not_found(self): + def test_find_matching_file_pattern_not_found(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".SQLT", dir=self.tempdir): out = dbt_common.clients.system.find_matching(self.tempdir, [""], "*.sql") self.assertEqual(out, []) - def test_ignore_spec(self): + def test_ignore_spec(self) -> None: with NamedTemporaryFile(prefix="sql-files", suffix=".sql", dir=self.tempdir): out = dbt_common.clients.system.find_matching( self.tempdir, @@ -207,7 +207,7 @@ def test_ignore_spec(self): ) self.assertEqual(out, []) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 @@ -215,18 +215,18 @@ def tearDown(self): class TestUntarPackage(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.base_dir = mkdtemp() self.tempdir = mkdtemp(dir=self.base_dir) self.tempdest = mkdtemp(dir=self.base_dir) - def tearDown(self): + def tearDown(self) -> None: try: shutil.rmtree(self.base_dir) except Exception as e: # noqa: F841 pass - def test_untar_package_success(self): + def test_untar_package_success(self) -> None: # set up a valid tarball to test against with NamedTemporaryFile( prefix="my-package.2", suffix=".tar.gz", dir=self.tempdir, delete=False @@ -244,7 +244,7 @@ def test_untar_package_success(self): path = Path(os.path.join(self.tempdest, relative_file_a)) assert path.is_file() - def test_untar_package_failure(self): + def test_untar_package_failure(self) -> None: # create a text file then rename it as a tar (so it's invalid) with NamedTemporaryFile( prefix="a", suffix=".txt", dir=self.tempdir, delete=False @@ -259,7 +259,7 @@ def test_untar_package_failure(self): with self.assertRaises(tarfile.ReadError) as exc: # noqa: F841 dbt_common.clients.system.untar_package(tar_file_path, self.tempdest) - def test_untar_package_empty(self): + def test_untar_package_empty(self) -> None: # create a tarball with nothing in it with NamedTemporaryFile( prefix="my-empty-package.2", suffix=".tar.gz", dir=self.tempdir diff --git a/tests/unit/test_ui.py b/tests/unit/test_ui.py index 22e431d5..5b70b1d1 100644 --- a/tests/unit/test_ui.py +++ b/tests/unit/test_ui.py @@ -1,11 +1,11 @@ from dbt_common.ui import warning_tag, error_tag -def test_warning_tag(): +def test_warning_tag() -> None: tagged = warning_tag("hi") assert "WARNING" in tagged -def test_error_tag(): +def test_error_tag() -> None: tagged = error_tag("hi") assert "ERROR" in tagged diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 250c20cc..93c57046 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -5,7 +5,7 @@ class TestDeepMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -27,7 +27,7 @@ def test__simple_cases(self): class TestMerge(unittest.TestCase): - def test__simple_cases(self): + def test__simple_cases(self) -> None: cases = [ {"args": [{}, {"a": 1}], "expected": {"a": 1}, "description": "one key into empty"}, { @@ -49,7 +49,7 @@ def test__simple_cases(self): class TestDeepMap(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.input_value = { "foo": { "bar": "hello", @@ -74,7 +74,7 @@ def intify_all(value, _): except (TypeError, ValueError): return -1 - def test__simple_cases(self): + def test__simple_cases(self) -> None: expected = { "foo": { "bar": -1, @@ -104,7 +104,7 @@ def special_keypath(value, keypath): else: return value - def test__keypath(self): + def test__keypath(self) -> None: expected = { "foo": { "bar": "hello", @@ -128,11 +128,11 @@ def test__keypath(self): actual = dbt_common.utils.dict.deep_map_render(self.special_keypath, expected) self.assertEqual(actual, expected) - def test__noop(self): + def test__noop(self) -> None: actual = dbt_common.utils.dict.deep_map_render(lambda x, _: x, self.input_value) self.assertEqual(actual, self.input_value) - def test_trivial(self): + def test_trivial(self) -> None: cases = [[], {}, 1, "abc", None, True] for case in cases: result = dbt_common.utils.dict.deep_map_render(lambda x, _: x, case)