Skip to content

Commit

Permalink
Merge branch 'main' into er/create-template-with-ui
Browse files Browse the repository at this point in the history
  • Loading branch information
emmyoop authored Jan 25, 2024
2 parents 80a2c66 + 5ce19dd commit 25fa354
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 66 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240122-163546.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Clean up macro contexts.
time: 2024-01-22T16:35:46.907999-05:00
custom:
Author: peterallenwebb
Issue: "35"
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240123-161107.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Inject TagIterator into BlockIterator for greater flexibility.
time: 2024-01-23T16:11:07.24321-05:00
custom:
Author: peterallenwebb
Issue: "38"
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240123-194242.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Change metadata_vars \`if not\` to \`if ... is None\`
time: 2024-01-23T19:42:42.95727089Z
custom:
Author: truls-p
Issue: "6073"
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ jobs:

steps:
- name: Check out the repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.11'

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci_code_quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ jobs:

steps:
- name: Check out the repository
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.11'

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ jobs:

steps:
- name: "Check out the repository"
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: "Set up Python ${{ matrix.python-version }}"
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down
107 changes: 54 additions & 53 deletions dbt_common/clients/_jinja_blocks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from collections import namedtuple
from typing import Iterator, List, Optional, Set, Union

from dbt_common.exceptions import (
BlockDefinitionNotAtTopError,
Expand All @@ -12,40 +13,42 @@
)


def regex(pat):
def regex(pat: str) -> re.Pattern:
return re.compile(pat, re.DOTALL | re.MULTILINE)


class BlockData:
"""raw plaintext data from the top level of the file."""

def __init__(self, contents):
def __init__(self, contents: str) -> None:
self.block_type_name = "__dbt__data"
self.contents = contents
self.contents: str = contents
self.full_block = contents


class BlockTag:
def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw):
def __init__(
self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None
) -> None:
self.block_type_name = block_type_name
self.block_name = block_name
self.contents = contents
self.full_block = full_block

def __str__(self):
def __str__(self) -> str:
return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name)

def __repr__(self):
def __repr__(self) -> str:
return str(self)

@property
def end_block_type_name(self):
def end_block_type_name(self) -> str:
return "end{}".format(self.block_type_name)

def end_pat(self):
def end_pat(self) -> re.Pattern:
# we don't want to use string formatting here because jinja uses most
# of the string formatting operators in its syntax...
pattern = "".join(
pattern: str = "".join(
(
r"(?P<endblock>((?:\s*\{\%\-|\{\%)\s*",
self.end_block_type_name,
Expand Down Expand Up @@ -98,44 +101,38 @@ def end_pat(self):


class TagIterator:
def __init__(self, data):
self.data = data
self.blocks = []
self._parenthesis_stack = []
self.pos = 0

def linepos(self, end=None) -> str:
"""Given an absolute position in the input data, return a pair of
def __init__(self, text: str) -> None:
self.text: str = text
self.pos: int = 0

def linepos(self, end: Optional[int] = None) -> str:
"""Given an absolute position in the input text, return a pair of
line number + relative position to the start of the line.
"""
end_val: int = self.pos if end is None else end
data = self.data[:end_val]
text = self.text[:end_val]
# if not found, rfind returns -1, and -1+1=0, which is perfect!
last_line_start = data.rfind("\n") + 1
last_line_start = text.rfind("\n") + 1
# it's easy to forget this, but line numbers are 1-indexed
line_number = data.count("\n") + 1
line_number = text.count("\n") + 1
return f"{line_number}:{end_val - last_line_start}"

def advance(self, new_position):
def advance(self, new_position: int) -> None:
self.pos = new_position

def rewind(self, amount=1):
def rewind(self, amount: int = 1) -> None:
self.pos -= amount

def _search(self, pattern):
return pattern.search(self.data, self.pos)
def _search(self, pattern: re.Pattern) -> Optional[re.Match]:
return pattern.search(self.text, self.pos)

def _match(self, pattern):
return pattern.match(self.data, self.pos)
def _match(self, pattern: re.Pattern) -> Optional[re.Match]:
return pattern.match(self.text, self.pos)

def _first_match(self, *patterns, **kwargs):
def _first_match(self, *patterns) -> Optional[re.Match]: # type: ignore
matches = []
for pattern in patterns:
# default to 'search', but sometimes we want to 'match'.
if kwargs.get("method", "search") == "search":
match = self._search(pattern)
else:
match = self._match(pattern)
match = self._search(pattern)
if match:
matches.append(match)
if not matches:
Expand All @@ -144,13 +141,13 @@ def _first_match(self, *patterns, **kwargs):
# TODO: do I need to account for m.start(), or is this ok?
return min(matches, key=lambda m: m.end())

def _expect_match(self, expected_name, *patterns, **kwargs):
match = self._first_match(*patterns, **kwargs)
def _expect_match(self, expected_name: str, *patterns) -> re.Match: # type: ignore
match = self._first_match(*patterns)
if match is None:
raise UnexpectedMacroEOFError(expected_name, self.data[self.pos :])
raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :])
return match

def handle_expr(self, match):
def handle_expr(self, match: re.Match) -> None:
"""Handle an expression. At this point we're at a string like:
{{ 1 + 2 }}
^ right here
Expand All @@ -176,12 +173,12 @@ def handle_expr(self, match):

self.advance(match.end())

def handle_comment(self, match):
def handle_comment(self, match: re.Match) -> None:
self.advance(match.end())
match = self._expect_match("#}", COMMENT_END_PATTERN)
self.advance(match.end())

def _expect_block_close(self):
def _expect_block_close(self) -> None:
"""Search for the tag close marker.
To the right of the type name, there are a few possibilities:
- a name (handled by the regex's 'block_name')
Expand All @@ -203,13 +200,13 @@ def _expect_block_close(self):
string_match = self._expect_match("string", STRING_PATTERN)
self.advance(string_match.end())

def handle_raw(self):
def handle_raw(self) -> int:
# raw blocks are super special, they are a single complete regex
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
self.advance(match.end())
return match.end()

def handle_tag(self, match):
def handle_tag(self, match: re.Match) -> Tag:
"""The tag could be one of a few things:
{% mytag %}
Expand All @@ -234,7 +231,7 @@ def handle_tag(self, match):
self._expect_block_close()
return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos)

def find_tags(self):
def find_tags(self) -> Iterator[Tag]:
while True:
match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN)
if match is None:
Expand All @@ -259,7 +256,7 @@ def find_tags(self):
"Invalid regex match in next_block, expected block start, " "expr start, or comment start"
)

def __iter__(self):
def __iter__(self) -> Iterator[Tag]:
return self.find_tags()


Expand All @@ -272,31 +269,33 @@ def __iter__(self):


class BlockIterator:
def __init__(self, data):
self.tag_parser = TagIterator(data)
self.current = None
self.stack = []
self.last_position = 0
def __init__(self, tag_iterator: TagIterator) -> None:
self.tag_parser = tag_iterator
self.current: Optional[Tag] = None
self.stack: List[str] = []
self.last_position: int = 0

@property
def current_end(self):
def current_end(self) -> int:
if self.current is None:
return 0
else:
return self.current.end

@property
def data(self):
return self.tag_parser.data
def data(self) -> str:
return self.tag_parser.text

def is_current_end(self, tag):
def is_current_end(self, tag: Tag) -> bool:
return (
tag.block_type_name.startswith("end")
and self.current is not None
and tag.block_type_name[3:] == self.current.block_type_name
)

def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
def find_blocks(
self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True
) -> Iterator[Union[BlockData, BlockTag]]:
"""Find all top-level blocks in the data."""
if allowed_blocks is None:
allowed_blocks = {"snapshot", "macro", "materialization", "docs"}
Expand Down Expand Up @@ -347,5 +346,7 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
if raw_data:
yield BlockData(raw_data)

def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
def lex_for_blocks(
self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True
) -> List[Union[BlockData, BlockTag]]:
return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data))
49 changes: 43 additions & 6 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import tempfile
from ast import literal_eval
from collections import ChainMap
from contextlib import contextmanager
from itertools import chain, islice
from typing import List, Union, Set, Optional, Dict, Any, Iterator, Type, Callable
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type
from typing_extensions import Protocol

import jinja2
Expand All @@ -21,7 +22,7 @@
get_materialization_macro_name,
get_test_macro_name,
)
from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag, TagIterator

from dbt_common.exceptions import (
CompilationError,
Expand Down Expand Up @@ -99,6 +100,41 @@ def _compile(self, source, filename):
return super()._compile(source, filename) # type: ignore


class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate):
    """Native template bound to ``MacroFuzzEnvironment``.

    Overrides ``new_context`` and ``render`` so that the caller-supplied
    variable mapping is layered over the template globals with a ``ChainMap``
    and handed to the context class directly.
    """

    environment_class = MacroFuzzEnvironment

    def new_context(
        self,
        vars: Optional[Dict[str, Any]] = None,
        shared: bool = False,
        locals: Optional[Mapping[str, Any]] = None,
    ) -> jinja2.runtime.Context:
        """Build a render context whose parent is ``vars`` layered over the globals.

        :param vars: Variables for the template. ``None`` is treated as an
            empty mapping.
        :param shared: Must be falsy; not supported by this override.
        :param locals: Must be falsy; not supported by this override.
        :raises Exception: If ``shared`` or ``locals`` is supplied.
        :return: A context created by the environment's ``context_class``.
        """
        # This custom override makes the assumption that the locals and shared
        # parameters are not used, so enforce that.
        if shared or locals:
            raise Exception(
                "The MacroFuzzTemplate.new_context() override cannot use the shared or locals parameters."
            )

        # The signature allows vars=None; without this guard None would end up
        # inside the ChainMap (or as the context parent itself) and the first
        # variable lookup would fail.
        if vars is None:
            vars = {}

        # Only wrap in a ChainMap when there are globals to fall back on.
        parent = ChainMap(vars, self.globals) if self.globals else vars

        return self.environment.context_class(self.environment, parent, self.name, self.blocks)

    def render(self, *args: Any, **kwargs: Any) -> Any:
        """Render the template using a single positional variable mapping.

        :param args: Exactly one positional argument, passed to
            ``new_context`` as ``vars``.
        :param kwargs: Must be empty.
        :raises Exception: If keyword arguments are given or the number of
            positional arguments is not exactly one.
        :return: The result of concatenating the root render function's
            output with the environment class's ``concat``, or whatever the
            environment's ``handle_exception`` returns if rendering raises.
        """
        if kwargs or len(args) != 1:
            raise Exception("The MacroFuzzTemplate.render() override requires exactly one argument.")

        ctx = self.new_context(args[0])

        try:
            return self.environment_class.concat(  # type: ignore
                self.root_render_func(ctx)  # type: ignore
            )
        except Exception:
            # Delegate error handling to the environment.
            return self.environment.handle_exception()


MacroFuzzEnvironment.template_class = MacroFuzzTemplate


class NativeSandboxEnvironment(MacroFuzzEnvironment):
code_generator_class = jinja2.nativetypes.NativeCodeGenerator

Expand Down Expand Up @@ -171,7 +207,7 @@ def render(self, *args, **kwargs):
with :func:`ast.literal_eval`, the parsed value is returned.
Otherwise, the string is returned.
"""
vars = dict(*args, **kwargs)
vars = args[0]

try:
return quoted_native_concat(self.root_render_func(self.new_context(vars)))
Expand Down Expand Up @@ -226,7 +262,7 @@ def get_macro(self):
# make_module is in jinja2.environment. It returns a TemplateModule
module = template.make_module(vars=self.context, shared=False)
macro = module.__dict__[get_dbt_macro_name(name)]
module.__dict__.update(self.context)

return macro

@contextmanager
Expand Down Expand Up @@ -480,7 +516,7 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str:


def extract_toplevel_blocks(
data: str,
text: str,
allowed_blocks: Optional[Set[str]] = None,
collect_raw_data: bool = True,
) -> List[Union[BlockData, BlockTag]]:
Expand All @@ -498,4 +534,5 @@ def extract_toplevel_blocks(
:return: A list of `BlockTag`s matching the allowed block types and (if
`collect_raw_data` is `True`) `BlockData` objects.
"""
return BlockIterator(data).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)
tag_iterator = TagIterator(text)
return BlockIterator(tag_iterator).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)
Loading

0 comments on commit 25fa354

Please sign in to comment.