From f6d2ceceaaf8d7967e50fbd703a4957bf07472df Mon Sep 17 00:00:00 2001 From: Govinda Totla Date: Tue, 17 Dec 2024 12:44:11 +0530 Subject: [PATCH] Revamped Jupyter AI --- .../jupyter_ai_magics/prompts.py | 124 ++++++++++ .../jupyter_ai_magics/providers.py | 94 ++------ .../jupyter_ai/_variable_describer.py | 146 ++++++++++++ .../jupyter_ai/chat_handlers/base.py | 11 +- .../jupyter_ai/chat_handlers/clear.py | 2 +- .../jupyter_ai/chat_handlers/default.py | 43 +++- .../jupyter_ai/chat_handlers/export.py | 1 + .../jupyter-ai/jupyter_ai/config_manager.py | 38 +++- packages/jupyter-ai/jupyter_ai/extension.py | 33 ++- .../jupyter_ai/filter_autocomplete_globals.py | 43 ++++ packages/jupyter-ai/jupyter_ai/handlers.py | 13 +- packages/jupyter-ai/jupyter_ai/models.py | 22 +- packages/jupyter-ai/schema/plugin.json | 2 +- packages/jupyter-ai/src/completions/plugin.ts | 8 +- .../jupyter-ai/src/completions/provider.ts | 43 +++- .../jupyter-ai/src/components/chat-input.tsx | 122 ++++++++-- .../src/components/chat-input/send-button.tsx | 134 +---------- .../src/components/chat-messages.tsx | 129 ++++++++--- .../chat-messages/chat-message-menu.tsx | 187 ++++++++++++---- .../src/components/chat-settings.tsx | 211 ++++++++++++------ packages/jupyter-ai/src/components/chat.tsx | 162 ++++++++------ .../components/code-blocks/code-toolbar.tsx | 75 +++---- .../src/components/rendermime-markdown.tsx | 9 +- .../src/contexts/active-cell-context.tsx | 5 + .../src/contexts/notebook-tracker-context.tsx | 19 ++ .../jupyter-ai/src/contexts/utils-context.tsx | 20 ++ packages/jupyter-ai/src/handler.ts | 29 ++- packages/jupyter-ai/src/index.ts | 16 +- packages/jupyter-ai/src/utils.ts | 177 ++++++++++++++- .../jupyter-ai/src/widgets/chat-sidebar.tsx | 8 +- 30 files changed, 1392 insertions(+), 534 deletions(-) create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/prompts.py create mode 100644 packages/jupyter-ai/jupyter_ai/_variable_describer.py create mode 100644 packages/jupyter-ai/jupyter_ai/filter_autocomplete_globals.py create mode 100644 packages/jupyter-ai/src/contexts/notebook-tracker-context.tsx create mode 100644 packages/jupyter-ai/src/contexts/utils-context.tsx diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/prompts.py b/packages/jupyter-ai-magics/jupyter_ai_magics/prompts.py new file mode 100644 index 000000000..8e9fcedb4 --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/prompts.py @@ -0,0 +1,124 @@ +CHAT_SYSTEM_PROMPT = """ +You are Jupyternaut, a conversational assistant living in JupyterLab to help users. +You are an expert in Jupyter Notebook, Data Visualization, Data Science and Data Analysis. +You always use Markdown to format your response. +Code blocks must be formatted in Markdown. +Math should be rendered with inline TeX markup, surrounded by $. +If you do not know the answer to a question, answer truthfully by responding that you do not know. +The following is a friendly conversation between you and a human. +""".strip() + +CHAT_DEFAULT_TEMPLATE = """ + +{% if notebook_code %} +Below is a Jupyter notebook, in jupytext "percent" format that the Human is working with: + +{{notebook_code}} + +{% endif %} +{% if active_cell_id %} +The following cell is selected currently by the user: {{active_cell_id}} +{% endif %} +{% if selection and selection.type == 'text' %} +The following code is selected by the user in the active cell: + +{{selection.source}} + +{% endif %} +{% if variable_context %} +The kernel is currently running with some of the global variables and their values/types listed below: +{{variable_context}} +{% endif %} + +Unless asked otherwise, answer all questions as concisely and directly as possible. +That is, you directly output code, when asked, with minimal explanation. +Avoid writing cell ids or too many comments in the code. +""" + +COMPLETION_SYSTEM_PROMPT = """ +You are a python coding assistant capable of completing code. +You should only produce code. Avoid comments in the code. Produce clean code. +The code is written in JupyterLab, a data analysis and code development +environment which can execute code extended with additional syntax for +interactive features, such as magics. + +Here are some examples of Python code completion: + +Example 1: +Input: +def calculate_area(radius): + return 3.14 * [BLANK] + +Output: +radius ** 2 + +Example 2: +Input: +for i in range(10): + if i % 2 == 0: + [BLANK] + +Output: +print(i) + +Example 3: +Input: +try: + result = 10 / 0 +except [BLANK]: + print("Division by zero!") + +Output: +ZeroDivisionError + +Example 4: +Input: +import random + +numbers = [1, 2, 3, 4, 5] +random[BLANK] + +Output: +.shuffle(numbers) + +Example 5: +Input: +def quick_sort(arr): + # exp[BLANK] + +Output: +lanation: + +Example 6: +Input: +import pandas as pd +import numpy as np + +def create_random_dataframe(arr): + [BLANK] + +Output: +return pd.DataFrame(arr, columns=["column1", "column2"]) + +Example 7: +Input: +def get_params(): + return (1, 2) + +x, y[BLANK] + +print(x + y) + +Output: + = get_params() +""".strip() + +# only add the suffix bit if present to save input tokens/computation time +COMPLETION_DEFAULT_TEMPLATE = """ +Now, complete the following Python code being written in {{filename}}: + +{{prefix}}[BLANK]{{suffix}} + +Fill in the blank to complete the code block. +Your response should include only the code to replace [BLANK], without surrounding backticks. +Do not return a linebreak at the beginning of your response.""" \ No newline at end of file diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index fac868229..aeefe8189 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -49,66 +49,7 @@ InlineCompletionStreamChunk, ) from .models.persona import Persona - -CHAT_SYSTEM_PROMPT = """ -You are Jupyternaut, a conversational assistant living in JupyterLab to help users. -You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. -You are talkative and you provide lots of specific details from the foundation model's context. -You may use Markdown to format your response. -If your response includes code, they must be enclosed in Markdown fenced code blocks (with triple backticks before and after). -If your response includes mathematical notation, they must be expressed in LaTeX markup and enclosed in LaTeX delimiters. -All dollar quantities (of USD) must be formatted in LaTeX, with the `$` symbol escaped by a single backslash `\\`. -- Example prompt: `If I have \\\\$100 and spend \\\\$20, how much money do I have left?` -- **Correct** response: `You have \\(\\$80\\) remaining.` -- **Incorrect** response: `You have $80 remaining.` -If you do not know the answer to a question, answer truthfully by responding that you do not know. -The following is a friendly conversation between you and a human. -""".strip() - -CHAT_DEFAULT_TEMPLATE = """ -{% if context %} -Context: -{{context}} - -{% endif %} -Current conversation: -{{history}} -Human: {{input}} -AI:""" - -HUMAN_MESSAGE_TEMPLATE = """ -{% if context %} -Context: -{{context}} - -{% endif %} -{{input}} -""" - -COMPLETION_SYSTEM_PROMPT = """ -You are an application built to provide helpful code completion suggestions. -You should only produce code. Keep comments to minimum, use the -programming language comment syntax. Produce clean code. -The code is written in JupyterLab, a data analysis and code development -environment which can execute code extended with additional syntax for -interactive features, such as magics. -""".strip() - -# only add the suffix bit if present to save input tokens/computation time -COMPLETION_DEFAULT_TEMPLATE = """ -The document is called `{{filename}}` and written in {{language}}. -{% if suffix %} -The code after the completion request is: - -``` -{{suffix}} -``` -{% endif %} - -Complete the following code: - -``` -{{prefix}}""" +from .prompts import CHAT_DEFAULT_TEMPLATE, CHAT_SYSTEM_PROMPT, COMPLETION_DEFAULT_TEMPLATE, COMPLETION_SYSTEM_PROMPT class EnvAuthStrategy(BaseModel): @@ -215,7 +156,6 @@ def server_settings(cls, value): _server_settings = None - class BaseProvider(BaseModel, metaclass=ProviderMetaclass): # # pydantic config @@ -291,6 +231,11 @@ class Config: Providers are not allowed to mutate this dictionary. """ + chat_system_prompt: str = CHAT_SYSTEM_PROMPT + chat_default_prompt: str = CHAT_DEFAULT_TEMPLATE + completion_system_prompt: str = COMPLETION_SYSTEM_PROMPT + completion_default_prompt: str = COMPLETION_DEFAULT_TEMPLATE + @classmethod def chat_models(self): """Models which are suitable for chat.""" @@ -353,6 +298,9 @@ def __init__(self, *args, **kwargs): } super().__init__(*args, **kwargs, **model_kwargs) + def process_notebook_for_context(self, code_cells: List[str], active_cell: Optional[int]) -> str: + return "\n\n".join(code_cells) + async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]: """ Calls self._call() asynchronously in a separate thread for providers @@ -409,23 +357,25 @@ def get_chat_prompt_template(self) -> PromptTemplate: return ChatPromptTemplate.from_messages( [ SystemMessagePromptTemplate.from_template( - CHAT_SYSTEM_PROMPT - ).format(provider_name=name, local_model_id=self.model_id), - MessagesPlaceholder(variable_name="history"), + self.chat_system_prompt, + template_format="jinja2" + ), HumanMessagePromptTemplate.from_template( - HUMAN_MESSAGE_TEMPLATE, - template_format="jinja2", + self.chat_default_prompt, + template_format="jinja2" ), + MessagesPlaceholder(variable_name="history"), + HumanMessagePromptTemplate.from_template("{input}"), ] ) else: return PromptTemplate( input_variables=["history", "input", "context"], - template=CHAT_SYSTEM_PROMPT.format( + template=self.chat_system_prompt.format( provider_name=name, local_model_id=self.model_id ) + "\n\n" - + CHAT_DEFAULT_TEMPLATE, + + self.chat_default_prompt, template_format="jinja2", ) @@ -437,18 +387,18 @@ def get_completion_prompt_template(self) -> PromptTemplate: if self.is_chat_provider: return ChatPromptTemplate.from_messages( [ - SystemMessagePromptTemplate.from_template(COMPLETION_SYSTEM_PROMPT), + SystemMessagePromptTemplate.from_template(self.completion_system_prompt), HumanMessagePromptTemplate.from_template( - COMPLETION_DEFAULT_TEMPLATE, template_format="jinja2" + self.completion_default_prompt, template_format="jinja2" ), ] ) else: return PromptTemplate( input_variables=["prefix", "suffix", "language", "filename"], - template=COMPLETION_SYSTEM_PROMPT + template=self.completion_default_prompt + "\n\n" - + COMPLETION_DEFAULT_TEMPLATE, + + self.completion_default_prompt, template_format="jinja2", ) diff --git a/packages/jupyter-ai/jupyter_ai/_variable_describer.py b/packages/jupyter-ai/jupyter_ai/_variable_describer.py new file mode 100644 index 000000000..f61b60721 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/_variable_describer.py @@ -0,0 +1,146 @@ +import pandas as pd +import inspect +import random +import types +from typing import Any, Callable +import io +from pydantic.v1 import BaseModel + +class VariableDescription(BaseModel): + name: str + type: str + value: str + # We cant use schema as it is an attribute of pydantic + structure: str | None = None + + def format(self) -> str: + res = "\n" + res += f"{self.name}\n" + res += f"{self.type}\n" + if self.structure: + res += f"{self.structure}\n" + res += f"{self.value}\n" + res += "\n" + + return res + + +def default_handler(name: str, x: Any): + fqn = f"{x.__class__.__module__}.{x.__class__.__name__}" + return VariableDescription( + name=name, + type=fqn, + value=str(x) + ) + +def dataframe_handler(name: str, x: pd.DataFrame): + info_buf = io.StringIO() + x.info(buf=info_buf, memory_usage=False, show_counts=False) + + return VariableDescription( + name=name, + type="pandas.DataFrame", + value="Some random rows from the dataframe:\n" + str(x.sample(min(5, len(x)))), + structure=info_buf.getvalue() + ) + +def function_handler(name: str, x: Callable): + return VariableDescription( + name=name, + type="function", + value=inspect.getsource(x) + ) + +def basic_type_handler(name: str, x: Any): + return VariableDescription( + name=name, + type=type(x).__name__, + value=x + ) + +def list_handler(name: str, x: Any): + return VariableDescription( + name=name, + type=type(x).__name__, + value=str(random.sample(list(x), 5)), + structure=f"Size: {len(x)}" + ) + +def string_handler(name: str, x: Any): + val = x if len(x) < 10 else x[:10] + "..." + return basic_type_handler(name, val) + + +class DescriberRegistry: + """ + Maintains a registry of handlers for different dtypes for sending context + to an LLM in jupyter ai + + This also accespts a _repr_llm_ that should return a dict with mapping of mime type to content. + Currently only application/jupyer+ai+var is supported. + + Sample Usage: + >>> from jupyter_ai._variable_describer import DescriberRegistry, VariableDescription + >>> int_handler = lambda name, value: VariableDescription(name=name, type="int", value=value) + >>> DescriberRegistry.register(int, int_handler) + """ + registry: dict[type, Callable[[str, Any], str]] = {} + + @classmethod + def _init(cls): + cls.register(pd.DataFrame, dataframe_handler) + cls.register(types.FunctionType, function_handler) + cls.register(int, basic_type_handler) + cls.register(float, basic_type_handler) + cls.register(str, string_handler) + cls.register(list, list_handler) + cls.register(tuple, list_handler) + cls.register(set, list_handler) + + @classmethod + def register(cls, var_type: type, handler: Callable[[str, Any], str]): + cls.registry[var_type] = handler + + @classmethod + def get(cls, value: Any): + return cls.registry.get(type(value), None) + + @classmethod + def _get_repr_handler(cls): + def _repr_handler(name: str, value: Any): + repr_llm = value._repr_llm_(name=name) + description = repr_llm.get("application/jupyer+ai+var", None) + + if description is None: + return default_handler(name, value) + + return VariableDescription.parse_obj(description) + return _repr_handler + + + @classmethod + def describe(cls, name: str, value: Any) -> str: + """ + Returns a description of the variable in an XML format + """ + # Priority to set handlers + # followed by repr_llm handlers + # followed by default handlers + handler = cls.get(value) + + if handler is not None and hasattr(value, "_repr_llm_"): + handler = cls._get_repr_handler() + + if handler is None: + handler = default_handler + + desc = handler(name, value) + if not isinstance(desc, VariableDescription): + raise TypeError("Handler must return an instance of VariableDescription") + + return desc.format() + +DescriberRegistry._init() + +def describe_var(name: str, value: Any) -> str: + return DescriberRegistry.describe(name, value) \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index c844650ad..5a4c2592a 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -16,6 +16,7 @@ Type, Union, cast, + Any ) from typing import get_args as get_type_args from uuid import uuid4 @@ -33,6 +34,8 @@ HumanChatMessage, Message, PendingMessage, + AgentStreamChunkMessage, + AgentStreamMessage ) from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider @@ -140,6 +143,7 @@ class BaseChatHandler: message_interrupted: Dict[str, asyncio.Event] """Dictionary mapping an agent message identifier to an asyncio Event which indicates if the message generation/streaming was interrupted.""" + show_help: bool = False def __init__( self, @@ -455,6 +459,7 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> Non [ f"* `{command_name}` — {handler.help}" for command_name, handler in slash_commands.items() + if handler.show_help ] ) @@ -523,6 +528,7 @@ async def stream_reply( human_msg: HumanChatMessage, pending_msg="Generating response", config: Optional[RunnableConfig] = None, + extra_metadata: dict={}, ): """ Streams a reply to a human message by invoking @@ -604,7 +610,10 @@ async def stream_reply( stream_id, stream_tombstone, complete=True, - metadata=metadata_handler.jai_metadata, + metadata=dict( + metadata_handler.jai_metadata, + **extra_metadata + ), ) del self.message_interrupted[stream_id] diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index d5b0ab6c7..70813289f 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -10,7 +10,7 @@ class ClearChatHandler(BaseChatHandler): name = "Clear chat messages" help = "Clear the chat window" routing_type = SlashCommandRoutingType(slash_id="clear") - + show_help: bool = True uses_llm = False def __init__(self, *args, **kwargs): diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 266ad73ad..00f8f50a3 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -9,7 +9,6 @@ from ..context_providers import ContextProviderException, find_commands from .base import BaseChatHandler, SlashCommandRoutingType - class DefaultChatHandler(BaseChatHandler): id = "default" name = "Default" @@ -29,9 +28,26 @@ def create_llm_chain( unified_parameters = { "verbose": True, **provider_params, - **(self.get_model_parameters(provider, provider_params)), + **(self.get_model_parameters(provider, provider_params)) } - llm = provider(**unified_parameters) + config = self.config_manager.get_config() + + if config.chat_prompt: + if config.chat_prompt.system: + unified_parameters["chat_system_prompt"] = config.chat_prompt.system + if config.chat_prompt.default: + unified_parameters["chat_default_prompt"] = config.chat_prompt.default + + if config.completion_prompt: + if config.completion_prompt.system: + unified_parameters["completion_system_prompt"] = config.chat_prompt.system + if config.completion_prompt.default: + unified_parameters["completion_default_prompt"] = config.chat_prompt.default + + + llm = provider( + **unified_parameters, + ) prompt_template = llm.get_chat_prompt_template() self.llm = llm @@ -57,7 +73,19 @@ async def process_message(self, message: HumanChatMessage): self.get_llm_chain() assert self.llm_chain - inputs = {"input": message.body} + inputs = dict( + input=message.body, + selection=message.selection, + notebook_code=self.llm.process_notebook_for_context( + code_cells=[ + cell.content for cell in message.notebook.notebook_code + ] if message.notebook else [], + active_cell=int(message.notebook.active_cell_id) if message.notebook.active_cell_id else None, + ), + active_cell_id=f"Cell {message.notebook.active_cell_id}" if message.notebook.active_cell_id else None, + variable_context=message.notebook.variable_context or None + ) + if "context" in self.prompt_template.input_variables: # include context from context providers. try: @@ -68,7 +96,12 @@ async def process_message(self, message: HumanChatMessage): inputs["context"] = context_prompt inputs["input"] = self.replace_prompt(inputs["input"]) - await self.stream_reply(inputs, message) + history = self.llm_chat_memory.messages + prompt = self.prompt_template.invoke(dict(inputs, history=history)) + extra_metadata = dict( + prompt=prompt.to_string() + ) + await self.stream_reply(inputs, message, extra_metadata=extra_metadata) async def make_context_prompt(self, human_msg: HumanChatMessage) -> str: return "\n\n".join( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py index 7323d81c1..ed029916f 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/export.py @@ -13,6 +13,7 @@ class ExportChatHandler(BaseChatHandler): name = "Export chat history" help = "Export chat history to a Markdown file" routing_type = SlashCommandRoutingType(slash_id="export") + show_help: bool = True uses_llm = False diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 71ca3f185..6b5db7e90 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -5,7 +5,6 @@ import time from copy import deepcopy from typing import List, Optional, Type, Union - from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest @@ -228,8 +227,11 @@ def _init_defaults(self): config_keys = GlobalConfig.__fields__.keys() schema_properties = self.validator.schema.get("properties", {}) default_config = { - field: schema_properties.get(field).get("default") for field in config_keys + # Making this more defensive to prevent from any errors + # due to additions in what we store in config + field: schema_properties.get(field, {}).get("default") for field in config_keys } + if self._defaults is None: return default_config @@ -251,10 +253,16 @@ def _read_config(self) -> GlobalConfig: if last_write <= self._last_read: return self._config + # Since we store only partial configs now, + # we merge the defaults into the partial updates + # the user has stored + defaults = self._init_defaults() with open(self.config_path, encoding="utf-8") as f: self._last_read = time.time_ns() raw_config = json.loads(f.read()) - config = GlobalConfig(**raw_config) + config_final = defaults.copy() + config_final.update(raw_config) + config = GlobalConfig(**config_final) self._validate_config(config) return config @@ -348,8 +356,20 @@ def _write_config(self, new_config: GlobalConfig): } self._validate_config(new_config) + + default_config = GlobalConfig(**self._init_defaults()).dict() + + res = { + key: value + for key, value in new_config.dict().items() + # We avoid storing the values that have not changed + # This helps the user to get any updates we make to the + # config in jupyter ai automatically + if default_config.get(key) != value + } + with open(self.config_path, "w") as f: - json.dump(new_config.dict(), f, indent=self.indentation_depth) + json.dump(res, f, indent=self.indentation_depth) def delete_api_key(self, key_name: str): config_dict = self._read_config().dict() @@ -374,6 +394,16 @@ def delete_api_key(self, key_name: str): config_dict["api_keys"].pop(key_name, None) self._write_config(GlobalConfig(**config_dict)) + def delete_config(self): + """ + Deletes the config stored in user's home + This helps the user resetting any changes they have + made locally + """ + if os.path.exists(self.config_path): + os.remove(self.config_path) + self._init_config() + def update_config(self, config_update: UpdateConfigRequest): # type:ignore last_write = os.stat(self.config_path).st_mtime_ns if config_update.last_read and config_update.last_read < last_write: diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 08c8c5a47..e32ebfed7 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -7,11 +7,13 @@ from importlib_metadata import entry_points from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics import BaseProvider, JupyternautPersona +from jupyter_ai_magics.prompts import CHAT_DEFAULT_TEMPLATE, CHAT_SYSTEM_PROMPT, COMPLETION_DEFAULT_TEMPLATE, COMPLETION_SYSTEM_PROMPT from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp from tornado.web import StaticFileHandler from traitlets import Dict, Integer, List, Unicode +from .models import PromptConfig from .chat_handlers import ( AskChatHandler, ClearChatHandler, @@ -47,11 +49,10 @@ You can ask me a question using the text box below. You can also use these commands: {slash_commands_list} -You can use the following commands to add context to your questions: -{context_commands_list} +You can use @ to inject information about any variable from your active notbook +in your query. E.g. Explain the schema of `@my_dataframe`. -Jupyter AI includes [magic commands](https://jupyter-ai.readthedocs.io/en/latest/users/index.html#the-ai-and-ai-magic-commands) that you can use in your notebooks. -For more information, see the [documentation](https://jupyter-ai.readthedocs.io). +{persona_name} knows about your notebook, currently active cell and your current selection. """ @@ -161,6 +162,16 @@ class AiExtension(ExtensionApp): config=True, ) + default_completion_model = Unicode( + default_value=None, + allow_none=True, + help=""" + Default completion model to use, as string in the format + :, defaults to None. + """, + config=True, + ) + default_api_keys = Dict( key_trait=Unicode(), value_trait=Unicode(), @@ -221,12 +232,20 @@ def initialize_settings(self): self.settings["model_parameters"] = self.model_parameters self.log.info(f"Configured model parameters: {self.model_parameters}") - defaults = { "model_provider_id": self.default_language_model, "embeddings_provider_id": self.default_embeddings_model, + "completions_model_provider_id": self.default_completion_model, "api_keys": self.default_api_keys, "fields": self.model_parameters, + "chat_prompt": PromptConfig( + system=CHAT_SYSTEM_PROMPT, + default=CHAT_DEFAULT_TEMPLATE + ), + "completion_prompt": PromptConfig( + system=COMPLETION_SYSTEM_PROMPT, + default=COMPLETION_DEFAULT_TEMPLATE + ) } # Fetch LM & EM providers @@ -459,7 +478,9 @@ def _init_context_provders(self): "context_providers": self.settings["jai_context_providers"], } context_providers_clses = [ - FileContextProvider, + # TODO: Robustify and test this and add it back + # Drops file context provider + # FileContextProvider, ] for context_provider_ep in context_providers_eps: try: diff --git a/packages/jupyter-ai/jupyter_ai/filter_autocomplete_globals.py b/packages/jupyter-ai/jupyter_ai/filter_autocomplete_globals.py new file mode 100644 index 000000000..873d20512 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/filter_autocomplete_globals.py @@ -0,0 +1,43 @@ +import types +import inspect + +# These globals are provided by ipython +# and are not needed for autocomplete +IPYTHON_GLOBALS = [ + "In", + "Out", + "get_ipython", + "exit", + "quit", + "open", + "Err", + "filter_globals" +] + +def _is_defined_in_main(val): + """ + When a variable is defined in main, the source file path + is not available for it and the function throws + + TODO: Find if there is a better way to filter out objects + that are imported. For example, when we do from datetime import datetime + we typically do not want to pollute autocomplete with datetime + """ + try: + inspect.getsourcefile(val) + return False + except: + return True + +def filter_globals(data: dict): + """ + Utility method to fetch autocompletes that is used by the frontend + when user triggers it using @ for variable context insertion + """ + return [ + key + for key, value in data.items() + if not key.startswith('_') + and key not in IPYTHON_GLOBALS + and _is_defined_in_main(value) + ] \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 28b169c00..4c9b9602e 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -260,6 +260,7 @@ def broadcast_message(self, message: Message): stream_message.body += chunk.content stream_message.metadata = chunk.metadata stream_message.complete = chunk.stream_complete + stream_message.metadata = chunk.metadata break elif isinstance(message, PendingMessage): self.pending_messages.append(message) @@ -293,9 +294,6 @@ async def on_message(self, message): chat_request = request message_body = chat_request.prompt - if chat_request.selection: - message_body += f"\n\n```\n{chat_request.selection.source}\n```\n" - # message broadcast to chat clients chat_message_id = str(uuid.uuid4()) chat_message = HumanChatMessage( @@ -305,6 +303,7 @@ async def on_message(self, message): prompt=chat_request.prompt, selection=chat_request.selection, client=self.chat_client, + notebook=chat_request.notebook ) # broadcast the message to other clients @@ -556,6 +555,9 @@ def post(self): 500, "Unexpected error occurred while updating the config." ) from e + @web.authenticated + def delete(self): + self.config_manager.delete_config() class ApiKeysHandler(BaseAPIHandler): @property @@ -677,6 +679,11 @@ def _get_slash_command_options(self) -> List[ListOptionsEntry]: ): continue + # TODO: Drop support for show_help and enable all the slash + # commands after robustifying them + if not chat_handler.show_help: + continue + routing_type = chat_handler.routing_type # filter out any chat handler that is unsupported by the current LLM diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 48dbe6193..c946d11c8 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -32,11 +32,20 @@ class CellWithErrorSelection(BaseModel): Selection = Union[TextSelection, CellSelection, CellWithErrorSelection] +class NotebookCell(BaseModel): + content: Optional[str] + type: Literal["raw", "markdown", "code"] + +class Notebook(BaseModel): + notebook_code: Optional[list[NotebookCell]] + active_cell_id: Optional[str] + variable_context: Optional[str] # the type of message used to chat with the agent class ChatRequest(BaseModel): prompt: str selection: Optional[Selection] + notebook: Optional[Notebook | None] class StopRequest(BaseModel): @@ -139,6 +148,8 @@ class HumanChatMessage(BaseModel): prompt: str """The prompt typed into the chat input by the user.""" selection: Optional[Selection] + """The current notebook and its cells""" + notebook: Optional[Notebook | None] """The selection included with the prompt, if any.""" client: ChatClient @@ -227,6 +238,10 @@ class IndexMetadata(BaseModel): dirs: List[IndexedDir] +class PromptConfig(BaseModel): + system: Optional[str] + default: Optional[str] + class DescribeConfigResponse(BaseModel): model_provider_id: Optional[str] embeddings_provider_id: Optional[str] @@ -240,7 +255,8 @@ class DescribeConfigResponse(BaseModel): last_read: int completions_model_provider_id: Optional[str] completions_fields: Dict[str, Dict[str, Any]] - + chat_prompt: Optional[PromptConfig] + completion_prompt: Optional[PromptConfig] def forbid_none(cls, v): assert v is not None, "size may not be None" @@ -258,6 +274,8 @@ class UpdateConfigRequest(BaseModel): last_read: Optional[int] completions_model_provider_id: Optional[str] completions_fields: Optional[Dict[str, Dict[str, Any]]] + chat_prompt: Optional[PromptConfig] + completion_prompt: Optional[PromptConfig] _validate_send_wse = validator("send_with_shift_enter", allow_reuse=True)( forbid_none @@ -277,6 +295,8 @@ class GlobalConfig(BaseModel): api_keys: Dict[str, str] completions_model_provider_id: Optional[str] completions_fields: Dict[str, Dict[str, Any]] + chat_prompt: Optional[PromptConfig] + completion_prompt: Optional[PromptConfig] class ListSlashCommandsEntry(BaseModel): diff --git a/packages/jupyter-ai/schema/plugin.json b/packages/jupyter-ai/schema/plugin.json index 78804b5c6..fef1c3f4a 100644 --- a/packages/jupyter-ai/schema/plugin.json +++ b/packages/jupyter-ai/schema/plugin.json @@ -7,7 +7,7 @@ "jupyter.lab.shortcuts": [ { "command": "jupyter-ai:focus-chat-input", - "keys": ["Accel Shift 1"], + "keys": ["Accel Shift K"], "selector": "body", "preventDefault": false } diff --git a/packages/jupyter-ai/src/completions/plugin.ts b/packages/jupyter-ai/src/completions/plugin.ts index 4487b2752..af05a7f20 100644 --- a/packages/jupyter-ai/src/completions/plugin.ts +++ b/packages/jupyter-ai/src/completions/plugin.ts @@ -2,6 +2,7 @@ import { JupyterFrontEnd, JupyterFrontEndPlugin } from '@jupyterlab/application'; +import { INotebookTracker } from '@jupyterlab/notebook'; import { ICompletionProviderManager } from '@jupyterlab/completer'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; import { @@ -54,7 +55,8 @@ export const completionPlugin: JupyterFrontEndPlugin = { requires: [ ICompletionProviderManager, IEditorLanguageRegistry, - ISettingRegistry + ISettingRegistry, + INotebookTracker ], optional: [IJaiStatusItem], provides: IJaiCompletionProvider, @@ -63,6 +65,7 @@ export const completionPlugin: JupyterFrontEndPlugin = { completionManager: ICompletionProviderManager, languageRegistry: IEditorLanguageRegistry, settingRegistry: ISettingRegistry, + notebookTracker: INotebookTracker, statusItem: IJaiStatusItem | null ): Promise => { if (typeof completionManager.registerInlineProvider === 'undefined') { @@ -76,7 +79,8 @@ export const completionPlugin: JupyterFrontEndPlugin = { const completionHandler = new CompletionWebsocketHandler(); const provider = new JaiInlineProvider({ completionHandler, - languageRegistry + languageRegistry, + notebookTracker }); await completionHandler.initialize(); diff --git a/packages/jupyter-ai/src/completions/provider.ts b/packages/jupyter-ai/src/completions/provider.ts index 80c199f5a..5cb15308f 100644 --- a/packages/jupyter-ai/src/completions/provider.ts +++ b/packages/jupyter-ai/src/completions/provider.ts @@ -6,6 +6,7 @@ import { IInlineCompletionItem, CompletionHandler } from '@jupyterlab/completer'; +import { INotebookTracker } from '@jupyterlab/notebook'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; import { Notification, showErrorMessage } from '@jupyterlab/apputils'; import { JSONValue, PromiseDelegate } from '@lumino/coreutils'; @@ -20,6 +21,7 @@ import { AiCompleterService as AiService } from './types'; import { DocumentWidget } from '@jupyterlab/docregistry'; import { jupyternautIcon } from '../icons'; import { CompletionWebsocketHandler } from './handler'; +import { formatCodeForCell } from '../utils'; type StreamChunk = AiService.InlineCompletionStreamChunk; @@ -50,6 +52,7 @@ export class JaiInlineProvider return 'JupyterAI'; } + async fetch( request: CompletionHandler.IRequest, context: IInlineCompletionContext @@ -65,6 +68,7 @@ export class JaiInlineProvider // from other (e.g. less expensive or faster) providers. return { items: [] }; } + const mime = request.mimeType ?? 'text/plain'; const language = this.options.languageRegistry.findByMIME(mime); if (!language) { @@ -247,11 +251,21 @@ export class JaiInlineProvider * Extract prefix from request, accounting for context window limit. */ private _prefixFromRequest(request: CompletionHandler.IRequest): string { - const textBefore = request.text.slice(0, request.offset); - const prefix = textBefore.slice( - -Math.min(this._settings.maxPrefix, textBefore.length) + const notebookTracker = this.options.notebookTracker; + const cells = notebookTracker?.currentWidget?.model!.sharedModel.cells; + const currentCellIndex = cells?.findIndex( + cell => cell.id === notebookTracker?.activeCell?.model.sharedModel.id + ); + + const previousCells = cells?.slice(0, currentCellIndex); + const prevCode = previousCells?.map(cell => formatCodeForCell(cell)).join('\n\n'); + let prefix = request.text.slice(0, request.offset); + if (prevCode && previousCells) { + prefix = prevCode + '\n\n# %%\n' + prefix; + } + return prefix.slice( + Math.max(0, prefix.length - this._settings.maxPrefix) ); - return prefix; } /** @@ -259,11 +273,21 @@ export class JaiInlineProvider */ private _suffixFromRequest(request: CompletionHandler.IRequest): string { const textAfter = request.text.slice(request.offset); - const prefix = textAfter.slice( - 0, - Math.min(this._settings.maxPrefix, textAfter.length) - ); - return prefix; + const notebookTracker = this.options.notebookTracker; + const cells = notebookTracker?.currentWidget?.model?.sharedModel.cells; + + const currentCellIndex = cells?.findIndex( + cell => cell.id === notebookTracker?.activeCell?.model?.sharedModel.id + ); + + const nextCells = cells?.slice((currentCellIndex || 0) + 1); + const nextCode = nextCells?.map((cell, i) => cell.source).join('\n\n# %%\n'); + + if (textAfter && nextCode) { + return textAfter + "\n\n# %%\n" + nextCode; + } + + return textAfter.slice(0, this._settings.maxSuffix) } private _resolveLanguage(language: IEditorLanguage | null) { @@ -293,6 +317,7 @@ export namespace JaiInlineProvider { export interface IOptions { completionHandler: CompletionWebsocketHandler; languageRegistry: IEditorLanguageRegistry; + notebookTracker: INotebookTracker; } export interface ISettings { diff --git a/packages/jupyter-ai/src/components/chat-input.tsx b/packages/jupyter-ai/src/components/chat-input.tsx index 1e19f7774..3530cbd9c 100644 --- a/packages/jupyter-ai/src/components/chat-input.tsx +++ b/packages/jupyter-ai/src/components/chat-input.tsx @@ -25,6 +25,9 @@ import { AiService } from '../handler'; import { SendButton, SendButtonProps } from './chat-input/send-button'; import { useActiveCellContext } from '../contexts/active-cell-context'; import { ChatHandler } from '../chat_handler'; +import { useSelectionContext } from '../contexts/selection-context'; +import { useNotebookTrackerContext } from '../contexts/notebook-tracker-context'; +import { formatCodeForCell, getCompletion, processVariables } from '../utils'; type ChatInputProps = { chatHandler: ChatHandler; @@ -58,7 +61,8 @@ const DEFAULT_COMMAND_ICONS: Record = { '/generate': , '/help': , '/learn': , - '@file': , + // TODO: Reenable it when we are more confident here + // '@file': , unknown: }; @@ -86,12 +90,12 @@ function renderAutocompleteOption( > {option.label} - 0 && {' — ' + option.description} - + } ); @@ -110,6 +114,8 @@ export function ChatInput(props: ChatInputProps): JSX.Element { >([]); const [currSlashCommand, setCurrSlashCommand] = useState(null); const activeCell = useActiveCellContext(); + const [textSelection] = useSelectionContext(); + const notebookTracker = useNotebookTrackerContext(); /** * Effect: fetch the list of available slash commands from the backend on @@ -212,8 +218,27 @@ export function ChatInput(props: ChatInputProps): JSX.Element { } }, [open, highlighted]); - function onSend(selection?: AiService.Selection) { - const prompt = input; + + const _getNotebookCells = () => { + const cells = notebookTracker?.currentWidget?.model?.sharedModel.cells; + const notebookCode: AiService.NotebookCell[] | undefined = cells?.map((cell, index) => ({ + content: formatCodeForCell(cell, index), + type: cell.cell_type + })); + + + const activeCellId = activeCell.manager.getActiveCellId() + const activeCellIdx = cells?.findIndex(cell => cell.id === activeCellId); + + return { + notebookCode, + activeCellId: activeCellIdx !== -1 ? activeCellIdx : undefined + } + } + + async function onSend() { + const {varValues, processedInput} = await processVariables(input, notebookTracker); + const prompt = processedInput; setInput(''); // if the current slash command is `/fix`, we always include a code cell @@ -231,10 +256,64 @@ export function ChatInput(props: ChatInputProps): JSX.Element { return; } + const selection: AiService.Selection | null = textSelection?.text ? { + type: "text", + source: textSelection.text + } : activeCell.manager.getContent(false) ? { + type: "cell", + source: activeCell.manager.getContent(false)?.source || "" + } : null; + + const { notebookCode, activeCellId } = _getNotebookCells(); + const notebook: AiService.ChatNotebookContent | null = { + notebook_code: notebookCode, + active_cell_id: activeCellId, + variable_context: varValues + } + // otherwise, send a ChatRequest with the prompt and selection - props.chatHandler.sendMessage({ prompt, selection }); + props.chatHandler.sendMessage({ prompt, selection, notebook }); } + useEffect(() => { + const _getcompletionUtil = async (prefix: string) => { + const completions = await getCompletion(prefix, notebookTracker); + + setAutocompleteOptions(completions.map(option => ({ + id: option, + // Add an explict space for better user experience + // as one generally wants to select and type directly + label: `${option} `, + description: "", + only_start: false + }))) + setOpen(true); + return; + } + + // This represents a slash command run directly + // However, when a /command has a certain prompt associated + // with it we do not want to return just yet to still support + // @ based autocompletions + // For eg. /ask What is @x? + if (input.startsWith("/") && !input.includes(" ")) return; + + const splitInput = input.split(" "); + if (!splitInput.length) { + setOpen(false); + return; + } + + const prefix = splitInput[splitInput.length - 1]; + if (prefix[0] !== "@") { + setOpen(false); + return; + } + + _getcompletionUtil(prefix.slice(1)); + }, [input]) + + const inputExists = !!input.trim(); function handleKeyDown(event: React.KeyboardEvent) { if (event.key !== 'Enter') { @@ -265,15 +344,13 @@ export function ChatInput(props: ChatInputProps): JSX.Element { } // Set the helper text based on whether Shift+Enter is used for sending. - const helperText = props.sendWithShiftEnter ? ( + const helperText = ( - Press Shift+Enter to send message + Open with Ctrl+Shift+Space - ) : ( - - Press Shift+Enter to add a new line - - ); + ) + + ; const sendButtonProps: SendButtonProps = { onSend, @@ -295,8 +372,15 @@ export function ChatInput(props: ChatInputProps): JSX.Element { ): AiService.AutocompleteOption[] { const lastWord = getLastWord(inputValue); if (lastWord === '') { - return []; + return []; + } + + // When we trigger @ based autocompletions, we want to filter out + // slash commands + if (lastWord.startsWith("@")) { + return options.filter(option => !option.label.startsWith("/")) } + const isStart = lastWord === inputValue; return options.filter( option => @@ -304,6 +388,7 @@ export function ChatInput(props: ChatInputProps): JSX.Element { ); } + return ( 2 ? helperText : ' '} + helperText={helperText} /> )} /> diff --git a/packages/jupyter-ai/src/components/chat-input/send-button.tsx b/packages/jupyter-ai/src/components/chat-input/send-button.tsx index f62f6ee67..06b49098e 100644 --- a/packages/jupyter-ai/src/components/chat-input/send-button.tsx +++ b/packages/jupyter-ai/src/components/chat-input/send-button.tsx @@ -1,20 +1,15 @@ -import React, { useCallback, useState } from 'react'; -import { Box, Menu, MenuItem, Typography } from '@mui/material'; -import KeyboardArrowDown from '@mui/icons-material/KeyboardArrowDown'; +import React from 'react'; +import { Box } from '@mui/material'; import SendIcon from '@mui/icons-material/Send'; import StopIcon from '@mui/icons-material/Stop'; import { TooltippedButton } from '../mui-extras/tooltipped-button'; -import { includeSelectionIcon } from '../../icons'; -import { useActiveCellContext } from '../../contexts/active-cell-context'; -import { useSelectionContext } from '../../contexts/selection-context'; -import { AiService } from '../../handler'; const FIX_TOOLTIP = '/fix requires an active code cell with an error'; export type SendButtonProps = { - onSend: (selection?: AiService.Selection) => unknown; onStop: () => unknown; + onSend: () => unknown; sendWithShiftEnter: boolean; currSlashCommand: string | null; inputExists: boolean; @@ -27,20 +22,6 @@ export type SendButtonProps = { }; export function SendButton(props: SendButtonProps): JSX.Element { - const [menuAnchorEl, setMenuAnchorEl] = useState(null); - const [menuOpen, setMenuOpen] = useState(false); - const [textSelection] = useSelectionContext(); - const activeCell = useActiveCellContext(); - - const openMenu = useCallback((el: HTMLElement | null) => { - setMenuAnchorEl(el); - setMenuOpen(true); - }, []); - - const closeMenu = useCallback(() => { - setMenuOpen(false); - }, []); - let action: 'send' | 'stop' | 'fix' = props.inputExists ? 'send' : props.streamingReplyHere @@ -58,17 +39,6 @@ export function SendButton(props: SendButtonProps): JSX.Element { disabled = true; } - const includeSelectionDisabled = !(activeCell.exists || textSelection); - - const includeSelectionTooltip = - action === 'fix' - ? FIX_TOOLTIP - : textSelection - ? `${textSelection.text.split('\n').length} lines selected` - : activeCell.exists - ? 'Code from 1 active cell' - : 'No selection or active cell'; - const defaultTooltip = props.sendWithShiftEnter ? 'Send message (SHIFT+ENTER)' : 'Send message (ENTER)'; @@ -82,37 +52,6 @@ export function SendButton(props: SendButtonProps): JSX.Element { ? 'Message must not be empty' : defaultTooltip; - function sendWithSelection() { - // if the current slash command is `/fix`, `props.onSend()` should always - // include the code cell with error output, so the `selection` argument does - // not need to be defined. - if (action === 'fix') { - props.onSend(); - closeMenu(); - return; - } - - // otherwise, parse the text selection or active cell, with the text - // selection taking precedence. - if (textSelection?.text) { - props.onSend({ - type: 'text', - source: textSelection.text - }); - closeMenu(); - return; - } - - if (activeCell.exists) { - props.onSend({ - type: 'cell', - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - source: activeCell.manager.getContent(false)!.source - }); - closeMenu(); - return; - } - } return ( @@ -132,73 +71,6 @@ export function SendButton(props: SendButtonProps): JSX.Element { > {action === 'stop' ? : } - { - openMenu(e.currentTarget); - }} - disabled={disabled} - tooltip="" - buttonProps={{ - variant: 'contained', - onKeyDown: e => { - if (e.key !== 'Enter' && e.key !== ' ') { - return; - } - openMenu(e.currentTarget); - // stopping propagation of this event prevents the prompt from being - // sent when the dropdown button is selected and clicked via 'Enter'. - e.stopPropagation(); - } - }} - sx={{ - minWidth: 'unset', - padding: '4px 0px', - borderRadius: '0px 2px 2px 0px', - borderLeft: '1px solid white' - }} - > - - - - { - sendWithSelection(); - // prevent sending second message with no selection - e.stopPropagation(); - }} - disabled={includeSelectionDisabled} - > - - - Send message with selection - - {includeSelectionTooltip} - - - - ); } diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx index 5c4286f8f..a93f64009 100644 --- a/packages/jupyter-ai/src/components/chat-messages.tsx +++ b/packages/jupyter-ai/src/components/chat-messages.tsx @@ -1,18 +1,19 @@ import React, { useState, useEffect } from 'react'; - -import { Avatar, Box, Typography } from '@mui/material'; +import { Avatar, Box, IconButton, Paper, Typography } from '@mui/material'; import type { SxProps, Theme } from '@mui/material'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ServerConnection } from '@jupyterlab/services'; // TODO: delete jupyternaut from frontend package - +import CheckIcon from '@mui/icons-material/Check'; +import ContentCopyIcon from '@mui/icons-material/ContentCopy'; import { AiService } from '../handler'; import { RendermimeMarkdown } from './rendermime-markdown'; import { useCollaboratorsContext } from '../contexts/collaborators-context'; import { ChatMessageMenu } from './chat-messages/chat-message-menu'; -import { ChatMessageDelete } from './chat-messages/chat-message-delete'; import { ChatHandler } from '../chat_handler'; import { IJaiMessageFooter } from '../tokens'; +import { ChevronRight, ExpandMore } from '@mui/icons-material'; +import { CopyStatus, useCopy } from '../hooks/use-copy'; type ChatMessagesProps = { rmRegistry: IRenderMimeRegistry; @@ -26,6 +27,7 @@ type ChatMessageHeaderProps = { chatHandler: ChatHandler; timestamp: string; sx?: SxProps; + onFullPromptChange?: () => void; }; function sortMessages( @@ -114,11 +116,6 @@ export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element { ? props.message.client.display_name : props.message.persona.name; - const shouldShowMenu = - props.message.type === 'agent' || - (props.message.type === 'agent-stream' && props.message.complete); - const shouldShowDelete = props.message.type === 'human'; - return ( {name} - + {props.timestamp} - {shouldShowMenu && ( - )} - {shouldShowDelete && ( - - )} ); } +interface ChatMessageProps { + message: AiService.ChatMessage; + timestamp: string; + chatHandler: ChatHandler; + rmRegistry: IRenderMimeRegistry; + messageFooter?: IJaiMessageFooter | null; +} + +const ChatMessage = ({ + message, + timestamp, + chatHandler, + rmRegistry, + messageFooter +}: ChatMessageProps) => { + const [showFullPrompt, setShowFullPrompt] = useState(false); + const { + copy: copyText, + copyStatus + } = useCopy(); + + const CopyStatusIcon = copyStatus === CopyStatus.Copied ? CheckIcon : ContentCopyIcon; + + return ( + + + {message.type === 'agent-stream' && message.metadata?.prompt && + setShowFullPrompt(x => !x)}> + {showFullPrompt ? : } + + + + setShowFullPrompt(x => !x)} + > + {showFullPrompt ? "Hide" : "View"} Prompt + + copyText(message.metadata?.prompt)}> + + + + { + showFullPrompt && ( + + + + ) + } + + } + + {messageFooter && ( + + )} + + ) +} + export function ChatMessages(props: ChatMessagesProps): JSX.Element { const [timestamps, setTimestamps] = useState>({}); const [sortedMessages, setSortedMessages] = useState( @@ -216,25 +283,13 @@ export function ChatMessages(props: ChatMessagesProps): JSX.Element { > {sortedMessages.map(message => { return ( - - - - {props.messageFooter && ( - - )} - ); })} diff --git a/packages/jupyter-ai/src/components/chat-messages/chat-message-menu.tsx b/packages/jupyter-ai/src/components/chat-messages/chat-message-menu.tsx index a10c061ee..8b00c51b1 100644 --- a/packages/jupyter-ai/src/components/chat-messages/chat-message-menu.tsx +++ b/packages/jupyter-ai/src/components/chat-messages/chat-message-menu.tsx @@ -1,11 +1,12 @@ -import React, { useRef, useState } from 'react'; +import React, { useMemo, useRef, useState } from 'react'; import { IconButton, Menu, MenuItem, SxProps } from '@mui/material'; import { MoreVert } from '@mui/icons-material'; import { addAboveIcon, addBelowIcon, - copyIcon + copyIcon, + deleteIcon } from '@jupyterlab/ui-components'; import { AiService } from '../../handler'; @@ -13,24 +14,17 @@ import { CopyStatus, useCopy } from '../../hooks/use-copy'; import { useReplace } from '../../hooks/use-replace'; import { useActiveCellContext } from '../../contexts/active-cell-context'; import { replaceCellIcon } from '../../icons'; +import { ChatHandler } from '../../chat_handler'; -type ChatMessageMenuProps = { - message: AiService.ChatMessage; - - /** - * Styles applied to the menu icon button. - */ - sx?: SxProps; +type ChatMessageMenuBaseProps = { + menuItems: { + onClick: () => void; + children?: React.ReactNode; + }[]; }; -export function ChatMessageMenu(props: ChatMessageMenuProps): JSX.Element { +export function ChatMessageMenuBase(props: ChatMessageMenuBaseProps): JSX.Element { const menuButtonRef = useRef(null); - const { copy, copyLabel } = useCopy({ - labelOverrides: { [CopyStatus.None]: 'Copy response' } - }); - const { replace, replaceLabel } = useReplace(); - const activeCell = useActiveCellContext(); - const [menuOpen, setMenuOpen] = useState(false); const [anchorEl, setAnchorEl] = React.useState(null); @@ -39,14 +33,6 @@ export function ChatMessageMenu(props: ChatMessageMenuProps): JSX.Element { setMenuOpen(true); }; - const insertAboveLabel = activeCell.exists - ? 'Insert response above active cell' - : 'Insert response above active cell (no active cell)'; - - const insertBelowLabel = activeCell.exists - ? 'Insert response below active cell' - : 'Insert response below active cell (no active cell)'; - const menuItemSx: SxProps = { display: 'flex', alignItems: 'center', @@ -54,9 +40,16 @@ export function ChatMessageMenu(props: ChatMessageMenuProps): JSX.Element { lineHeight: 0 }; + const onClickHandler = (callback: () => void) => { + return () => { + callback(); + setMenuOpen(false); + } + } + return ( <> - + setMenuOpen(false)} anchorEl={anchorEl} > - copy(props.message.body)} sx={menuItemSx}> - - {copyLabel} - - replace(props.message.body)} sx={menuItemSx}> - - {replaceLabel} - - activeCell.manager.insertAbove(props.message.body)} - disabled={!activeCell.exists} - sx={menuItemSx} - > - - {insertAboveLabel} - - activeCell.manager.insertBelow(props.message.body)} - disabled={!activeCell.exists} - sx={menuItemSx} - > - - {insertBelowLabel} - + { + props.menuItems.map(menuItemProps => ( + + {menuItemProps.children} + + )) + } ); } + + +interface ChatMessageMenuProps { + message: AiService.ChatMessage; + chatHandler: ChatHandler; +} + +export const AIMessageMenu = (props: ChatMessageMenuProps) => { + const { copy, copyLabel } = useCopy({ + labelOverrides: { [CopyStatus.None]: 'Copy response' } + }); + const { replace, replaceLabel } = useReplace(); + const activeCell = useActiveCellContext(); + + const menuItems: ChatMessageMenuBaseProps["menuItems"] = useMemo(() => { + const res = [{ + onClick: () => copy(props.message.body), + children: ( + <> + + {copyLabel} + + ) + }, { + onClick: () => replace(props.message.body), + children: ( + <> + + {replaceLabel} + + ) + }]; + + if (activeCell.exists) { + res.push({ + onClick: () => activeCell.manager.insertAbove(props.message.body), + children: ( + <> + + Insert response above active cell + + ) + }) + + res.push({ + onClick: () => activeCell.manager.insertBelow(props.message.body), + children: ( + <> + + Insert response below active cell + + ) + }) + } + + return res; + }, [activeCell]); + + return ( + + ) +} + +export const HumanMessageMenu = (props: ChatMessageMenuProps) => { + const { copy, copyLabel } = useCopy(); + + // We replace back all the content of the form `var` to @var + const messageToCopy = props.message.body.replace(/`([^`]*)`/g, '@$1'); + + const menuItems: ChatMessageMenuBaseProps["menuItems"] = useMemo(() => { + const res = [{ + onClick: () => copy(messageToCopy), + children: ( + <> + + {copyLabel} + + ) + }, { + onClick: () => props.chatHandler.sendMessage({ + type: 'clear', + target: props.message.id + }), + children: ( + <> + + Delete This Exchange + + ) + }]; + + + return res; + }, []); + + return ( + + ) +} + +export const ChatMessageMenu = (props: ChatMessageMenuProps) => { + switch (props.message.type) { + case "human": + return + case "agent": + case "agent-stream": + return + } +} \ No newline at end of file diff --git a/packages/jupyter-ai/src/components/chat-settings.tsx b/packages/jupyter-ai/src/components/chat-settings.tsx index a1ad0a9b6..9d084d0de 100644 --- a/packages/jupyter-ai/src/components/chat-settings.tsx +++ b/packages/jupyter-ai/src/components/chat-settings.tsx @@ -5,15 +5,11 @@ import { Alert, Button, IconButton, - FormControl, - FormControlLabel, - FormLabel, MenuItem, - Radio, - RadioGroup, TextField, Tooltip, - CircularProgress + CircularProgress, + Stack, } from '@mui/material'; import SettingsIcon from '@mui/icons-material/Settings'; import WarningAmberIcon from '@mui/icons-material/WarningAmber'; @@ -22,13 +18,13 @@ import { Select } from './select'; import { AiService } from '../handler'; import { ModelFields } from './settings/model-fields'; import { ServerInfoState, useServerInfo } from './settings/use-server-info'; -import { ExistingApiKeys } from './settings/existing-api-keys'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { minifyUpdate } from './settings/minify'; import { useStackingAlert } from './mui-extras/stacking-alert'; import { RendermimeMarkdown } from './rendermime-markdown'; import { IJaiCompletionProvider } from '../tokens'; import { getProviderId, getModelLocalId } from '../utils'; +import { ExpandLess, ExpandMore } from '@mui/icons-material'; type ChatSettingsProps = { rmRegistry: IRenderMimeRegistry; @@ -88,6 +84,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { const [apiKeys, setApiKeys] = useState>({}); const [sendWse, setSendWse] = useState(false); const [fields, setFields] = useState>({}); + const [chatPromptConfig, setChatPromptConfig] = useState(); + const [clmPromptConfig, setClmPromptConfig] = useState(); + const [showPrompt, setShowPrompt] = useState(false); const [isCompleterEnabled, setIsCompleterEnabled] = useState( props.completionProvider && props.completionProvider.isEnabled() @@ -109,6 +108,7 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { // whether the form is currently saving const [saving, setSaving] = useState(false); + const [reseting, setReseting] = useState(false); /** * Effect: initialize inputs after fetching server info. @@ -132,6 +132,8 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { } setLmProvider(server.chat.lmProvider); setClmProvider(server.completions.lmProvider); + setChatPromptConfig(server.config.chat_prompt); + setClmPromptConfig(server.config.completion_prompt) }, [server]); /** @@ -190,6 +192,31 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { setFields(currFields); }, [server, lmProvider]); + const handleReset = async () => { + if (server.state !== ServerInfoState.Ready) { + return; + } + + setReseting(true); + try { + await apiKeysAlert.clear(); + await AiService.deleteConfig(); + } catch (e) { + console.error(e); + const msg = + e instanceof Error || typeof e === 'string' + ? e.toString() + : 'An unknown error occurred. Check the console for more details.'; + alert.show('error', msg); + return; + } finally { + setSaving(false); + } + await server.refetchAll(); + alert.show('success', 'Settings reset successfully.'); + setReseting(false); + } + const handleSave = async () => { // compress fields with JSON values if (server.state !== ServerInfoState.Ready) { @@ -226,7 +253,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element { } }), completions_model_provider_id: clmGlobalId, - send_with_shift_enter: sendWse + send_with_shift_enter: sendWse, + chat_prompt: chatPromptConfig, + completion_prompt: clmPromptConfig }; updateRequest = minifyUpdate(server.config, updateRequest); updateRequest.last_read = server.config.last_read; @@ -480,68 +509,110 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {

No Inline Completion models.

)} - {/* API Keys section */} -

API Keys

- - {Object.entries(apiKeys).length === 0 && - server.config.api_keys.length === 0 ? ( -

No API keys are required by the selected models.

- ) : null} - - {/* API key inputs for newly-used providers */} - {Object.entries(apiKeys).map(([apiKeyName, apiKeyValue], idx) => ( - - setApiKeys(apiKeys => ({ - ...apiKeys, - [apiKeyName]: e.target.value - })) - } - /> - ))} - {/* Pre-existing API keys */} - - - {/* Input */} -

Input

- - - When writing a message, press Enter to: - - { - setSendWse(e.target.value === 'newline'); - }} - > - } - label="Send the message" - /> - } - label={ - <> - Start a new line (use Shift+Enter to send) - - } - /> - - - + + {/* Prompt Editor */} + setShowPrompt(x => !x)} sx={{ cursor: "pointer" }}> +

+ Edit Prompts +

+ {showPrompt ? : } +
+ { + showPrompt && ( + + setChatPromptConfig(config => ({ + default: "", + ...config, + system: evt.target.value + }))} + fullWidth + sx={{ + scrollbarWidth: "none" + }} + /> + setChatPromptConfig(config => ({ + system: "", + ...config, + default: evt.target.value + }))} + fullWidth + sx={{ + scrollbarWidth: "none" + }} + helperText={ + <> + notebook_code (str): The code of the active nodebook in jupytext percent format +
+ active_cell_id (str): The uuid of the active cell +
+ selection (dict | None): The current selection of the user +
+ variable_context (str): Description of variables used in the input using @ decorator + + } + /> + { + ( + <> + setChatPromptConfig(config => ({ + default: "", + ...config, + system: evt.target.value + }))} + fullWidth + sx={{ + scrollbarWidth: "none" + }} + /> + setChatPromptConfig(config => ({ + system: "", + ...config, + default: evt.target.value + }))} + fullWidth + sx={{ + scrollbarWidth: "none" + }} + helperText={ + <> + prefix (str): The code of before the cursor position including previous cells +
+ suffix (str): The code after the cursor position including previous cells + + } + /> + + ) + } +
+ ) + } + + + + diff --git a/packages/jupyter-ai/src/components/chat.tsx b/packages/jupyter-ai/src/components/chat.tsx index d55b1a5ce..d40e50d8c 100644 --- a/packages/jupyter-ai/src/components/chat.tsx +++ b/packages/jupyter-ai/src/components/chat.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect } from 'react'; +import React, { useState, useEffect, useRef } from 'react'; import { Box } from '@mui/system'; import { Button, IconButton, Stack } from '@mui/material'; import SettingsIcon from '@mui/icons-material/Settings'; @@ -9,7 +9,7 @@ import type { IThemeManager } from '@jupyterlab/apputils'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import type { User } from '@jupyterlab/services'; import { ISignal } from '@lumino/signaling'; - +import { INotebookTracker } from '@jupyterlab/notebook'; import { JlThemeProvider } from './jl-theme-provider'; import { ChatMessages } from './chat-messages'; import { PendingMessages } from './pending-messages'; @@ -33,6 +33,8 @@ import { UserContextProvider, useUserContext } from '../contexts/user-context'; import { ScrollContainer } from './scroll-container'; import { TooltippedIconButton } from './mui-extras/tooltipped-icon-button'; import { TelemetryContextProvider } from '../contexts/telemetry-context'; +import { NotebookTrackerContext } from '../contexts/notebook-tracker-context'; +import { UtilsContext } from '../contexts/utils-context'; type ChatBodyProps = { chatHandler: ChatHandler; @@ -207,6 +209,10 @@ export type ChatProps = { messageFooter: IJaiMessageFooter | null; telemetryHandler: IJaiTelemetryHandler | null; userManager: User.IManager; + notebookTracker: INotebookTracker | null; + // Helps going back to the notebook to the active cell + // After some action such as code addition is performed + goBackToNotebook: () => void; }; enum ChatView { @@ -217,12 +223,26 @@ enum ChatView { export function Chat(props: ChatProps): JSX.Element { const [view, setView] = useState(props.chatView || ChatView.Chat); const [showWelcomeMessage, setShowWelcomeMessage] = useState(false); + const ref = useRef(); const openSettingsView = () => { setShowWelcomeMessage(false); setView(ChatView.Settings); }; + // TODO: Find a better way to do this + // This helps in keyboard accessibility where the user can simply press + // Escape to go back to the notebook whenever anything is active in the chat + // One way could be to add this to all focusable elements such as inputs and + // buttons that is less maintainable + useEffect(() => { + document.addEventListener("keydown", (evt) => { + if (evt.key === "Escape" && ref.current && ref.current.contains(document.activeElement)) { + props.goBackToNotebook(); + } + }) + }, []); + return ( @@ -231,76 +251,80 @@ export function Chat(props: ChatProps): JSX.Element { activeCellManager={props.activeCellManager} > - - =4.3.0. - // See: https://jupyterlab.readthedocs.io/en/latest/extension/extension_migration.html#css-styling - className="jp-ThemedContainer" - // root box should not include padding as it offsets the vertical - // scrollbar to the left - sx={{ - width: '100%', - height: '100%', - boxSizing: 'border-box', - background: 'var(--jp-layout-color0)', - display: 'flex', - flexDirection: 'column' - }} + + + - {/* top bar */} - - {view !== ChatView.Chat ? ( - setView(ChatView.Chat)}> - - - ) : ( - - )} - {view === ChatView.Chat ? ( - - {!showWelcomeMessage && ( - - props.chatHandler.sendMessage({ type: 'clear' }) - } - tooltip="New chat" - > - - - )} - openSettingsView()}> - + + {/* top bar */} + + {view !== ChatView.Chat ? ( + setView(ChatView.Chat)}> + - - ) : ( - + ) : ( + + )} + {view === ChatView.Chat ? ( + + {!showWelcomeMessage && ( + + props.chatHandler.sendMessage({ type: 'clear' }) + } + tooltip="New chat" + > + + + )} + openSettingsView()}> + + + + ) : ( + + )} + + {/* body */} + {view === ChatView.Chat && ( + + )} + {view === ChatView.Settings && ( + )} - - {/* body */} - {view === ChatView.Chat && ( - - )} - {view === ChatView.Settings && ( - - )} - - + + + + diff --git a/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx b/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx index 315e5d4d6..bbce5bb2d 100644 --- a/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx +++ b/packages/jupyter-ai/src/components/code-blocks/code-toolbar.tsx @@ -17,6 +17,7 @@ import { useCopy } from '../../hooks/use-copy'; import { AiService } from '../../handler'; import { useTelemetry } from '../../contexts/telemetry-context'; import { TelemetryEvent } from '../../tokens'; +import { useUtilsContext } from '../../contexts/utils-context'; export type CodeToolbarProps = { /** @@ -50,14 +51,31 @@ export function CodeToolbar(props: CodeToolbarProps): JSX.Element { borderTop: 'none' }} > + + - - ); } +const useCellAction = (callback: () => void, command: string, props: ToolbarButtonProps) => { + const { goBackToNotebook } = useUtilsContext(); + const telemetryHandler = useTelemetry(); + + return () => { + callback(); + if (goBackToNotebook) goBackToNotebook(); + + try { + telemetryHandler.onEvent(buildTelemetryEvent(command, props)); + } catch (e) { + console.error(e); + return; + } + } +} + type ToolbarButtonProps = { code: string; activeCellExists: boolean; @@ -96,24 +114,15 @@ function buildTelemetryEvent( } function InsertAboveButton(props: ToolbarButtonProps) { - const telemetryHandler = useTelemetry(); const tooltip = props.activeCellExists ? 'Insert above active cell' : 'Insert above active cell (no active cell)'; + const cb = useCellAction(() => props.activeCellManager.insertAbove(props.code), 'insert-above', props); return ( { - props.activeCellManager.insertAbove(props.code); - - try { - telemetryHandler.onEvent(buildTelemetryEvent('insert-above', props)); - } catch (e) { - console.error(e); - return; - } - }} + onClick={cb} disabled={!props.activeCellExists} > @@ -122,25 +131,16 @@ function InsertAboveButton(props: ToolbarButtonProps) { } function InsertBelowButton(props: ToolbarButtonProps) { - const telemetryHandler = useTelemetry(); const tooltip = props.activeCellExists ? 'Insert below active cell' : 'Insert below active cell (no active cell)'; + const cb = useCellAction(() => props.activeCellManager.insertBelow(props.code), 'insert-below', props); return ( { - props.activeCellManager.insertBelow(props.code); - - try { - telemetryHandler.onEvent(buildTelemetryEvent('insert-below', props)); - } catch (e) { - console.error(e); - return; - } - }} + onClick={cb} > @@ -148,23 +148,14 @@ function InsertBelowButton(props: ToolbarButtonProps) { } function ReplaceButton(props: ToolbarButtonProps) { - const telemetryHandler = useTelemetry(); const { replace, replaceDisabled, replaceLabel } = useReplace(); + const cb = useCellAction(() => replace(props.code), 'replace', props); return ( { - replace(props.code); - - try { - telemetryHandler.onEvent(buildTelemetryEvent('replace', props)); - } catch (e) { - console.error(e); - return; - } - }} + onClick={cb} > @@ -172,23 +163,15 @@ function ReplaceButton(props: ToolbarButtonProps) { } export function CopyButton(props: ToolbarButtonProps): JSX.Element { - const telemetryHandler = useTelemetry(); const { copy, copyLabel } = useCopy(); + const cb = useCellAction(() => copy(props.code), 'copy', props); + return ( { - copy(props.code); - - try { - telemetryHandler.onEvent(buildTelemetryEvent('copy', props)); - } catch (e) { - console.error(e); - return; - } - }} + onClick={cb} aria-label="Copy to clipboard" > diff --git a/packages/jupyter-ai/src/components/rendermime-markdown.tsx b/packages/jupyter-ai/src/components/rendermime-markdown.tsx index 9a0278517..31c996111 100644 --- a/packages/jupyter-ai/src/components/rendermime-markdown.tsx +++ b/packages/jupyter-ai/src/components/rendermime-markdown.tsx @@ -21,6 +21,11 @@ type RendermimeMarkdownProps = { * `AgentStreamMessage`, in which case this should be set to `false`. */ complete: boolean; + /** + * Hides the code toolbar. We require it in cases when we + * are rendering code blocks in prompts + */ + hideCodeToolbar?: boolean; }; /** @@ -63,7 +68,7 @@ function RendermimeMarkdownBase(props: RendermimeMarkdownProps): JSX.Element { useEffect(() => { const renderContent = async () => { // initialize mime model - const mdStr = escapeLatexDelimiters(props.markdownStr); + const mdStr = escapeLatexDelimiters(props.markdownStr); const model = props.rmRegistry.createModel({ data: { [MD_MIME_TYPE]: mdStr } }); @@ -126,7 +131,7 @@ function RendermimeMarkdownBase(props: RendermimeMarkdownProps): JSX.Element { // Render a `CodeToolbar` element underneath each code block. // We use ReactDOM.createPortal() so each `CodeToolbar` element is able // to use the context in the main React tree. - codeToolbarDefns.map(codeToolbarDefn => { + !props.hideCodeToolbar && codeToolbarDefns.map(codeToolbarDefn => { const [codeToolbarRoot, codeToolbarProps] = codeToolbarDefn; return createPortal( , diff --git a/packages/jupyter-ai/src/contexts/active-cell-context.tsx b/packages/jupyter-ai/src/contexts/active-cell-context.tsx index 72e93a8ca..6979a2ffc 100644 --- a/packages/jupyter-ai/src/contexts/active-cell-context.tsx +++ b/packages/jupyter-ai/src/contexts/active-cell-context.tsx @@ -74,6 +74,11 @@ export class ActiveCellManager { return this._activeCellErrorChanged; } + getActiveCellId() { + const sharedModel = this._activeCell?.model.sharedModel; + return sharedModel?.id; + } + /** * Returns an `ActiveCellContent` object that describes the current active * cell. If no active cell exists, this method returns `null`. diff --git a/packages/jupyter-ai/src/contexts/notebook-tracker-context.tsx b/packages/jupyter-ai/src/contexts/notebook-tracker-context.tsx new file mode 100644 index 000000000..ace1042dc --- /dev/null +++ b/packages/jupyter-ai/src/contexts/notebook-tracker-context.tsx @@ -0,0 +1,19 @@ +import { INotebookTracker } from '@jupyterlab/notebook'; +import React, { useContext } from 'react'; + +/** + * NotebookTrackerContext is a React context that holds an INotebookTracker instance or null. + * It is used to provide access to the notebook tracker object throughout the React component tree. + */ +export const NotebookTrackerContext = React.createContext(null); + +/** + * useNotebookTrackerContext is a custom hook that provides access to the current value of the NotebookTrackerContext. + * This hook allows components to easily access the notebook tracker without directly using the context API. + * + * @returns The current value of the NotebookTrackerContext, which is either an INotebookTracker instance or null. + */ +export const useNotebookTrackerContext = () => { + const tracker = useContext(NotebookTrackerContext); + return tracker; +} diff --git a/packages/jupyter-ai/src/contexts/utils-context.tsx b/packages/jupyter-ai/src/contexts/utils-context.tsx new file mode 100644 index 000000000..c3647b059 --- /dev/null +++ b/packages/jupyter-ai/src/contexts/utils-context.tsx @@ -0,0 +1,20 @@ +import React, { useContext } from 'react'; + +/** + * UtilsContext is a React context that holds an object with an optional goBackToNotebook function. + * It is used to provide utility functions throughout the React component tree. + */ +export const UtilsContext = React.createContext<{ + goBackToNotebook?: () => void; +}>({}); + +/** + * useUtilsContext is a custom hook that provides access to the current value of the UtilsContext. + * This hook allows components to easily access the utility functions without directly using the context API. + * + * @returns The current value of the UtilsContext, which is an object that may contain a goBackToNotebook function. + */ +export const useUtilsContext = () => { + const utils = useContext(UtilsContext); + return utils; +} diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 1b58ed2af..89758bf5d 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -78,9 +78,21 @@ export namespace AiService { | CellSelection | CellWithErrorSelection; + export type NotebookCell = { + content?: string; + type: string; + } + + export type ChatNotebookContent = { + notebook_code?: NotebookCell[]; + active_cell_id?: number; + variable_context?: string; + } + export type ChatRequest = { prompt: string; - selection?: Selection; + selection?: Selection | null; + notebook?: ChatNotebookContent | null; }; export type ClearRequest = { @@ -201,6 +213,11 @@ export namespace AiService { pending_messages: PendingMessage[]; }; + export type PromptConfig = { + system: string; + default: string; + } + export type DescribeConfigResponse = { model_provider_id: string | null; embeddings_provider_id: string | null; @@ -209,6 +226,8 @@ export namespace AiService { fields: Record>; last_read: number; completions_model_provider_id: string | null; + chat_prompt: PromptConfig; + completion_prompt: PromptConfig }; export type UpdateConfigRequest = { @@ -220,6 +239,8 @@ export namespace AiService { last_read?: number; completions_model_provider_id?: string | null; completions_fields?: Record>; + chat_prompt?: PromptConfig; + completion_prompt?: PromptConfig }; export async function getConfig(): Promise { @@ -302,6 +323,12 @@ export namespace AiService { }); } + export async function deleteConfig(): Promise { + return requestAPI('config', { + method: 'DELETE' + }); + } + export async function deleteApiKey(keyName: string): Promise { return requestAPI(`api_keys/${keyName}`, { method: 'DELETE' diff --git a/packages/jupyter-ai/src/index.ts b/packages/jupyter-ai/src/index.ts index f24fbfa00..b6534dcf0 100644 --- a/packages/jupyter-ai/src/index.ts +++ b/packages/jupyter-ai/src/index.ts @@ -3,7 +3,7 @@ import { JupyterFrontEndPlugin, ILayoutRestorer } from '@jupyterlab/application'; - +import { INotebookTracker } from '@jupyterlab/notebook'; import { IWidgetTracker, ReactWidget, @@ -51,7 +51,8 @@ const plugin: JupyterFrontEndPlugin = { IThemeManager, IJaiCompletionProvider, IJaiMessageFooter, - IJaiTelemetryHandler + IJaiTelemetryHandler, + INotebookTracker ], provides: IJaiCore, activate: async ( @@ -62,7 +63,8 @@ const plugin: JupyterFrontEndPlugin = { themeManager: IThemeManager | null, completionProvider: IJaiCompletionProvider | null, messageFooter: IJaiMessageFooter | null, - telemetryHandler: IJaiTelemetryHandler | null + telemetryHandler: IJaiTelemetryHandler | null, + notebookTracker: INotebookTracker | null ) => { /** * Initialize selection watcher singleton @@ -85,6 +87,10 @@ const plugin: JupyterFrontEndPlugin = { }); }; + const goBackToNotebook = () => { + app.commands.execute('notebook:enter-edit-mode'); + } + const focusInputSignal = new Signal({}); let chatWidget: ReactWidget; @@ -102,7 +108,9 @@ const plugin: JupyterFrontEndPlugin = { focusInputSignal, messageFooter, telemetryHandler, - app.serviceManager.user + app.serviceManager.user, + notebookTracker, + goBackToNotebook ); } catch (e) { chatWidget = buildErrorWidget(themeManager); diff --git a/packages/jupyter-ai/src/utils.ts b/packages/jupyter-ai/src/utils.ts index e21b47c6b..55f4e088c 100644 --- a/packages/jupyter-ai/src/utils.ts +++ b/packages/jupyter-ai/src/utils.ts @@ -1,10 +1,13 @@ /** * Contains various utility functions shared throughout the project. */ -import { Notebook } from '@jupyterlab/notebook'; +import { INotebookTracker, Notebook } from '@jupyterlab/notebook'; import { FileEditor } from '@jupyterlab/fileeditor'; import { CodeEditor } from '@jupyterlab/codeeditor'; import { Widget } from '@lumino/widgets'; +import { IKernelConnection } from '@jupyterlab/services/lib/kernel/kernel'; +import { IIOPubMessage } from '@jupyterlab/services/lib/kernel/messages'; +import { CellModel } from '@jupyterlab/cells'; /** * Get text selection from an editor widget (DocumentWidget#content). @@ -71,3 +74,175 @@ export function getModelLocalId(globalModelId: string): string | null { const components = globalModelId.split(':').slice(1); return components.join(':'); } + +/** + * Executes a given code string in the provided Jupyter Notebook kernel and retrieves the output. + * + * @param kernel - The kernel connection to execute the code. + * @param code - The code string to be executed. + * @returns A promise that resolves with the output of the executed code or null if the execution fails. + */ +export async function executeCode(kernel: IKernelConnection, code: string): Promise { + const executeRequest = kernel.requestExecute({ code }); + let variableValue: string | null = null; + + kernel.registerMessageHook(executeRequest.msg.header.msg_id, (msg: IIOPubMessage) => { + if ( + msg.header.msg_type === 'stream' && + // @ts-expect-error tserror + msg.content.name === 'stdout' + ) { + // @ts-expect-error tserror + variableValue = msg.content.text.trim(); + } + return true; + }); + + const reply = await executeRequest.done; + if (reply && reply.content.status === 'ok') { + return variableValue; + } else { + console.error('Failed to retrieve variable value'); + return null; + } +} + +/** + * Generates the code to describe a variable using a Jupyter AI utility. + * + * @param name - The name of the variable to describe. + * @returns The generated code as a string. + */ +const getCode = (name: string) => ` +from jupyter_ai._variable_describer import describe_var +print(describe_var('${name}', ${name})) +` + +/** + * Retrieves the description of a variable from the current Jupyter Notebook kernel. + * + * @param variableName - The name of the variable to describe. + * @param notebookTracker - The notebook tracker to get the current notebook and kernel. + * @returns A promise that resolves with the variable description or null if retrieval fails. + */ +export async function getVariableDescription( + variableName: string, + notebookTracker: INotebookTracker +): Promise { + const notebook = notebookTracker.currentWidget; + if (notebook && notebook.sessionContext.session?.kernel) { + const kernel = notebook.sessionContext.session.kernel; + try { + return await executeCode(kernel, getCode(variableName)); + } catch (error) { + console.error('Error retrieving variable value:', error); + return null; + } + } else { + console.error('No active kernel found'); + return null; + } +} + +/** + * Processes variables in the user input by replacing them with their descriptions from the Jupyter Notebook kernel. + * + * @param userInput - The user input string containing variables to process. + * @param notebookTracker - The notebook tracker to get the current notebook and kernel. + * @returns A promise that resolves with an object containing the processed input and variable values. + */ +export async function processVariables( + userInput: string, + notebookTracker: INotebookTracker | null + ): Promise<{ + processedInput: string, + varValues: string + }> { + if (!notebookTracker) { + return { + processedInput: userInput, + varValues: "" + } + } + + const variablePattern = / @([a-zA-Z_][a-zA-Z0-9_]*(\[[^\]]+\]|(\.[a-zA-Z_][a-zA-Z0-9_]*)?)*)[\s,\-]?/g; + let match; + let variablesProcessed: string[] = []; + let varValues = ''; + let processedInput = userInput; + + while ((match = variablePattern.exec(userInput)) !== null) { + const variableName = match[1]; + if (variablesProcessed.includes(variableName)) { + continue; + } + variablesProcessed.push(variableName); + try { + const varDescription = await getVariableDescription(variableName, notebookTracker); + varValues += `${varDescription}`; + processedInput = processedInput.replace(`@${variableName}`, `\`${variableName}\``); + } catch (error) { + console.error(`Error accessing variable ${variableName}:`, error); + } + } + + return { processedInput, varValues} + } + +/** + * Code to retrieve the global variables from the Jupyter Notebook kernel. + */ +const GLOBALS_CODE = ` +from jupyter_ai.filter_autocomplete_globals import filter_globals +print(filter_globals(globals())) +` + +/** + * Retrieves completion suggestions for a given prefix from the current Jupyter Notebook kernel's global variables. + * + * @param prefix - The prefix to match global variables against. + * @param notebookTracker - The notebook tracker to get the current notebook and kernel. + * @returns A promise that resolves with an array of matching global variable names. + */ +export async function getCompletion( + prefix: string, + notebookTracker: INotebookTracker | null +): Promise { + const notebook = notebookTracker?.currentWidget; + if (notebook?.sessionContext?.session?.kernel) { + const kernel = notebook.sessionContext.session.kernel; + const globalsString = await executeCode(kernel, GLOBALS_CODE); + + if (!globalsString) return []; + + const globals: string[] = JSON.parse(globalsString.replace(/'/g, '"')); + const filteredGlobals = globals.filter(x => x.startsWith(prefix)) + + + return filteredGlobals; + } + + return []; +} + +/** + * Formats code for a Jupyter Notebook cell, adding appropriate annotations based on cell type. + * + * @param cell - The cell model containing the code to format. + * @param index - Optional index of the cell for annotation. + * @returns The formatted code as a string. + */ +export const formatCodeForCell = (cell: CellModel["sharedModel"], index?: number) => { + const cellNumber = index !== undefined ? `Cell ${index}` : ''; + + switch (cell.cell_type) { + case "code": + return `# %% ${cellNumber}\n${cell.source}` + case "markdown": + case "raw": + const source = cell.source.split("\n").map(line => `# ${line}`).join("\n") + return `# %% [markdown] ${cellNumber}\n${source}` + default: + return "" + } + } \ No newline at end of file diff --git a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx index 732eedd3c..76eac4e9c 100644 --- a/packages/jupyter-ai/src/widgets/chat-sidebar.tsx +++ b/packages/jupyter-ai/src/widgets/chat-sidebar.tsx @@ -4,7 +4,7 @@ import { ReactWidget } from '@jupyterlab/apputils'; import type { IThemeManager } from '@jupyterlab/apputils'; import type { User } from '@jupyterlab/services'; import type { Awareness } from 'y-protocols/awareness'; - +import { INotebookTracker } from '@jupyterlab/notebook'; import { Chat } from '../components/chat'; import { chatIcon } from '../icons'; import { SelectionWatcher } from '../selection-watcher'; @@ -29,7 +29,9 @@ export function buildChatSidebar( focusInputSignal: ISignal, messageFooter: IJaiMessageFooter | null, telemetryHandler: IJaiTelemetryHandler | null, - userManager: User.IManager + userManager: User.IManager, + notebookTracker: INotebookTracker | null, + goBackToNotebook: () => void ): ReactWidget { const ChatWidget = ReactWidget.create( ); ChatWidget.id = 'jupyter-ai::chat';