Skip to content

Commit

Permalink
Make magic aliases user-customizable
Browse files Browse the repository at this point in the history
  • Loading branch information
krassowski committed Jul 17, 2024
1 parent 12d069e commit e82a451
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
10 changes: 10 additions & 0 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,16 @@ Aliases' names can contain ASCII letters (uppercase and lowercase), numbers, hyp

Aliases must refer to models or `LLMChain` objects; they cannot refer to other aliases.

To customize the aliases on startup you can set the `c.AiMagics.aliases` tratilet in `ipython_config.py`, for example:

```python
c.AiMagics.aliases = {
"my_custom_alias": "my_provider:my_model"
}
```

The location of `ipython_config.py` file is documented in [IPython configuration reference](https://ipython.readthedocs.io/en/stable/config/intro.html).

### Using magic commands with SageMaker endpoints

You can use magic commands with models hosted using Amazon SageMaker.
Expand Down
16 changes: 15 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional

import click
import traitlets
from IPython import get_ipython
from IPython.core.magic import Magics, line_cell_magic, magics_class
from IPython.display import HTML, JSON, Markdown, Math
Expand Down Expand Up @@ -122,6 +123,19 @@ class CellMagicError(BaseException):

@magics_class
class AiMagics(Magics):

aliases = traitlets.Dict(
default_value=MODEL_ID_ALIASES,
value_trait=traitlets.Unicode(),
key_trait=traitlets.Unicode(),
help="""Aliases for model identifiers.
Keys define aliases, values define the provider and the model to use.
The values should include identifiers in in the `provider:model` format.
""",
config=True,
)

def __init__(self, shell):
super().__init__(shell)
self.transcript_openai = []
Expand All @@ -145,7 +159,7 @@ def __init__(self, shell):
self.providers = get_lm_providers()

# initialize a registry of custom model/chain names
self.custom_model_registry = MODEL_ID_ALIASES
self.custom_model_registry = self.aliases

def _ai_bulleted_list_models_for_provider(self, provider_id, Provider):
output = ""
Expand Down
11 changes: 11 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from IPython import InteractiveShell
from traitlets.config.loader import Config


def test_aliases_config():
ip = InteractiveShell()
ip.config = Config()
ip.config.AiMagics.aliases = {"my_custom_alias": "my_provider:my_model"}
ip.extension_manager.load_extension("jupyter_ai_magics")
providers_list = ip.run_line_magic("ai", "list").text
assert "my_custom_alias" in providers_list

0 comments on commit e82a451

Please sign in to comment.