Skip to content

Commit

Permalink
Refactor macro contexts. (#36)
Browse files Browse the repository at this point in the history
* Refactor macro contexts.

* Changelog entry.
  • Loading branch information
peterallenwebb authored Jan 23, 2024
1 parent 4e84585 commit 8d69478
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240122-163546.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Clean up macro contexts.
time: 2024-01-22T16:35:46.907999-05:00
custom:
Author: peterallenwebb
Issue: "35"
42 changes: 39 additions & 3 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import tempfile
from ast import literal_eval
from collections import ChainMap
from contextlib import contextmanager
from itertools import chain, islice
from typing import List, Union, Set, Optional, Dict, Any, Iterator, Type, Callable
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type
from typing_extensions import Protocol

import jinja2
Expand Down Expand Up @@ -99,6 +100,41 @@ def _compile(self, source, filename):
return super()._compile(source, filename) # type: ignore


class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate):
environment_class = MacroFuzzEnvironment

def new_context(
self,
vars: Optional[Dict[str, Any]] = None,
shared: bool = False,
locals: Optional[Mapping[str, Any]] = None,
) -> jinja2.runtime.Context:
# This custom override makes the assumption that the locals and shared
# parameters are not used, so enforce that.
if shared or locals:
raise Exception("The MacroFuzzTemplate.new_context() override cannot use the shared or locals parameters.")

parent = ChainMap(vars, self.globals) if self.globals else vars

return self.environment.context_class(self.environment, parent, self.name, self.blocks)

def render(self, *args: Any, **kwargs: Any) -> Any:
if kwargs or len(args) != 1:
raise Exception("The MacroFuzzTemplate.render() override requires exactly one argument.")

ctx = self.new_context(args[0])

try:
return self.environment_class.concat( # type: ignore
self.root_render_func(ctx) # type: ignore
)
except Exception:
return self.environment.handle_exception()


MacroFuzzEnvironment.template_class = MacroFuzzTemplate


class NativeSandboxEnvironment(MacroFuzzEnvironment):
code_generator_class = jinja2.nativetypes.NativeCodeGenerator

Expand Down Expand Up @@ -171,7 +207,7 @@ def render(self, *args, **kwargs):
with :func:`ast.literal_eval`, the parsed value is returned.
Otherwise, the string is returned.
"""
vars = dict(*args, **kwargs)
vars = args[0]

try:
return quoted_native_concat(self.root_render_func(self.new_context(vars)))
Expand Down Expand Up @@ -226,7 +262,7 @@ def get_macro(self):
# make_module is in jinja2.environment. It returns a TemplateModule
module = template.make_module(vars=self.context, shared=False)
macro = module.__dict__[get_dbt_macro_name(name)]
module.__dict__.update(self.context)

return macro

@contextmanager
Expand Down

0 comments on commit 8d69478

Please sign in to comment.