Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some type annotations #156

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions dbt_common/contracts/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
@dataclass
class BaseConfig(AdditionalPropertiesAllowed, Replaceable):
# enable syntax like: config['key']
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
return self.get(key)

# like doing 'get' on a dictionary
def get(self, key, default=None):
def get(self, key: str, default: Any = None) -> Any:
if hasattr(self, key):
return getattr(self, key)
elif key in self._extra:
Expand All @@ -30,13 +30,13 @@ def get(self, key, default=None):
return default

# enable syntax like: config['key'] = value
def __setitem__(self, key, value):
def __setitem__(self, key: str, value) -> None:
if hasattr(self, key):
setattr(self, key, value)
else:
self._extra[key] = value

def __delitem__(self, key):
def __delitem__(self, key: str) -> None:
if hasattr(self, key):
msg = (
'Error, tried to delete config key "{}": Cannot delete ' "built-in keys"
Expand All @@ -60,7 +60,7 @@ def _content_iterator(self, include_condition: Callable[[Field], bool]):
def __iter__(self):
yield from self._content_iterator(include_condition=lambda f: True)

def __len__(self):
def __len__(self) -> int:
return len(self._get_fields()) + len(self._extra)

@staticmethod
Expand Down Expand Up @@ -221,7 +221,7 @@ def _merge_field_value(
merge_behavior: MergeBehavior,
self_value: Any,
other_value: Any,
):
) -> Any:
if merge_behavior == MergeBehavior.Clobber:
return other_value
elif merge_behavior == MergeBehavior.Append:
Expand Down
2 changes: 1 addition & 1 deletion dbt_common/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ def get_invocation_id() -> str:
return _INVOCATION_ID


def reset_invocation_id():
def reset_invocation_id() -> None:
global _INVOCATION_ID
_INVOCATION_ID = str(uuid.uuid4())
40 changes: 20 additions & 20 deletions dbt_common/semver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ class VersionSpecification(dbtClassMixin):
_VERSION_REGEX = re.compile(_VERSION_REGEX_PAT_STR, re.VERBOSE)


def _cmp(a, b):
def _cmp(a, b) -> int:
"""Return negative if a<b, zero if a==b, positive if a>b."""
return (a > b) - (a < b)


@dataclass
class VersionSpecifier(VersionSpecification):
def to_version_string(self, skip_matcher=False):
def to_version_string(self, skip_matcher: bool = False) -> str:
prerelease = ""
build = ""
matcher = ""
Expand All @@ -92,7 +92,7 @@ def to_version_string(self, skip_matcher=False):
)

@classmethod
def from_version_string(cls, version_string):
def from_version_string(cls, version_string: str) -> "VersionSpecifier":
match = _VERSION_REGEX.match(version_string)

if not match:
Expand All @@ -104,7 +104,7 @@ def from_version_string(cls, version_string):

return cls.from_dict(matched)

def __str__(self):
def __str__(self) -> str:
return self.to_version_string()

def to_range(self) -> "VersionRange":
Expand All @@ -123,7 +123,7 @@ def to_range(self) -> "VersionRange":

return VersionRange(start=range_start, end=range_end)

def compare(self, other):
def compare(self, other: "VersionSpecifier") -> int:
if self.is_unbounded or other.is_unbounded:
return 0

Expand Down Expand Up @@ -192,36 +192,36 @@ def compare(self, other):

return 0

def __lt__(self, other):
def __lt__(self, other: "VersionSpecifier") -> bool:
return self.compare(other) == -1

def __gt__(self, other):
def __gt__(self, other: "VersionSpecifier") -> bool:
return self.compare(other) == 1

def __eq___(self, other):
def __eq___(self, other: "VersionSpecifier") -> bool:
return self.compare(other) == 0

def __cmp___(self, other):
def __cmp___(self, other: "VersionSpecifier") -> int:
return self.compare(other)

@property
def is_unbounded(self):
def is_unbounded(self) -> bool:
return False

@property
def is_lower_bound(self):
def is_lower_bound(self) -> bool:
return self.matcher in [Matchers.GREATER_THAN, Matchers.GREATER_THAN_OR_EQUAL]

@property
def is_upper_bound(self):
def is_upper_bound(self) -> bool:
return self.matcher in [Matchers.LESS_THAN, Matchers.LESS_THAN_OR_EQUAL]

@property
def is_exact(self):
def is_exact(self) -> bool:
return self.matcher == Matchers.EXACT

@classmethod
def _nat_cmp(cls, a, b):
def _nat_cmp(cls, a, b) -> int:
def cmp_prerelease_tag(a, b):
if isinstance(a, int) and isinstance(b, int):
return _cmp(a, b)
Expand Down Expand Up @@ -358,23 +358,23 @@ def __init__(self, *args, **kwargs) -> None:
matcher=Matchers.EXACT, major=None, minor=None, patch=None, prerelease=None, build=None
)

def __str__(self):
def __str__(self) -> str:
return "*"

@property
def is_unbounded(self):
def is_unbounded(self) -> bool:
return True

@property
def is_lower_bound(self):
def is_lower_bound(self) -> bool:
return False

@property
def is_upper_bound(self):
def is_upper_bound(self) -> bool:
return False

@property
def is_exact(self):
def is_exact(self) -> bool:
return False


Expand Down Expand Up @@ -418,7 +418,7 @@ def reduce_versions(*args):
return to_return


def versions_compatible(*args):
def versions_compatible(*args) -> bool:
if len(args) == 1:
return True

Expand Down
4 changes: 2 additions & 2 deletions dbt_common/utils/casting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This is useful for proto generated classes in particular, since
# the default for protobuf for strings is the empty string, so
# Optional[str] types don't work for generated Python classes.
from typing import Optional
from typing import Any, Dict, Mapping, Optional


def cast_to_str(string: Optional[str]) -> str:
Expand All @@ -18,7 +18,7 @@ def cast_to_int(integer: Optional[int]) -> int:
return integer


def cast_dict_to_dict_of_strings(dct):
def cast_dict_to_dict_of_strings(dct: Mapping[Any, Any]) -> Dict[str, str]:
new_dct = {}
for k, v in dct.items():
new_dct[str(k)] = str(v)
Expand Down
3 changes: 2 additions & 1 deletion dbt_common/utils/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from typing import Callable

from dbt_common.events.types import RecordRetryException, RetryExternalCall
from dbt_common.exceptions import ConnectionError
Expand All @@ -7,7 +8,7 @@
import requests


def connection_exception_retry(fn, max_attempts: int, attempt: int = 0):
def connection_exception_retry(fn: Callable, max_attempts: int, attempt: int = 0):
"""Handle connection retries gracefully.

Attempts to run a function that makes an external call, if the call fails
Expand Down
14 changes: 9 additions & 5 deletions dbt_common/utils/jinja.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
from typing import Optional

from dbt_common.exceptions import DbtInternalError


MACRO_PREFIX = "dbt_macro__"
DOCS_PREFIX = "dbt_docs__"


def get_dbt_macro_name(name):
def get_dbt_macro_name(name: str) -> str:
if name is None:
raise DbtInternalError("Got None for a macro name!")
return f"{MACRO_PREFIX}{name}"


def get_dbt_docs_name(name):
def get_dbt_docs_name(name: str) -> str:
if name is None:
raise DbtInternalError("Got None for a doc name!")
return f"{DOCS_PREFIX}{name}"


def get_materialization_macro_name(materialization_name, adapter_type=None, with_prefix=True):
def get_materialization_macro_name(
materialization_name: str, adapter_type: Optional[str] = None, with_prefix: bool = True
) -> str:
if adapter_type is None:
adapter_type = "default"
name = f"materialization_{materialization_name}_{adapter_type}"
return get_dbt_macro_name(name) if with_prefix else name


def get_docs_macro_name(docs_name, with_prefix=True):
def get_docs_macro_name(docs_name: str, with_prefix: bool = True) -> str:
return get_dbt_docs_name(docs_name) if with_prefix else docs_name


def get_test_macro_name(test_name, with_prefix=True):
def get_test_macro_name(test_name: str, with_prefix: bool = True) -> str:
name = f"test_{test_name}"
return get_dbt_macro_name(name) if with_prefix else name
Loading