From 4b8a41e6b2642d9c4835db6b5ae4796e25120494 Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Mon, 19 Aug 2024 12:25:55 -0400 Subject: [PATCH] Even More Type Annotations (#180) * Even more type annotations. * Fix some issues for Python 3.8 * Type annotations * Generic type to work around lack of Self in Python before 3.11 --- dbt_common/clients/agate_helper.py | 4 +-- dbt_common/clients/jinja.py | 14 ++++++----- dbt_common/contracts/util.py | 11 +++++--- dbt_common/dataclass_schema.py | 7 +++--- dbt_common/exceptions/system.py | 4 +-- dbt_common/semver.py | 4 +-- tests/unit/test_connection_retries.py | 4 +-- tests/unit/test_diff.py | 36 ++++++++++++++------------- tests/unit/test_functions.py | 10 ++++---- tests/unit/test_jinja.py | 2 +- tests/unit/test_utils.py | 5 ++-- 11 files changed, 55 insertions(+), 46 deletions(-) diff --git a/dbt_common/clients/agate_helper.py b/dbt_common/clients/agate_helper.py index 3aade66d..45f525a0 100644 --- a/dbt_common/clients/agate_helper.py +++ b/dbt_common/clients/agate_helper.py @@ -1,6 +1,6 @@ from codecs import BOM_UTF8 -import agate # type: ignore +import agate import datetime import isodate import json @@ -149,7 +149,7 @@ def as_matrix(table): return [r.values() for r in table.rows.values()] -def from_csv(abspath, text_columns, delimiter=","): +def from_csv(abspath, text_columns, delimiter=",") -> agate.Table: type_tester = build_type_tester(text_columns=text_columns) with open(abspath, encoding="utf-8") as fp: if fp.read(1) != BOM: diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 44d3eade..f6e90659 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -9,12 +9,12 @@ from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type from typing_extensions import Protocol -import jinja2 # type: ignore -import jinja2.ext # type: ignore -import jinja2.nativetypes # type: ignore -import jinja2.nodes # type: ignore -import jinja2.parser # type: ignore -import jinja2.sandbox # type: ignore +import jinja2 +import jinja2.ext +import jinja2.nativetypes +import jinja2.nodes +import jinja2.parser +import jinja2.sandbox from dbt_common.tests import test_caching_enabled from dbt_common.utils.jinja import ( @@ -124,6 +124,7 @@ def new_context( "shared or locals parameters." ) + vars = {} if vars is None else vars parent = ChainMap(vars, self.globals) if self.globals else vars return self.environment.context_class(self.environment, parent, self.name, self.blocks) @@ -544,6 +545,7 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str: def _get_blocks_hash(text: str, allowed_blocks: Optional[Set[str]], collect_raw_data: bool) -> int: """Provides a hash function over the arguments to extract_toplevel_blocks, in order to support caching.""" + allowed_blocks = allowed_blocks or set() allowed_tuple = tuple(sorted(allowed_blocks) or []) return text.__hash__() + allowed_tuple.__hash__() + collect_raw_data.__hash__() diff --git a/dbt_common/contracts/util.py b/dbt_common/contracts/util.py index 7bd26e3b..34134b52 100644 --- a/dbt_common/contracts/util.py +++ b/dbt_common/contracts/util.py @@ -1,15 +1,20 @@ import dataclasses -from typing import Any +from typing import Any, TypeVar + +_R = TypeVar("_R", bound="Replaceable") # TODO: remove from dbt_common.contracts.util:: Replaceable + references class Replaceable: - def replace(self, **kwargs: Any): + def replace(self: _R, **kwargs: Any) -> _R: return dataclasses.replace(self, **kwargs) # type: ignore +_M = TypeVar("_M", bound="Mergeable") + + class Mergeable(Replaceable): - def merged(self, *args): + def merged(self: _M, *args: Any) -> _M: """Perform a shallow merge, where the last non-None write wins. This is intended to merge dataclasses that are a collection of optional values. """ diff --git a/dbt_common/dataclass_schema.py b/dbt_common/dataclass_schema.py index 4e003b13..0aad4d5f 100644 --- a/dbt_common/dataclass_schema.py +++ b/dbt_common/dataclass_schema.py @@ -1,4 +1,4 @@ -from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple +from typing import Any, ClassVar, Dict, get_type_hints, List, Optional, Tuple, Union import re import jsonschema from dataclasses import fields, Field @@ -6,7 +6,6 @@ from datetime import datetime from dateutil.parser import parse -# type: ignore from mashumaro.config import ( TO_DICT_ADD_OMIT_NONE_FLAG, ADD_SERIALIZATION_CONTEXT, @@ -33,8 +32,8 @@ def serialize(self, value: datetime) -> str: out += "Z" return out - def deserialize(self, value) -> datetime: - return value if isinstance(value, datetime) else parse(cast(str, value)) + def deserialize(self, value: Union[datetime, str]) -> datetime: + return value if isinstance(value, datetime) else parse(value) class dbtMashConfig(MashBaseConfig): diff --git a/dbt_common/exceptions/system.py b/dbt_common/exceptions/system.py index b0062f63..b576baba 100644 --- a/dbt_common/exceptions/system.py +++ b/dbt_common/exceptions/system.py @@ -26,7 +26,7 @@ class WorkingDirectoryError(CommandError): def __init__(self, cwd: str, cmd: List[str], msg: str) -> None: super().__init__(cwd, cmd, msg) - def __str__(self): + def __str__(self, prefix: str = "! ") -> str: return f'{self.msg}: "{self.cwd}"' @@ -46,5 +46,5 @@ def __init__( self.stderr = scrub_secrets(stderr.decode("utf-8"), env_secrets()) self.args = (cwd, self.cmd, returncode, self.stdout, self.stderr, msg) - def __str__(self): + def __str__(self, prefix: str = "! ") -> str: return f"{self.msg} running: {self.cmd}" diff --git a/dbt_common/semver.py b/dbt_common/semver.py index 4c411911..391d8898 100644 --- a/dbt_common/semver.py +++ b/dbt_common/semver.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import re -from typing import List, Iterable +from typing import Iterable, List, Union import dbt_common.exceptions.base from dbt_common.exceptions import VersionsNotCompatibleError @@ -378,7 +378,7 @@ def is_exact(self) -> bool: return False -def reduce_versions(*args): +def reduce_versions(*args: Union[VersionSpecifier, VersionRange, str]) -> VersionRange: version_specifiers = [] for version in args: diff --git a/tests/unit/test_connection_retries.py b/tests/unit/test_connection_retries.py index 44fc72f5..12352352 100644 --- a/tests/unit/test_connection_retries.py +++ b/tests/unit/test_connection_retries.py @@ -5,12 +5,12 @@ from dbt_common.utils.connection import connection_exception_retry -def no_retry_fn(): +def no_retry_fn() -> str: return "success" class TestNoRetries: - def test_no_retry(self): + def test_no_retry(self) -> None: fn_to_retry = functools.partial(no_retry_fn) result = connection_exception_retry(fn_to_retry, 3) diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py index 54f735e3..26d9d490 100644 --- a/tests/unit/test_diff.py +++ b/tests/unit/test_diff.py @@ -1,12 +1,14 @@ import json -from typing import Any, Dict +from typing import Any, Dict, List import pytest from dbt_common.record import Diff +Case = List[Dict[str, Any]] + @pytest.fixture -def current_query(): +def current_query() -> Case: return [ { "params": { @@ -21,7 +23,7 @@ def current_query(): @pytest.fixture -def query_modified_order(): +def query_modified_order() -> Case: return [ { "params": { @@ -36,7 +38,7 @@ def query_modified_order(): @pytest.fixture -def query_modified_value(): +def query_modified_value() -> Case: return [ { "params": { @@ -51,7 +53,7 @@ def query_modified_value(): @pytest.fixture -def current_simple(): +def current_simple() -> Case: return [ { "params": { @@ -65,7 +67,7 @@ def current_simple(): @pytest.fixture -def current_simple_modified(): +def current_simple_modified() -> Case: return [ { "params": { @@ -79,7 +81,7 @@ def current_simple_modified(): @pytest.fixture -def env_record(): +def env_record() -> Case: return [ { "params": {}, @@ -94,7 +96,7 @@ def env_record(): @pytest.fixture -def modified_env_record(): +def modified_env_record() -> Case: return [ { "params": {}, @@ -108,30 +110,30 @@ def modified_env_record(): ] -def test_diff_query_records_no_diff(current_query, query_modified_order): +def test_diff_query_records_no_diff(current_query: Case, query_modified_order: Case) -> None: # Setup: Create an instance of Diff diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) result = diff_instance.diff_query_records(current_query, query_modified_order) # the order changed but the diff should be empty - expected_result = {} + expected_result: Dict[str, Any] = {} assert result == expected_result # Replace expected_result with what you actually expect -def test_diff_query_records_with_diff(current_query, query_modified_value): +def test_diff_query_records_with_diff(current_query: Case, query_modified_value: Case) -> None: diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) result = diff_instance.diff_query_records(current_query, query_modified_value) # the values changed this time - expected_result = { + expected_result: Dict[str, Any] = { "values_changed": {"root[0]['result']['table'][1]['b']": {"new_value": 7, "old_value": 10}} } assert result == expected_result -def test_diff_env_records(env_record, modified_env_record): +def test_diff_env_records(env_record: Case, modified_env_record: Case) -> None: diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) @@ -147,17 +149,17 @@ def test_diff_env_records(env_record, modified_env_record): assert result == expected_result -def test_diff_default_no_diff(current_simple): +def test_diff_default_no_diff(current_simple: Case) -> None: diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) # use the same list to ensure no diff result = diff_instance.diff_default(current_simple, current_simple) - expected_result = {} + expected_result: Dict[str, Any] = {} assert result == expected_result -def test_diff_default_with_diff(current_simple, current_simple_modified): +def test_diff_default_with_diff(current_simple: Case, current_simple_modified: Case) -> None: diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) @@ -170,7 +172,7 @@ def test_diff_default_with_diff(current_simple, current_simple_modified): # Mock out reading the files so we don't have to class MockFile: - def __init__(self, json_data): + def __init__(self, json_data) -> None: self.json_data = json_data def __enter__(self): diff --git a/tests/unit/test_functions.py b/tests/unit/test_functions.py index 372b2bda..9a8a9c22 100644 --- a/tests/unit/test_functions.py +++ b/tests/unit/test_functions.py @@ -38,7 +38,7 @@ def valid_error_names() -> Set[str]: class TestWarnOrError: - def test_fires_error(self, valid_error_names: Set[str]): + def test_fires_error(self, valid_error_names: Set[str]) -> None: functions.WARN_ERROR_OPTIONS = WarnErrorOptions( include="*", valid_error_names=valid_error_names ) @@ -49,8 +49,8 @@ def test_fires_warning( self, valid_error_names: Set[str], event_catcher: EventCatcher, - set_event_manager_with_catcher, - ): + set_event_manager_with_catcher: None, + ) -> None: functions.WARN_ERROR_OPTIONS = WarnErrorOptions( include="*", exclude=list(valid_error_names), valid_error_names=valid_error_names ) @@ -62,8 +62,8 @@ def test_silenced( self, valid_error_names: Set[str], event_catcher: EventCatcher, - set_event_manager_with_catcher, - ): + set_event_manager_with_catcher: None, + ) -> None: functions.WARN_ERROR_OPTIONS = WarnErrorOptions( include="*", silence=list(valid_error_names), valid_error_names=valid_error_names ) diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index e906a0ac..cf44eee7 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -227,7 +227,7 @@ def test_incomplete_block_failure(self) -> None: with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock"}) - def test_wrong_end_failure(self): + def test_wrong_end_failure(self) -> None: body = "{% myblock foo %} {% endotherblock %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 93c57046..bb5563e2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,5 @@ import unittest +from typing import Any, Tuple, Union import dbt_common.exceptions import dbt_common.utils.dict @@ -68,7 +69,7 @@ def setUp(self) -> None: } @staticmethod - def intify_all(value, _): + def intify_all(value, _) -> int: try: return int(value) except (TypeError, ValueError): @@ -98,7 +99,7 @@ def test__simple_cases(self) -> None: self.assertEqual(actual, expected) @staticmethod - def special_keypath(value, keypath): + def special_keypath(value: Any, keypath: Tuple[Union[str, int], ...]) -> Any: if tuple(keypath) == ("foo", "baz", 1): return "hello" else: