diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index a41c556c7..a9385801a 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -166,7 +166,7 @@ def initialize_settings(self): # requires the event loop to be running on init. So instead we schedule # this as a task that is run as soon as the loop starts, and pass # consumers a Future that resolves to the Dask client when awaited. - dask_client_future = loop.create_task(self._get_dask_client()) + self.settings["dask_client_future"] = loop.create_task(self._get_dask_client()) eps = entry_points() # initialize chat handlers @@ -178,7 +178,7 @@ def initialize_settings(self): "root_chat_handlers": self.settings["jai_root_chat_handlers"], "chat_history": self.settings["chat_history"], "root_dir": self.serverapp.root_dir, - "dask_client_future": dask_client_future, + "dask_client_future": self.settings["dask_client_future"], "model_parameters": self.settings["model_parameters"], } default_chat_handler = DefaultChatHandler(**chat_handler_kwargs) @@ -263,3 +263,27 @@ def initialize_settings(self): async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) + + async def stop_extension(self): + """ + Public method called by Jupyter Server when the server is stopping. + This calls the cleanup code defined in `self._stop_exception()` inside + an exception handler, as the server halts if this method raises an + exception. + """ + try: + await self._stop_extension() + except Exception as e: + self.log.error("Jupyter AI raised an exception while stopping:") + self.log.exception(e) + + async def _stop_extension(self): + """ + Private method that defines the cleanup code to run when the server is + stopping. + """ + if "dask_client_future" in self.settings: + dask_client: DaskClient = await self.settings["dask_client_future"] + self.log.info("Closing Dask client.") + await dask_client.close() + self.log.debug("Closed Dask client.")