Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Aug 6, 2024
1 parent f96b202 commit e5b3e10
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 13 deletions.
1 change: 1 addition & 0 deletions packages/jupyter-ai-test/jupyter_ai_test/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class TestProviderWithStreaming(BaseProvider, TestLLMWithStreaming):
"""User inputs expected by this provider when initializing it. Each `Field` `f`
should be passed in the constructor as a keyword argument, keyed by `f.key`."""


class TestProviderAskLearnUnsupported(BaseProvider, TestLLMWithStreaming):
id: ClassVar[str] = "test-provider-ask-learn-unsupported"
"""ID for this provider class."""
Expand Down
18 changes: 11 additions & 7 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class BaseChatHandler:
"""Format string template that is used to build the help message. Specified
from traitlets configuration."""

chat_handlers: Dict[str, 'BaseChatHandler']
chat_handlers: Dict[str, "BaseChatHandler"]
"""Dictionary of chat handlers. Allows one chat handler to reference other
chat handlers, which is necessary for some use-cases like printing the help
message."""
Expand All @@ -127,7 +127,7 @@ def __init__(
preferred_dir: Optional[str],
dask_client_future: Awaitable[DaskClient],
help_message_template: str,
chat_handlers: Dict[str, 'BaseChatHandler'],
chat_handlers: Dict[str, "BaseChatHandler"],
):
self.log = log
self.config_manager = config_manager
Expand Down Expand Up @@ -381,16 +381,18 @@ def output_dir(self) -> str:
return self.preferred_dir
else:
return self.root_dir

def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> None:
"""Sends a help message to all connected clients."""
lm_provider = self.config_manager.lm_provider
unsupported_slash_commands = lm_provider.unsupported_slash_commands if lm_provider else set()
unsupported_slash_commands = (
lm_provider.unsupported_slash_commands if lm_provider else set()
)
chat_handlers = self.chat_handlers
slash_commands = {k: v for k, v in chat_handlers.items() if k != "default" }
slash_commands = {k: v for k, v in chat_handlers.items() if k != "default"}
for key in unsupported_slash_commands:
del slash_commands[key]

# markdown string that lists the slash commands
slash_commands_list = "\n".join(
[
Expand All @@ -399,7 +401,9 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non
]
)

help_message_body = self.help_message_template.format(persona_name=self.persona.name, slash_commands_list=slash_commands_list)
help_message_body = self.help_message_template.format(
persona_name=self.persona.name, slash_commands_list=slash_commands_list
)
help_message = AgentChatMessage(
id=uuid4().hex,
time=time.time(),
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ async def process_message(self, _):
break

# re-send help message
self.send_help_message()
self.send_help_message()
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .base import BaseChatHandler, SlashCommandRoutingType


class HelpChatHandler(BaseChatHandler):
id = "help"
name = "Help"
Expand All @@ -15,4 +16,4 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def process_message(self, message: HumanChatMessage):
self.send_help_message(message)
self.send_help_message(message)
9 changes: 5 additions & 4 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ class AiExtension(ExtensionApp):
'Jupyternaut'.
- `slash_commands_list`: A string containing a bulleted list of the
slash commands available to the configured language model.
slash commands available to the configured language model.
""",
config=True
config=True,
)

def initialize_settings(self):
Expand Down Expand Up @@ -369,9 +369,10 @@ def _show_help_message(self):
# call `send_help_message()` on any instance of `BaseChatHandler`. The
# `default` chat handler should always exist, so we reference that
# object when calling `send_help_message()`.
default_chat_handler: DefaultChatHandler = self.settings["jai_chat_handlers"]["default"]
default_chat_handler: DefaultChatHandler = self.settings["jai_chat_handlers"][
"default"
]
default_chat_handler.send_help_message()


async def _get_dask_client(self):
return DaskClient(processes=False, asynchronous=True)
Expand Down

0 comments on commit e5b3e10

Please sign in to comment.