From fd50136dc62c8fd74d55f42aa3118bca18e81a6b Mon Sep 17 00:00:00 2001 From: "albert.franzi" <3647015+afranzi@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:14:04 +0100 Subject: [PATCH 1/3] feat: Allow Macros & Filters definitions --- dbt_common/clients/jinja.py | 22 +++++++++++++++++++-- dbt_common/utils/importer.py | 38 ++++++++++++++++++++++++++++++++++++ tests/__init__.py | 0 tests/unit/__init__.py | 0 tests/unit/test_jinja.py | 20 +++++++++++++++++-- 5 files changed, 76 insertions(+), 4 deletions(-) create mode 100644 dbt_common/utils/importer.py create mode 100644 tests/__init__.py create mode 100644 tests/unit/__init__.py diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 44d3eade..5f4b4f54 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -6,7 +6,7 @@ from collections import ChainMap from contextlib import contextmanager from itertools import chain, islice -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type, Literal from typing_extensions import Protocol import jinja2 # type: ignore @@ -34,7 +34,7 @@ UndefinedCompilationError, ) from dbt_common.exceptions.macros import MacroReturn, UndefinedMacroError, CaughtMacroError - +from dbt_common.utils.importer import import_from_string SUPPORTED_LANG_ARG = jinja2.nodes.Name("supported_languages", "param") @@ -461,6 +461,14 @@ def __reduce__(self): } +def import_user_defined_code(code: Literal['MACROS', 'FILTERS']) -> dict[str, Callable] | None: + module_path = os.environ.get(f"DBT_USER_DEFINED_{code}") + user_defined_code = None + if module_path: + user_defined_code = import_from_string(module_path) + return user_defined_code + + def get_environment( node=None, capture_macros: bool = False, @@ -488,6 +496,16 @@ def get_environment( env = env_cls(**args) env.filters.update(filters) + # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. + # https://jinja.palletsprojects.com/en/3.0.x/api/#jinja2.Environment.globals + user_defined_macros = import_user_defined_code('MACROS') + if user_defined_macros: + env.globals.update(user_defined_macros) + + user_defined_filters = import_user_defined_code('FILTERS') + if user_defined_filters: + env.filters.update(user_defined_filters) + return env diff --git a/dbt_common/utils/importer.py b/dbt_common/utils/importer.py new file mode 100644 index 00000000..338eba25 --- /dev/null +++ b/dbt_common/utils/importer.py @@ -0,0 +1,38 @@ +import importlib +from typing import Any + + +class ImportFromStringError(Exception): + pass + + +def import_from_string(import_str: Any) -> Any: + if not isinstance(import_str, str): + return import_str + + module_str, _, attrs_str = import_str.partition(":") + if not module_str or not attrs_str: + message = ( + 'Import string "{import_str}" must be in format ":".' + ) + raise ImportFromStringError(message.format(import_str=import_str)) + + try: + module = importlib.import_module(module_str) + except ModuleNotFoundError as exc: + if exc.name != module_str: + raise exc from None + message = 'Could not import module "{module_str}".' + raise ImportFromStringError(message.format(module_str=module_str)) + + instance = module + try: + for attr_str in attrs_str.split("."): + instance = getattr(instance, attr_str) + except AttributeError: + message = 'Attribute "{attrs_str}" not found in module "{module_str}".' + raise ImportFromStringError( + message.format(attrs_str=attrs_str, module_str=module_str) + ) + + return instance diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index f038a1ec..3edfa8fe 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -1,8 +1,12 @@ +import os import unittest -from dbt_common.clients.jinja import extract_toplevel_blocks +from dbt_common.clients.jinja import extract_toplevel_blocks, get_environment from dbt_common.exceptions import CompilationError +dbt_macros = {"universe_number": lambda: 42} +dbt_filters = {"multiply": lambda x, y: x * y} + class TestBlockLexer(unittest.TestCase): def test_basic(self): @@ -17,6 +21,19 @@ def test_basic(self): self.assertEqual(blocks[0].contents, body) self.assertEqual(blocks[0].full_block, block_data) + def test_env(self): + os.environ['DBT_USER_DEFINED_MACROS'] = 'tests.unit.test_jinja:dbt_macros' + os.environ['DBT_USER_DEFINED_FILTERS'] = 'tests.unit.test_jinja:dbt_filters' + + body = ( + "{% set name = 'potato' %}" + "{{ name }} - {{ universe_number() | multiply(2) }}" + ) + env = get_environment(None, capture_macros=True) + template = env.from_string(body) + render = template.render({}) + self.assertEqual(render, "potato - 84") + def test_multiple(self): body_one = '{{ config(foo="bar") }}\r\nselect * from this.that\r\n' body_two = ( @@ -433,7 +450,6 @@ def test_if_endfor_newlines(self): + x_block ) - if_you_do_this_you_are_awful = """ {#} here is a comment with a block inside {% block x %} asdf {% endblock %} {#} {% do From b53a7f1f8ddc836f057bd412b96fedf044bddd4d Mon Sep 17 00:00:00 2001 From: "albert.franzi" <3647015+afranzi@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:20:10 +0100 Subject: [PATCH 2/3] feat: Allow Macros & Filters definitions --- .changes/unreleased/Features-20240207-172004.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Features-20240207-172004.yaml diff --git a/.changes/unreleased/Features-20240207-172004.yaml b/.changes/unreleased/Features-20240207-172004.yaml new file mode 100644 index 00000000..2b2bf0e4 --- /dev/null +++ b/.changes/unreleased/Features-20240207-172004.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Enable User Defined macros & filters in Jinja templating +time: 2024-02-07T17:20:04.295019+01:00 +custom: + Author: afranzi + Issue: "71" From c4d5394779d9f435911ab7c691a68b099fc575ec Mon Sep 17 00:00:00 2001 From: "albert.franzi" <3647015+afranzi@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:26:21 +0100 Subject: [PATCH 3/3] lint: Format code --- dbt_common/clients/jinja.py | 20 ++++++++++++++++---- dbt_common/utils/importer.py | 8 ++------ tests/unit/test_jinja.py | 9 +++------ 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 5f4b4f54..c964fffe 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -6,7 +6,19 @@ from collections import ChainMap from contextlib import contextmanager from itertools import chain, islice -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type, Literal +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Union, + Set, + Type, + Literal, +) from typing_extensions import Protocol import jinja2 # type: ignore @@ -461,7 +473,7 @@ def __reduce__(self): } -def import_user_defined_code(code: Literal['MACROS', 'FILTERS']) -> dict[str, Callable] | None: +def import_user_defined_code(code: Literal["MACROS", "FILTERS"]) -> dict[str, Callable] | None: module_path = os.environ.get(f"DBT_USER_DEFINED_{code}") user_defined_code = None if module_path: @@ -498,11 +510,11 @@ def get_environment( # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. # https://jinja.palletsprojects.com/en/3.0.x/api/#jinja2.Environment.globals - user_defined_macros = import_user_defined_code('MACROS') + user_defined_macros = import_user_defined_code("MACROS") if user_defined_macros: env.globals.update(user_defined_macros) - user_defined_filters = import_user_defined_code('FILTERS') + user_defined_filters = import_user_defined_code("FILTERS") if user_defined_filters: env.filters.update(user_defined_filters) diff --git a/dbt_common/utils/importer.py b/dbt_common/utils/importer.py index 338eba25..f77520ee 100644 --- a/dbt_common/utils/importer.py +++ b/dbt_common/utils/importer.py @@ -12,9 +12,7 @@ def import_from_string(import_str: Any) -> Any: module_str, _, attrs_str = import_str.partition(":") if not module_str or not attrs_str: - message = ( - 'Import string "{import_str}" must be in format ":".' - ) + message = 'Import string "{import_str}" must be in format ":".' raise ImportFromStringError(message.format(import_str=import_str)) try: @@ -31,8 +29,6 @@ def import_from_string(import_str: Any) -> Any: instance = getattr(instance, attr_str) except AttributeError: message = 'Attribute "{attrs_str}" not found in module "{module_str}".' - raise ImportFromStringError( - message.format(attrs_str=attrs_str, module_str=module_str) - ) + raise ImportFromStringError(message.format(attrs_str=attrs_str, module_str=module_str)) return instance diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index 3edfa8fe..877120b7 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -22,13 +22,10 @@ def test_basic(self): self.assertEqual(blocks[0].full_block, block_data) def test_env(self): - os.environ['DBT_USER_DEFINED_MACROS'] = 'tests.unit.test_jinja:dbt_macros' - os.environ['DBT_USER_DEFINED_FILTERS'] = 'tests.unit.test_jinja:dbt_filters' + os.environ["DBT_USER_DEFINED_MACROS"] = "tests.unit.test_jinja:dbt_macros" + os.environ["DBT_USER_DEFINED_FILTERS"] = "tests.unit.test_jinja:dbt_filters" - body = ( - "{% set name = 'potato' %}" - "{{ name }} - {{ universe_number() | multiply(2) }}" - ) + body = "{% set name = 'potato' %}" "{{ name }} - {{ universe_number() | multiply(2) }}" env = get_environment(None, capture_macros=True) template = env.from_string(body) render = template.render({})