Skip to content

Commit

Permalink
More Type Annotations (#177)
Browse files Browse the repository at this point in the history
* More type annotations.

* Fix Python 3.8 issues. More typing.

* More Python 3.8 fixes
  • Loading branch information
peterallenwebb authored Aug 1, 2024
1 parent bef3b7d commit c9cc99e
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 26 deletions.
4 changes: 2 additions & 2 deletions dbt_common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dbt_common.record import Recorder


class CaseInsensitiveMapping(Mapping):
class CaseInsensitiveMapping(Mapping[str, str]):
def __init__(self, env: Mapping[str, str]):
self._env = {k.casefold(): (k, v) for k, v in env.items()}

Expand Down Expand Up @@ -65,7 +65,7 @@ def env_secrets(self) -> List[str]:


def reliably_get_invocation_var() -> ContextVar[InvocationContext]:
invocation_var: Optional[ContextVar] = next(
invocation_var: Optional[ContextVar[InvocationContext]] = next(
(cv for cv in copy_context() if cv.name == _INVOCATION_CONTEXT_VAR.name), None
)

Expand Down
12 changes: 6 additions & 6 deletions dbt_common/contracts/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, Field

from itertools import chain
from typing import Callable, Dict, Any, List, TypeVar, Type
from typing import Any, Callable, Dict, Iterator, List, Type, TypeVar

from dbt_common.contracts.config.metadata import Metadata
from dbt_common.exceptions import CompilationError, DbtInternalError
Expand Down Expand Up @@ -45,7 +45,7 @@ def __delitem__(self, key: str) -> None:
else:
del self._extra[key]

def _content_iterator(self, include_condition: Callable[[Field], bool]):
def _content_iterator(self, include_condition: Callable[[Field[Any]], bool]) -> Iterator[str]:
seen = set()
for fld, _ in self._get_fields():
seen.add(fld.name)
Expand All @@ -57,7 +57,7 @@ def _content_iterator(self, include_condition: Callable[[Field], bool]):
seen.add(key)
yield key

def __iter__(self):
def __iter__(self) -> Iterator[str]:
yield from self._content_iterator(include_condition=lambda f: True)

def __len__(self) -> int:
Expand All @@ -76,7 +76,7 @@ def compare_key(
elif key in unrendered and key not in other:
return False
else:
return unrendered[key] == other[key]
return bool(unrendered[key] == other[key])

@classmethod
def same_contents(cls, unrendered: Dict[str, Any], other: Dict[str, Any]) -> bool:
Expand Down Expand Up @@ -203,11 +203,11 @@ def metadata_key(cls) -> str:
return "compare"

@classmethod
def should_include(cls, fld: Field) -> bool:
def should_include(cls, fld: Field[Any]) -> bool:
return cls.from_field(fld) == cls.Include


def _listify(value: Any) -> List:
def _listify(value: Any) -> List[Any]:
if isinstance(value, list):
return value[:]
else:
Expand Down
7 changes: 4 additions & 3 deletions dbt_common/contracts/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import dataclasses
from typing import Any


# TODO: remove from dbt_common.contracts.util:: Replaceable + references
class Replaceable:
def replace(self, **kwargs):
return dataclasses.replace(self, **kwargs)
def replace(self, **kwargs: Any):
return dataclasses.replace(self, **kwargs) # type: ignore


class Mergeable(Replaceable):
Expand All @@ -15,7 +16,7 @@ def merged(self, *args):
replacements = {}
cls = type(self)
for arg in args:
for field in dataclasses.fields(cls):
for field in dataclasses.fields(cls): # type: ignore
value = getattr(arg, field.name)
if value is not None:
replacements[field.name] = value
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/dataclass_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def json_schema(cls):
return json_schema

@classmethod
def validate(cls, data):
def validate(cls, data: Any) -> None:
json_schema = cls.json_schema()
validator = jsonschema.Draft7Validator(json_schema)
error = next(iter(validator.iter_errors(data)), None)
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/events/contextvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_contextvars(prefix: str) -> Dict[str, Any]:
return rv


def get_node_info():
def get_node_info() -> Dict[str, Any]:
cvars = get_contextvars(LOG_PREFIX)
if "node_info" in cvars:
return cvars["node_info"]
Expand Down
26 changes: 13 additions & 13 deletions dbt_common/exceptions/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import builtins
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
import os

from dbt_common.constants import SECRET_ENV_PREFIX
Expand All @@ -23,7 +23,7 @@ class DbtBaseException(Exception):
CODE = -32000
MESSAGE = "Server Error"

def data(self):
def data(self) -> Dict[str, Any]:
# if overriding, make sure the result is json-serializable.
return {
"type": self.__class__.__name__,
Expand All @@ -32,15 +32,15 @@ def data(self):


class DbtInternalError(DbtBaseException):
def __init__(self, msg: str):
def __init__(self, msg: str) -> None:
self.stack: List = []
self.msg = scrub_secrets(msg, env_secrets())

@property
def type(self) -> str:
return "Internal"

def process_stack(self):
def process_stack(self) -> List[str]:
lines = []
stack = self.stack
first = True
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(self, msg: str, node=None) -> None:
self.node = node
self.msg = scrub_secrets(msg, env_secrets())

def add_node(self, node=None):
def add_node(self, node=None) -> None:
if node is not None and node is not self.node:
if self.node is not None:
self.stack.append(self.node)
Expand All @@ -91,7 +91,7 @@ def add_node(self, node=None):
def type(self):
return "Runtime"

def node_to_string(self, node: Any):
def node_to_string(self, node: Any) -> str:
"""Given a node-like object we attempt to create the best identifier we can."""
result = ""
if hasattr(node, "resource_type"):
Expand All @@ -103,7 +103,7 @@ def node_to_string(self, node: Any):

return result.strip() if result != "" else "<Unknown>"

def process_stack(self):
def process_stack(self) -> List[str]:
lines = []
stack = self.stack + [self.node]
first = True
Expand All @@ -122,7 +122,7 @@ def process_stack(self):

return lines

def validator_error_message(self, exc: builtins.Exception):
def validator_error_message(self, exc: builtins.Exception) -> str:
"""Given a dbt.dataclass_schema.ValidationError return the relevant parts as a string.
dbt.dataclass_schema.ValidationError is basically a jsonschema.ValidationError)
Expand All @@ -132,7 +132,7 @@ def validator_error_message(self, exc: builtins.Exception):
path = "[%s]" % "][".join(map(repr, exc.relative_path))
return f"at path {path}: {exc.message}"

def __str__(self, prefix: str = "! "):
def __str__(self, prefix: str = "! ") -> str:
node_string = ""

if self.node is not None:
Expand All @@ -149,7 +149,7 @@ def __str__(self, prefix: str = "! "):

return lines[0] + "\n" + "\n".join([" " + line for line in lines[1:]])

def data(self):
def data(self) -> Dict[str, Any]:
result = DbtBaseException.data(self)
if self.node is None:
return result
Expand Down Expand Up @@ -236,7 +236,7 @@ class DbtDatabaseError(DbtRuntimeError):
CODE = 10003
MESSAGE = "Database Error"

def process_stack(self):
def process_stack(self) -> List[str]:
lines = []

if hasattr(self.node, "build_path") and self.node.build_path:
Expand All @@ -250,7 +250,7 @@ def type(self):


class UnexpectedNullError(DbtDatabaseError):
def __init__(self, field_name: str, source):
def __init__(self, field_name: str, source) -> None:
self.field_name = field_name
self.source = source
msg = (
Expand All @@ -268,7 +268,7 @@ def __init__(self, cwd: str, cmd: List[str], msg: str = "Error running command")
self.cmd = cmd_scrubbed
self.args = (cwd, cmd_scrubbed, msg)

def __str__(self):
def __str__(self, prefix: str = "! ") -> str:
if len(self.cmd) == 0:
return f"{self.msg}: No arguments given"
return f'{self.msg}: "{self.cmd[0]}"'

0 comments on commit c9cc99e

Please sign in to comment.