Update composition tests and add context builder merge logic
Updated unit tests to reflect changes in overlay patterns and added merge keys in composition. Introduced new module `data_merge.py` and implemented merge strategies for combining configurations within the context builder.
coordt committed Nov 10, 2024
1 parent 9137a77 commit ba65296
Showing 7 changed files with 449 additions and 7 deletions.
10 changes: 6 additions & 4 deletions project_forge/configurations/composition.py
@@ -7,16 +7,14 @@
from pydantic_core.core_schema import ValidationInfo

from project_forge.configurations.pattern import Pattern, read_pattern_file
from project_forge.context_builder.data_merge import MergeMethods
from project_forge.core.exceptions import PathNotFoundError, RepoAuthError, RepoNotFoundError
from project_forge.core.io import parse_file
from project_forge.core.location import Location

SkippedHook = Literal["pre", "post", "all", "none"]
"""Types of hooks to skip."""

MergeMethods = Literal["overwrite", "nested-overwrite", "comprehensive"]
"""Types of merge methods."""


class Overlay(BaseModel):
"""An object describing how to overlay a pattern in a composition."""
@@ -127,7 +125,11 @@ def from_location(cls, location: Union[str, Location]) -> "Composition":
return cls(overlays=[Overlay(pattern_location=location)])

def cache_data(self) -> None:
"""Makes sure all the patterns are cached and have their pattern objects loaded."""
"""
Makes sure all the patterns are cached and have their pattern objects loaded.
Accessing the `pattern` property on the overlay will lazily load the pattern.
"""
for overlay in self.overlays:
_ = overlay.pattern

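Not part of the diff — a minimal, hypothetical usage sketch of the lazy caching described in the `cache_data` docstring above (the pattern path is made up):

from project_forge.configurations.composition import Composition

# Hypothetical pattern location; `from_location` accepts a str or Location.
composition = Composition.from_location("patterns/python-package/pattern.toml")
composition.cache_data()  # touching each overlay's `pattern` property loads its pattern file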
27 changes: 25 additions & 2 deletions project_forge/context_builder/context.py
@@ -1,9 +1,10 @@
"""Builds and manages the rendering context."""

import datetime
from typing import Callable
from typing import Callable, Mapping

from project_forge.configurations.composition import Composition
from project_forge.context_builder.data_merge import MERGE_FUNCTION, MergeMethods
from project_forge.context_builder.overlays import process_overlay
from project_forge.rendering.expressions import render_expression

@@ -36,5 +37,27 @@ def build_context(composition: Composition, ui: Callable) -> dict:
running_context[key] = render_expression(value, running_context)

for overlay in composition.overlays:
running_context.update(process_overlay(overlay, running_context, ui))
overlay_context = process_overlay(overlay, running_context, ui)
running_context = update_context(composition.merge_keys or {}, running_context, overlay_context)
return running_context


def update_context(merge_keys: Mapping[str, MergeMethods], left: dict, right: dict) -> dict:
"""Return a dict where the left is updated with the right according to the composition rules."""
left_keys = set(left.keys())
right_keys = set(right.keys())
common_keys = left_keys.intersection(right_keys)
new_keys = right_keys - common_keys
result = {}

for key, value in left.items():
if key in right:
merge_func = MERGE_FUNCTION[merge_keys.get(key.lower(), "comprehensive")]
result[key] = merge_func(value, right[key])
else:
result[key] = value

for key in new_keys:
result[key] = right[key]

return result
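Not part of the commit — a minimal usage sketch of the new `update_context` function with made-up data, illustrating the default comprehensive merge versus a per-key "update" strategy (keys in `merge_keys` are looked up lower-cased):

from project_forge.context_builder.context import update_context

running = {"name": "base", "deps": ["a"]}               # context accumulated so far
overlay = {"name": "patch", "deps": ["b"], "extra": 1}  # context produced by the next overlay

# Default strategy ("comprehensive"): scalars are overwritten,
# lists are merged and de-duplicated (order not guaranteed).
update_context({}, running, overlay)
# -> {"name": "patch", "deps": ["a", "b"], "extra": 1}

# Per-key strategy: "deps" uses the plain update, so the overlay's list replaces the old one.
update_context({"deps": "update"}, running, overlay)
# -> {"name": "patch", "deps": ["b"], "extra": 1}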
169 changes: 169 additions & 0 deletions project_forge/context_builder/data_merge.py
@@ -0,0 +1,169 @@
"""Tools for merging data."""

import copy
import logging
from collections import OrderedDict
from functools import reduce
from typing import Any, Iterable, Literal, MutableMapping, TypeVar, overload

from immutabledict import immutabledict

logger = logging.getLogger(__name__)

T = TypeVar("T")


@overload
def freeze_data(obj: set | frozenset) -> frozenset: ...


@overload
def freeze_data(obj: tuple | list) -> tuple: ...


@overload
def freeze_data(obj: dict | OrderedDict | immutabledict) -> immutabledict: ...


@overload
def freeze_data(obj: str) -> str: ...


@overload
def freeze_data(obj: int) -> int: ...


@overload
def freeze_data(obj: float) -> float: ...


@overload
def freeze_data(obj: bytes) -> bytes: ...


def freeze_data(obj: Any) -> Any:
"""Check type and recursively return a new read-only object."""
if isinstance(obj, (str, int, float, bytes, type(None), bool)):
return obj
elif isinstance(obj, tuple) and type(obj) is not tuple: # assumed namedtuple
return type(obj)(*(freeze_data(i) for i in obj))
elif isinstance(obj, (tuple, list)):
return tuple(freeze_data(i) for i in obj)
elif isinstance(obj, (dict, OrderedDict, immutabledict)):
return immutabledict({k: freeze_data(v) for k, v in obj.items()})
elif isinstance(obj, (set, frozenset)):
return frozenset(freeze_data(i) for i in obj)
raise ValueError(obj)


def merge_iterables(iter1: Iterable, iter2: Iterable) -> set:
"""
Merge and de-duplicate a bunch of lists into a single list.
Order is not guaranteed.
Args:
iter1: An Iterable
iter2: An Iterable
Returns:
The merged, de-duplicated sequence as a set
"""
from itertools import chain

return set(chain(freeze_data(iter1), freeze_data(iter2)))


def update(left_val: T, right_val: T) -> T:
"""Do a `dict.update` on all the dicts."""
match left_val, right_val:
case (dict(), dict()):
return left_val | right_val # type: ignore[operator]
case _:
return right_val


def nested_overwrite(*dicts: dict) -> dict:
"""
Merges dicts deeply.
Args:
*dicts: List of dicts to merge with the first one as the base
Returns:
dict: The merged dict
"""

def merge_into(d1: dict, d2: dict) -> dict:
for key, value in d2.items():
if key not in d1 or not isinstance(d1[key], dict):
d1[key] = copy.deepcopy(value)
else:
d1[key] = merge_into(d1[key], value)
return d1

return reduce(merge_into, dicts, {})


def comprehensive_merge(left_val: T, right_val: T) -> T:
"""
Merges data comprehensively.
All arguments must be of the same type.
- Scalars are overwritten by the new values
- lists are merged and de-duplicated
- dicts are recursively merged
Args:
left_val: The item to merge into
right_val: The item to merge from
Returns:
The merged data
"""
dict_types = (dict, OrderedDict, immutabledict)
iterable_types = (list, set, tuple)

def merge_into(d1: Any, d2: Any) -> Any:
if isinstance(d1, dict_types) and isinstance(d2, dict_types):
if isinstance(d1, OrderedDict) or isinstance(d2, OrderedDict):
od1: MutableMapping[Any, Any] = OrderedDict(d1)
od2: MutableMapping[Any, Any] = OrderedDict(d2)
else:
od1 = dict(d1)
od2 = dict(d2)

for key in od2:
od1[key] = merge_into(od1[key], od2[key]) if key in od1 else copy.deepcopy(od2[key])
return od1 # type: ignore[return-value]
elif isinstance(d1, list) and isinstance(d2, iterable_types):
return list(merge_iterables(d1, d2))
elif isinstance(d1, set) and isinstance(d2, iterable_types):
return merge_iterables(d1, d2)
elif isinstance(d1, tuple) and isinstance(d2, iterable_types):
return tuple(merge_iterables(d1, d2))
else:
return copy.deepcopy(d2)

return merge_into(left_val, right_val)


# Strategies for merging data.
MergeMethods = Literal["overwrite", "comprehensive"]

UPDATE = "update"
"""Overwrite at the top level like `dict.update()`."""

COMPREHENSIVE = "comprehensive"
"""Comprehensively merge the two data structures.
- Scalars are overwritten by the new values
- lists are merged and de-duplicated
- dicts are recursively merged
"""

MERGE_FUNCTION = {
COMPREHENSIVE: comprehensive_merge,
UPDATE: update,
}
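Again not part of the commit — a hedged sketch, with made-up data, of how the two registered strategies and `freeze_data` behave according to the docstrings above:

from project_forge.context_builder.data_merge import comprehensive_merge, freeze_data, update

left = {"tags": ["py", "cli"], "opts": {"debug": True}, "version": 1}
right = {"tags": ["cli", "web"], "opts": {"color": False}, "version": 2}

# comprehensive: scalars overwritten, lists merged and de-duplicated (order not guaranteed),
# dicts merged recursively.
comprehensive_merge(left, right)
# -> {"tags": ["py", "cli", "web"], "opts": {"debug": True, "color": False}, "version": 2}

# update: top-level `dict.update` semantics -- right-hand values replace left-hand values wholesale.
update(left, right)
# -> {"tags": ["cli", "web"], "opts": {"color": False}, "version": 2}

# freeze_data returns hashable, read-only equivalents so iterables can be de-duplicated via sets.
freeze_data({"a": [1, 2]})
# -> immutabledict({"a": (1, 2)})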
1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"click>=8.1.7",
"pydantic-settings>=2.4.0",
"platformdirs>=4.2.2",
"immutabledict>=4.2.0",
]
authors = [
{ name = "Calloway Project", email = "[email protected]" },
49 changes: 48 additions & 1 deletion tests/test_context_builder/test_context.py
@@ -2,7 +2,7 @@

import datetime

from project_forge.context_builder.context import build_context, get_starting_context
from project_forge.context_builder.context import build_context, get_starting_context, update_context
from unittest.mock import Mock, patch


@@ -25,6 +25,7 @@ def test_build_context_with_extra_context_and_overlays_composes_correct():
patch("project_forge.context_builder.context.process_overlay") as mock_process_overlay,
):
composition = Mock()
composition.merge_keys = {}
composition.extra_context = {"key": "{{ value }}", "overlay_key": "I should get overwritten"}
composition.overlays = ["overlay1", "overlay2"]

@@ -67,3 +68,49 @@ def test_build_context_with_empty_composition_is_starting_context():
mock_render_expression.assert_not_called()
mock_process_overlay.assert_not_called()
assert mock_get_starting_context.called


class TestUpdateContext:
"""Tests for the update_context function."""

def test_default_behavior_uses_comprehensive_merge(self):
"""The result should contain all the keys and the values should be merged comprehensively."""
# Assemble
merge_keys = {}
left = {"a": 1, "b": [1, 2, 3], "c": 3}
right = {"a": 2, "b": [4, 5, 6], "d": 4}
expected_result = {"a": 2, "b": [1, 2, 3, 4, 5, 6], "c": 3, "d": 4}

# Act
result = update_context(merge_keys, left, right)

# Assert
assert result == expected_result, f"Expected {expected_result}, but got {result}"

def test_updating_empty_dicts_returns_empty_dict(self):
"""Updating an empty dict with an empty dict should return an empty dict."""
# Assemble
merge_keys = {"a": "update", "b": "nested_overwrite"}
left = {}
right = {}
expected_result = {}

# Act
result = update_context(merge_keys, left, right)

# Assert
assert result == expected_result, f"Expected {expected_result}, but got {result}"

def test_respects_methods_in_merge_keys(self):
"""Update context should use the specified merge strategy."""
# Assemble
merge_keys = {"b": "update"}
left = {"a": 1, "b": [1, 2, 3], "c": 3}
right = {"a": 2, "b": [4, 5, 6], "d": 4}
expected_result = {"a": 2, "b": [4, 5, 6], "c": 3, "d": 4}

# Act
result = update_context(merge_keys, left, right)

# Assert
assert result == expected_result, f"Expected {expected_result}, but got {result}"
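A hypothetical extra case in the same style (not in this commit), sketching the recursive-dict branch of the default comprehensive strategy as documented in `data_merge.comprehensive_merge`; it reuses the `update_context` import already at the top of the module:

def test_nested_dicts_merge_recursively():
    """Dict values under a common key should be merged key by key by the default strategy."""
    # Assemble
    merge_keys = {}
    left = {"settings": {"debug": True, "level": 1}}
    right = {"settings": {"level": 2, "verbose": True}}
    expected_result = {"settings": {"debug": True, "level": 2, "verbose": True}}

    # Act
    result = update_context(merge_keys, left, right)

    # Assert
    assert result == expected_result, f"Expected {expected_result}, but got {result}"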
