Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save wiseaidev/bc102165f43db4ebd84fcdb4c5bfb129 to your computer and use it in GitHub Desktop.
Save wiseaidev/bc102165f43db4ebd84fcdb4c5bfb129 to your computer and use it in GitHub Desktop.
A small proxy kernel (and its installer) that manages a wrapped kernel over a TCP connection. It is designed for the case where the server starts kernels with the IPC transport but the wrapped kernel only supports TCP (as the Rust kernel does).
import argparse
import json
import os
import os.path
import shutil
import sys
from jupyter_client.kernelspec import (KernelSpec, KernelSpecManager,
NoSuchKernel)
# Command-line interface of the installer:
#   --kernel          name of the installed kernelspec to wrap (required)
#   --implementation  path of the proxy kernel script that gets copied next
#                     to the generated kernel.json (required)
#   --quiet           suppress the progress messages printed through log()
parser = argparse.ArgumentParser()
for required_option in ("--kernel", "--implementation"):
    parser.add_argument(required_option, type=str, required=True)
parser.add_argument("--quiet", default=False, action="store_true")
args = parser.parse_args()
def log(*log_args):
    """Print ``log_args`` like ``print`` would, unless ``--quiet`` was given."""
    if args.quiet:
        return
    print(*log_args)
# Look up the kernelspec that is going to be wrapped. When it does not exist,
# list what is installed and stop.
kernel_spec_manager = KernelSpecManager()
try:
    real_kernel_spec: KernelSpec = kernel_spec_manager.get_kernel_spec(args.kernel)
except NoSuchKernel:
    print(f"No kernel installed with name {args.kernel}. Available kernels:")
    for name, path in kernel_spec_manager.find_kernel_specs().items():
        print(f" - {name}\t{path}")
    sys.exit(1)

real_kernel_install_path = real_kernel_spec.resource_dir
new_kernel_name = f"{args.kernel}_tcp"
new_kernel_install_path = os.path.join(
    os.path.dirname(real_kernel_install_path), new_kernel_name
)

# Check everything that can fail *before* touching the installed kernel, so a
# bad invocation never leaves a half-moved, unusable kernelspec behind.
if not os.path.isfile(args.implementation):
    sys.exit(f"Proxy kernel implementation not found: {args.implementation}")
if os.path.exists(new_kernel_install_path):
    # shutil.move() into an existing directory would nest the kernel inside
    # it instead of renaming it (typically: the installer was already run).
    sys.exit(
        f"{new_kernel_install_path} already exists; "
        f"is the proxy already installed for {args.kernel}?"
    )

log(f"Moving {args.kernel} kernel from {real_kernel_spec.resource_dir}...")
shutil.move(real_kernel_install_path, new_kernel_install_path)

# Update the moved kernel name and args. We tag it _tcp because the proxy will
# impersonate it and should be the one using the real name.
new_kernel_json_path = os.path.join(new_kernel_install_path, "kernel.json")
with open(new_kernel_json_path, "r", encoding="utf-8") as in_:
    real_kernel_json = json.load(in_)
real_kernel_json["name"] = new_kernel_name
# Any argument that pointed into the old install directory must follow the move.
real_kernel_json["argv"] = [
    arg.replace(real_kernel_install_path, new_kernel_install_path)
    for arg in real_kernel_json["argv"]
]
with open(new_kernel_json_path, "w", encoding="utf-8") as out:
    json.dump(real_kernel_json, out, indent=2)
log(f"Wrote modified kernel.json for {new_kernel_name} in {new_kernel_json_path}")

log(
    f"Installing the proxy kernel in place of {args.kernel} in {real_kernel_install_path}"
)
# The original directory was moved away above, so recreate it for the proxy.
os.makedirs(real_kernel_install_path)
proxy_kernel_implementation_path = os.path.join(
    real_kernel_install_path, "ipc_proxy_kernel.py"
)

# The proxy kernelspec keeps the user-visible identity of the real kernel but
# launches the proxy script, telling it which (renamed) kernel to wrap.
proxy_kernel_spec = KernelSpec()
proxy_kernel_spec.argv = [
    sys.executable,
    proxy_kernel_implementation_path,
    "{connection_file}",
    f"--kernel={new_kernel_name}",
]
proxy_kernel_spec.display_name = real_kernel_spec.display_name
proxy_kernel_spec.interrupt_mode = real_kernel_spec.interrupt_mode or "message"
proxy_kernel_spec.language = real_kernel_spec.language

proxy_kernel_json_path = os.path.join(real_kernel_install_path, "kernel.json")
with open(proxy_kernel_json_path, "w", encoding="utf-8") as out:
    json.dump(proxy_kernel_spec.to_dict(), out, indent=2)
log(f"Installed proxy kernelspec: {proxy_kernel_spec.to_json()}")

shutil.copy(args.implementation, proxy_kernel_implementation_path)
print("Proxy kernel installed. Go to 'Runtime > Change runtime type' and select 'Rust'")
import argparse
import json
from threading import Thread
import zmq
from jupyter_client import KernelClient
from jupyter_client.channels import HBChannel
from jupyter_client.manager import KernelManager
from jupyter_client.session import Session
from traitlets.traitlets import Type
# Command line of the proxy kernel: the connection file handed over by the
# server, and the name of the wrapped kernel to launch.
parser = argparse.ArgumentParser()
parser.add_argument("connection_file")
parser.add_argument("--kernel", required=True, type=str)
args = parser.parse_args()

# Read the connection details chosen by the server.
with open(args.connection_file, "r") as connection_file:
    connection_file_contents = json.load(connection_file)

transport = str(connection_file_contents["transport"])
ip = str(connection_file_contents["ip"])
shell_port, stdin_port, control_port, iopub_port, hb_port = (
    int(connection_file_contents[port_key])
    for port_key in (
        "shell_port",
        "stdin_port",
        "control_port",
        "iopub_port",
        "hb_port",
    )
)
signature_scheme = str(connection_file_contents["signature_scheme"])
key = str(connection_file_contents["key"]).encode()

# ZeroMQ socket type used on each side of every channel:
#
#   channel | kernel side | client side
#   --------+-------------+------------
#   shell   | ROUTER      | DEALER
#   stdin   | ROUTER      | DEALER
#   ctrl    | ROUTER      | DEALER
#   iopub   | PUB         | SUB
#   hb      | REP         | REQ
zmq_context = zmq.Context()
def create_and_bind_socket(port: int, socket_type: int):
    """Create a socket of ``socket_type`` and bind it to ``port``.

    The address is built from the module-level ``transport`` and ``ip`` read
    from the connection file: ``tcp://{ip}:{port}`` for tcp and
    ``ipc://{ip}-{port}`` for ipc.

    Args:
        port: Port (tcp) or address suffix (ipc) to bind to; must be positive.
        socket_type: zmq socket type constant, e.g. ``zmq.ROUTER``.

    Returns:
        The bound socket.

    Raises:
        ValueError: If ``port`` is not positive or ``transport`` is neither
            ``"tcp"`` nor ``"ipc"``.
    """
    if port <= 0:
        raise ValueError(f"Invalid port: {port}")
    if transport == "tcp":
        addr = f"tcp://{ip}:{port}"
    elif transport == "ipc":
        addr = f"ipc://{ip}-{port}"
    else:
        raise ValueError(f"Unknown transport: {transport}")

    socket = zmq_context.socket(socket_type)
    try:
        socket.linger = 1000  # ipykernel does this
        socket.bind(addr)
    except Exception:
        # Do not leave a half-configured socket attached to the context when
        # binding fails (e.g. the address is already in use).
        socket.close(linger=0)
        raise
    return socket
# Bind the kernel-side socket of every channel, in the same order as the
# channel table: the three ROUTER channels first, then iopub and heartbeat.
shell_socket, stdin_socket, control_socket = (
    create_and_bind_socket(router_port, zmq.ROUTER)
    for router_port in (shell_port, stdin_port, control_port)
)
iopub_socket = create_and_bind_socket(iopub_port, zmq.PUB)
hb_socket = create_and_bind_socket(hb_port, zmq.REP)

# The heartbeat is not forwarded: the proxy and the real kernel each have
# their own. A QUEUE device with the same REP socket on both ends is the
# heartbeat implementation borrowed from ipykernel; it blocks, so it runs in
# its own thread.
heartbeat_thread = Thread(target=zmq.device, args=(zmq.QUEUE, hb_socket, hb_socket))
heartbeat_thread.start()
def ZMQProxyChannel_factory(proxy_server_socket: "zmq.Socket"):
    """Return a channel class that bridges ``proxy_server_socket`` to the kernel.

    The returned class has the constructor shape ``(socket, session, loop)``
    and the ``start`` / ``stop`` / ``is_alive`` methods used by this file's
    kernel client. Instead of handling messages, a started channel forwards
    everything between ``proxy_server_socket`` and the socket connected to the
    wrapped kernel.

    Args:
        proxy_server_socket: The already-bound, server-facing socket for the
            channel this class will serve.

    Returns:
        The ``ZMQProxyChannel`` class bound to ``proxy_server_socket``.
    """

    def forward(frontend, backend):
        """Run the blocking zmq proxy until one of its sockets goes away."""
        try:
            zmq.proxy(frontend, backend)
        except zmq.ZMQError as error:
            # Closing a socket or terminating the context is how the proxy is
            # shut down, so these two errors mean "finished", not "failed".
            # Anything else is unexpected and is re-raised.
            if error.errno not in (zmq.ETERM, zmq.ENOTSOCK):
                raise

    class ZMQProxyChannel(object):
        """Channel that forwards raw messages instead of processing them."""

        # Socket connected to the wrapped kernel; None once the channel stops.
        kernel_client_socket: "zmq.Socket" = None
        # Stored for interface compatibility; messages are passed through
        # without being decoded.
        session: "Session" = None

        def __init__(self, socket: "zmq.Socket", session: "Session", _=None):
            # The third positional argument is accepted and ignored.
            super().__init__()
            self.kernel_client_socket = socket
            self.session = session

        def start(self):
            """Start forwarding between the server socket and the kernel."""
            # Very convenient zmq device here, proxy will handle the actual
            # zmq proxying on each of our connected sockets (other than
            # heartbeat). It blocks while they are connected so stick it in a
            # thread.
            Thread(
                target=forward,
                args=(proxy_server_socket, self.kernel_client_socket),
            ).start()

        def stop(self):
            """Close the kernel-side socket; safe to call more than once."""
            if self.kernel_client_socket is not None:
                try:
                    self.kernel_client_socket.close(linger=0)
                except Exception:
                    # Best effort: the channel is going away regardless.
                    pass
                self.kernel_client_socket = None

        def is_alive(self):
            """Return True until ``stop`` has been called."""
            return self.kernel_client_socket is not None

    return ZMQProxyChannel
class ProxyKernelClient(KernelClient):
    """Kernel client whose channels forward traffic to the server sockets.

    Each channel other than the heartbeat uses a class produced by
    ``ZMQProxyChannel_factory`` for the matching server-facing socket.
    """

    # The heartbeat channel is the only one that is not a forwarding channel.
    hb_channel_class = Type(HBChannel)

    # One forwarding channel class per server-facing socket.
    iopub_channel_class = Type(ZMQProxyChannel_factory(iopub_socket))
    control_channel_class = Type(ZMQProxyChannel_factory(control_socket))
    stdin_channel_class = Type(ZMQProxyChannel_factory(stdin_socket))
    shell_channel_class = Type(ZMQProxyChannel_factory(shell_socket))
# Configure a manager for the wrapped kernel: it is launched by name, always
# over tcp (whatever transport the server asked the proxy for), and its
# clients are built from ProxyKernelClient so their channels forward traffic.
kernel_manager = KernelManager()
kernel_manager.kernel_name = args.kernel
kernel_manager.transport = "tcp"
kernel_manager.client_factory = ProxyKernelClient
# No automatic restart: the proxy exits together with the wrapped kernel
# (see the wait() below).
kernel_manager.autorestart = False
# Make sure the wrapped kernel uses the same session info. This way we don't
# need to decode them before forwarding, we can directly pass everything
# through.
kernel_manager.session.signature_scheme = signature_scheme
kernel_manager.session.key = key
kernel_manager.start_kernel()
# Connect to the real kernel we just started and start up all the proxies.
kernel_client: ProxyKernelClient = kernel_manager.client()
kernel_client.start_channels()
# Everything should be up and running. We now just wait for the managed kernel
# process to exit and when that happens, shutdown and exit with the same code.
exit_code = kernel_manager.kernel.wait()
# Tear down in order: the kernel-facing channels first, then every socket
# created from the proxy's own context (0 is passed as the linger argument).
kernel_client.stop_channels()
zmq_context.destroy(0)
exit(exit_code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment