Added configurable option for model parameters.
3coins committed Nov 8, 2023
1 parent 960c692 commit aab2f6f
Showing 7 changed files with 102 additions and 38 deletions.
69 changes: 69 additions & 0 deletions docs/source/users/index.md
@@ -882,3 +882,72 @@ To allow more than one provider in the allowlist, repeat the runtime configuration
```
jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21
```

### Model parameters
This configuration allows specifying arbitrary parameters that are unpacked and passed to the provider class.
This is useful for passing model tuning parameters that affect the model's response generation.
It is also an appropriate place to pass custom attributes required by certain providers or models.

The accepted value is a dictionary whose top-level keys are model IDs in the `provider_id:model_id` format, and whose
values are arbitrary dictionaries that are unpacked and passed as-is to the provider class.
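
Schematically, the value has the following shape (a sketch with placeholder names, not a real provider or model):

```python
# Placeholder sketch: keys follow "<provider_id>:<model_id>"; the inner
# dict is unpacked into the provider class constructor as-is.
model_parameters = {
    "<provider_id>:<model_id>": {
        "<constructor_param>": "<value>",
    },
}
```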

#### Configuring as a startup option
In this example, the `bedrock` provider will be instantiated with the given `model_kwargs` value when the `ai21.j2-mid-v1` model is selected.

```bash
jupyter lab --AiExtension.model_parameters {"bedrock:ai21.j2-mid-v1":{"model_kwargs":{"maxTokens":200}}}
```
The above results in the following provider instantiation:

```python
BedrockProvider(model_kwargs={"maxTokens":200}, ...)
```

Here is another example, where the `anthropic` provider will be instantiated with the given `max_tokens` and `temperature` values when the `claude-2` model is selected.

```bash
jupyter lab --AiExtension.model_parameters='{"anthropic:claude-2":{"max_tokens":1024,"temperature":0.9}}'
```
The above results in the following provider instantiation:

```python
AnthropicProvider(max_tokens=1024, temperature=0.9, ...)
```

#### Configuring via a config file
This configuration can also be specified in a JSON config file. The file should be named `jupyter_jupyter_ai_config.json` and saved to a path where JupyterLab can find it. You can list these
paths by running the `jupyter --paths` command and picking one of the paths under the `config` section.

Here is example output from the `jupyter --paths` command:

```bash
(jupyter-ai-lab4) ➜ jupyter --paths
config:
    /opt/anaconda3/envs/jupyter-ai-lab4/etc/jupyter
    /Users/3coins/.jupyter
    /Users/3coins/.local/etc/jupyter
    /usr/3coins/etc/jupyter
    /etc/jupyter
data:
    /opt/anaconda3/envs/jupyter-ai-lab4/share/jupyter
    /Users/3coins/Library/Jupyter
    /Users/3coins/.local/share/jupyter
    /usr/local/share/jupyter
    /usr/share/jupyter
runtime:
    /Users/3coins/Library/Jupyter/runtime
```

Here is an example of configuring the `bedrock` provider for the `ai21.j2-mid-v1` model. Note that the per-model entries are nested under the `model_parameters` key of the `AiExtension` section:
```json
{
    "AiExtension": {
        "model_parameters": {
            "bedrock:ai21.j2-mid-v1": {
                "model_kwargs": {
                    "maxTokens": 200
                }
            }
        }
    }
}
```
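
Since Jupyter config files can also be written in Python, the same setting can likely be expressed in a `jupyter_jupyter_ai_config.py` file (a sketch assuming the standard traitlets Python config format):

```python
# jupyter_jupyter_ai_config.py — Python-config equivalent of the JSON above.
# The `c` object is provided by the traitlets config loader at startup.
c.AiExtension.model_parameters = {
    "bedrock:ai21.j2-mid-v1": {
        "model_kwargs": {"maxTokens": 200},
    },
}
```
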
34 changes: 1 addition & 33 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -242,15 +242,6 @@ def allows_concurrency(self):
        return True


def pop_with_default(model: Mapping[str, Any], name: str, default: Any) -> Any:
    try:
        value = model.pop(name)
    except KeyError as e:
        return default

    return value


class AI21Provider(BaseProvider, AI21):
    id = "ai21"
    name = "AI21"
@@ -632,9 +623,6 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
        TextField(
            key="response_path", label="Response path (required)", format="jsonpath"
        ),
        MultilineTextField(
            key="endpoint_kwargs", label="Endpoint arguments", format="json"
        ),
    ]

    def __init__(self, *args, **kwargs):
@@ -644,15 +632,7 @@ def __init__(self, *args, **kwargs):
            request_schema=request_schema, response_path=response_path
        )

        endpoint_kwargs = pop_with_default(kwargs, "endpoint_kwargs", "{}")
        endpoint_kwargs = json.loads(endpoint_kwargs)

        super().__init__(
            *args,
            **kwargs,
            content_handler=content_handler,
            endpoint_kwargs=endpoint_kwargs,
        )
        super().__init__(*args, **kwargs, content_handler=content_handler)

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        return await self._call_in_executor(*args, **kwargs)
@@ -677,14 +657,8 @@ class BedrockProvider(BaseProvider, Bedrock):
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
        MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"),
    ]

    def __init__(self, *args, **kwargs):
        model_kwargs = pop_with_default(kwargs, "model_kwargs", "{}")
        model_kwargs = json.loads(model_kwargs)
        super().__init__(*args, **kwargs, model_kwargs=model_kwargs)

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        return await self._call_in_executor(*args, **kwargs)

@@ -707,14 +681,8 @@ class BedrockChatProvider(BaseProvider, BedrockChat):
            format="text",
        ),
        TextField(key="region_name", label="Region name (optional)", format="text"),
        MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"),
    ]

    def __init__(self, *args, **kwargs):
        model_kwargs = pop_with_default(kwargs, "model_kwargs", "{}")
        model_kwargs = json.loads(model_kwargs)
        super().__init__(*args, **kwargs, model_kwargs=model_kwargs)

    async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
        return await self._call_in_executor(*args, **kwargs)

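The net effect of the removals in this file: `model_kwargs` no longer arrives as a JSON string from a per-provider UI text field to be parsed in each `__init__`; it now arrives pre-parsed through the `model_parameters` setting and is forwarded like any other constructor argument. A rough before/after sketch (simplified; values hypothetical):

```python
import json

# Before: each provider popped a JSON string out of kwargs and parsed it.
kwargs = {"model_kwargs": '{"maxTokens": 200}'}
model_kwargs = json.loads(kwargs.pop("model_kwargs", "{}"))

# After: the dict comes straight from the model_parameters config, with
# no parsing step, and is unpacked into the provider constructor:
provider_params = {"model_id": "ai21.j2-mid-v1"}
model_parameters = {"model_kwargs": {"maxTokens": 200}}
# llm = BedrockProvider(**provider_params, **model_parameters)
```
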
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -36,7 +36,8 @@ def __init__(self, retriever, *args, **kwargs):
    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
        self.llm = provider(**provider_params)
        model_parameters = self.get_model_parameters(provider, provider_params)
        self.llm = provider(**provider_params, **model_parameters)
        memory = ConversationBufferWindowMemory(
            memory_key="chat_history", return_messages=True, k=2
        )
11 changes: 10 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -3,7 +3,7 @@
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, Dict, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from uuid import uuid4

from jupyter_ai.config_manager import ConfigManager, Logger
@@ -23,10 +23,12 @@ def __init__(
        log: Logger,
        config_manager: ConfigManager,
        root_chat_handlers: Dict[str, "RootChatHandler"],
        model_parameters: Dict[str, Dict],
    ):
        self.log = log
        self.config_manager = config_manager
        self._root_chat_handlers = root_chat_handlers
        self.model_parameters = model_parameters
        self.parser = argparse.ArgumentParser()
        self.llm = None
        self.llm_params = None
@@ -122,6 +124,13 @@ def get_llm_chain(self):
        self.llm_params = lm_provider_params
        return self.llm_chain

    def get_model_parameters(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
        return self.model_parameters.get(
            f"{provider.id}:{provider_params['model_id']}", {}
        )

    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
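
For reference, the lookup added in `get_model_parameters` above composes a `provider_id:model_id` key and falls back to an empty dict. A minimal standalone sketch of that behavior (the free function and sample values are illustrative, not part of the codebase):

```python
from typing import Any, Dict

def get_model_parameters(
    model_parameters: Dict[str, Dict[str, Any]],
    provider_id: str,
    model_id: str,
) -> Dict[str, Any]:
    # Keys follow the "<provider_id>:<model_id>" convention; models with
    # no configured parameters get an empty dict.
    return model_parameters.get(f"{provider_id}:{model_id}", {})

# Illustrative usage with hypothetical values:
params = get_model_parameters(
    {"bedrock:ai21.j2-mid-v1": {"model_kwargs": {"maxTokens": 200}}},
    "bedrock",
    "ai21.j2-mid-v1",
)
assert params == {"model_kwargs": {"maxTokens": 200}}
# The chat handlers then unpack this alongside the regular provider
# parameters: provider(**provider_params, **params)
```
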
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -40,7 +40,8 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
        llm = provider(**provider_params)
        model_parameters = self.get_model_parameters(provider, provider_params)
        llm = provider(**provider_params, **model_parameters)

        if llm.is_chat_provider:
            prompt_template = ChatPromptTemplate.from_messages(
4 changes: 3 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -226,7 +226,9 @@ def __init__(self, root_dir: str, *args, **kwargs):
    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
        llm = provider(**provider_params)
        model_parameters = self.get_model_parameters(provider, provider_params)
        llm = provider(**provider_params, **model_parameters)

        self.llm = llm
        return llm

16 changes: 15 additions & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
@@ -4,7 +4,7 @@
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from traitlets import List, Unicode
from traitlets import Dict, List, Unicode

from .chat_handlers import (
    AskChatHandler,
@@ -53,13 +53,26 @@ class AiExtension(ExtensionApp):
        config=True,
    )

    model_parameters = Dict(
        key_trait=Unicode(),
        value_trait=Dict(),
        default_value={},
        help="""Key-value pairs for model id and corresponding parameters that
        are passed to the provider class. The values are unpacked and passed to
        the provider class as-is.""",
        allow_none=True,
        config=True,
    )

    def initialize_settings(self):
        start = time.time()
        restrictions = {
            "allowed_providers": self.allowed_providers,
            "blocked_providers": self.blocked_providers,
        }

        self.settings["model_parameters"] = self.model_parameters

        self.settings["lm_providers"] = get_lm_providers(
            log=self.log, restrictions=restrictions
        )
@@ -107,6 +120,7 @@ def initialize_settings(self):
            "log": self.log,
            "config_manager": self.settings["jai_config_manager"],
            "root_chat_handlers": self.settings["jai_root_chat_handlers"],
            "model_parameters": self.settings["model_parameters"],
        }
        default_chat_handler = DefaultChatHandler(
            **chat_handler_kwargs, chat_history=self.settings["chat_history"]
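
As background on the new trait: traitlets validates `Dict` traits with `Unicode` keys and `Dict` values at assignment time. A minimal sketch outside the extension (assuming standard traitlets behavior):

```python
from traitlets import Dict, Unicode
from traitlets.config import Configurable

class Example(Configurable):
    # Mirrors the trait declared on AiExtension above.
    model_parameters = Dict(
        key_trait=Unicode(),
        value_trait=Dict(),
        default_value={},
        config=True,
    )

e = Example()
e.model_parameters = {"bedrock:ai21.j2-mid-v1": {"model_kwargs": {"maxTokens": 200}}}
print(e.model_parameters["bedrock:ai21.j2-mid-v1"])  # {'model_kwargs': {'maxTokens': 200}}
```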
