diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 339863c3f..b30a80d0e 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -289,13 +289,17 @@ def filter_predicate(local_model_id: str): else: return model_id in self.allowed_models + def filter_model_list(model_list: Optional[List[str]]) -> List[str]: + if model_list is not None: + return list(filter(filter_predicate, provider.models)) + else: + return [] + # filter out every model w/ model ID according to allow/blocklist for provider in providers: - provider.models = list(filter(filter_predicate, provider.models)) - provider.chat_models = list(filter(filter_predicate, provider.chat_models)) - provider.completion_models = list( - filter(filter_predicate, provider.completion_models) - ) + provider.models = filter_model_list(provider.models) + provider.chat_models = filter_model_list(provider.chat_models) + provider.completion_models = filter_model_list(provider.completion_models) # filter out every provider with no models which satisfy the allow/blocklist, then return return filter((lambda p: len(p.models) > 0), providers)