diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index a41c556c7..a9385801a 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -166,7 +166,7 @@ def initialize_settings(self): # requires the event loop to be running on init. So instead we schedule # this as a task that is run as soon as the loop starts, and pass # consumers a Future that resolves to the Dask client when awaited. - dask_client_future = loop.create_task(self._get_dask_client()) + self.settings["dask_client_future"] = loop.create_task(self._get_dask_client()) eps = entry_points() # initialize chat handlers @@ -178,7 +178,7 @@ def initialize_settings(self): "root_chat_handlers": self.settings["jai_root_chat_handlers"], "chat_history": self.settings["chat_history"], "root_dir": self.serverapp.root_dir, - "dask_client_future": dask_client_future, + "dask_client_future": self.settings["dask_client_future"], "model_parameters": self.settings["model_parameters"], } default_chat_handler = DefaultChatHandler(**chat_handler_kwargs) @@ -263,3 +263,27 @@ def initialize_settings(self): async def _get_dask_client(self): return DaskClient(processes=False, asynchronous=True) + + async def stop_extension(self): + """ + Public method called by Jupyter Server when the server is stopping. + This calls the cleanup code defined in `self._stop_exception()` inside + an exception handler, as the server halts if this method raises an + exception. + """ + try: + await self._stop_extension() + except Exception as e: + self.log.error("Jupyter AI raised an exception while stopping:") + self.log.exception(e) + + async def _stop_extension(self): + """ + Private method that defines the cleanup code to run when the server is + stopping. + """ + if "dask_client_future" in self.settings: + dask_client: DaskClient = await self.settings["dask_client_future"] + self.log.info("Closing Dask client.") + await dask_client.close() + self.log.debug("Closed Dask client.") diff --git a/yarn.lock b/yarn.lock index 3cdc4ba8b..8c398f3fd 100644 --- a/yarn.lock +++ b/yarn.lock @@ -15070,11 +15070,11 @@ __metadata: "typescript@patch:typescript@^3 || ^4#~builtin": version: 4.9.5 - resolution: "typescript@patch:typescript@npm%3A4.9.5#~builtin::version=4.9.5&hash=23ec76" + resolution: "typescript@patch:typescript@npm%3A4.9.5#~builtin::version=4.9.5&hash=289587" bin: tsc: bin/tsc tsserver: bin/tsserver - checksum: ab417a2f398380c90a6cf5a5f74badd17866adf57f1165617d6a551f059c3ba0a3e4da0d147b3ac5681db9ac76a303c5876394b13b3de75fdd5b1eaa06181c9d + checksum: 1f8f3b6aaea19f0f67cba79057674ba580438a7db55057eb89cc06950483c5d632115c14077f6663ea76fd09fce3c190e6414bb98582ec80aa5a4eaf345d5b68 languageName: node linkType: hard