From 4094644fd4c77350ebc7a1ea28e69df49a428097 Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 16 Feb 2023 11:37:07 +0100 Subject: [PATCH] async fixes in kernel usage handler - handle get_msg being async _or not_ - use async poller to avoid blocking while waiting for usage response --- jupyter_resource_usage/api.py | 35 ++++++++++++++--------------------- pyproject.toml | 1 + 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/jupyter_resource_usage/api.py b/jupyter_resource_usage/api.py index 2b948af..d54dea3 100644 --- a/jupyter_resource_usage/api.py +++ b/jupyter_resource_usage/api.py @@ -1,21 +1,15 @@ import json from concurrent.futures import ThreadPoolExecutor +from inspect import isawaitable import psutil -import zmq +import zmq.asyncio from jupyter_client.jsonutil import date_default from jupyter_server.base.handlers import APIHandler -from jupyter_server.utils import url_path_join from packaging import version from tornado import web from tornado.concurrent import run_on_executor -try: - # Traitlets >= 4.3.3 - from traitlets import Callable -except ImportError: - from .utils import Callable - try: import ipykernel @@ -24,8 +18,6 @@ except ImportError: USAGE_IS_SUPPORTED = False -MAX_RETRIES = 3 - class ApiHandler(APIHandler): executor = ThreadPoolExecutor(max_workers=5) @@ -113,17 +105,18 @@ async def get(self, matched_part=None, *args, **kwargs): usage_request = session.msg("usage_request", {}) control_channel.send(usage_request) - poller = zmq.Poller() + poller = zmq.asyncio.Poller() control_socket = control_channel.socket poller.register(control_socket, zmq.POLLIN) - for i in range(1, MAX_RETRIES + 1): - timeout_ms = 1000 * i - events = dict(poller.poll(timeout_ms)) - if not events: - self.write(json.dumps({})) - break - if control_socket not in events: - continue - res = await client.control_channel.get_msg(timeout=0) + # previous behavior was 3 retries: 1 + 2 + 3 = 6 seconds + timeout_ms = 6_000 + events = dict(await poller.poll(timeout_ms)) + if control_socket not in events: + self.write(json.dumps({})) + else: + res = client.control_channel.get_msg(timeout=0) + if isawaitable(res): + # control_channel.get_msg may return a Future, + # depending on configured KernelManager class + res = await res self.write(json.dumps(res, default=date_default)) - break diff --git a/pyproject.toml b/pyproject.toml index 9baa269..5921a17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "jupyter_server>=1.0", "prometheus_client", "psutil~=5.6", + "pyzmq>=19", ] dynamic = ["version"]