Skip to content

Instantly share code, notes, and snippets.

@SpencerPark
Last active December 20, 2023 13:26
Show Gist options
  • Save SpencerPark/e2732061ad19c1afa4a33a58cb8f18a9 to your computer and use it in GitHub Desktop.
Save SpencerPark/e2732061ad19c1afa4a33a58cb8f18a9 to your computer and use it in GitHub Desktop.
A small proxy kernel (and its installer) that manages a wrapped kernel over a TCP connection. It is designed for the case where the server starts kernels with the IPC transport but the wrapped kernel only supports TCP (like IJava). See https://gist.github.com/SpencerPark/447de114fcd3e6a272dc140809462e30 for a sample notebook that installs this.
import argparse
import json
import os
import os.path
import shutil
import sys
from jupyter_client.kernelspec import KernelSpec, KernelSpecManager, NoSuchKernel
# Command-line interface of the installer:
#   --kernel          name of the installed kernel to wrap (required)
#   --implementation  path to the proxy kernel script to install (required)
#   --quiet           suppress progress output from log()
# Option values are kept as plain strings and --quiet is off unless given,
# which are argparse's defaults and therefore not spelled out.
parser = argparse.ArgumentParser()
parser.add_argument("--kernel", required=True)
parser.add_argument("--implementation", required=True)
parser.add_argument("--quiet", action="store_true")
args = parser.parse_args()
def log(*log_args):
    """Print progress output unless the installer was started with --quiet.

    Args:
        *log_args: Values forwarded unchanged to ``print``.

    Reads the module-level ``args`` namespace produced by the argument
    parser.
    """
    if args.quiet:
        return
    print(*log_args)
kernel_spec_manager = KernelSpecManager()

# Look up the kernel that is going to be wrapped. If the name is unknown, list
# what is installed so the user can correct the --kernel argument.
try:
    real_kernel_spec: KernelSpec = kernel_spec_manager.get_kernel_spec(args.kernel)
except NoSuchKernel:
    print(f"No kernel installed with name {args.kernel}. Available kernels:")
    for name, path in kernel_spec_manager.find_kernel_specs().items():
        print(f" - {name}\t{path}")
    # sys.exit instead of the site-provided exit() helper, which is intended
    # for interactive sessions and is not guaranteed to be defined.
    sys.exit(1)

log(f"Moving {args.kernel} kernel from {real_kernel_spec.resource_dir}...")
real_kernel_install_path = real_kernel_spec.resource_dir
new_kernel_name = f"{args.kernel}_tcp"
new_kernel_install_path = os.path.join(
    os.path.dirname(real_kernel_install_path), new_kernel_name
)
# NOTE(review): if new_kernel_install_path already exists as a directory (for
# example because the installer ran before), shutil.move places the source
# *inside* it instead of replacing it. Confirm whether re-running the
# installer needs to be supported.
shutil.move(real_kernel_install_path, new_kernel_install_path)

# Update the moved kernel name and args. We tag it _tcp because the proxy will
# impersonate it and should be the one using the real name.
new_kernel_json_path = os.path.join(new_kernel_install_path, "kernel.json")
with open(new_kernel_json_path, "r", encoding="utf-8") as in_:
    real_kernel_json = json.load(in_)
real_kernel_json["name"] = new_kernel_name
# Launch arguments may point into the old install directory; retarget them to
# the directory the kernel was just moved to.
real_kernel_json["argv"] = [
    arg.replace(real_kernel_install_path, new_kernel_install_path)
    for arg in real_kernel_json["argv"]
]
with open(new_kernel_json_path, "w", encoding="utf-8") as out:
    json.dump(real_kernel_json, out)
log(f"Wrote modified kernel.json for {new_kernel_name} in {new_kernel_json_path}")

log(
    f"Installing the proxy kernel in place of {args.kernel} in {real_kernel_install_path}"
)
os.makedirs(real_kernel_install_path)
proxy_kernel_implementation_path = os.path.join(
    real_kernel_install_path, "ipc_proxy_kernel.py"
)

# The proxy takes over the original kernel name and is started with the
# connection file plus the name of the renamed (wrapped) kernel to launch.
proxy_kernel_spec = KernelSpec()
proxy_kernel_spec.argv = [
    sys.executable,
    proxy_kernel_implementation_path,
    "{connection_file}",
    f"--kernel={new_kernel_name}",
]
proxy_kernel_spec.display_name = real_kernel_spec.display_name
proxy_kernel_spec.interrupt_mode = real_kernel_spec.interrupt_mode or "message"
proxy_kernel_spec.language = real_kernel_spec.language

proxy_kernel_json_path = os.path.join(real_kernel_install_path, "kernel.json")
with open(proxy_kernel_json_path, "w", encoding="utf-8") as out:
    json.dump(proxy_kernel_spec.to_dict(), out, indent=2)
log(f"Installed proxy kernelspec: {proxy_kernel_spec.to_json()}")

shutil.copy(args.implementation, proxy_kernel_implementation_path)

# Name the kernel that was actually wrapped instead of hard-coding 'java'; the
# output is unchanged when the installer is run with --kernel=java.
print(
    "Proxy kernel installed. Go to 'Runtime > Change runtime type' "
    f"and select '{args.kernel}'"
)
import argparse
import json
from threading import Thread
import zmq
from jupyter_client import KernelClient
from jupyter_client.channels import HBChannel
from jupyter_client.manager import KernelManager
from jupyter_client.session import Session
from traitlets.traitlets import Type
# Command-line interface of the proxy kernel: the connection file handed over
# by the Jupyter server, and the name of the wrapped kernel to launch.
parser = argparse.ArgumentParser()
parser.add_argument("connection_file")
parser.add_argument("--kernel", type=str, required=True)
args = parser.parse_args()

# parse connection file details
with open(args.connection_file, "r", encoding="utf-8") as connection_file:
    connection_file_contents = json.load(connection_file)

# The address and the five channel ports are mandatory; a missing entry fails
# fast with a KeyError naming the absent field.
ip = str(connection_file_contents["ip"])
shell_port = int(connection_file_contents["shell_port"])
stdin_port = int(connection_file_contents["stdin_port"])
control_port = int(connection_file_contents["control_port"])
iopub_port = int(connection_file_contents["iopub_port"])
hb_port = int(connection_file_contents["hb_port"])

# Transport and signing details fall back to defaults when absent instead of
# aborting with a KeyError. An empty key means messages are not signed.
# NOTE(review): the defaults are chosen to match jupyter_client's own
# fallbacks for connection info - confirm against the targeted version.
transport = str(connection_file_contents.get("transport", "tcp"))
signature_scheme = str(
    connection_file_contents.get("signature_scheme", "hmac-sha256")
)
key = str(connection_file_contents.get("key", "")).encode()
# ZeroMQ socket type used on each Jupyter channel, per side of the connection.
# The proxy binds the kernel_type socket of every channel towards the server.
#
# channel | kernel_type | client_type
# shell   | ROUTER      | DEALER
# stdin   | ROUTER      | DEALER
# ctrl    | ROUTER      | DEALER
# iopub   | PUB         | SUB
# hb      | REP         | REQ

# Context that owns every server-facing socket created by this script; it is
# destroyed once the wrapped kernel has exited.
zmq_context = zmq.Context()
def create_and_bind_socket(port: int, socket_type: int):
    """Create a socket of ``socket_type`` and bind it for the given port.

    The address is derived from the module-level ``transport`` and ``ip``
    values read from the connection file: ``tcp://{ip}:{port}`` for tcp and
    ``ipc://{ip}-{port}`` for ipc.

    Args:
        port: Port number (tcp) or address suffix (ipc); must be positive.
        socket_type: ZeroMQ socket type constant, e.g. ``zmq.ROUTER``.

    Returns:
        The bound ``zmq.Socket``.

    Raises:
        ValueError: If ``port`` is not positive or the transport is unknown.
        zmq.ZMQError: If binding fails; the socket is closed before the error
            is re-raised so it is not leaked.
    """
    if port <= 0:
        raise ValueError(f"Invalid port: {port}")

    if transport == "tcp":
        addr = f"tcp://{ip}:{port}"
    elif transport == "ipc":
        addr = f"ipc://{ip}-{port}"
    else:
        raise ValueError(f"Unknown transport: {transport}")

    socket: zmq.Socket = zmq_context.socket(socket_type)
    socket.linger = 1000  # ipykernel does this
    try:
        socket.bind(addr)
    except zmq.ZMQError:
        # Do not leave a half-initialised socket registered with the context.
        socket.close(linger=0)
        raise
    return socket
# Bind the server-facing socket of every channel, in the order shell, stdin,
# control, iopub, heartbeat, using the kernel-side socket type for each.
shell_socket, stdin_socket, control_socket, iopub_socket, hb_socket = (
    create_and_bind_socket(channel_port, channel_socket_type)
    for channel_port, channel_socket_type in (
        (shell_port, zmq.ROUTER),
        (stdin_port, zmq.ROUTER),
        (control_port, zmq.ROUTER),
        (iopub_port, zmq.PUB),
        (hb_port, zmq.REP),
    )
)

# The proxy answers heartbeats itself rather than relaying them to the real
# kernel, which has its own heartbeat. The device pipes the socket back into
# itself (credit to ipykernel for this neat little heartbeat implementation).
Thread(
    target=zmq.device,
    args=(zmq.QUEUE, hb_socket, hb_socket),
).start()
def ZMQProxyChannel_factory(proxy_server_socket: zmq.Socket):
    """Build a channel class that pipes traffic through ``proxy_server_socket``.

    Args:
        proxy_server_socket: The bound, server-facing socket of one channel.

    Returns:
        A ``ZMQProxyChannel`` class whose instances proxy between
        ``proxy_server_socket`` and the socket connected to the wrapped
        kernel that they are constructed with.
    """

    class ZMQProxyChannel:
        """Channel that forwards messages with ``zmq.proxy`` in a thread."""

        # Socket connected to the wrapped kernel; None once stopped.
        kernel_client_socket: zmq.Socket = None
        # Session handed over at construction; stored but not used here.
        session: Session = None

        def __init__(self, socket: zmq.Socket, session: Session, _=None):
            # The third positional argument is accepted and ignored.
            super().__init__()
            self.session = session
            self.kernel_client_socket = socket

        def start(self):
            """Start proxying between the server and kernel sockets."""
            # zmq.proxy blocks for as long as the sockets are connected, so it
            # runs in its own thread (heartbeat is handled separately).
            forwarder = Thread(
                target=zmq.proxy,
                args=(proxy_server_socket, self.kernel_client_socket),
            )
            forwarder.start()

        def stop(self):
            """Close the kernel-side socket; does nothing if already stopped."""
            client_socket = self.kernel_client_socket
            if client_socket is None:
                return
            try:
                client_socket.close(linger=0)
            except Exception:
                # Best effort: a failing close must not prevent shutdown.
                pass
            self.kernel_client_socket = None

        def is_alive(self):
            """Return True until ``stop`` has released the kernel socket."""
            return self.kernel_client_socket is not None

    return ZMQProxyChannel
class ProxyKernelClient(KernelClient):
    """Kernel client whose channels relay traffic instead of handling it.

    The shell, stdin, control and iopub channel classes are proxy channels,
    each tied to the matching server-facing socket. The heartbeat channel is
    the regular ``HBChannel``.
    """

    shell_channel_class = Type(default_value=ZMQProxyChannel_factory(shell_socket))
    stdin_channel_class = Type(default_value=ZMQProxyChannel_factory(stdin_socket))
    control_channel_class = Type(
        default_value=ZMQProxyChannel_factory(control_socket)
    )
    iopub_channel_class = Type(default_value=ZMQProxyChannel_factory(iopub_socket))
    hb_channel_class = Type(default_value=HBChannel)
# Manager for the wrapped kernel. It is always started with the tcp transport,
# whatever transport the server asked the proxy itself to use.
kernel_manager = KernelManager()
kernel_manager.kernel_name = args.kernel
kernel_manager.transport = "tcp"
kernel_manager.client_factory = ProxyKernelClient
kernel_manager.autorestart = False

# Make sure the wrapped kernel uses the same session info. This way we don't
# need to decode them before forwarding, we can directly pass everything
# through.
kernel_manager.session.signature_scheme = signature_scheme
kernel_manager.session.key = key

kernel_client = None
try:
    kernel_manager.start_kernel()

    # Connect to the real kernel we just started and start up all the proxies.
    kernel_client = kernel_manager.client()
    kernel_client.start_channels()

    # Everything should be up and running. We now just wait for the managed
    # kernel process to exit and when that happens, shutdown and exit with the
    # same code.
    # NOTE(review): relies on KernelManager.kernel exposing the child process
    # handle - confirm the attribute exists in the installed jupyter_client.
    while True:
        try:
            exit_code = kernel_manager.kernel.wait()
            break
        except KeyboardInterrupt:
            # The proxy may be installed with the "signal" interrupt mode, in
            # which case the interrupt arrives here as SIGINT. Pass it on to
            # the wrapped kernel and keep waiting instead of dying.
            kernel_manager.interrupt_kernel()
finally:
    # Always tear the sockets down, also when startup failed: the heartbeat
    # and proxy threads are not daemon threads and would otherwise keep the
    # process alive.
    try:
        if kernel_client is not None:
            kernel_client.stop_channels()
    finally:
        zmq_context.destroy(0)

# Raise SystemExit directly rather than calling the site-provided exit()
# helper, which is intended for interactive sessions.
raise SystemExit(exit_code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment