From f9ce8e3b1747359b74f6a7a84c7ed2453638c683 Mon Sep 17 00:00:00 2001 From: Philipp Temminghoff Date: Sun, 8 Dec 2024 14:15:47 +0100 Subject: [PATCH] chore: add merge_models helper --- src/llmling/config/utils.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/llmling/config/utils.py b/src/llmling/config/utils.py index 641ed91..dae6b86 100644 --- a/src/llmling/config/utils.py +++ b/src/llmling/config/utils.py @@ -3,7 +3,9 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar + +from pydantic import BaseModel from llmling.config.manager import ConfigManager from llmling.config.models import ( @@ -24,6 +26,7 @@ from llmling.tools.toolsets import ToolSet logger = get_logger(__name__) +T = TypeVar("T", bound=BaseModel) def toolset_config_to_toolset(config) -> ToolSet: @@ -83,3 +86,32 @@ def prepare_runtime( raise TypeError(msg) return runtime_cls.from_config(config) + + +def merge_models(base: T, overlay: T) -> T: + """Deep merge two Pydantic models.""" + if not isinstance(overlay, type(base)): + msg = f"Cannot merge different types: {type(base)} and {type(overlay)}" + raise TypeError(msg) + + # Start with copy of base + merged_data = base.model_dump() + + # Get overlay data (excluding None values) + overlay_data = overlay.model_dump(exclude_none=True) + + for field_name, field_value in overlay_data.items(): + base_value = merged_data.get(field_name) + + match (base_value, field_value): + case (list(), list()): + merged_data[field_name] = [ + *base_value, + *(item for item in field_value if item not in base_value), + ] + case (dict(), dict()): + merged_data[field_name] = base_value | field_value + case _: + merged_data[field_name] = field_value + + return base.__class__.model_validate(merged_data)