From 8d6947893272be15faf4fac4456cf4b18f902278 Mon Sep 17 00:00:00 2001 From: Peter Webb Date: Tue, 23 Jan 2024 12:48:44 -0500 Subject: [PATCH] Refactor macro contexts. (#36) * Refactor macro contexts. * Changelog entry. --- .../Under the Hood-20240122-163546.yaml | 6 +++ dbt_common/clients/jinja.py | 42 +++++++++++++++++-- 2 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20240122-163546.yaml diff --git a/.changes/unreleased/Under the Hood-20240122-163546.yaml b/.changes/unreleased/Under the Hood-20240122-163546.yaml new file mode 100644 index 00000000..0d32e2e3 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240122-163546.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Clean up macro contexts. +time: 2024-01-22T16:35:46.907999-05:00 +custom: + Author: peterallenwebb + Issue: "35" diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 1b6de92b..ca9a4b55 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -3,9 +3,10 @@ import os import tempfile from ast import literal_eval +from collections import ChainMap from contextlib import contextmanager from itertools import chain, islice -from typing import List, Union, Set, Optional, Dict, Any, Iterator, Type, Callable +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type from typing_extensions import Protocol import jinja2 @@ -99,6 +100,41 @@ def _compile(self, source, filename): return super()._compile(source, filename) # type: ignore +class MacroFuzzTemplate(jinja2.nativetypes.NativeTemplate): + environment_class = MacroFuzzEnvironment + + def new_context( + self, + vars: Optional[Dict[str, Any]] = None, + shared: bool = False, + locals: Optional[Mapping[str, Any]] = None, + ) -> jinja2.runtime.Context: + # This custom override makes the assumption that the locals and shared + # parameters are not used, so enforce that. + if shared or locals: + raise Exception("The MacroFuzzTemplate.new_context() override cannot use the shared or locals parameters.") + + parent = ChainMap(vars, self.globals) if self.globals else vars + + return self.environment.context_class(self.environment, parent, self.name, self.blocks) + + def render(self, *args: Any, **kwargs: Any) -> Any: + if kwargs or len(args) != 1: + raise Exception("The MacroFuzzTemplate.render() override requires exactly one argument.") + + ctx = self.new_context(args[0]) + + try: + return self.environment_class.concat( # type: ignore + self.root_render_func(ctx) # type: ignore + ) + except Exception: + return self.environment.handle_exception() + + +MacroFuzzEnvironment.template_class = MacroFuzzTemplate + + class NativeSandboxEnvironment(MacroFuzzEnvironment): code_generator_class = jinja2.nativetypes.NativeCodeGenerator @@ -171,7 +207,7 @@ def render(self, *args, **kwargs): with :func:`ast.literal_eval`, the parsed value is returned. Otherwise, the string is returned. """ - vars = dict(*args, **kwargs) + vars = args[0] try: return quoted_native_concat(self.root_render_func(self.new_context(vars))) @@ -226,7 +262,7 @@ def get_macro(self): # make_module is in jinja2.environment. It returns a TemplateModule module = template.make_module(vars=self.context, shared=False) macro = module.__dict__[get_dbt_macro_name(name)] - module.__dict__.update(self.context) + return macro @contextmanager