Skip to content

Commit

Permalink
Add type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Jan 23, 2024
1 parent f6040c6 commit 49f15a3
Showing 1 changed file with 39 additions and 44 deletions.
83 changes: 39 additions & 44 deletions dbt_common/clients/_jinja_blocks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from collections import namedtuple
from typing import Iterator, Optional, List

from dbt_common.exceptions import (
BlockDefinitionNotAtTopError,
Expand All @@ -12,40 +13,40 @@
)


def regex(pat):
def regex(pat: str) -> re.Pattern:
return re.compile(pat, re.DOTALL | re.MULTILINE)


class BlockData:
"""raw plaintext data from the top level of the file."""

def __init__(self, contents):
def __init__(self, contents: str) -> None:
self.block_type_name = "__dbt__data"
self.contents = contents
self.contents: str = contents
self.full_block = contents


class BlockTag:
def __init__(self, block_type_name, block_name, contents=None, full_block=None, **kw):
def __init__(self, block_type_name: str, block_name: str, contents: Optional[str] = None, full_block: Optional[str] = None) -> None:
self.block_type_name = block_type_name
self.block_name = block_name
self.contents = contents
self.full_block = full_block

def __str__(self):
def __str__(self) -> str:
return "BlockTag({!r}, {!r})".format(self.block_type_name, self.block_name)

def __repr__(self):
def __repr__(self) -> str:
return str(self)

@property
def end_block_type_name(self):
def end_block_type_name(self) -> str:
return "end{}".format(self.block_type_name)

def end_pat(self):
def end_pat(self) -> re.Pattern:
# we don't want to use string formatting here because jinja uses most
# of the string formatting operators in its syntax...
pattern = "".join(
pattern: str = "".join(
(
r"(?P<endblock>((?:\s*\{\%\-|\{\%)\s*",
self.end_block_type_name,
Expand Down Expand Up @@ -98,13 +99,11 @@ def end_pat(self):


class TagIterator:
def __init__(self, text):
self.text = text
self.blocks = []
self._parenthesis_stack = []
self.pos = 0
def __init__(self, text: str) -> None:
self.text: str = text
self.pos: int = 0

def linepos(self, end=None) -> str:
def linepos(self, end: Optional[int] = None) -> str:
"""Given an absolute position in the input text, return a pair of
line number + relative position to the start of the line.
"""
Expand All @@ -116,26 +115,22 @@ def linepos(self, end=None) -> str:
line_number = text.count("\n") + 1
return f"{line_number}:{end_val - last_line_start}"

def advance(self, new_position):
def advance(self, new_position: int) -> None:
self.pos = new_position

def rewind(self, amount=1):
def rewind(self, amount: int = 1) -> None:
self.pos -= amount

def _search(self, pattern):
def _search(self, pattern: re.Pattern) -> Optional[re.Match]:
return pattern.search(self.text, self.pos)

def _match(self, pattern):
def _match(self, pattern: re.Pattern) -> Optional[re.Match]:
return pattern.match(self.text, self.pos)

def _first_match(self, *patterns, **kwargs):
def _first_match(self, *patterns) -> Optional[re.Match]: # type: ignore
matches = []
for pattern in patterns:
# default to 'search', but sometimes we want to 'match'.
if kwargs.get("method", "search") == "search":
match = self._search(pattern)
else:
match = self._match(pattern)
match = self._search(pattern)
if match:
matches.append(match)
if not matches:
Expand All @@ -144,13 +139,13 @@ def _first_match(self, *patterns, **kwargs):
# TODO: do I need to account for m.start(), or is this ok?
return min(matches, key=lambda m: m.end())

def _expect_match(self, expected_name, *patterns, **kwargs):
match = self._first_match(*patterns, **kwargs)
def _expect_match(self, expected_name: str, *patterns) -> re.Match: # type: ignore
match = self._first_match(*patterns)
if match is None:
raise UnexpectedMacroEOFError(expected_name, self.text[self.pos :])
raise UnexpectedMacroEOFError(expected_name, self.text[self.pos:])
return match

def handle_expr(self, match):
def handle_expr(self, match: re.Match) -> None:
"""Handle an expression. At this point we're at a string like:
{{ 1 + 2 }}
^ right here
Expand All @@ -176,12 +171,12 @@ def handle_expr(self, match):

self.advance(match.end())

def handle_comment(self, match):
def handle_comment(self, match: re.Match) -> None:
self.advance(match.end())
match = self._expect_match("#}", COMMENT_END_PATTERN)
self.advance(match.end())

def _expect_block_close(self):
def _expect_block_close(self) -> None:
"""Search for the tag close marker.
To the right of the type name, there are a few possiblities:
- a name (handled by the regex's 'block_name')
Expand All @@ -203,13 +198,13 @@ def _expect_block_close(self):
string_match = self._expect_match("string", STRING_PATTERN)
self.advance(string_match.end())

def handle_raw(self):
def handle_raw(self) -> int:
# raw blocks are super special, they are a single complete regex
match = self._expect_match("{% raw %}...{% endraw %}", RAW_BLOCK_PATTERN)
self.advance(match.end())
return match.end()

def handle_tag(self, match):
def handle_tag(self, match: re.Match) -> Tag:
"""The tag could be one of a few things:
{% mytag %}
Expand All @@ -234,7 +229,7 @@ def handle_tag(self, match):
self._expect_block_close()
return Tag(block_type_name=block_type_name, block_name=block_name, start=start_pos, end=self.pos)

def find_tags(self):
def find_tags(self) -> Iterator[Tag]:
while True:
match = self._first_match(BLOCK_START_PATTERN, COMMENT_START_PATTERN, EXPR_START_PATTERN)
if match is None:
Expand All @@ -259,7 +254,7 @@ def find_tags(self):
"Invalid regex match in next_block, expected block start, " "expr start, or comment start"
)

def __iter__(self):
def __iter__(self) -> Iterator[Tag]:
return self.find_tags()


Expand All @@ -272,31 +267,31 @@ def __iter__(self):


class BlockIterator:
def __init__(self, tag_iterator):
def __init__(self, tag_iterator: TagIterator) -> None:
self.tag_parser = tag_iterator
self.current = None
self.stack = []
self.last_position = 0
self.current: Optional[Tag] = None
self.stack: List[str] = []
self.last_position: int = 0

@property
def current_end(self):
def current_end(self) -> int:
if self.current is None:
return 0
else:
return self.current.end

@property
def data(self):
def data(self) -> str:
return self.tag_parser.text

def is_current_end(self, tag):
def is_current_end(self, tag: Tag) -> bool:
return (
tag.block_type_name.startswith("end")
and self.current is not None
and tag.block_type_name[3:] == self.current.block_type_name
)

def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
def find_blocks(self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True) -> Iterator[BlockData | BlockTag]:
"""Find all top-level blocks in the data."""
if allowed_blocks is None:
allowed_blocks = {"snapshot", "macro", "materialization", "docs"}
Expand Down Expand Up @@ -347,5 +342,5 @@ def find_blocks(self, allowed_blocks=None, collect_raw_data=True):
if raw_data:
yield BlockData(raw_data)

def lex_for_blocks(self, allowed_blocks=None, collect_raw_data=True):
def lex_for_blocks(self, allowed_blocks: Optional[set[str]] = None, collect_raw_data: bool = True) -> List[BlockData | BlockTag]:
return list(self.find_blocks(allowed_blocks=allowed_blocks, collect_raw_data=collect_raw_data))

0 comments on commit 49f15a3

Please sign in to comment.