From eb57bbe01aced8699fbb33200d841aa498cf9d72 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Sat, 7 Sep 2024 17:31:52 -0400 Subject: [PATCH 1/7] Add types to jinja-related files --- dbt_common/clients/jinja.py | 124 ++++++++++++++++++++------------- dbt_common/exceptions/jinja.py | 9 +-- 2 files changed, 80 insertions(+), 53 deletions(-) diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index f6e90659..d27c2645 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -6,7 +6,21 @@ from collections import ChainMap from contextlib import contextmanager from itertools import chain, islice -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type +from types import CodeType +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Union, + Set, + Type, + NoReturn, +) + from typing_extensions import Protocol import jinja2 @@ -39,10 +53,17 @@ SUPPORTED_LANG_ARG = jinja2.nodes.Name("supported_languages", "param") # Global which can be set by dependents of dbt-common (e.g. core via flag parsing) -MACRO_DEBUGGING = False +MACRO_DEBUGGING: Union[str, bool] = False + +_ParseReturn = Union[jinja2.nodes.Node, List[jinja2.nodes.Node]] + + +# Temporary type capturing the concept the functions in this file expect for a "node" +class _NodeProtocol(Protocol): + pass -def _linecache_inject(source, write): +def _linecache_inject(source: str, write: bool) -> str: if write: # this is the only reliable way to accomplish this. Obviously, it's # really darn noisy and will fill your temporary directory @@ -58,18 +79,18 @@ def _linecache_inject(source, write): else: # `codecs.encode` actually takes a `bytes` as the first argument if # the second argument is 'hex' - mypy does not know this. - rnd = codecs.encode(os.urandom(12), "hex") # type: ignore + rnd = codecs.encode(os.urandom(12), "hex") filename = rnd.decode("ascii") # put ourselves in the cache cache_entry = (len(source), None, [line + "\n" for line in source.splitlines()], filename) # linecache does in fact have an attribute `cache`, thanks - linecache.cache[filename] = cache_entry # type: ignore + linecache.cache[filename] = cache_entry return filename class MacroFuzzParser(jinja2.parser.Parser): - def parse_macro(self): + def parse_macro(self) -> jinja2.nodes.Macro: node = jinja2.nodes.Macro(lineno=next(self.stream).lineno) # modified to fuzz macros defined in the same file. this way @@ -83,16 +104,13 @@ def parse_macro(self): class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment): - def _parse(self, source, name, filename): + def _parse( + self, source: str, name: Optional[str], filename: Optional[str] + ) -> jinja2.nodes.Template: return MacroFuzzParser(self, source, name, filename).parse() - def _compile(self, source, filename): + def _compile(self, source: str, filename: str) -> CodeType: """ - - - - - Override jinja's compilation. Use to stash the rendered source inside the python linecache for debugging when the appropriate environment variable is set. @@ -108,7 +126,7 @@ def _compile(self, source, filename): class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate): - environment_class = MacroFuzzEnvironment + environment_class = MacroFuzzEnvironment # type: ignore def new_context( self, @@ -171,11 +189,11 @@ class NumberMarker(NativeMarker): pass -def _is_number(value) -> bool: +def _is_number(value: Any) -> bool: return isinstance(value, (int, float)) and not isinstance(value, bool) -def quoted_native_concat(nodes): +def quoted_native_concat(nodes: Iterator[str]) -> Any: """Handle special case for native_concat from the NativeTemplate. This is almost native_concat from the NativeTemplate, except in the @@ -213,7 +231,7 @@ def quoted_native_concat(nodes): class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate): # mypy: ignore environment_class = NativeSandboxEnvironment # type: ignore - def render(self, *args, **kwargs): + def render(self, *args: Any, **kwargs: Any) -> Any: """Render the template to produce a native Python type. If the result is a single node, its value is returned. Otherwise, @@ -229,6 +247,11 @@ def render(self, *args, **kwargs): return self.environment.handle_exception() +class MacroProtocol(Protocol): + name: str + macro_sql: str + + NativeSandboxEnvironment.template_class = NativeSandboxTemplate # type: ignore @@ -236,7 +259,7 @@ class TemplateCache: def __init__(self) -> None: self.file_cache: Dict[str, jinja2.Template] = {} - def get_node_template(self, node) -> jinja2.Template: + def get_node_template(self, node: MacroProtocol) -> jinja2.Template: key = node.macro_sql if key in self.file_cache: @@ -251,7 +274,7 @@ def get_node_template(self, node) -> jinja2.Template: self.file_cache[key] = template return template - def clear(self): + def clear(self) -> None: self.file_cache.clear() @@ -262,13 +285,13 @@ class BaseMacroGenerator: def __init__(self, context: Optional[Dict[str, Any]] = None) -> None: self.context: Optional[Dict[str, Any]] = context - def get_template(self): + def get_template(self) -> jinja2.Template: raise NotImplementedError("get_template not implemented!") def get_name(self) -> str: raise NotImplementedError("get_name not implemented!") - def get_macro(self): + def get_macro(self) -> Callable: name = self.get_name() template = self.get_template() # make the module. previously we set both vars and local, but that's @@ -286,7 +309,7 @@ def exception_handler(self) -> Iterator[None]: except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e: raise CaughtMacroError(e) - def call_macro(self, *args, **kwargs): + def call_macro(self, *args: Any, **kwargs: Any) -> Any: # called from __call__ methods if self.context is None: raise DbtInternalError("Context is still None in call_macro!") @@ -301,11 +324,6 @@ def call_macro(self, *args, **kwargs): return e.value -class MacroProtocol(Protocol): - name: str - macro_sql: str - - class CallableMacroGenerator(BaseMacroGenerator): def __init__( self, @@ -315,7 +333,7 @@ def __init__( super().__init__(context) self.macro = macro - def get_template(self): + def get_template(self) -> jinja2.Template: return template_cache.get_node_template(self.macro) def get_name(self) -> str: @@ -332,14 +350,14 @@ def exception_handler(self) -> Iterator[None]: raise e # this makes MacroGenerator objects callable like functions - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.call_macro(*args, **kwargs) class MaterializationExtension(jinja2.ext.Extension): tags = ["materialization"] - def parse(self, parser): + def parse(self, parser: jinja2.parser.Parser) -> _ParseReturn: node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno) materialization_name = parser.parse_assign_target(name_only=True).name @@ -382,7 +400,7 @@ def parse(self, parser): class DocumentationExtension(jinja2.ext.Extension): tags = ["docs"] - def parse(self, parser): + def parse(self, parser: jinja2.parser.Parser) -> _ParseReturn: node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno) docs_name = parser.parse_assign_target(name_only=True).name @@ -396,7 +414,7 @@ def parse(self, parser): class TestExtension(jinja2.ext.Extension): tags = ["test"] - def parse(self, parser): + def parse(self, parser: jinja2.parser.Parser) -> _ParseReturn: node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno) test_name = parser.parse_assign_target(name_only=True).name @@ -406,13 +424,19 @@ def parse(self, parser): return node -def _is_dunder_name(name): +def _is_dunder_name(name: str) -> bool: return name.startswith("__") and name.endswith("__") -def create_undefined(node=None): +def create_undefined(node: Optional[_NodeProtocol] = None) -> Type[jinja2.Undefined]: class Undefined(jinja2.Undefined): - def __init__(self, hint=None, obj=None, name=None, exc=None): + def __init__( + self, + hint: Optional[str] = None, + obj: Any = None, + name: Optional[str] = None, + exc: Any = None, + ) -> None: super().__init__(hint=hint, name=name) self.node = node self.name = name @@ -422,12 +446,12 @@ def __init__(self, hint=None, obj=None, name=None, exc=None): self.unsafe_callable = False self.alters_data = False - def __getitem__(self, name): + def __getitem__(self, name: Any) -> "Undefined": # Propagate the undefined value if a caller accesses this as if it # were a dictionary return self - def __getattr__(self, name): + def __getattr__(self, name: str) -> "Undefined": if name == "name" or _is_dunder_name(name): raise AttributeError( "'{}' object has no attribute '{}'".format(type(self).__name__, name) @@ -437,11 +461,11 @@ def __getattr__(self, name): return self.__class__(hint=self.hint, name=self.name) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> "Undefined": return self - def __reduce__(self): - raise UndefinedCompilationError(name=self.name, node=node) + def __reduce__(self) -> NoReturn: + raise UndefinedCompilationError(name=self.name or "unknown", node=node) return Undefined @@ -463,7 +487,7 @@ def __reduce__(self): def get_environment( - node=None, + node: Optional[_NodeProtocol] = None, capture_macros: bool = False, native: bool = False, ) -> jinja2.Environment: @@ -472,7 +496,7 @@ def get_environment( } if capture_macros: - args["undefined"] = create_undefined(node) + args["undefined"] = create_undefined(node) # type: ignore args["extensions"].append(MaterializationExtension) args["extensions"].append(DocumentationExtension) @@ -493,7 +517,7 @@ def get_environment( @contextmanager -def catch_jinja(node=None) -> Iterator[None]: +def catch_jinja(node: Optional[_NodeProtocol] = None) -> Iterator[None]: try: yield except jinja2.exceptions.TemplateSyntaxError as e: @@ -506,16 +530,16 @@ def catch_jinja(node=None) -> Iterator[None]: raise -_TESTING_PARSE_CACHE: Dict[str, jinja2.Template] = {} +_TESTING_PARSE_CACHE: Dict[str, jinja2.nodes.Template] = {} -def parse(string): +def parse(string: Any) -> jinja2.nodes.Template: str_string = str(string) if test_caching_enabled() and str_string in _TESTING_PARSE_CACHE: return _TESTING_PARSE_CACHE[str_string] with catch_jinja(): - parsed = get_environment().parse(str(string)) + parsed: jinja2.nodes.Template = get_environment().parse(str(string)) if test_caching_enabled(): _TESTING_PARSE_CACHE[str_string] = parsed return parsed @@ -524,10 +548,10 @@ def parse(string): def get_template( string: str, ctx: Dict[str, Any], - node=None, + node: Optional[_NodeProtocol] = None, capture_macros: bool = False, native: bool = False, -): +) -> jinja2.Template: with catch_jinja(node): env = get_environment(node, capture_macros, native=native) @@ -535,7 +559,9 @@ def get_template( return env.from_string(template_source, globals=ctx) -def render_template(template, ctx: Dict[str, Any], node=None) -> str: +def render_template( + template: jinja2.Template, ctx: Dict[str, Any], node: Optional[_NodeProtocol] = None +) -> str: with catch_jinja(node): return template.render(ctx) diff --git a/dbt_common/exceptions/jinja.py b/dbt_common/exceptions/jinja.py index 8edfd87a..a697c34d 100644 --- a/dbt_common/exceptions/jinja.py +++ b/dbt_common/exceptions/jinja.py @@ -1,8 +1,9 @@ +from dbt_common.clients._jinja_blocks import Tag, TagIterator from dbt_common.exceptions import CompilationError class BlockDefinitionNotAtTopError(CompilationError): - def __init__(self, tag_parser, tag_start) -> None: + def __init__(self, tag_parser: TagIterator, tag_start: int) -> None: self.tag_parser = tag_parser self.tag_start = tag_start super().__init__(msg=self.get_message()) @@ -31,7 +32,7 @@ def get_message(self) -> str: class MissingControlFlowStartTagError(CompilationError): - def __init__(self, tag, expected_tag: str, tag_parser) -> None: + def __init__(self, tag: Tag, expected_tag: str, tag_parser: TagIterator) -> None: self.tag = tag self.expected_tag = expected_tag self.tag_parser = tag_parser @@ -47,7 +48,7 @@ def get_message(self) -> str: class NestedTagsError(CompilationError): - def __init__(self, outer, inner) -> None: + def __init__(self, outer: Tag, inner: Tag) -> None: self.outer = outer self.inner = inner super().__init__(msg=self.get_message()) @@ -62,7 +63,7 @@ def get_message(self) -> str: class UnexpectedControlFlowEndTagError(CompilationError): - def __init__(self, tag, expected_tag: str, tag_parser) -> None: + def __init__(self, tag: Tag, expected_tag: str, tag_parser: TagIterator) -> None: self.tag = tag self.expected_tag = expected_tag self.tag_parser = tag_parser From 436ac248ea99dffa62916c3bf6acd3bfdeca218b Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Sat, 7 Sep 2024 17:43:33 -0400 Subject: [PATCH 2/7] Fix circular import --- dbt_common/exceptions/jinja.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dbt_common/exceptions/jinja.py b/dbt_common/exceptions/jinja.py index a697c34d..19355df1 100644 --- a/dbt_common/exceptions/jinja.py +++ b/dbt_common/exceptions/jinja.py @@ -1,9 +1,13 @@ -from dbt_common.clients._jinja_blocks import Tag, TagIterator + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from dbt_common.clients._jinja_blocks import Tag, TagIterator + from dbt_common.exceptions import CompilationError class BlockDefinitionNotAtTopError(CompilationError): - def __init__(self, tag_parser: TagIterator, tag_start: int) -> None: + def __init__(self, tag_parser: "TagIterator", tag_start: int) -> None: self.tag_parser = tag_parser self.tag_start = tag_start super().__init__(msg=self.get_message()) @@ -32,7 +36,7 @@ def get_message(self) -> str: class MissingControlFlowStartTagError(CompilationError): - def __init__(self, tag: Tag, expected_tag: str, tag_parser: TagIterator) -> None: + def __init__(self, tag: "Tag", expected_tag: str, tag_parser: "TagIterator") -> None: self.tag = tag self.expected_tag = expected_tag self.tag_parser = tag_parser @@ -48,7 +52,7 @@ def get_message(self) -> str: class NestedTagsError(CompilationError): - def __init__(self, outer: Tag, inner: Tag) -> None: + def __init__(self, outer: "Tag", inner: "Tag") -> None: self.outer = outer self.inner = inner super().__init__(msg=self.get_message()) @@ -63,7 +67,7 @@ def get_message(self) -> str: class UnexpectedControlFlowEndTagError(CompilationError): - def __init__(self, tag: Tag, expected_tag: str, tag_parser: TagIterator) -> None: + def __init__(self, tag: "Tag", expected_tag: str, tag_parser: "TagIterator") -> None: self.tag = tag self.expected_tag = expected_tag self.tag_parser = tag_parser From f63804af8a9af847905102ed16a892be550260e3 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Sat, 7 Sep 2024 17:46:02 -0400 Subject: [PATCH 3/7] Fix black --- dbt_common/exceptions/jinja.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/exceptions/jinja.py b/dbt_common/exceptions/jinja.py index 19355df1..d73295f4 100644 --- a/dbt_common/exceptions/jinja.py +++ b/dbt_common/exceptions/jinja.py @@ -1,5 +1,5 @@ - from typing import TYPE_CHECKING + if TYPE_CHECKING: from dbt_common.clients._jinja_blocks import Tag, TagIterator From 05d6ae77f3a7ca0387cf768477e02063e549578a Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Sun, 8 Sep 2024 15:39:17 -0400 Subject: [PATCH 4/7] Add types to semver code --- dbt_common/record.py | 10 ++++- dbt_common/semver.py | 66 ++++++++++++++++++------------- tests/unit/test_behavior_flags.py | 8 ++-- tests/unit/test_diff.py | 25 +++++++----- tests/unit/test_functions.py | 1 + tests/unit/test_semver.py | 4 +- tests/unit/test_utils.py | 2 +- 7 files changed, 71 insertions(+), 45 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index 612ddf75..76d2c330 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -29,8 +29,14 @@ def __init__(self, params, result) -> None: def to_dict(self) -> Dict[str, Any]: return { - "params": self.params._to_dict() if hasattr(self.params, "_to_dict") else dataclasses.asdict(self.params), # type: ignore - "result": self.result._to_dict() if hasattr(self.result, "_to_dict") else dataclasses.asdict(self.result) if self.result is not None else None, # type: ignore + "params": self.params._to_dict() + if hasattr(self.params, "_to_dict") + else dataclasses.asdict(self.params), + "result": self.result._to_dict() + if hasattr(self.result, "_to_dict") + else dataclasses.asdict(self.result) + if self.result is not None + else None, } @classmethod diff --git a/dbt_common/semver.py b/dbt_common/semver.py index 391d8898..e248bee8 100644 --- a/dbt_common/semver.py +++ b/dbt_common/semver.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import re -from typing import Iterable, List, Union +from typing import Any, Iterable, List, Union import dbt_common.exceptions.base from dbt_common.exceptions import VersionsNotCompatibleError @@ -67,9 +67,9 @@ class VersionSpecification(dbtClassMixin): _VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE) -def _cmp(a, b) -> int: +def _cmp(a: Any, b: Any) -> int: """Return negative if ab.""" - return (a > b) - (a < b) + return int((a > b) - (a < b)) @dataclass @@ -102,7 +102,9 @@ def from_version_string(cls, version_string: str) -> "VersionSpecifier": matched = {k: v for k, v in match.groupdict().items() if v is not None} - return cls.from_dict(matched) + spec = cls.from_dict(matched) + assert isinstance(spec, VersionSpecifier) + return spec def __str__(self) -> str: return self.to_version_string() @@ -198,10 +200,11 @@ def __lt__(self, other: "VersionSpecifier") -> bool: def __gt__(self, other: "VersionSpecifier") -> bool: return self.compare(other) == 1 - def __eq___(self, other: "VersionSpecifier") -> bool: + def __eq__(self, other: object) -> bool: + assert isinstance(other, VersionSpecifier) return self.compare(other) == 0 - def __cmp___(self, other: "VersionSpecifier") -> int: + def __cmp__(self, other: "VersionSpecifier") -> int: return self.compare(other) @property @@ -221,8 +224,8 @@ def is_exact(self) -> bool: return self.matcher == Matchers.EXACT @classmethod - def _nat_cmp(cls, a, b) -> int: - def cmp_prerelease_tag(a, b): + def _nat_cmp(cls, a: str, b: str) -> int: + def cmp_prerelease_tag(a: Union[str, int], b: Union[str, int]) -> int: if isinstance(a, int) and isinstance(b, int): return _cmp(a, b) elif isinstance(a, int): @@ -234,10 +237,10 @@ def cmp_prerelease_tag(a, b): a, b = a or "", b or "" a_parts, b_parts = a.split("."), b.split(".") - a_parts = [int(x) if re.match(r"^\d+$", x) else x for x in a_parts] - b_parts = [int(x) if re.match(r"^\d+$", x) else x for x in b_parts] - for sub_a, sub_b in zip(a_parts, b_parts): - cmp_result = cmp_prerelease_tag(sub_a, sub_b) + a_parts_2 = [int(x) if re.match(r"^\d+$", x) else x for x in a_parts] + b_parts_2 = [int(x) if re.match(r"^\d+$", x) else x for x in b_parts] + for sub_a, sub_b in zip(a_parts_2, b_parts_2): + cmp_result = cmp_prerelease_tag(sub_a, sub_b) # type: ignore if cmp_result != 0: return cmp_result else: @@ -249,13 +252,15 @@ class VersionRange: start: VersionSpecifier end: VersionSpecifier - def _try_combine_exact(self, a, b): + def _try_combine_exact(self, a: VersionSpecifier, b: VersionSpecifier) -> VersionSpecifier: if a.compare(b) == 0: return a else: raise VersionsNotCompatibleError() - def _try_combine_lower_bound_with_exact(self, lower, exact): + def _try_combine_lower_bound_with_exact( + self, lower: VersionSpecifier, exact: VersionSpecifier + ) -> VersionSpecifier: comparison = lower.compare(exact) if comparison < 0 or (comparison == 0 and lower.matcher == Matchers.GREATER_THAN_OR_EQUAL): @@ -263,7 +268,9 @@ def _try_combine_lower_bound_with_exact(self, lower, exact): raise VersionsNotCompatibleError() - def _try_combine_lower_bound(self, a, b): + def _try_combine_lower_bound( + self, a: VersionSpecifier, b: VersionSpecifier + ) -> VersionSpecifier: if b.is_unbounded: return a elif a.is_unbounded: @@ -280,10 +287,12 @@ def _try_combine_lower_bound(self, a, b): elif a.is_exact: return self._try_combine_lower_bound_with_exact(b, a) - elif b.is_exact: + else: return self._try_combine_lower_bound_with_exact(a, b) - def _try_combine_upper_bound_with_exact(self, upper, exact): + def _try_combine_upper_bound_with_exact( + self, upper: VersionSpecifier, exact: VersionSpecifier + ) -> VersionSpecifier: comparison = upper.compare(exact) if comparison > 0 or (comparison == 0 and upper.matcher == Matchers.LESS_THAN_OR_EQUAL): @@ -291,7 +300,9 @@ def _try_combine_upper_bound_with_exact(self, upper, exact): raise VersionsNotCompatibleError() - def _try_combine_upper_bound(self, a, b): + def _try_combine_upper_bound( + self, a: VersionSpecifier, b: VersionSpecifier + ) -> VersionSpecifier: if b.is_unbounded: return a elif a.is_unbounded: @@ -308,15 +319,14 @@ def _try_combine_upper_bound(self, a, b): elif a.is_exact: return self._try_combine_upper_bound_with_exact(b, a) - elif b.is_exact: + else: return self._try_combine_upper_bound_with_exact(a, b) - def reduce(self, other): + def reduce(self, other: "VersionRange") -> "VersionRange": start = None if self.start.is_exact and other.start.is_exact: start = end = self._try_combine_exact(self.start, other.start) - else: start = self._try_combine_lower_bound(self.start, other.start) end = self._try_combine_upper_bound(self.end, other.end) @@ -326,7 +336,7 @@ def reduce(self, other): return VersionRange(start=start, end=end) - def __str__(self): + def __str__(self) -> str: result = [] if self.start.is_unbounded and self.end.is_unbounded: @@ -340,7 +350,7 @@ def __str__(self): return ", ".join(result) - def to_version_string_pair(self): + def to_version_string_pair(self) -> List[str]: to_return = [] if not self.start.is_unbounded: @@ -353,7 +363,7 @@ def to_version_string_pair(self): class UnboundedVersionSpecifier(VersionSpecifier): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__( matcher=Matchers.EXACT, major=None, minor=None, patch=None, prerelease=None, build=None ) @@ -418,7 +428,7 @@ def reduce_versions(*args: Union[VersionSpecifier, VersionRange, str]) -> Versio return to_return -def versions_compatible(*args) -> bool: +def versions_compatible(*args: Union[VersionSpecifier, VersionRange, str]) -> bool: if len(args) == 1: return True @@ -429,7 +439,9 @@ def versions_compatible(*args) -> bool: return False -def find_possible_versions(requested_range, available_versions: Iterable[str]): +def find_possible_versions( + requested_range: VersionRange, available_versions: Iterable[str] +) -> List[str]: possible_versions = [] for version_string in available_versions: @@ -443,7 +455,7 @@ def find_possible_versions(requested_range, available_versions: Iterable[str]): def resolve_to_specific_version( - requested_range, available_versions: Iterable[str] + requested_range: VersionRange, available_versions: Iterable[str] ) -> Optional[str]: max_version = None max_version_string = None diff --git a/tests/unit/test_behavior_flags.py b/tests/unit/test_behavior_flags.py index 256b744e..88fcb849 100644 --- a/tests/unit/test_behavior_flags.py +++ b/tests/unit/test_behavior_flags.py @@ -4,7 +4,7 @@ from dbt_common.exceptions.base import CompilationError -def test_behavior_default(): +def test_behavior_default() -> None: behavior = Behavior( [ {"name": "default_false_flag", "default": False}, @@ -17,7 +17,7 @@ def test_behavior_default(): assert behavior.default_true_flag.setting is True -def test_behavior_user_override(): +def test_behavior_user_override() -> None: behavior = Behavior( [ {"name": "flag_default_false", "default": False}, @@ -43,7 +43,7 @@ def test_behavior_user_override(): assert behavior.flag_default_true_override_true.setting is True -def test_behavior_unregistered_flag_raises_correct_exception(): +def test_behavior_unregistered_flag_raises_correct_exception() -> None: behavior = Behavior( [ {"name": "behavior_flag_exists", "default": False}, @@ -56,7 +56,7 @@ def test_behavior_unregistered_flag_raises_correct_exception(): assert behavior.behavior_flag_does_not_exist -def test_behavior_flag_can_be_used_as_conditional(): +def test_behavior_flag_can_be_used_as_conditional() -> None: behavior = Behavior( [ {"name": "flag_false", "default": False}, diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py index 26d9d490..002f793e 100644 --- a/tests/unit/test_diff.py +++ b/tests/unit/test_diff.py @@ -1,8 +1,10 @@ import json -from typing import Any, Dict, List +from inspect import Traceback +from typing import Any, Callable, Dict, List, Optional, Type import pytest from dbt_common.record import Diff +from tox.pytest import MonkeyPatch Case = List[Dict[str, Any]] @@ -172,22 +174,27 @@ def test_diff_default_with_diff(current_simple: Case, current_simple_modified: C # Mock out reading the files so we don't have to class MockFile: - def __init__(self, json_data) -> None: + def __init__(self, json_data: Any) -> None: self.json_data = json_data - def __enter__(self): + def __enter__(self) -> "MockFile": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[Exception]], + exc_val: Optional[Exception], + exc_tb: Optional[Traceback], + ) -> None: pass - def read(self): + def read(self) -> str: return json.dumps(self.json_data) # Create a Mock Open Function -def mock_open(mock_files): - def open_mock(file, *args, **kwargs): +def mock_open(mock_files: Dict[str, Any]) -> Callable[..., MockFile]: + def open_mock(file: str, *args: Any, **kwargs: Any) -> MockFile: if file in mock_files: return MockFile(mock_files[file]) raise FileNotFoundError(f"No mock file found for {file}") @@ -195,7 +202,7 @@ def open_mock(file, *args, **kwargs): return open_mock -def test_calculate_diff_no_diff(monkeypatch) -> None: +def test_calculate_diff_no_diff(monkeypatch: MonkeyPatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ @@ -259,7 +266,7 @@ def test_calculate_diff_no_diff(monkeypatch) -> None: assert result == expected_result -def test_calculate_diff_with_diff(monkeypatch) -> None: +def test_calculate_diff_with_diff(monkeypatch: MonkeyPatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ diff --git a/tests/unit/test_functions.py b/tests/unit/test_functions.py index 9a8a9c22..6c4126a1 100644 --- a/tests/unit/test_functions.py +++ b/tests/unit/test_functions.py @@ -17,6 +17,7 @@ def code(self) -> str: return "Z050" def message(self) -> str: + assert isinstance(self.msg, str) return self.msg diff --git a/tests/unit/test_semver.py b/tests/unit/test_semver.py index 383d3479..a4b76ef8 100644 --- a/tests/unit/test_semver.py +++ b/tests/unit/test_semver.py @@ -39,13 +39,13 @@ def create_range( class TestSemver(unittest.TestCase): - def assertVersionSetResult(self, inputs, output_range) -> None: + def assertVersionSetResult(self, inputs: List[str], output_range: List[Optional[str]]) -> None: expected = create_range(*output_range) for permutation in itertools.permutations(inputs): self.assertEqual(reduce_versions(*permutation), expected) - def assertInvalidVersionSet(self, inputs) -> None: + def assertInvalidVersionSet(self, inputs: List[str]) -> None: for permutation in itertools.permutations(inputs): with self.assertRaises(VersionsNotCompatibleError): reduce_versions(*permutation) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index bb5563e2..356a3898 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -69,7 +69,7 @@ def setUp(self) -> None: } @staticmethod - def intify_all(value, _) -> int: + def intify_all(value: Any, _: Any) -> int: try: return int(value) except (TypeError, ValueError): From 6bf126ccf9c0a5f335d798bedb6146e735dd9a73 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Sun, 8 Sep 2024 15:47:15 -0400 Subject: [PATCH 5/7] Fix monkeypatch type --- tests/unit/test_diff.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py index 002f793e..295d4e61 100644 --- a/tests/unit/test_diff.py +++ b/tests/unit/test_diff.py @@ -4,7 +4,6 @@ import pytest from dbt_common.record import Diff -from tox.pytest import MonkeyPatch Case = List[Dict[str, Any]] @@ -202,7 +201,7 @@ def open_mock(file: str, *args: Any, **kwargs: Any) -> MockFile: return open_mock -def test_calculate_diff_no_diff(monkeypatch: MonkeyPatch) -> None: +def test_calculate_diff_no_diff(monkeypatch: pytest.MonkeyPatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ @@ -266,7 +265,7 @@ def test_calculate_diff_no_diff(monkeypatch: MonkeyPatch) -> None: assert result == expected_result -def test_calculate_diff_with_diff(monkeypatch: MonkeyPatch) -> None: +def test_calculate_diff_with_diff(monkeypatch: pytest.MonkeyPatch) -> None: # Mock data for the files current_recording_data = { "GetEnvRecord": [ From b0485f9d1c6d93955847a075621ffd8a4f06d276 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Mon, 9 Sep 2024 10:19:39 -0400 Subject: [PATCH 6/7] Some final annotations --- tests/unit/test_behavior_flags.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_behavior_flags.py b/tests/unit/test_behavior_flags.py index 88fcb849..1433bbdd 100644 --- a/tests/unit/test_behavior_flags.py +++ b/tests/unit/test_behavior_flags.py @@ -2,6 +2,7 @@ from dbt_common.behavior_flags import Behavior from dbt_common.exceptions.base import CompilationError +from unit.utils import EventCatcher def test_behavior_default() -> None: @@ -69,7 +70,7 @@ def test_behavior_flag_can_be_used_as_conditional() -> None: assert True if behavior.flag_true else False -def test_behavior_flags_emit_deprecation_event_on_evaluation(event_catcher) -> None: +def test_behavior_flags_emit_deprecation_event_on_evaluation(event_catcher: EventCatcher) -> None: behavior = Behavior( [ {"name": "flag_false", "default": False}, @@ -89,7 +90,7 @@ def test_behavior_flags_emit_deprecation_event_on_evaluation(event_catcher) -> N assert len(event_catcher.caught_events) == 1 -def test_behavior_flags_emit_correct_deprecation_event(event_catcher) -> None: +def test_behavior_flags_emit_correct_deprecation_event(event_catcher: EventCatcher) -> None: behavior = Behavior([{"name": "flag_false", "default": False}], {}) # trigger the evaluation @@ -102,7 +103,7 @@ def test_behavior_flags_emit_correct_deprecation_event(event_catcher) -> None: assert msg.data.flag_source == __name__ # defaults to the calling module -def test_behavior_flags_no_deprecation_event_on_no_warn(event_catcher) -> None: +def test_behavior_flags_no_deprecation_event_on_no_warn(event_catcher: EventCatcher) -> None: behavior = Behavior([{"name": "flag_false", "default": False}], {}) # trigger the evaluation with no_warn, no event should fire From d3ccb68d8bbd359a10c205bb7d3a77ea2a264d12 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Mon, 9 Sep 2024 10:28:31 -0400 Subject: [PATCH 7/7] Add explicit path for import --- tests/unit/test_behavior_flags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_behavior_flags.py b/tests/unit/test_behavior_flags.py index 1433bbdd..73c550da 100644 --- a/tests/unit/test_behavior_flags.py +++ b/tests/unit/test_behavior_flags.py @@ -2,7 +2,7 @@ from dbt_common.behavior_flags import Behavior from dbt_common.exceptions.base import CompilationError -from unit.utils import EventCatcher +from tests.unit.utils import EventCatcher def test_behavior_default() -> None: