From a13b5a60914210352755219364eadb3dd19e10ec Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Mon, 10 Jun 2024 10:25:51 +0100 Subject: [PATCH] Prevent overriding `server_settings` on base provider class --- .../jupyter-ai-magics/jupyter_ai_magics/providers.py | 12 ++++++++++++ .../tests/test_provider_metaclass.py | 11 ++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index bef7efacc..3d39b22f4 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -197,6 +197,18 @@ def __new__(mcs, name, bases, namespace, **kwargs): return cls + @property + def server_settings(cls): + return cls._server_settings + + @server_settings.setter + def server_settings(cls, value): + if cls._server_settings is not None: + raise AttributeError("'server_settings' attribute was already set") + cls._server_settings = value + + _server_settings = None + class BaseProvider(BaseModel, metaclass=ProviderMetaclass): # diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_provider_metaclass.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_provider_metaclass.py index ddf99a245..359fe3774 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_provider_metaclass.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_provider_metaclass.py @@ -1,8 +1,10 @@ +from types import MappingProxyType from typing import ClassVar, Optional from langchain.pydantic_v1 import BaseModel +from pytest import raises -from ..providers import ProviderMetaclass +from ..providers import BaseProvider, ProviderMetaclass def test_provider_metaclass(): @@ -24,3 +26,10 @@ class Child(Base, Parent, metaclass=ProviderMetaclass): test: ClassVar[str] = "expected" assert Child.test == "expected" + + +def test_base_provider_server_settings_read_only(): + BaseProvider.server_settings = MappingProxyType({}) + + with raises(AttributeError, match="'server_settings' attribute was already set"): + BaseProvider.server_settings = MappingProxyType({})