diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 1383b62e9..65d3ab159 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -163,70 +163,71 @@ def _init_validator(self) -> Validator: self.validator = Validator(schema) def _init_config(self): - try: - if os.path.exists(self.config_path): - with open(self.config_path, encoding="utf-8") as f: - config = GlobalConfig(**json.loads(f.read())) - lm_id = config.model_provider_id - em_id = config.embeddings_provider_id - - # if the currently selected language or embedding model are - # forbidden, set them to `None` and log a warning. - if lm_id is not None and not self._validate_model( - lm_id, raise_exc=False - ): - self.log.warning( - f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.model_provider_id = None - if em_id is not None and not self._validate_model( - em_id, raise_exc=False - ): - self.log.warning( - f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." - ) - config.embeddings_provider_id = None - - # if the currently selected language or embedding model ids are - # not associated with models, set them to `None` and log a warning. - if ( - lm_id is not None - and not get_lm_provider(lm_id, self._lm_providers)[1] - ): - self.log.warning( - f"No language model is associated with '{lm_id}'. Setting to None." - ) - config.model_provider_id = None - if ( - em_id is not None - and not get_em_provider(em_id, self._em_providers)[1] - ): - self.log.warning( - f"No embedding model is associated with '{em_id}'. Setting to None." - ) - config.embeddings_provider_id = None - - # re-write to the file to validate the config and apply any - # updates to the config file immediately - self._write_config(config) - return - - properties = self.validator.schema.get("properties", {}) - field_list = GlobalConfig.__fields__.keys() - field_dict = { - field: properties.get(field).get("default") for field in field_list - } - default_config = GlobalConfig(**field_dict) - self._write_config(default_config) - - except ValidationError as e: - formatted_error = _format_validation_errors(e) - self.config_error = APIErrorModel( - type="ValidationError", - message="Configuration validation failed", - details=formatted_error, + # try: + if os.path.exists(self.config_path): + self._process_existing_config() + else: + self._create_default_config() + # except ValidationError as e: + # self._handle_validation_error(e) + + def _process_existing_config(self): + with open(self.config_path, encoding="utf-8") as f: + config = GlobalConfig(**json.loads(f.read())) + self._validate_lm_em_id(config) + self._write_config(config) + + def _create_default_config(self): + properties = self.validator.schema.get("properties", {}) + field_list = GlobalConfig.__fields__.keys() + field_dict = { + field: properties.get(field).get("default") for field in field_list + } + default_config = GlobalConfig(**field_dict) + self._write_config(default_config) + + def _validate_lm_em_id(self, config): + lm_id = config.model_provider_id + em_id = config.embeddings_provider_id + + # if the currently selected language or embedding model are + # forbidden, set them to `None` and log a warning. + if lm_id is not None and not self._validate_model(lm_id, raise_exc=False): + self.log.warning( + f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None." + ) + config.model_provider_id = None + if em_id is not None and not self._validate_model(em_id, raise_exc=False): + self.log.warning( + f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None." ) - self.log.error(f"Configuration validation error: {self.config_error}") + config.embeddings_provider_id = None + + # if the currently selected language or embedding model ids are + # not associated with models, set them to `None` and log a warning. + if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]: + self.log.warning( + f"No language model is associated with '{lm_id}'. Setting to None." + ) + config.model_provider_id = None + if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]: + self.log.warning( + f"No embedding model is associated with '{em_id}'. Setting to None." + ) + config.embeddings_provider_id = None + + # re-write to the file to validate the config and apply any + # updates to the config file immediately + self._write_config(config) + + def _handle_validation_error(self, e: ValidationError): + formatted_error = _format_validation_errors(e) + self._config_error = APIErrorModel( + type="ValidationError", + message="Configuration validation failed", + details=formatted_error, + ) + self.log.error(f"Configuration validation error: {self.config_error}") def _read_config(self) -> GlobalConfig: """Returns the user's current configuration as a GlobalConfig object. diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 010df9a93..b15be577f 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -361,9 +361,10 @@ def config_manager(self): @web.authenticated def get(self): - if self.config_manager.config_error: + config_error = self.config_manager.get_config_error() + if config_error: self.set_status(400) - self.finish(self.config_manager.config_error.json()) + self.finish(config_error.json()) return config = self.config_manager.get_config()