diff --git a/jupyter_resource_usage/api.py b/jupyter_resource_usage/api.py index 2b948af..d54dea3 100644 --- a/jupyter_resource_usage/api.py +++ b/jupyter_resource_usage/api.py @@ -1,21 +1,15 @@ import json from concurrent.futures import ThreadPoolExecutor +from inspect import isawaitable import psutil -import zmq +import zmq.asyncio from jupyter_client.jsonutil import date_default from jupyter_server.base.handlers import APIHandler -from jupyter_server.utils import url_path_join from packaging import version from tornado import web from tornado.concurrent import run_on_executor -try: - # Traitlets >= 4.3.3 - from traitlets import Callable -except ImportError: - from .utils import Callable - try: import ipykernel @@ -24,8 +18,6 @@ except ImportError: USAGE_IS_SUPPORTED = False -MAX_RETRIES = 3 - class ApiHandler(APIHandler): executor = ThreadPoolExecutor(max_workers=5) @@ -113,17 +105,18 @@ async def get(self, matched_part=None, *args, **kwargs): usage_request = session.msg("usage_request", {}) control_channel.send(usage_request) - poller = zmq.Poller() + poller = zmq.asyncio.Poller() control_socket = control_channel.socket poller.register(control_socket, zmq.POLLIN) - for i in range(1, MAX_RETRIES + 1): - timeout_ms = 1000 * i - events = dict(poller.poll(timeout_ms)) - if not events: - self.write(json.dumps({})) - break - if control_socket not in events: - continue - res = await client.control_channel.get_msg(timeout=0) + # previous behavior was 3 retries: 1 + 2 + 3 = 6 seconds + timeout_ms = 6_000 + events = dict(await poller.poll(timeout_ms)) + if control_socket not in events: + self.write(json.dumps({})) + else: + res = client.control_channel.get_msg(timeout=0) + if isawaitable(res): + # control_channel.get_msg may return a Future, + # depending on configured KernelManager class + res = await res self.write(json.dumps(res, default=date_default)) - break diff --git a/pyproject.toml b/pyproject.toml index 9baa269..5921a17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "jupyter_server>=1.0", "prometheus_client", "psutil~=5.6", + "pyzmq>=19", ] dynamic = ["version"]