Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Type Annotations in jinja and semver Code #187

Merged
merged 7 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 75 additions & 49 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,21 @@
from collections import ChainMap
from contextlib import contextmanager
from itertools import chain, islice
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type
from types import CodeType
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Union,
Set,
Type,
NoReturn,
)

from typing_extensions import Protocol

import jinja2
Expand Down Expand Up @@ -39,10 +53,17 @@
SUPPORTED_LANG_ARG = jinja2.nodes.Name("supported_languages", "param")

# Global which can be set by dependents of dbt-common (e.g. core via flag parsing)
MACRO_DEBUGGING = False
MACRO_DEBUGGING: Union[str, bool] = False

_ParseReturn = Union[jinja2.nodes.Node, List[jinja2.nodes.Node]]


# Temporary type capturing the concept the functions in this file expect for a "node"
class _NodeProtocol(Protocol):
pass


def _linecache_inject(source, write):
def _linecache_inject(source: str, write: bool) -> str:
if write:
# this is the only reliable way to accomplish this. Obviously, it's
# really darn noisy and will fill your temporary directory
Expand All @@ -58,18 +79,18 @@ def _linecache_inject(source, write):
else:
# `codecs.encode` actually takes a `bytes` as the first argument if
# the second argument is 'hex' - mypy does not know this.
rnd = codecs.encode(os.urandom(12), "hex") # type: ignore
rnd = codecs.encode(os.urandom(12), "hex")
filename = rnd.decode("ascii")

# put ourselves in the cache
cache_entry = (len(source), None, [line + "\n" for line in source.splitlines()], filename)
# linecache does in fact have an attribute `cache`, thanks
linecache.cache[filename] = cache_entry # type: ignore
linecache.cache[filename] = cache_entry
return filename


class MacroFuzzParser(jinja2.parser.Parser):
def parse_macro(self):
def parse_macro(self) -> jinja2.nodes.Macro:
node = jinja2.nodes.Macro(lineno=next(self.stream).lineno)

# modified to fuzz macros defined in the same file. this way
Expand All @@ -83,16 +104,13 @@ def parse_macro(self):


class MacroFuzzEnvironment(jinja2.sandbox.SandboxedEnvironment):
def _parse(self, source, name, filename):
def _parse(
self, source: str, name: Optional[str], filename: Optional[str]
) -> jinja2.nodes.Template:
return MacroFuzzParser(self, source, name, filename).parse()

def _compile(self, source, filename):
def _compile(self, source: str, filename: str) -> CodeType:
"""





Override jinja's compilation. Use to stash the rendered source inside
the python linecache for debugging when the appropriate environment
variable is set.
Expand All @@ -108,7 +126,7 @@ def _compile(self, source, filename):


class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate):
environment_class = MacroFuzzEnvironment
environment_class = MacroFuzzEnvironment # type: ignore

def new_context(
self,
Expand Down Expand Up @@ -171,11 +189,11 @@ class NumberMarker(NativeMarker):
pass


def _is_number(value) -> bool:
def _is_number(value: Any) -> bool:
return isinstance(value, (int, float)) and not isinstance(value, bool)


def quoted_native_concat(nodes):
def quoted_native_concat(nodes: Iterator[str]) -> Any:
"""Handle special case for native_concat from the NativeTemplate.

This is almost native_concat from the NativeTemplate, except in the
Expand Down Expand Up @@ -213,7 +231,7 @@ def quoted_native_concat(nodes):
class NativeSandboxTemplate(jinja2.nativetypes.NativeTemplate): # mypy: ignore
environment_class = NativeSandboxEnvironment # type: ignore

def render(self, *args, **kwargs):
def render(self, *args: Any, **kwargs: Any) -> Any:
"""Render the template to produce a native Python type.

If the result is a single node, its value is returned. Otherwise,
Expand All @@ -229,14 +247,19 @@ def render(self, *args, **kwargs):
return self.environment.handle_exception()


class MacroProtocol(Protocol):
name: str
macro_sql: str


NativeSandboxEnvironment.template_class = NativeSandboxTemplate # type: ignore


class TemplateCache:
def __init__(self) -> None:
self.file_cache: Dict[str, jinja2.Template] = {}

def get_node_template(self, node) -> jinja2.Template:
def get_node_template(self, node: MacroProtocol) -> jinja2.Template:
key = node.macro_sql

if key in self.file_cache:
Expand All @@ -251,7 +274,7 @@ def get_node_template(self, node) -> jinja2.Template:
self.file_cache[key] = template
return template

def clear(self):
def clear(self) -> None:
self.file_cache.clear()


Expand All @@ -262,13 +285,13 @@ class BaseMacroGenerator:
def __init__(self, context: Optional[Dict[str, Any]] = None) -> None:
self.context: Optional[Dict[str, Any]] = context

def get_template(self):
def get_template(self) -> jinja2.Template:
raise NotImplementedError("get_template not implemented!")

def get_name(self) -> str:
raise NotImplementedError("get_name not implemented!")

def get_macro(self):
def get_macro(self) -> Callable:
name = self.get_name()
template = self.get_template()
# make the module. previously we set both vars and local, but that's
Expand All @@ -286,7 +309,7 @@ def exception_handler(self) -> Iterator[None]:
except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e:
raise CaughtMacroError(e)

def call_macro(self, *args, **kwargs):
def call_macro(self, *args: Any, **kwargs: Any) -> Any:
# called from __call__ methods
if self.context is None:
raise DbtInternalError("Context is still None in call_macro!")
Expand All @@ -301,11 +324,6 @@ def call_macro(self, *args, **kwargs):
return e.value


class MacroProtocol(Protocol):
name: str
macro_sql: str


class CallableMacroGenerator(BaseMacroGenerator):
def __init__(
self,
Expand All @@ -315,7 +333,7 @@ def __init__(
super().__init__(context)
self.macro = macro

def get_template(self):
def get_template(self) -> jinja2.Template:
return template_cache.get_node_template(self.macro)

def get_name(self) -> str:
Expand All @@ -332,14 +350,14 @@ def exception_handler(self) -> Iterator[None]:
raise e

# this makes MacroGenerator objects callable like functions
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.call_macro(*args, **kwargs)


class MaterializationExtension(jinja2.ext.Extension):
tags = ["materialization"]

def parse(self, parser):
def parse(self, parser: jinja2.parser.Parser) -> _ParseReturn:
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
materialization_name = parser.parse_assign_target(name_only=True).name

Expand Down Expand Up @@ -382,7 +400,7 @@ def parse(self, parser):
class DocumentationExtension(jinja2.ext.Extension):
tags = ["docs"]

def parse(self, parser):
def parse(self, parser: jinja2.parser.Parser) -> _ParseReturn:
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
docs_name = parser.parse_assign_target(name_only=True).name

Expand All @@ -396,7 +414,7 @@ def parse(self, parser):
class TestExtension(jinja2.ext.Extension):
tags = ["test"]

def parse(self, parser):
def parse(self, parser: jinja2.parser.Parser) -> _ParseReturn:
node = jinja2.nodes.Macro(lineno=next(parser.stream).lineno)
test_name = parser.parse_assign_target(name_only=True).name

Expand All @@ -406,13 +424,19 @@ def parse(self, parser):
return node


def _is_dunder_name(name):
def _is_dunder_name(name: str) -> bool:
return name.startswith("__") and name.endswith("__")


def create_undefined(node=None):
def create_undefined(node: Optional[_NodeProtocol] = None) -> Type[jinja2.Undefined]:
class Undefined(jinja2.Undefined):
def __init__(self, hint=None, obj=None, name=None, exc=None):
def __init__(
self,
hint: Optional[str] = None,
obj: Any = None,
name: Optional[str] = None,
exc: Any = None,
) -> None:
super().__init__(hint=hint, name=name)
self.node = node
self.name = name
Expand All @@ -422,12 +446,12 @@ def __init__(self, hint=None, obj=None, name=None, exc=None):
self.unsafe_callable = False
self.alters_data = False

def __getitem__(self, name):
def __getitem__(self, name: Any) -> "Undefined":
# Propagate the undefined value if a caller accesses this as if it
# were a dictionary
return self

def __getattr__(self, name):
def __getattr__(self, name: str) -> "Undefined":
if name == "name" or _is_dunder_name(name):
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
Expand All @@ -437,11 +461,11 @@ def __getattr__(self, name):

return self.__class__(hint=self.hint, name=self.name)

def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> "Undefined":
return self

def __reduce__(self):
raise UndefinedCompilationError(name=self.name, node=node)
def __reduce__(self) -> NoReturn:
raise UndefinedCompilationError(name=self.name or "unknown", node=node)

return Undefined

Expand All @@ -463,7 +487,7 @@ def __reduce__(self):


def get_environment(
node=None,
node: Optional[_NodeProtocol] = None,
capture_macros: bool = False,
native: bool = False,
) -> jinja2.Environment:
Expand All @@ -472,7 +496,7 @@ def get_environment(
}

if capture_macros:
args["undefined"] = create_undefined(node)
args["undefined"] = create_undefined(node) # type: ignore

args["extensions"].append(MaterializationExtension)
args["extensions"].append(DocumentationExtension)
Expand All @@ -493,7 +517,7 @@ def get_environment(


@contextmanager
def catch_jinja(node=None) -> Iterator[None]:
def catch_jinja(node: Optional[_NodeProtocol] = None) -> Iterator[None]:
try:
yield
except jinja2.exceptions.TemplateSyntaxError as e:
Expand All @@ -506,16 +530,16 @@ def catch_jinja(node=None) -> Iterator[None]:
raise


_TESTING_PARSE_CACHE: Dict[str, jinja2.Template] = {}
_TESTING_PARSE_CACHE: Dict[str, jinja2.nodes.Template] = {}


def parse(string):
def parse(string: Any) -> jinja2.nodes.Template:
str_string = str(string)
if test_caching_enabled() and str_string in _TESTING_PARSE_CACHE:
return _TESTING_PARSE_CACHE[str_string]

with catch_jinja():
parsed = get_environment().parse(str(string))
parsed: jinja2.nodes.Template = get_environment().parse(str(string))
if test_caching_enabled():
_TESTING_PARSE_CACHE[str_string] = parsed
return parsed
Expand All @@ -524,18 +548,20 @@ def parse(string):
def get_template(
string: str,
ctx: Dict[str, Any],
node=None,
node: Optional[_NodeProtocol] = None,
capture_macros: bool = False,
native: bool = False,
):
) -> jinja2.Template:
with catch_jinja(node):
env = get_environment(node, capture_macros, native=native)

template_source = str(string)
return env.from_string(template_source, globals=ctx)


def render_template(template, ctx: Dict[str, Any], node=None) -> str:
def render_template(
template: jinja2.Template, ctx: Dict[str, Any], node: Optional[_NodeProtocol] = None
) -> str:
with catch_jinja(node):
return template.render(ctx)

Expand Down
13 changes: 9 additions & 4 deletions dbt_common/exceptions/jinja.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dbt_common.clients._jinja_blocks import Tag, TagIterator

from dbt_common.exceptions import CompilationError


class BlockDefinitionNotAtTopError(CompilationError):
def __init__(self, tag_parser, tag_start) -> None:
def __init__(self, tag_parser: "TagIterator", tag_start: int) -> None:
self.tag_parser = tag_parser
self.tag_start = tag_start
super().__init__(msg=self.get_message())
Expand Down Expand Up @@ -31,7 +36,7 @@ def get_message(self) -> str:


class MissingControlFlowStartTagError(CompilationError):
def __init__(self, tag, expected_tag: str, tag_parser) -> None:
def __init__(self, tag: "Tag", expected_tag: str, tag_parser: "TagIterator") -> None:
self.tag = tag
self.expected_tag = expected_tag
self.tag_parser = tag_parser
Expand All @@ -47,7 +52,7 @@ def get_message(self) -> str:


class NestedTagsError(CompilationError):
def __init__(self, outer, inner) -> None:
def __init__(self, outer: "Tag", inner: "Tag") -> None:
self.outer = outer
self.inner = inner
super().__init__(msg=self.get_message())
Expand All @@ -62,7 +67,7 @@ def get_message(self) -> str:


class UnexpectedControlFlowEndTagError(CompilationError):
def __init__(self, tag, expected_tag: str, tag_parser) -> None:
def __init__(self, tag: "Tag", expected_tag: str, tag_parser: "TagIterator") -> None:
self.tag = tag
self.expected_tag = expected_tag
self.tag_parser = tag_parser
Expand Down
Loading
Loading