diff --git a/.changes/unreleased/Under the Hood-20240123-161107.yaml b/.changes/unreleased/Under the Hood-20240123-161107.yaml new file mode 100644 index 00000000..68a83ab4 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240123-161107.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Inject TagIterator into BlockIterator for greater flexibility. +time: 2024-01-23T16:11:07.24321-05:00 +custom: + Author: peterallenwebb + Issue: "38" diff --git a/dbt_common/clients/_jinja_blocks.py b/dbt_common/clients/_jinja_blocks.py index 9a830570..c6058bfa 100644 --- a/dbt_common/clients/_jinja_blocks.py +++ b/dbt_common/clients/_jinja_blocks.py @@ -1,5 +1,6 @@ import re from collections import namedtuple +from typing import Iterator, List, Optional, Set, Union from dbt_common.exceptions import ( BlockDefinitionNotAtTopError, @@ -12,40 +13,42 @@ ) -def regex(pat): +def regex(pat: str) -> re.Pattern: return re.compile(pat, re.DOTALL | re.MULTILINE) class BlockData: """raw plaintext data from the top level of the file.""" - def __init__(self, contents): + def __init__(self, contents: str) -> None: self.block_type_name = "__dbt__data" - self.contents = contents + self.contents: str = contents self.full_block = contents class BlockTag: - def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw): + def __init__( + self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None + ) -> None: self.block_type_name = block_type_name self.block_name = block_name self.contents = contents self.full_block = full_block - def __str__(self): + def __str__(self) -> str: return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name) - def __repr__(self): + def __repr__(self) -> str: return str(self) @property - def end_block_type_name(self): + def end_block_type_name(self) -> str: return "end{}".format(self.block_type_name) - def end_pat(self): + def end_pat(self) -> re.Pattern: # we don't want to use string formatting here because jinja uses most # of the string formatting operators in its syntax... - pattern = "".join( + pattern: str = "".join( ( r"(?P((?:\s*\{\%\-|\{\%)\s*", self.end_block_type_name, @@ -98,44 +101,38 @@ def end_pat(self): class TagIterator: - def __init__(self, data): - self.data = data - self.blocks = [] - self._parenthesis_stack = [] - self.pos = 0 - - def linepos(self, end=None) -> str: - """Given an absolute position in the input data, return a pair of + def __init__(self, text: str) -> None: + self.text: str = text + self.pos: int = 0 + + def linepos(self, end: Optional[int] = None) -> str: + """Given an absolute position in the input text, return a pair of line number + relative position to the start of the line. """ end_val: int = self.pos if end is None else end - data = self.data[:end_val] + text = self.text[:end_val] # if not found, rfind returns -1, and -1+1=0, which is perfect! - last_line_start = data.rfind("\n") + 1 + last_line_start = text.rfind("\n") + 1 # it's easy to forget this, but line numbers are 1-indexed - line_number = data.count("\n") + 1 + line_number = text.count("\n") + 1 return f"{line_number}:{end_val - last_line_start}" - def advance(self, new_position): + def advance(self, new_position: int) -> None: self.pos = new_position - def rewind(self, amount=1): + def rewind(self, amount: int = 1) -> None: self.pos -= amount - def _search(self, pattern): - return pattern.search(self.data, self.pos) + def _search(self, pattern: re.Pattern) -> Optional[re.Match]: + return pattern.search(self.text, self.pos) - def _match(self, pattern): - return pattern.match(self.data, self.pos) + def _match(self, pattern: re.Pattern) -> Optional[re.Match]: + return pattern.match(self.text, self.pos) - def _first_match(self, *patterns, **kwargs): + def _first_match(self, *patterns) -> Optional[re.Match]: # type: ignore matches = [] for pattern in patterns: - # default to 'search', but sometimes we want to 'match'. - if kwargs.get("method", "search") == "search": - match = self._search(pattern) - else: - match = self._match(pattern) + match = self._search(pattern) if match: matches.append(match) if not matches: @@ -144,13 +141,13 @@ def _first_match(self, *patterns, **kwargs): # TODO: do I need to account for m.start(), or is this ok? return min(matches, key=lambda m: m.end()) - def _expect_match(self, expected_name, *patterns, **kwargs): - match = self._first_match(*patterns, **kwargs) + def _expect_match(self, expected_name: str, *patterns) -> re.Match: # type: ignore + match = self._first_match(*patterns) if match is None: - raise UnexpectedMacroEOFError(expected_name, self.data[self.pos :]) + raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :]) return match - def handle_expr(self, match): + def handle_expr(self, match: re.Match) -> None: """Handle an expression. At this point we're at a string like: {{ 1 + 2 }} ^ right here @@ -176,12 +173,12 @@ def handle_expr(self, match): self.advance(match.end()) - def handle_comment(self, match): + def handle_comment(self, match: re.Match) -> None: self.advance(match.end()) match = self._expect_match("#}", COMMENT_END_PATTERN) self.advance(match.end()) - def _expect_block_close(self): + def _expect_block_close(self) -> None: """Search for the tag close marker. To the right of the type name, there are a few possiblities: - a name (handled by the regex's 'block_name') @@ -203,13 +200,13 @@ def _expect_block_close(self): string_match = self._expect_match("string", STRING_PATTERN) self.advance(string_match.end()) - def handle_raw(self): + def handle_raw(self) -> int: # raw blocks are super special, they are a single complete regex match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN) self.advance(match.end()) return match.end() - def handle_tag(self, match): + def handle_tag(self, match: re.Match) -> Tag: """The tag could be one of a few things: {% mytag %} @@ -234,7 +231,7 @@ def handle_tag(self, match): self._expect_block_close() return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos) - def find_tags(self): + def find_tags(self) -> Iterator[Tag]: while True: match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN) if match is None: @@ -259,7 +256,7 @@ def find_tags(self): "Invalid regex match in next_block, expected block start, " "expr start, or comment start" ) - def __iter__(self): + def __iter__(self) -> Iterator[Tag]: return self.find_tags() @@ -272,31 +269,33 @@ def __iter__(self): class BlockIterator: - def __init__(self, data): - self.tag_parser = TagIterator(data) - self.current = None - self.stack = [] - self.last_position = 0 + def __init__(self, tag_iterator: TagIterator) -> None: + self.tag_parser = tag_iterator + self.current: Optional[Tag] = None + self.stack: List[str] = [] + self.last_position: int = 0 @property - def current_end(self): + def current_end(self) -> int: if self.current is None: return 0 else: return self.current.end @property - def data(self): - return self.tag_parser.data + def data(self) -> str: + return self.tag_parser.text - def is_current_end(self, tag): + def is_current_end(self, tag: Tag) -> bool: return ( tag.block_type_name.startswith("end") and self.current is not None and tag.block_type_name[3:] == self.current.block_type_name ) - def find_blocks(self, allowed_blocks=None, collect_raw_data=True): + def find_blocks( + self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True + ) -> Iterator[Union[BlockData, BlockTag]]: """Find all top-level blocks in the data.""" if allowed_blocks is None: allowed_blocks = {"snapshot", "macro", "materialization", "docs"} @@ -347,5 +346,7 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True): if raw_data: yield BlockData(raw_data) - def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True): + def lex_for_blocks( + self, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True + ) -> List[Union[BlockData, BlockTag]]: return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)) diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index ca9a4b55..b8c0d03a 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -22,7 +22,7 @@ get_materialization_macro_name, get_test_macro_name, ) -from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag +from dbt_common.clients._jinja_blocks import BlockIterator, BlockData, BlockTag, TagIterator from dbt_common.exceptions import ( CompilationError, @@ -516,7 +516,7 @@ def render_template(template, ctx: Dict[str, Any], node=None) -> str: def extract_toplevel_blocks( - data: str, + text: str, allowed_blocks: Optional[Set[str]] = None, collect_raw_data: bool = True, ) -> List[Union[BlockData, BlockTag]]: @@ -534,4 +534,5 @@ def extract_toplevel_blocks( :return: A list of `BlockTag`s matching the allowed block types and (if `collect_raw_data` is `True`) `BlockData` objects. """ - return BlockIterator(data).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data) + tag_iterator = TagIterator(text) + return BlockIterator(tag_iterator).lex_for_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data)