Skip to content

Commit

Permalink
Document how to add custom model providers (#420)
Browse files Browse the repository at this point in the history
* Document how to add custom model providers

* Apply suggestions from review

Co-authored-by: Jason Weill <[email protected]>

---------

Co-authored-by: Jason Weill <[email protected]>
  • Loading branch information
krassowski and JasonWeill authored Oct 30, 2023
1 parent b06e259 commit 7e4a2a5
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 6 deletions.
92 changes: 91 additions & 1 deletion docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,96 @@ responsible for all charges they incur when they make API requests. Review your
provider's pricing information before submitting requests via Jupyter AI.
:::

### Custom model providers

You can define new providers using the LangChain framework API. Custom providers
inherit from both `jupyter-ai`'s ``BaseProvider`` and `langchain`'s [``LLM``][LLM].
You can either import a pre-defined model from [LangChain LLM list][langchain_llms],
or define a [custom LLM][custom_llm].
In the example below, we define a provider with two models using
a dummy ``FakeListLLM`` model, which returns responses from the ``responses``
keyword argument.

```python
# my_package/my_provider.py
from jupyter_ai_magics import BaseProvider
from langchain.llms import FakeListLLM


class MyProvider(BaseProvider, FakeListLLM):
id = "my_provider"
name = "My Provider"
model_id_key = "model"
models = [
"model_a",
"model_b"
]
def __init__(self, **kwargs):
model = kwargs.get("model_id")
kwargs["responses"] = (
["This is a response from model 'a'"]
if model == "model_a" else
["This is a response from model 'b'"]
)
super().__init__(**kwargs)
```


If the new provider inherits from [``BaseChatModel``][BaseChatModel], it will be available
both in the chat UI and with magic commands. Otherwise, users can only use the new provider
with magic commands.

To make the new provider available, you need to declare it as an [entry point](https://setuptools.pypa.io/en/latest/userguide/entry_point.html):

```toml
# my_package/pyproject.toml
[project]
name = "my_package"
version = "0.0.1"

[project.entry-points."jupyter_ai.model_providers"]
my-provider = "my_provider:MyProvider"
```

To test that the above minimal provider package works, install it with:

```sh
# from `my_package` directory
pip install -e .
```

Then, restart JupyterLab. You should now see an info message in the log that mentions
your new provider's `id`:

```
[I 2023-10-29 13:56:16.915 AiExtension] Registered model provider `my_provider`.
```

[langchain_llms]: https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.llms
[custom_llm]: https://python.langchain.com/docs/modules/model_io/models/llms/custom_llm
[LLM]: https://api.python.langchain.com/en/latest/llms/langchain.llms.base.LLM.html#langchain.llms.base.LLM
[BaseChatModel]: https://api.python.langchain.com/en/latest/chat_models/langchain.chat_models.base.BaseChatModel.html


### Customizing prompt templates

To modify the prompt template for a given format, override the ``get_prompt_template`` method:

```python
from langchain.prompts import PromptTemplate


class MyProvider(BaseProvider, FakeListLLM):
# (... properties as above ...)
def get_prompt_template(self, format) -> PromptTemplate:
if format === "code":
return PromptTemplate.from_template(
"{prompt}\n\nProduce output as source code only, "
"with no text or explanation before or after it."
)
return super().get_prompt_template(format)
```

## The chat interface

The easiest way to get started with Jupyter AI is to use the chat interface.
Expand Down Expand Up @@ -689,7 +779,7 @@ Write a poem about C++.

You can also define a custom LangChain chain:

```
```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
Expand Down
10 changes: 6 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ def get_lm_providers(
for model_provider_ep in model_provider_eps:
try:
provider = model_provider_ep.load()
except:
except Exception as e:
log.error(
f"Unable to load model provider class from entry point `{model_provider_ep.name}`."
f"Unable to load model provider class from entry point `{model_provider_ep.name}`: %s.",
e,
)
continue
if not is_provider_allowed(provider.id, restrictions):
Expand All @@ -58,9 +59,10 @@ def get_em_providers(
for model_provider_ep in model_provider_eps:
try:
provider = model_provider_ep.load()
except:
except Exception as e:
log.error(
f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`."
f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`: %s.",
e,
)
continue
if not is_provider_allowed(provider.id, restrictions):
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def broadcast_message(self, message: Message):
self.chat_history.append(message)

async def on_message(self, message):
self.log.debug("Message recieved: %s", message)
self.log.debug("Message received: %s", message)

try:
message = json.loads(message)
Expand Down

0 comments on commit 7e4a2a5

Please sign in to comment.