Skip to content

Commit

Permalink
refactor _init_config
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii-i committed Nov 16, 2023
1 parent 73df3f4 commit 7680760
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 65 deletions.
127 changes: 64 additions & 63 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,70 +163,71 @@ def _init_validator(self) -> Validator:
self.validator = Validator(schema)

def _init_config(self):
try:
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(
lm_id, raise_exc=False
):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(
em_id, raise_exc=False
):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if (
lm_id is not None
and not get_lm_provider(lm_id, self._lm_providers)[1]
):
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
if (
em_id is not None
and not get_em_provider(em_id, self._em_providers)[1]
):
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
return

properties = self.validator.schema.get("properties", {})
field_list = GlobalConfig.__fields__.keys()
field_dict = {
field: properties.get(field).get("default") for field in field_list
}
default_config = GlobalConfig(**field_dict)
self._write_config(default_config)

except ValidationError as e:
formatted_error = _format_validation_errors(e)
self.config_error = APIErrorModel(
type="ValidationError",
message="Configuration validation failed",
details=formatted_error,
# try:
if os.path.exists(self.config_path):
self._process_existing_config()
else:
self._create_default_config()
# except ValidationError as e:
# self._handle_validation_error(e)

def _process_existing_config(self):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
self._validate_lm_em_id(config)
self._write_config(config)

def _create_default_config(self):
properties = self.validator.schema.get("properties", {})
field_list = GlobalConfig.__fields__.keys()
field_dict = {
field: properties.get(field).get("default") for field in field_list
}
default_config = GlobalConfig(**field_dict)
self._write_config(default_config)

def _validate_lm_em_id(self, config):
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(lm_id, raise_exc=False):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(em_id, raise_exc=False):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
self.log.error(f"Configuration validation error: {self.config_error}")
config.embeddings_provider_id = None

# if the currently selected language or embedding model ids are
# not associated with models, set them to `None` and log a warning.
if lm_id is not None and not get_lm_provider(lm_id, self._lm_providers)[1]:
self.log.warning(
f"No language model is associated with '{lm_id}'. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not get_em_provider(em_id, self._em_providers)[1]:
self.log.warning(
f"No embedding model is associated with '{em_id}'. Setting to None."
)
config.embeddings_provider_id = None

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)

def _handle_validation_error(self, e: ValidationError):
formatted_error = _format_validation_errors(e)
self._config_error = APIErrorModel(
type="ValidationError",
message="Configuration validation failed",
details=formatted_error,
)
self.log.error(f"Configuration validation error: {self.config_error}")

def _read_config(self) -> GlobalConfig:
"""Returns the user's current configuration as a GlobalConfig object.
Expand Down
5 changes: 3 additions & 2 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,10 @@ def config_manager(self):

@web.authenticated
def get(self):
if self.config_manager.config_error:
config_error = self.config_manager.get_config_error()
if config_error:
self.set_status(400)
self.finish(self.config_manager.config_error.json())
self.finish(config_error.json())
return

config = self.config_manager.get_config()
Expand Down

0 comments on commit 7680760

Please sign in to comment.