From 09c3c359addf60e26078207990ad2ca932cf2613 Mon Sep 17 00:00:00 2001 From: Jason Grout Date: Tue, 25 Jul 2023 10:11:22 -0600 Subject: [PATCH] Merge connection info into existing connection file if it already exists (#1133) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ipykernel/kernelapp.py | 26 +++++---- ipykernel/tests/test_ipkernel_direct.py | 2 +- ipykernel/tests/test_kernelapp.py | 71 +++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 10 deletions(-) diff --git a/ipykernel/kernelapp.py b/ipykernel/kernelapp.py index 922db6347..da6d19656 100644 --- a/ipykernel/kernelapp.py +++ b/ipykernel/kernelapp.py @@ -24,7 +24,6 @@ ) from IPython.core.profiledir import ProfileDir from IPython.core.shellapp import InteractiveShellApp, shell_aliases, shell_flags -from jupyter_client import write_connection_file from jupyter_client.connect import ConnectionFileMixin from jupyter_client.session import Session, session_aliases, session_flags from jupyter_core.paths import jupyter_runtime_dir @@ -44,10 +43,11 @@ from traitlets.utils.importstring import import_item from zmq.eventloop.zmqstream import ZMQStream -from .control import ControlThread -from .heartbeat import Heartbeat +from .connect import get_connection_info, write_connection_file # local imports +from .control import ControlThread +from .heartbeat import Heartbeat from .iostream import IOPubThread from .ipkernel import IPythonKernel from .parentpoller import ParentPollerUnix, ParentPollerWindows @@ -260,12 +260,7 @@ def _bind_socket(self, s, port): def write_connection_file(self): """write connection info to JSON file""" cf = self.abs_connection_file - if os.path.exists(cf): - self.log.debug("Connection file %s already exists", cf) - return - self.log.debug("Writing connection file: %s", cf) - write_connection_file( - cf, + connection_info = dict( ip=self.ip, key=self.session.key, transport=self.transport, @@ -275,6 +270,19 @@ def write_connection_file(self): iopub_port=self.iopub_port, control_port=self.control_port, ) + if os.path.exists(cf): + # If the file exists, merge our info into it. For example, if the + # original file had port number 0, we update with the actual port + # used. + existing_connection_info = get_connection_info(cf, unpack=True) + connection_info = dict(existing_connection_info, **connection_info) + if connection_info == existing_connection_info: + self.log.debug("Connection file %s with current information already exists", cf) + return + + self.log.debug("Writing connection file: %s", cf) + + write_connection_file(cf, **connection_info) def cleanup_connection_file(self): """Clean up our connection file.""" diff --git a/ipykernel/tests/test_ipkernel_direct.py b/ipykernel/tests/test_ipkernel_direct.py index 20e92a402..b0dbb01f5 100644 --- a/ipykernel/tests/test_ipkernel_direct.py +++ b/ipykernel/tests/test_ipkernel_direct.py @@ -20,7 +20,7 @@ class user_mod: __dict__ = {} -async def test_properities(ipkernel: IPythonKernel) -> None: +async def test_properties(ipkernel: IPythonKernel) -> None: ipkernel.user_module = user_mod() ipkernel.user_ns = {} diff --git a/ipykernel/tests/test_kernelapp.py b/ipykernel/tests/test_kernelapp.py index 9a7b1f92e..da38777d0 100644 --- a/ipykernel/tests/test_kernelapp.py +++ b/ipykernel/tests/test_kernelapp.py @@ -1,13 +1,17 @@ +import json import os import threading import time from unittest.mock import patch import pytest +from jupyter_core.paths import secure_write +from traitlets.config.loader import Config from ipykernel.kernelapp import IPKernelApp from .conftest import MockKernel +from .utils import TemporaryWorkingDirectory try: import trio @@ -47,6 +51,73 @@ def trigger_stop(): app.close() +@pytest.mark.skipif(os.name == "nt", reason="permission errors on windows") +def test_merge_connection_file(): + cfg = Config() + with TemporaryWorkingDirectory() as d: + cfg.ProfileDir.location = d + cf = os.path.join(d, "kernel.json") + initial_connection_info = { + "ip": "*", + "transport": "tcp", + "shell_port": 0, + "hb_port": 0, + "iopub_port": 0, + "stdin_port": 0, + "control_port": 53555, + "key": "abc123", + "signature_scheme": "hmac-sha256", + "kernel_name": "My Kernel", + } + # We cannot use connect.write_connection_file since + # it replaces port number 0 with a random port + # and we want IPKernelApp to do that replacement. + with secure_write(cf) as f: + json.dump(initial_connection_info, f) + assert os.path.exists(cf) + + app = IPKernelApp(config=cfg, connection_file=cf) + + # Calling app.initialize() does not work in the test, so we call the relevant functions that initialize() calls + # We must pass in an empty argv, otherwise the default is to try to parse the test runner's argv + super(IPKernelApp, app).initialize(argv=[""]) + app.init_connection_file() + app.init_sockets() + app.init_heartbeat() + app.write_connection_file() + + # Initialize should have merged the actual connection info + # with the connection info in the file + assert cf == app.abs_connection_file + assert os.path.exists(cf) + + with open(cf) as f: + new_connection_info = json.load(f) + + # ports originally set as 0 have been replaced + for port in ("shell", "hb", "iopub", "stdin"): + key = f"{port}_port" + # We initially had the port as 0 + assert initial_connection_info[key] == 0 + # the port is not 0 now + assert new_connection_info[key] > 0 + # the port matches the port the kernel actually used + assert new_connection_info[key] == getattr(app, key), f"{key}" + del new_connection_info[key] + del initial_connection_info[key] + + # The wildcard ip address was also replaced + assert new_connection_info["ip"] != "*" + del new_connection_info["ip"] + del initial_connection_info["ip"] + + # everything else in the connection file is the same + assert initial_connection_info == new_connection_info + + app.close() + os.remove(cf) + + @pytest.mark.skipif(trio is None, reason="requires trio") def test_trio_loop(): app = IPKernelApp(trio_loop=True)