diff --git a/.changes/unreleased/Under the Hood-20240122-163546.yaml b/.changes/unreleased/Under the Hood-20240122-163546.yaml new file mode 100644 index 00000000..0d32e2e3 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240122-163546.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Clean up macro contexts. +time: 2024-01-22T16:35:46.907999-05:00 +custom: + Author: peterallenwebb + Issue: "35" diff --git a/.changes/unreleased/Under the Hood-20240123-161107.yaml b/.changes/unreleased/Under the Hood-20240123-161107.yaml new file mode 100644 index 00000000..68a83ab4 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240123-161107.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Inject TagIterator into BlockIterator for greater flexibility. +time: 2024-01-23T16:11:07.24321-05:00 +custom: + Author: peterallenwebb + Issue: "38" diff --git a/.changes/unreleased/Under the Hood-20240123-194242.yaml b/.changes/unreleased/Under the Hood-20240123-194242.yaml new file mode 100644 index 00000000..a0cb7431 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240123-194242.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Change metadata_vars \`if not\` to \`if ... is None\` +time: 2024-01-23T19:42:42.95727089Z +custom: + Author: truls-p + Issue: "6073" diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c3b8366b..035240f4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -38,10 +38,10 @@ jobs: steps: - name: Check out the repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' diff --git a/.github/workflows/ci_code_quality.yml b/.github/workflows/ci_code_quality.yml index 7b12e3d5..b8b33314 100644 --- a/.github/workflows/ci_code_quality.yml +++ b/.github/workflows/ci_code_quality.yml @@ -38,10 +38,10 @@ jobs: steps: - name: Check out the repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' diff --git a/.github/workflows/ci_tests.yml b/.github/workflows/ci_tests.yml index 03cdffb8..87550594 100644 --- a/.github/workflows/ci_tests.yml +++ b/.github/workflows/ci_tests.yml @@ -43,10 +43,10 @@ jobs: steps: - name: "Check out the repository" - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: "Set up Python ${{ matrix.python-version }}" - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index 9a830570..c6058bfa 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -1,5 +1,6 @@ import re from collections import namedtuple +from typing import Iterator, List, Optional, Set, Union from dbt_common.exceptions import ( BlockDefinitionNotAtTopError, @@ -12,40 +13,42 @@ ) -def regex(pat): +def regex(pat: str) -> re.Pattern: return re.compile(pat, re.DOTALL | re.MULTILINE) class BlockData: """raw plaintext data from the top level of the file.""" - def __init__(self, contents): + def __init__(self, contents: str) -> None: self.block_type_name = "__dbt__data" - self.contents = contents + self.contents: str = contents self.full_block = contents class BlockTag: - def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw): + def __init__( + self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None + ) -> None: self.block_type_name = block_type_name self.block_name = block_name self.contents = contents self.full_block = full_block - def __str__(self): + def __str__(self) -> str: return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name) - def __repr__(self): + def __repr__(self) -> str: return str(self) @property - def end_block_type_name(self): + def end_block_type_name(self) -> str: return "end{}".format(self.block_type_name) - def end_pat(self): + def end_pat(self) -> re.Pattern: # we don't want to use string formatting here because jinja uses most # of the string formatting operators in its syntax... - pattern = "".join( + pattern: str = "".join( ( r"(?P((?:\s*\{\%\-|\{\%)\s*", self.end_block_type_name, @@ -98,44 +101,38 @@ def end_pat(self): class TagIterator: - def __init__(self, data): - self.data = data - self.blocks = [] - self._parenthesis_stack = [] - self.pos = 0 - - def linepos(self, end=None) -> str: - """Given an absolute position in the input data, return a pair of + def __init__(self, text: str) -> None: + self.text: str = text + self.pos: int = 0 + + def linepos(self, end: Optional[int] = None) -> str: + """Given an absolute position in the input text, return a pair of line number + relative position to the start of the line. """ end_val: int = self.pos if end is None else end - data = self.data[:end_val] + text = self.text[:end_val] # if not found, rfind returns -1, and -1+1=0, which is perfect! - last_line_start = data.rfind("\n") + 1 + last_line_start = text.rfind("\n") + 1 # it's easy to forget this, but line numbers are 1-indexed - line_number = data.count("\n") + 1 + line_number = text.count("\n") + 1 return f"{line_number}:{end_val - last_line_start}" - def advance(self, new_position): + def advance(self, new_position: int) -> None: self.pos = new_position - def rewind(self, amount=1): + def rewind(self, amount: int = 1) -> None: self.pos -= amount - def _search(self, pattern): - return pattern.search(self.data, self.pos) + def _search(self, pattern: re.Pattern) -> Optional[re.Match]: + return pattern.search(self.text, self.pos) - def _match(self, pattern): - return pattern.match(self.data, self.pos) + def _match(self, pattern: re.Pattern) -> Optional[re.Match]: + return pattern.match(self.text, self.pos) - def _first_match(self, *patterns, **kwargs): + def _first_match(self, *patterns) -> Optional[re.Match]: # type: ignore matches = [] for pattern in patterns: - # default to 'search', but sometimes we want to 'match'. - if kwargs.get("method", "search") == "search": - match = self._search(pattern) - else: - match = self._match(pattern) + match = self._search(pattern) if match: matches.append(match) if not matches: @@ -144,13 +141,13 @@ def _first_match(self, *patterns, **kwargs): # TODO: do I need to account for m.start(), or is this ok? return min(matches, key=lambda m: m.end()) - def _expect_match(self, expected_name, *patterns, **kwargs): - match = self._first_match(*patterns, **kwargs) + def _expect_match(self, expected_name: str, *patterns) -> re.Match: # type: ignore + match = self._first_match(*patterns) if match is None: - raise UnexpectedMacroEOFError(expected_name, self.data[self.pos :]) + raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :]) return match - def handle_expr(self, match): + def handle_expr(self, match: re.Match) -> None: """Handle an expression. At this point we're at a string like: {{ 1 + 2 }} ^ right here @@ -176,12 +173,12 @@ def handle_expr(self, match): self.advance(match.end()) - def handle_comment(self, match): + def handle_comment(self, match: re.Match) -> None: self.advance(match.end()) match = self._expect_match("#}", COMMENT_END_PATTERN) self.advance(match.end()) - def _expect_block_close(self): + def _expect_block_close(self) -> None: """Search for the tag close marker. To the right of the type name, there are a few possiblities: - a name (handled by the regex's 'block_name') @@ -203,13 +200,13 @@ def _expect_block_close(self): string_match = self._expect_match("string", STRING_PATTERN) self.advance(string_match.end()) - def handle_raw(self): + def handle_raw(self) -> int: # raw blocks are super special, they are a single complete regex match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN) self.advance(match.end()) return match.end() - def handle_tag(self, match): + def handle_tag(self, match: re.Match) -> Tag: """The tag could be one of a few things: {% mytag %} @@ -234,7 +231,7 @@ def handle_tag(self, match): self._expect_block_close() return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos) - def find_tags(self): + def find_tags(self) -> Iterator[Tag]: while True: match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN) if match is None: @@ -259,7 +256,7 @@ def find_tags(self): "Invalid regex match in next_block, expected block start, " "expr start, or comment start" ) - def __iter__(self): + def __iter__(self) -> Iterator[Tag]: return self.find_tags() @@ -272,31 +269,33 @@ def __iter__(self): class BlockIterator: - def __init__(self, data): - self.tag_parser = TagIterator(data) - self.current = None - self.stack = [] - self.last_position = 0 + def __init__(self, tag_iterator: TagIterator) -> None: + self.tag_parser = tag_iterator + self.current: Optional[Tag] = None + self.stack: List[str] = [] + self.last_position: int = 0 @property - def current_end(self): + def current_end(self) -> int: if self.current is None: return 0 else: return self.current.end @property - def data(self): - return self.tag_parser.data + def data(self) -> str: + return self.tag_parser.text - def is_current_end(self, tag): + def is_current_end(self, tag: Tag) -> bool: return ( tag.block_type_name.startswith("end") and self.current is not None and tag.block_type_name[3:] == self.current.block_type_name ) - def find_blocks(self, allowed_blocks=None, collect_raw_data=True): + def find_blocks( + self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True + ) -> Iterator[Union[BlockData, BlockTag]]: """Find all top-level blocks in the data.""" if allowed_blocks is None: allowed_blocks = {"snapshot", "macro", "materialization", "docs"} @@ -347,5 +346,7 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True): if raw_data: yield BlockData(raw_data) - def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True): + def lex_for_blocks( + self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True + ) -> List[Union[BlockData, BlockTag]]: return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)) diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 1b6de92b..b8c0d03a 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -3,9 +3,10 @@ import os import tempfile from ast import literal_eval +from collections import ChainMap from contextlib import contextmanager from itertools import chain, islice -from typing import List, Union, Set, Optional, Dict, Any, Iterator, Type, Callable +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type from typing_extensions import Protocol import jinja2 @@ -21,7 +22,7 @@ get_materialization_macro_name, get_test_macro_name, ) -from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag +from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag, TagIterator from dbt_common.exceptions import ( CompilationError, @@ -99,6 +100,41 @@ def _compile(self, source, filename): return super()._compile(source, filename) # type: ignore +class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate): + environment_class = MacroFuzzEnvironment + + def new_context( + self, + vars: Optional[Dict[str, Any]] = None, + shared: bool = False, + locals: Optional[Mapping[str, Any]] = None, + ) -> jinja2.runtime.Context: + # This custom override makes the assumption that the locals and shared + # parameters are not used, so enforce that. + if shared or locals: + raise Exception("The MacroFuzzTemplate.new_context() override cannot use the shared or locals parameters.") + + parent = ChainMap(vars, self.globals) if self.globals else vars + + return self.environment.context_class(self.environment, parent, self.name, self.blocks) + + def render(self, *args: Any, **kwargs: Any) -> Any: + if kwargs or len(args) != 1: + raise Exception("The MacroFuzzTemplate.render() override requires exactly one argument.") + + ctx = self.new_context(args[0]) + + try: + return self.environment_class.concat( # type: ignore + self.root_render_func(ctx) # type: ignore + ) + except Exception: + return self.environment.handle_exception() + + +MacroFuzzEnvironment.template_class = MacroFuzzTemplate + + class NativeSandboxEnvironment(MacroFuzzEnvironment): code_generator_class = jinja2.nativetypes.NativeCodeGenerator @@ -171,7 +207,7 @@ def render(self, *args, **kwargs): with :func:`ast.literal_eval`, the parsed value is returned. Otherwise, the string is returned. """ - vars = dict(*args, **kwargs) + vars = args[0] try: return quoted_native_concat(self.root_render_func(self.new_context(vars))) @@ -226,7 +262,7 @@ def get_macro(self): # make_module is in jinja2.environment. It returns a TemplateModule module = template.make_module(vars=self.context, shared=False) macro = module.__dict__[get_dbt_macro_name(name)] - module.__dict__.update(self.context) + return macro @contextmanager @@ -480,7 +516,7 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str: def extract_toplevel_blocks( - data: str, + text: str, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True, ) -> List[Union[BlockData, BlockTag]]: @@ -498,4 +534,5 @@ def extract_toplevel_blocks( :return: A list of `BlockTag`s matching the allowed block types and (if `collect_raw_data` is `True`) `BlockData` objects. """ - return BlockIterator(data).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data) + tag_iterator = TagIterator(text) + return BlockIterator(tag_iterator).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data) diff --git a/dbt_common/events/functions.py b/dbt_common/events/functions.py index 60ef4d27..fde2219b 100644 --- a/dbt_common/events/functions.py +++ b/dbt_common/events/functions.py @@ -138,7 +138,7 @@ def fire_event(e: BaseEvent, level: Optional[EventLevel] = None) -> None: def get_metadata_vars() -> Dict[str, str]: global metadata_vars - if not metadata_vars: + if metadata_vars is None: metadata_vars = { k[len(_METADATA_ENV_PREFIX) :]: v for k, v in os.environ.items() if k.startswith(_METADATA_ENV_PREFIX) }