diff --git a/.changes/unreleased/Features-20240207-172004.yaml b/.changes/unreleased/Features-20240207-172004.yaml new file mode 100644 index 00000000..2b2bf0e4 --- /dev/null +++ b/.changes/unreleased/Features-20240207-172004.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Enable User Defined macros & filters in Jinja templating +time: 2024-02-07T17:20:04.295019+01:00 +custom: + Author: afranzi + Issue: "71" diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 44d3eade..c964fffe 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -6,7 +6,19 @@ from collections import ChainMap from contextlib import contextmanager from itertools import chain, islice -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Union, + Set, + Type, + Literal, +) from typing_extensions import Protocol import jinja2 # type: ignore @@ -34,7 +46,7 @@ UndefinedCompilationError, ) from dbt_common.exceptions.macros import MacroReturn, UndefinedMacroError, CaughtMacroError - +from dbt_common.utils.importer import import_from_string SUPPORTED_LANG_ARG = jinja2.nodes.Name("supported_languages", "param") @@ -461,6 +473,14 @@ def __reduce__(self): } +def import_user_defined_code(code: Literal["MACROS", "FILTERS"]) -> dict[str, Callable] | None: + module_path = os.environ.get(f"DBT_USER_DEFINED_{code}") + user_defined_code = None + if module_path: + user_defined_code = import_from_string(module_path) + return user_defined_code + + def get_environment( node=None, capture_macros: bool = False, @@ -488,6 +508,16 @@ def get_environment( env = env_cls(**args) env.filters.update(filters) + # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. + # https://jinja.palletsprojects.com/en/3.0.x/api/#jinja2.Environment.globals + user_defined_macros = import_user_defined_code("MACROS") + if user_defined_macros: + env.globals.update(user_defined_macros) + + user_defined_filters = import_user_defined_code("FILTERS") + if user_defined_filters: + env.filters.update(user_defined_filters) + return env diff --git a/dbt_common/utils/importer.py b/dbt_common/utils/importer.py new file mode 100644 index 00000000..f77520ee --- /dev/null +++ b/dbt_common/utils/importer.py @@ -0,0 +1,34 @@ +import importlib +from typing import Any + + +class ImportFromStringError(Exception): + pass + + +def import_from_string(import_str: Any) -> Any: + if not isinstance(import_str, str): + return import_str + + module_str, _, attrs_str = import_str.partition(":") + if not module_str or not attrs_str: + message = 'Import string "{import_str}" must be in format ":".' + raise ImportFromStringError(message.format(import_str=import_str)) + + try: + module = importlib.import_module(module_str) + except ModuleNotFoundError as exc: + if exc.name != module_str: + raise exc from None + message = 'Could not import module "{module_str}".' + raise ImportFromStringError(message.format(module_str=module_str)) + + instance = module + try: + for attr_str in attrs_str.split("."): + instance = getattr(instance, attr_str) + except AttributeError: + message = 'Attribute "{attrs_str}" not found in module "{module_str}".' + raise ImportFromStringError(message.format(attrs_str=attrs_str, module_str=module_str)) + + return instance diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index f038a1ec..877120b7 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -1,8 +1,12 @@ +import os import unittest -from dbt_common.clients.jinja import extract_toplevel_blocks +from dbt_common.clients.jinja import extract_toplevel_blocks, get_environment from dbt_common.exceptions import CompilationError +dbt_macros = {"universe_number": lambda: 42} +dbt_filters = {"multiply": lambda x, y: x * y} + class TestBlockLexer(unittest.TestCase): def test_basic(self): @@ -17,6 +21,16 @@ def test_basic(self): self.assertEqual(blocks[0].contents, body) self.assertEqual(blocks[0].full_block, block_data) + def test_env(self): + os.environ["DBT_USER_DEFINED_MACROS"] = "tests.unit.test_jinja:dbt_macros" + os.environ["DBT_USER_DEFINED_FILTERS"] = "tests.unit.test_jinja:dbt_filters" + + body = "{% set name = 'potato' %}" "{{ name }} - {{ universe_number() | multiply(2) }}" + env = get_environment(None, capture_macros=True) + template = env.from_string(body) + render = template.render({}) + self.assertEqual(render, "potato - 84") + def test_multiple(self): body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' body_two = ( @@ -433,7 +447,6 @@ def test_if_endfor_newlines(self): + x_block ) - if_you_do_this_you_are_awful = """ {#} here is a comment with a block inside {% block x %} asdf {% endblock %} {#} {% do