diff --git a/docs/source/users/index.md b/docs/source/users/index.md
index b52fb66bb..4226df56b 100644
--- a/docs/source/users/index.md
+++ b/docs/source/users/index.md
@@ -882,3 +882,72 @@ To allow more than one provider in the allowlist, repeat the runtime configuration.
 ```
 jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21
 ```
+
+### Model parameters
+This configuration allows specifying arbitrary parameters that are unpacked and passed to the provider class. This is useful for passing model-tuning parameters that affect the responses generated by the model. It is also an appropriate place to pass custom attributes required by certain providers or models.
+
+The accepted value is a dictionary whose top-level keys are model IDs in the form `<provider_id>:<model_id>`, and whose values are arbitrary dictionaries that are unpacked and passed as-is to the provider class.
+
+#### Configuring as a startup option
+In this sample, the `bedrock` provider is created with the given `model_kwargs` when the `ai21.j2-mid-v1` model is selected. Note that the JSON value must be quoted so that the shell passes it as a single argument.
+
+```bash
+jupyter lab --AiExtension.model_parameters='{"bedrock:ai21.j2-mid-v1":{"model_kwargs":{"maxTokens":200}}}'
+```
+
+The above results in the following LLM class being instantiated:
+
+```python
+BedrockProvider(model_kwargs={"maxTokens":200}, ...)
+```
+
+Here is another example, where the `anthropic` provider is created with the given values for `max_tokens` and `temperature` when the `claude-2` model is selected.
+
+```bash
+jupyter lab --AiExtension.model_parameters='{"anthropic:claude-2":{"max_tokens":1024,"temperature":0.9}}'
+```
+
+The above results in the following LLM class being instantiated:
+
+```python
+AnthropicProvider(max_tokens=1024, temperature=0.9, ...)
+```
+
+#### Configuring as a config file
+This configuration can also be specified in a config file in JSON format. The file should be named `jupyter_jupyter_ai_config.json` and saved in a path that JupyterLab reads configuration from. You can find these paths by running the `jupyter --paths` command and picking one of the paths listed under the `config` section.
+
+Here is an example of running the `jupyter --paths` command:
+
+```bash
+(jupyter-ai-lab4) ➜ jupyter --paths
+config:
+    /opt/anaconda3/envs/jupyter-ai-lab4/etc/jupyter
+    /Users/3coins/.jupyter
+    /Users/3coins/.local/etc/jupyter
+    /usr/3coins/etc/jupyter
+    /etc/jupyter
+data:
+    /opt/anaconda3/envs/jupyter-ai-lab4/share/jupyter
+    /Users/3coins/Library/Jupyter
+    /Users/3coins/.local/share/jupyter
+    /usr/local/share/jupyter
+    /usr/share/jupyter
+runtime:
+    /Users/3coins/Library/Jupyter/runtime
+```
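+
+If you just need the primary user config directory, the `jupyter --config-dir` command prints it directly:
+
+```bash
+jupyter --config-dir
+# example output: /Users/3coins/.jupyter
+```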
+
+Here is an example of configuring the `bedrock` provider for the `ai21.j2-mid-v1` model. Note that the parameters are nested under the `model_parameters` key of the `AiExtension` section:
+
+```json
+{
+    "AiExtension": {
+        "model_parameters": {
+            "bedrock:ai21.j2-mid-v1": {
+                "model_kwargs": {
+                    "maxTokens": 200
+                }
+            }
+        }
+    }
+}
+```
diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index edc139d5e..a17067921 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -242,15 +242,6 @@ def allows_concurrency(self):
         return True
 
 
-def pop_with_default(model: Mapping[str, Any], name: str, default: Any) -> Any:
-    try:
-        value = model.pop(name)
-    except KeyError as e:
-        return default
-
-    return value
-
-
 class AI21Provider(BaseProvider, AI21):
     id = "ai21"
     name = "AI21"
@@ -632,9 +623,6 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
         TextField(
             key="response_path", label="Response path (required)", format="jsonpath"
         ),
-        MultilineTextField(
-            key="endpoint_kwargs", label="Endpoint arguments", format="json"
-        ),
     ]
 
     def __init__(self, *args, **kwargs):
@@ -644,15 +632,7 @@ def __init__(self, *args, **kwargs):
             request_schema=request_schema, response_path=response_path
         )
 
-        endpoint_kwargs = pop_with_default(kwargs, "endpoint_kwargs", "{}")
-        endpoint_kwargs = json.loads(endpoint_kwargs)
-
-        super().__init__(
-            *args,
-            **kwargs,
-            content_handler=content_handler,
-            endpoint_kwargs=endpoint_kwargs,
-        )
+        super().__init__(*args, **kwargs, content_handler=content_handler)
 
     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
         return await self._call_in_executor(*args, **kwargs)
@@ -677,14 +657,8 @@ class BedrockProvider(BaseProvider, Bedrock):
             format="text",
         ),
         TextField(key="region_name", label="Region name (optional)", format="text"),
-        MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"),
     ]
 
-    def __init__(self, *args, **kwargs):
-        model_kwargs = pop_with_default(kwargs, "model_kwargs", "{}")
-        model_kwargs = json.loads(model_kwargs)
-        super().__init__(*args, **kwargs, model_kwargs=model_kwargs)
-
     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
         return await self._call_in_executor(*args, **kwargs)
 
@@ -707,14 +681,8 @@ class BedrockChatProvider(BaseProvider, BedrockChat):
             format="text",
         ),
         TextField(key="region_name", label="Region name (optional)", format="text"),
-        MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"),
     ]
 
-    def __init__(self, *args, **kwargs):
-        model_kwargs = pop_with_default(kwargs, "model_kwargs", "{}")
-        model_kwargs = json.loads(model_kwargs)
-        super().__init__(*args, **kwargs, model_kwargs=model_kwargs)
-
     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
         return await self._call_in_executor(*args, **kwargs)
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
index 2f3f1388a..e5c852051 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -36,7 +36,8 @@ def __init__(self, retriever, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        self.llm = provider(**provider_params)
+        model_parameters = self.get_model_parameters(provider, provider_params)
+        self.llm = provider(**provider_params, **model_parameters)
         memory = ConversationBufferWindowMemory(
             memory_key="chat_history", return_messages=True, k=2
         )
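Note: `get_model_parameters`, called in the handler above, is defined on the base handler in the next file. It is a plain dictionary lookup keyed by `provider_id:model_id`, falling back to an empty dict. A standalone sketch of the behavior (illustrative function, not part of the patch):

```python
from typing import Any, Dict

def get_model_parameters(
    model_parameters: Dict[str, Dict[str, Any]],
    provider_id: str,
    provider_params: Dict[str, str],
) -> Dict[str, Any]:
    # Keys take the form "<provider_id>:<model_id>"; unknown models fall back to {}.
    return model_parameters.get(f"{provider_id}:{provider_params['model_id']}", {})

# Example, using the value from the docs above:
params = get_model_parameters(
    {"bedrock:ai21.j2-mid-v1": {"model_kwargs": {"maxTokens": 200}}},
    "bedrock",
    {"model_id": "ai21.j2-mid-v1"},
)
assert params == {"model_kwargs": {"maxTokens": 200}}
```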
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 7894dcfa5..5ffe65c7c 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -3,7 +3,7 @@
 import traceback
 
 # necessary to prevent circular import
-from typing import TYPE_CHECKING, Dict, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type
 from uuid import uuid4
 
 from jupyter_ai.config_manager import ConfigManager, Logger
@@ -23,10 +23,12 @@ def __init__(
         log: Logger,
         config_manager: ConfigManager,
         root_chat_handlers: Dict[str, "RootChatHandler"],
+        model_parameters: Dict[str, Dict],
     ):
         self.log = log
         self.config_manager = config_manager
         self._root_chat_handlers = root_chat_handlers
+        self.model_parameters = model_parameters
         self.parser = argparse.ArgumentParser()
         self.llm = None
         self.llm_params = None
@@ -122,6 +124,13 @@ def get_llm_chain(self):
             self.llm_params = lm_provider_params
         return self.llm_chain
 
+    def get_model_parameters(
+        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
+    ):
+        return self.model_parameters.get(
+            f"{provider.id}:{provider_params['model_id']}", {}
+        )
+
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 3638b6cfb..d329e05e2 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -40,7 +40,8 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        llm = provider(**provider_params)
+        model_parameters = self.get_model_parameters(provider, provider_params)
+        llm = provider(**provider_params, **model_parameters)
 
         if llm.is_chat_provider:
             prompt_template = ChatPromptTemplate.from_messages(
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
index 5036c8dfb..cdddee92d 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -226,7 +226,9 @@ def __init__(self, root_dir: str, *args, **kwargs):
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
-        llm = provider(**provider_params)
+        model_parameters = self.get_model_parameters(provider, provider_params)
+        llm = provider(**provider_params, **model_parameters)
+
         self.llm = llm
         return llm
 
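Note: with the per-provider JSON text fields (`endpoint_kwargs`, `model_kwargs`) removed from the settings UI, those arguments are now supplied through the `model_parameters` trait added in the next file. For example, SageMaker endpoint arguments could be passed at startup roughly like this (the endpoint name `my-endpoint` and the custom attribute value are hypothetical placeholders):

```bash
jupyter lab --AiExtension.model_parameters='{"sagemaker-endpoint:my-endpoint":{"endpoint_kwargs":{"CustomAttributes":"accept_eula=true"}}}'
```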
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index a2ecd5245..500bc9649 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -4,7 +4,7 @@
 from jupyter_ai.chat_handlers.learn import Retriever
 from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
 from jupyter_server.extension.application import ExtensionApp
-from traitlets import List, Unicode
+from traitlets import Dict, List, Unicode
 
 from .chat_handlers import (
     AskChatHandler,
@@ -81,6 +81,17 @@ class AiExtension(ExtensionApp):
         config=True,
     )
 
+    model_parameters = Dict(
+        key_trait=Unicode(),
+        value_trait=Dict(),
+        default_value={},
+        help="""Key-value pairs mapping a model ID to the parameters passed to
+        its provider class. The values are unpacked and passed to the provider
+        class as-is.""",
+        allow_none=True,
+        config=True,
+    )
+
     def initialize_settings(self):
         start = time.time()
 
@@ -96,6 +107,9 @@ def initialize_settings(self):
         self.log.info(f"Configured model allowlist: {self.allowed_models}")
         self.log.info(f"Configured model blocklist: {self.blocked_models}")
 
+        self.settings["model_parameters"] = self.model_parameters
+
         # Fetch LM & EM providers
         self.settings["lm_providers"] = get_lm_providers(
             log=self.log, restrictions=restrictions
@@ -147,6 +161,7 @@ def initialize_settings(self):
             "log": self.log,
             "config_manager": self.settings["jai_config_manager"],
             "root_chat_handlers": self.settings["jai_root_chat_handlers"],
+            "model_parameters": self.settings["model_parameters"],
         }
         default_chat_handler = DefaultChatHandler(
             **chat_handler_kwargs, chat_history=self.settings["chat_history"]
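Note: because the handlers unpack `provider_params` and `model_parameters` into a single constructor call, a key present in both dictionaries raises a `TypeError` rather than silently overriding one value with the other. A minimal, self-contained reproduction of the Python semantics involved (the constructor and parameter values are stand-ins):

```python
def construct(**kwargs):
    """Stand-in for a provider class constructor."""
    return kwargs

provider_params = {"model_id": "claude-2", "temperature": 0.5}
model_parameters = {"temperature": 0.9}

try:
    construct(**provider_params, **model_parameters)
except TypeError as err:
    # construct() got multiple values for keyword argument 'temperature'
    print(err)
```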