Created
July 7, 2023 21:44
-
-
Save tmaxwell-anthropic/de7a54753d312ccce831dbfd48ef2b12 to your computer and use it in GitHub Desktop.
Make Jupyter report the exit status when the kernel dies
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import signal | |
from jupyter_server.services.kernels.connection.base import ( | |
deserialize_binary_message, | |
deserialize_msg_from_ws_v1, | |
) | |
from jupyter_server.services.kernels.connection.channels import ( | |
ZMQChannelsWebsocketConnection, | |
) | |
from jupyter_server.services.kernels.kernelmanager import ServerKernelManager | |
class StatusReportingKernelManager(ServerKernelManager):
    """ServerKernelManager variant that also records why the kernel died.

    The human-readable explanation is stored in the ``reason`` attribute
    (``ServerKernelManager.reason``). It is reset to ``""`` whenever the
    kernel is found alive or there is no kernel at all.
    """

    @staticmethod
    def _describe_exit(ret: int) -> str:
        """Return a human-readable explanation for the exit status *ret*.

        A negative *ret* is reported as "exited with signal ``-ret``"; any
        other value is reported as a plain exit code.
        """
        if ret >= 0:
            return f"Jupyter kernel unexpectedly exited with code {ret}."

        try:
            signal_label = signal.Signals(-ret).name
        except ValueError:
            # Not every signal number has a member in the signal.Signals enum;
            # fall back to the bare number rather than raising out of the
            # liveness check.
            signal_label = str(-ret)
        reason = f"Jupyter kernel unexpectedly exited with signal {signal_label}."

        # SIGKILL is not defined by the signal module on every platform, so
        # look it up defensively instead of assuming the attribute exists.
        sigkill = getattr(signal, "SIGKILL", None)
        if sigkill is not None and ret == -sigkill:
            reason += (
                " (If it was using a lot of memory, it may have been killed by the "
                "OOM killer.)"
            )
        return reason

    async def _async_is_alive(self) -> bool:
        """Report whether the kernel is running, updating ``self.reason``.

        Returns:
            True when the provisioner's ``poll()`` yields ``None``; False when
            there is no kernel or ``poll()`` yields an exit status. In the
            latter case ``self.reason`` is set to a description of that status.
        """
        if not self.has_kernel:
            self.reason = ""
            return False

        # NOTE(review): presumably always true once has_kernel is set; kept
        # from the original code.
        assert self.provisioner is not None
        ret = await self.provisioner.poll()
        if ret is None:
            self.reason = ""
            return True

        reason = self._describe_exit(ret)
        # Log and store the reason only when it has changed, so that repeated
        # checks of the same dead kernel do not repeat the warning.
        if reason != self.reason:
            self.log.warning("Kernel %s: %s", self.kernel_id, reason)
            self.reason = reason
        return False

    is_alive = _async_is_alive
class StatusReportingZMQChannelsWebsocketConnection(ZMQChannelsWebsocketConnection):
    """ZMQChannelsWebsocketConnection that also tells the client why the kernel
    restarted or died.

    The explanation is read from ``kernel_manager.reason`` (when that attribute
    exists and is non-empty) and passed to ``write_stderr()``, using the header
    of the most recent ``execute_request`` as the parent header.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Header of the latest execute_request seen by handle_incoming_message();
        # None until such a request has been received.
        self.last_execute_request_header = None

    def _forward_exit_reason(self):
        """Pass kernel_manager.reason to write_stderr() if it is set to anything."""
        reason = getattr(self.kernel_manager, "reason", None)
        if reason:
            self.write_stderr(reason, parent_header=self.last_execute_request_header)

    # The restart hooks below first check whether kernel_manager.reason holds
    # anything helpful, then defer to the parent implementation.
    def on_kernel_restarted(self):
        self._forward_exit_reason()
        super().on_kernel_restarted()

    def on_restart_failed(self):
        self._forward_exit_reason()
        super().on_restart_failed()

    def handle_incoming_message(self, incoming_msg: str) -> None:
        """Forward a websocket message to the kernel, remembering the header of
        the latest ``execute_request`` in ``last_execute_request_header``.

        Intended to behave like
        ZMQChannelsWebsocketConnection.handle_incoming_message() apart from
        that bookkeeping.
        """
        if not self.channels:
            # already closed, ignore the message
            self.log.debug("Received message on closed websocket %r", incoming_msg)
            return

        uses_v1 = self.subprotocol == "v1.kernel.websocket.jupyter.org"
        if uses_v1:
            channel, msg_list = deserialize_msg_from_ws_v1(incoming_msg)
            msg = {"header": None}
        else:
            msg = (
                deserialize_binary_message(incoming_msg)
                if isinstance(incoming_msg, bytes)
                else json.loads(incoming_msg)
            )
            msg_list = []
            channel = msg.pop("channel", None)

        if channel is None:
            self.log.warning("No channel specified, assuming shell: %s", msg)
            channel = "shell"
        if channel not in self.channels:
            self.log.warning("No such channel: %r", channel)
            return

        # Resolve the header up front so the message type can be inspected
        # whether or not an allow-list is configured.
        header = self.get_part("header", msg["header"], msg_list)
        msg["header"] = header
        assert header is not None
        msg_type = header["msg_type"]
        if msg_type == "execute_request":
            self.last_execute_request_header = header

        allowed = self.multi_kernel_manager.allowed_message_types
        if allowed and msg_type not in allowed:
            self.log.warning(
                'Received message of type "%s", which is not allowed. Ignoring.'
                % msg_type
            )
            return

        stream = self.channels[channel]
        if uses_v1:
            self.session.send_raw(stream, msg_list)
        else:
            self.session.send(stream, msg)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment