Skip to content

Commit

Permalink
chore: add merge_models helper
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Dec 8, 2024
1 parent 5c8e7fe commit f9ce8e3
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion src/llmling/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar

from pydantic import BaseModel

from llmling.config.manager import ConfigManager
from llmling.config.models import (
Expand All @@ -24,6 +26,7 @@
from llmling.tools.toolsets import ToolSet

logger = get_logger(__name__)
T = TypeVar("T", bound=BaseModel)


def toolset_config_to_toolset(config) -> ToolSet:
Expand Down Expand Up @@ -83,3 +86,32 @@ def prepare_runtime(
raise TypeError(msg)

return runtime_cls.from_config(config)


def merge_models(base: T, overlay: T) -> T:
"""Deep merge two Pydantic models."""
if not isinstance(overlay, type(base)):
msg = f"Cannot merge different types: {type(base)} and {type(overlay)}"
raise TypeError(msg)

# Start with copy of base
merged_data = base.model_dump()

# Get overlay data (excluding None values)
overlay_data = overlay.model_dump(exclude_none=True)

for field_name, field_value in overlay_data.items():
base_value = merged_data.get(field_name)

match (base_value, field_value):
case (list(), list()):
merged_data[field_name] = [
*base_value,
*(item for item in field_value if item not in base_value),
]
case (dict(), dict()):
merged_data[field_name] = base_value | field_value
case _:
merged_data[field_name] = field_value

return base.__class__.model_validate(merged_data)

0 comments on commit f9ce8e3

Please sign in to comment.