diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 8beece4f6..5376c29b0 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -839,6 +839,16 @@ Aliases' names can contain ASCII letters (uppercase and lowercase), numbers, hyp Aliases must refer to models or `LLMChain` objects; they cannot refer to other aliases. +To customize the aliases on startup you can set the `c.AiMagics.aliases` tratilet in `ipython_config.py`, for example: + +```python +c.AiMagics.aliases = { + "my_custom_alias": "my_provider:my_model" +} +``` + +The location of `ipython_config.py` file is documented in [IPython configuration reference](https://ipython.readthedocs.io/en/stable/config/intro.html). + ### Using magic commands with SageMaker endpoints You can use magic commands with models hosted using Amazon SageMaker. diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index d605e37f2..e83d77d41 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -8,6 +8,7 @@ from typing import Optional import click +import traitlets from IPython import get_ipython from IPython.core.magic import Magics, line_cell_magic, magics_class from IPython.display import HTML, JSON, Markdown, Math @@ -122,6 +123,19 @@ class CellMagicError(BaseException): @magics_class class AiMagics(Magics): + + aliases = traitlets.Dict( + default_value=MODEL_ID_ALIASES, + value_trait=traitlets.Unicode(), + key_trait=traitlets.Unicode(), + help="""Aliases for model identifiers. + + Keys define aliases, values define the provider and the model to use. + The values should include identifiers in in the `provider:model` format. + """, + config=True, + ) + def __init__(self, shell): super().__init__(shell) self.transcript_openai = [] @@ -145,7 +159,7 @@ def __init__(self, shell): self.providers = get_lm_providers() # initialize a registry of custom model/chain names - self.custom_model_registry = MODEL_ID_ALIASES + self.custom_model_registry = self.aliases def _ai_bulleted_list_models_for_provider(self, provider_id, Provider): output = "" diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py new file mode 100644 index 000000000..ec1635347 --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py @@ -0,0 +1,11 @@ +from IPython import InteractiveShell +from traitlets.config.loader import Config + + +def test_aliases_config(): + ip = InteractiveShell() + ip.config = Config() + ip.config.AiMagics.aliases = {"my_custom_alias": "my_provider:my_model"} + ip.extension_manager.load_extension("jupyter_ai_magics") + providers_list = ip.run_line_magic("ai", "list").text + assert "my_custom_alias" in providers_list