Skip to content

Commit

Permalink
initialize config manager with errors
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Dec 8, 2023
1 parent de7fbaf commit 742f20b
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 67 deletions.
104 changes: 66 additions & 38 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,19 +164,29 @@ def _init_validator(self) -> Validator:
self.validator = Validator(schema)

def _init_config(self):
try:
if os.path.exists(self.config_path):
self._process_existing_config()
else:
self._create_default_config()
except ValidationError as e:
self._handle_validation_error(e)
# try:
if os.path.exists(self.config_path):
self._process_existing_config()
else:
self._create_default_config()
# except ValidationError as e:
# self._handle_validation_error(e)
# self._config = GlobalConfig(
# send_with_shift_enter=False, fields={}, api_keys={}
# )

def _process_existing_config(self):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
self._validate_lm_em_id(config)
raw_config = json.loads(f.read())

validated_raw_config = self._validate_lm_em_id(raw_config)

try:
config = GlobalConfig(**validated_raw_config)
self._write_config(config)
except ValidationError as e:
corrected_config = self._handle_validation_error(e, validated_raw_config)
self._write_config(corrected_config)

def _create_default_config(self):
properties = self.validator.schema.get("properties", {})
Expand All @@ -187,16 +197,16 @@ def _create_default_config(self):
default_config = GlobalConfig(**field_dict)
self._write_config(default_config)

def _validate_lm_em_id(self, config):
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id
def _validate_lm_em_id(self, raw_config):
lm_id = raw_config.get("model_provider_id")
em_id = raw_config.get("embeddings_provider_id")

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(lm_id, raise_exc=False):
warning_message = f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
self.log.warning(warning_message)
config.model_provider_id = None
raw_config["model_provider_id"] = None
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
Expand All @@ -206,7 +216,7 @@ def _validate_lm_em_id(self, config):
if em_id is not None and not self._validate_model(em_id, raise_exc=False):
warning_message = f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
self.log.warning(warning_message)
config.embeddings_provider_id = None
raw_config["embeddings_provider_id"] = None
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
Expand All @@ -220,7 +230,7 @@ def _validate_lm_em_id(self, config):
f"No language model is associated with '{lm_id}'. Setting to None."
)
self.log.warning(warning_message)
config.model_provider_id = None
raw_config["model_provider_id"] = None
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
Expand All @@ -232,28 +242,43 @@ def _validate_lm_em_id(self, config):
f"No embedding model is associated with '{em_id}'. Setting to None."
)
self.log.warning(warning_message)
config.embeddings_provider_id = None
raw_config["embeddings_provider_id"] = None
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
)
)

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)

def _handle_validation_error(self, e: ValidationError):
formatted_error = _format_validation_errors(e)
error_message = "Configuration validation failed"
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.CRITICAL,
message=error_message,
details=formatted_error,
)
)
self.log.error(f"{error_message}: {formatted_error}")
return raw_config

def _handle_validation_error(self, e: ValidationError, raw_config):
# Extract default values from schema
properties = self.validator.schema.get("properties", {})
field_list = GlobalConfig.__fields__.keys()
default_values = {
field: properties.get(field).get("default") for field in field_list
}

# Apply default values to erroneous fields
for error in e.errors():
field = error["loc"][0]
if field in default_values:
raw_config[field] = default_values[field]
warning_message = f"Error in '{field}': {error['msg']}. Resetting to default value ('{default_values[field]}')."
self.log.warning(warning_message)
self._config_errors.append(
ConfigErrorModel(
error_type=ConfigErrorType.WARNING, message=warning_message
)
)

# Create a config with default values for erroneous fields
config = GlobalConfig(**raw_config)
self.log.warning("\n\n\n Config \n\n\n")

self.log.warning(config)
self._validate_config(config)
return config

def _read_config(self) -> GlobalConfig:
"""Returns the user's current configuration as a GlobalConfig object.
Expand All @@ -264,12 +289,15 @@ def _read_config(self) -> GlobalConfig:
if last_write <= self._last_read:
return self._config

with open(self.config_path, encoding="utf-8") as f:
self._last_read = time.time_ns()
raw_config = json.loads(f.read())
config = GlobalConfig(**raw_config)
self._validate_config(config)
return config
with open(self.config_path, encoding="utf-8") as f:
self._last_read = time.time_ns()
raw_config = json.loads(f.read())
try:
config = GlobalConfig(**raw_config)
except ValidationError as e:
config = self._handle_validation_error(e, raw_config)
self._validate_config(config)
return config

def _validate_config(self, config: GlobalConfig):
"""Method used to validate the configuration. This is called after every
Expand Down Expand Up @@ -414,7 +442,7 @@ def update_config(self, config_update: UpdateConfigRequest):
def get_config(self):
config = self._read_config()
config_dict = config.dict(exclude_unset=True)
api_key_names = list(config_dict.pop("api_keys").keys())
api_key_names = list(config_dict.pop("api_keys", {}).keys())
return DescribeConfigResponse(
**config_dict,
api_keys=api_key_names,
Expand Down
38 changes: 28 additions & 10 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,7 @@

class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [
(r"api/ai/api_keys/(?P<api_key_name>\w+)", ApiKeysHandler),
(r"api/ai/config/?", GlobalConfigHandler),
(r"api/ai/chats/?", RootChatHandler),
(r"api/ai/chats/history?", ChatHistoryHandler),
(r"api/ai/providers?", ModelProviderHandler),
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
]
handlers = [(r"api/ai/config/?", GlobalConfigHandler)]

allowed_providers = List(
Unicode(),
Expand Down Expand Up @@ -139,15 +132,24 @@ def initialize_settings(self):
else:
# Log the error and proceed with limited functionality
self.log.error(f"Configuration errors detected: {config_errors}")
# TODO: self._initialize_limited_functionality()
self._initialize_limited_functionality(config_errors)

self.log.info("Registered providers.")
self.log.info(f"Registered {self.name} server extension")

latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.")

def _initialize_full_functionality(self):
self.handlers.extend(
[
(r"api/ai/api_keys/(?P<api_key_name>\w+)", ApiKeysHandler),
(r"api/ai/chats/?", RootChatHandler),
(r"api/ai/chats/history?", ChatHistoryHandler),
(r"api/ai/providers?", ModelProviderHandler),
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
]
)

# Store chat clients in a dictionary
self.settings["chat_clients"] = {}
self.settings["jai_root_chat_handlers"] = {}
Expand Down Expand Up @@ -204,5 +206,21 @@ def _initialize_full_functionality(self):
"/help": help_chat_handler,
}

self.log.info("Registered providers.")

def _initialize_limited_functionality(self, config_errors):
"""
Initialize the extension with limited functionality due to configuration errors.
"""
self.log.warning(
"Initializing Jupyter AI extension with limited functionality due to configuration errors."
)

# Capture configuration error details
config_errors = self.settings["jai_config_manager"].get_config_errors()
self.settings["config_errors"] = config_errors

self.settings["jai_chat_handlers"] = []

async def _get_dask_client(self):
return DaskClient(processes=False, asynchronous=True)
34 changes: 22 additions & 12 deletions packages/jupyter-ai/src/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { CollaboratorsContextProvider } from '../contexts/collaborators-context'
import { ScrollContainer } from './scroll-container';

type ChatBodyProps = {
chatHandler: ChatHandler;
chatHandler: ChatHandler | null;
setChatView: (view: ChatView) => void;
};

Expand All @@ -43,12 +43,22 @@ function ChatBody({
useEffect(() => {
async function fetchHistory() {
try {
const [history, config] = await Promise.all([
chatHandler.getHistory(),
AiService.getConfig()
]);
const config = await AiService.getConfig();
setSendWithShiftEnter(config.send_with_shift_enter ?? false);
setMessages(history.messages);

// Check if there are critical errors
const hasCriticalErrors = config.config_errors?.some(
error => error.error_type === AiService.ConfigErrorType.CRITICAL
);
console.log('\n\n\n *** \n\n\n');
console.log(hasCriticalErrors);
if (!hasCriticalErrors && chatHandler) {
const history = await chatHandler.getHistory();
setMessages(history.messages);
} else {
setMessages([]);
}

if (!config.model_provider_id) {
setShowWelcomeMessage(true);
}
Expand Down Expand Up @@ -78,9 +88,9 @@ function ChatBody({
setMessages(messageGroups => [...messageGroups, message]);
}

chatHandler.addListener(handleChatEvents);
chatHandler?.addListener(handleChatEvents);
return function cleanup() {
chatHandler.removeListener(handleChatEvents);
chatHandler?.removeListener(handleChatEvents);
};
}, [chatHandler]);

Expand All @@ -96,18 +106,18 @@ function ChatBody({
: '');

// send message to backend
const messageId = await chatHandler.sendMessage({ prompt });
const messageId = await chatHandler?.sendMessage({ prompt });

// await reply from agent
// no need to append to messageGroups state variable, since that's already
// handled in the effect hooks.
const reply = await chatHandler.replyFor(messageId);
const reply = await chatHandler?.replyFor(messageId ?? '');
if (replaceSelection && selection) {
const { cellId, ...selectionProps } = selection;
replaceSelectionFn({
...selectionProps,
...(cellId && { cellId }),
text: reply.body
text: reply?.body ?? ''
});
}
};
Expand Down Expand Up @@ -187,7 +197,7 @@ function ChatBody({

export type ChatProps = {
selectionWatcher: SelectionWatcher;
chatHandler: ChatHandler;
chatHandler: ChatHandler | null;
globalAwareness: Awareness | null;
chatView?: ChatView;
};
Expand Down
24 changes: 18 additions & 6 deletions packages/jupyter-ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { buildChatSidebar } from './widgets/chat-sidebar';
import { SelectionWatcher } from './selection-watcher';
import { ChatHandler } from './chat_handler';
import { buildErrorWidget } from './widgets/chat-error';
import { AiService } from './handler';

export type DocumentTracker = IWidgetTracker<IDocumentWidget>;

Expand All @@ -32,14 +33,25 @@ const plugin: JupyterFrontEndPlugin<void> = {
*/
const selectionWatcher = new SelectionWatcher(app.shell);

/**
* Initialize chat handler, open WS connection
*/
const chatHandler = new ChatHandler();

let chatWidget: ReactWidget | null = null;
let chatHandler: ChatHandler | null = null;

try {
await chatHandler.initialize();
// Fetch configuration to check for critical errors
const config = await AiService.getConfig();
console.log('\n\n\n *** \n\n\n');
console.log(config.config_errors);
const hasCriticalErrors = config.config_errors?.some(
error => error.error_type === AiService.ConfigErrorType.CRITICAL
);

if (!hasCriticalErrors) {
/**
* Initialize chat handler, open WS connection
*/
chatHandler = new ChatHandler();
await chatHandler.initialize();
}
chatWidget = buildChatSidebar(
selectionWatcher,
chatHandler,
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/src/widgets/chat-sidebar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { ChatHandler } from '../chat_handler';

export function buildChatSidebar(
selectionWatcher: SelectionWatcher,
chatHandler: ChatHandler,
chatHandler: ChatHandler | null,
globalAwareness: Awareness | null
): ReactWidget {
const ChatWidget = ReactWidget.create(
Expand Down

0 comments on commit 742f20b

Please sign in to comment.