Last active
October 29, 2020 11:34
-
-
Save vinayak-mehta/cf4ab04141809544a4f11280a78e98df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import json | |
import hmac | |
import hashlib | |
import traceback | |
from binascii import b2a_hex | |
from datetime import datetime, timezone | |
import zmq | |
from zmq.utils import jsonapi | |
from jupyter_client import KernelManager | |
def json_packer(obj):
    """Serialize *obj* with zmq's ``jsonapi.dumps``.

    Non-ASCII characters are kept as-is (``ensure_ascii=False``) and
    ``allow_nan=False`` is passed so NaN/Infinity are not emitted.
    """
    return jsonapi.dumps(obj, ensure_ascii=False, allow_nan=False)


def json_unpacker(s):
    """Deserialize a JSON payload *s* with zmq's ``jsonapi.loads``."""
    return jsonapi.loads(s)
def new_id():
    """Return a random identifier: 8 hex digits, a dash, then 24 hex digits.

    The 16 random bytes from ``os.urandom`` are split 4/12 and each part is
    hex-encoded.
    """
    raw = os.urandom(16)
    head, tail = raw[:4], raw[4:]
    return f"{b2a_hex(head).decode('ascii')}-{b2a_hex(tail).decode('ascii')}"
class Session:
    """Credentials shared by every message a client sends.

    Attributes:
        key: HMAC signing key passed in by the caller (the script below
            uses the kernel manager's session key).
        session_id: random identifier generated once per instance; it is
            placed in message headers and used as the msg_id prefix.
    """

    def __init__(self, key):
        self.key = key
        # One identifier for the whole lifetime of this Session object.
        self.session_id = new_id()
class Cutypr(object):
    """Minimal Jupyter kernel client speaking the wire protocol over raw ZeroMQ.

    Only the shell channel (requests, DEALER socket) and the iopub channel
    (broadcast output, SUB socket) are used; no stdin, control or heartbeat
    socket is opened.

    Args:
        session: object providing ``key`` (HMAC key, bytes) and
            ``session_id`` (str). Required in practice: ``__init__`` reads
            ``session.key`` even though the default is ``None``.
        ports: mapping from channel name (``"shell"``, ``"iopub"``, ...) to
            the TCP port the kernel listens on at 127.0.0.1.
    """

    # Separates ZMQ routing identities from the signed message frames.
    DELIM = b"<IDS|MSG>"
    # Messaging-protocol version advertised in every outgoing header.
    PROTOCOL_VERSION = "5.3"

    def __init__(self, session=None, ports=None):
        self.context = zmq.Context()
        self.session = session
        self.ports = ports
        # Incremented per outgoing message so msg_ids are unique within the session.
        self.message_count = 0
        # Sockets are created lazily by the *_channel properties.
        self._shell_channel = None
        self._iopub_channel = None
        # An empty key disables signing (jupyter_client convention): messages
        # then carry an empty signature and incoming ones are not checked.
        key = session.key
        self.auth = hmac.HMAC(key, digestmod=hashlib.sha256) if key else None

    def _make_url(self, channel):
        """Return the tcp://127.0.0.1 URL for *channel*'s port."""
        port = self.ports[channel]
        return f"tcp://127.0.0.1:{port}"

    def _make_channel(self, channel):
        """Create and connect the ZMQ socket for *channel* ("shell" or "iopub")."""
        socket_type = {
            "shell": zmq.DEALER,
            "iopub": zmq.SUB,
        }
        sock = self.context.socket(socket_type[channel])
        # Wait at most 1000 ms for unsent messages when the socket is closed.
        sock.linger = 1000
        sock.connect(self._make_url(channel))
        if channel == "iopub":
            # Subscribe to every topic; a SUB socket receives nothing otherwise.
            sock.setsockopt(zmq.SUBSCRIBE, b"")
        return sock

    @property
    def shell_channel(self):
        """DEALER socket for requests; created on first access."""
        if self._shell_channel is None:
            self._shell_channel = self._make_channel("shell")
        return self._shell_channel

    @property
    def iopub_channel(self):
        """SUB socket for kernel output; created on first access."""
        if self._iopub_channel is None:
            self._iopub_channel = self._make_channel("iopub")
        return self._iopub_channel

    def _make_message(self, message_type, content):
        """Build an outgoing message dict of type *message_type* with *content*.

        The header carries every field the messaging spec lists (msg_id,
        msg_type, username, session, date, version). Each call consumes one
        value of ``message_count``.
        """
        msg_id = f"{self.session.session_id}_{self.message_count}"
        self.message_count += 1
        header = {
            "msg_id": msg_id,
            "msg_type": message_type,
            "username": "vinayak",
            "session": self.session.session_id,
            "date": datetime.now(timezone.utc).isoformat(),
            "version": self.PROTOCOL_VERSION,
        }
        return {
            "header": header,
            "msg_id": header["msg_id"],
            "msg_type": header["msg_type"],
            "content": content,
            "metadata": {},
            "parent_header": {},
        }

    def sign(self, msg_list):
        """Return the hex HMAC-SHA256 signature (bytes) over the frames in *msg_list*.

        Returns ``b""`` when signing is disabled (empty session key).
        """
        if self.auth is None:
            return b""
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return h.hexdigest().encode("utf-8")

    def serialize(self, msg):
        """Turn a message dict into the frame list: delimiter, signature, 4 JSON frames."""
        msg_list = [
            json_packer(msg["header"]),
            json_packer(msg["parent_header"]),
            json_packer(msg["metadata"]),
            json_packer(msg.get("content", {})),
        ]
        signature = self.sign(msg_list)
        return [self.DELIM, signature, *msg_list]

    def execute(self, code):
        """Send an ``execute_request`` for *code* on the shell channel.

        Returns:
            The msg_id of the request that was sent.
        """
        content = dict(
            code=code,
            silent=False,
            store_history=True,
            # The protocol defines user_expressions as a dict, not null.
            user_expressions={},
            # No stdin socket is connected, so the kernel must not prompt for
            # input: it would wait for an input_reply that never arrives.
            allow_stdin=False,
            stop_on_error=True,
        )
        msg = self._make_message("execute_request", content)
        self.shell_channel.send_multipart(self.serialize(msg))
        return msg["header"]["msg_id"]

    def deserialize(self, msg_list):
        """Parse a multipart message received from the kernel.

        Frames before ``<IDS|MSG>`` are routing identities/topics whose number
        may vary, so the delimiter is located rather than assumed to sit at a
        fixed index. When signing is enabled the HMAC signature is verified.

        Returns:
            dict with ``header``, ``parent_header``, ``metadata``, ``content``
            and the ``msg_id``/``msg_type`` shortcuts from the header.

        Raises:
            ValueError: if the delimiter or message frames are missing, or the
                signature does not match.
        """
        try:
            start = msg_list.index(self.DELIM)
        except ValueError:
            raise ValueError("malformed message: no <IDS|MSG> delimiter") from None
        frames = msg_list[start + 1 : start + 6]
        if len(frames) < 5:
            raise ValueError("malformed message: expected signature and 4 frames")
        signature, *parts = frames
        if self.auth is not None and not hmac.compare_digest(signature, self.sign(parts)):
            raise ValueError("invalid message signature")
        header, parent_header, metadata, content = (json_unpacker(p) for p in parts)
        return {
            "header": header,
            "msg_id": header["msg_id"],
            "msg_type": header["msg_type"],
            "parent_header": parent_header,
            "metadata": metadata,
            "content": content,
        }

    def is_alive(self, channel):
        """Return whether the socket for *channel* exists.

        Accessing the ``shell``/``iopub`` property creates the socket if it is
        missing, so for those channels this returns True. An unknown channel
        name raises AttributeError.
        """
        # getattr instead of eval: same lookup, no code built from a string.
        return getattr(self, f"{channel}_channel") is not None

    def msg_ready(self, timeout=0):
        """Return True if an iopub message can be received within *timeout* ms."""
        return bool(self.iopub_channel.poll(timeout=timeout))

    def get_msg(self):
        """Block until an iopub message arrives and return it deserialized."""
        msg_list = self.iopub_channel.recv_multipart()
        return self.deserialize(msg_list)
def main():
    """Run an interactive prompt that executes each entered line in a new kernel.

    Starts a kernel with jupyter_client's ``KernelManager``, connects a
    ``Cutypr`` client to its shell and iopub ports, sends every non-blank line
    as an ``execute_request`` and echoes iopub output until the kernel reports
    ``idle``. Ctrl-D/Ctrl-C at the prompt exits; the kernel is shut down on
    every exit path.
    """
    manager = None
    try:
        manager = KernelManager()
        manager.start_kernel()
        # Look each port up by name instead of zipping against manager.ports,
        # whose ordering is an implementation detail of jupyter_client.
        port_names = ["shell", "stdin", "iopub", "hb", "control"]
        ports = {name: getattr(manager, f"{name}_port") for name in port_names}
        session = Session(key=manager.session.key)
        client = Cutypr(session=session, ports=ports)
        # Open the iopub SUB socket now, before the first execute_request.
        # PUB/SUB drops whatever is published before the subscription is in
        # place; creating it lazily after sending a request could lose the
        # output and the final "idle" status, hanging the wait loop below.
        _ = client.iopub_channel

        execution_count = 1
        while True:
            try:
                code = input(f"In [{execution_count}]: ")
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C at the prompt: leave the REPL cleanly.
                print()
                break
            if not code.strip():
                continue
            client.execute(code)
            execution_state = "busy"
            while execution_state != "idle":
                # Block up to 100 ms for output instead of busy-spinning.
                if not client.iopub_channel.poll(timeout=100):
                    if not manager.is_alive():
                        print("Kernel died; exiting.", file=sys.stderr)
                        return
                    continue
                while client.msg_ready():
                    msg = client.get_msg()
                    msg_type = msg["header"]["msg_type"]
                    content = msg["content"]
                    if msg_type == "status":
                        execution_state = content["execution_state"]
                    elif msg_type == "stream":
                        stream = {"stdout": sys.stdout, "stderr": sys.stderr}.get(content["name"])
                        if stream is not None:
                            # Stream text already carries its own newlines.
                            stream.write(content["text"])
                            stream.flush()
                    elif msg_type == "execute_result":
                        text = content.get("data", {}).get("text/plain")
                        if text is not None:
                            print(f"Out[{content.get('execution_count')}]: {text}")
                    elif msg_type == "execute_input":
                        execution_count = int(content["execution_count"]) + 1
                    elif msg_type == "error":
                        for frame in content["traceback"]:
                            print(frame, file=sys.stderr)
                    # display_data, clear_output and other types are ignored.
    except Exception:
        traceback.print_exc()
    finally:
        # manager may be unset (constructor failed) or hold no running kernel.
        if manager is not None and manager.has_kernel:
            manager.shutdown_kernel()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment