Skip to content

Instantly share code, notes, and snippets.

@oeway
Created June 30, 2024 21:25
Show Gist options
  • Save oeway/74632b5fd3fbb66e5dc2d54d065a8495 to your computer and use it in GitHub Desktop.
Save oeway/74632b5fd3fbb66e5dc2d54d065a8495 to your computer and use it in GitHub Desktop.
Python script for setting up a WebSocket proxy in Google Colab so that the embedded web page can communicate with a WebSocket server.

Description for colab_websocket.js

Title: colab_websocket.js

Description: This JavaScript file implements a WebSocket client for Google Colab that facilitates communication between the Colab notebook and a WebSocket server. It includes buffer handling for binary data and a simple debug console for sending and receiving messages.

Description for colab_websocket_proxy.py

Title: colab_websocket_proxy.py

Description: This Python script sets up a WebSocket proxy in Google Colab, allowing communication between the Colab notebook and an external WebSocket server. It uses websockets for WebSocket communication and integrates with Google Colab's output module to handle messages between Python and the JavaScript client.

/**
 * Report whether a value supplies its own JSON representation.
 *
 * @param {*} object - Any value.
 * @returns {boolean} True when `object` is a non-null object with a callable
 *   `toJSON` member, false otherwise.
 */
function isSerializable(object) {
  // Require a callable toJSON: remove_buffers() invokes it directly, so a
  // non-function property of that name would otherwise raise a TypeError.
  return (
    typeof object === "object" &&
    object !== null &&
    typeof object.toJSON === "function"
  );
}
/**
 * Test whether a value is a plain object, i.e. one whose `constructor`
 * is exactly `Object`.
 *
 * Falsy inputs are handed back unchanged (not coerced to `false`), so the
 * result should only be used for its truthiness.
 *
 * @param {*} value - Any value.
 * @returns {*} `true` for plain objects, `false` for other truthy values,
 *   or the falsy input itself.
 */
function isObject(value) {
  if (!value) {
    return value;
  }
  if (typeof value !== "object") {
    return false;
  }
  return value.constructor === Object;
}
/**
 * Write binary buffers back into a state tree, in place.
 *
 * @param {Object|Array} state - Tree to mutate.
 * @param {Array<Array<string|number>>} buffer_paths - One key path per
 *   buffer; every segment except the last must already exist in `state`.
 * @param {Array<ArrayBuffer|DataView>} buffers - Buffers matching
 *   `buffer_paths` by index. A DataView is replaced by its underlying
 *   ArrayBuffer before insertion.
 */
function put_buffers(state, buffer_paths, buffers) {
  const rawBuffers = buffers.map((item) =>
    item instanceof DataView ? item.buffer : item
  );
  buffer_paths.forEach((path, index) => {
    const lastKey = path[path.length - 1];
    // Walk down to the container that owns the final key.
    const parent = path.slice(0, -1).reduce((node, key) => node[key], state);
    parent[lastKey] = rawBuffers[index];
  });
}
/**
 * Separate binary buffers from a state tree.
 *
 * Arrays and plain objects are walked recursively; values exposing `toJSON`
 * are converted through it first. Each ArrayBuffer or buffer view found is
 * collected together with the key path leading to it. Containers are copied
 * lazily (only when something inside them changes), so the input is never
 * mutated.
 *
 * In arrays a removed buffer leaves `null` in its slot; in objects the key
 * is deleted. For a buffer view, the view's entire underlying `buffer` is
 * collected, not just the viewed range.
 *
 * @param {*} state - Value to scan.
 * @returns {{state: *, buffers: ArrayBuffer[], buffer_paths: Array<Array<string|number>>}}
 *   The buffer-free state plus the buffers and their paths, index-aligned.
 */
function remove_buffers(state) {
  const buffers = [];
  const buffer_paths = [];

  // Returns `obj` itself when nothing below it changed, otherwise a copy.
  function remove(obj, path) {
    if (isSerializable(obj)) {
      obj = obj.toJSON();
    }
    if (Array.isArray(obj)) {
      let is_cloned = false;
      for (let i = 0; i < obj.length; i++) {
        const value = obj[i];
        if (!value) {
          continue;
        }
        if (value instanceof ArrayBuffer || ArrayBuffer.isView(value)) {
          if (!is_cloned) {
            obj = obj.slice();
            is_cloned = true;
          }
          buffers.push(ArrayBuffer.isView(value) ? value.buffer : value);
          buffer_paths.push(path.concat([i]));
          obj[i] = null;
        } else {
          const new_value = remove(value, path.concat([i]));
          if (new_value !== value) {
            if (!is_cloned) {
              obj = obj.slice();
              is_cloned = true;
            }
            obj[i] = new_value;
          }
        }
      }
    } else if (isObject(obj)) {
      // Declared once per object (not once per key) so the object is
      // shallow-copied at most one time, mirroring the array branch.
      let is_cloned = false;
      // The key set is fixed when the loop starts, so reassigning `obj`
      // to its copy inside the loop does not disturb the iteration.
      for (const key in obj) {
        if (!Object.prototype.hasOwnProperty.call(obj, key)) {
          continue;
        }
        const value = obj[key];
        if (!value) {
          continue;
        }
        if (value instanceof ArrayBuffer || ArrayBuffer.isView(value)) {
          if (!is_cloned) {
            obj = { ...obj };
            is_cloned = true;
          }
          buffers.push(ArrayBuffer.isView(value) ? value.buffer : value);
          buffer_paths.push(path.concat([key]));
          delete obj[key];
        } else {
          const new_value = remove(value, path.concat([key]));
          if (new_value !== value) {
            if (!is_cloned) {
              obj = { ...obj };
              is_cloned = true;
            }
            obj[key] = new_value;
          }
        }
      }
    }
    return obj;
  }

  const new_state = remove(state, []);
  return { state: new_state, buffers, buffer_paths };
}
/**
 * WebSocket-shaped object that exchanges messages with the notebook kernel
 * over a Colab comm instead of a network socket.
 *
 * It exposes `readyState`, `onopen`/`onmessage`/`onclose`/`onerror`,
 * `send()`, `close()` and `addEventListener()`, and renders a small debug
 * console into the page.
 */
class LocalWebSocket {
  /**
   * @param {string} url - Target server URL; stored and logged only, this
   *   class never opens a connection to it itself.
   * @param {string} client_id - Identifier of this client; stored and logged.
   * @param {string} workspace - Comm target name handed to
   *   `google.colab.kernel.comms.open`.
   */
  constructor(url, client_id, workspace) {
    this.url = url;
    this.onopen = () => {};
    this.onmessage = () => {};
    this.onclose = () => {};
    this.onerror = () => {};
    this.client_id = client_id;
    this.workspace = workspace;
    console.log('Initializing LocalWebSocket with URL:', url);
    console.log('Client ID:', client_id);
    console.log('Workspace:', workspace);
    // Low-level sender: forwards a message object to the kernel comm.
    this.postMessage = (message) => {
      console.log('Posting message to kernel:', message);
      if (this.comm) {
        this.comm.send(message);
      } else {
        console.error('Comm not initialized.');
      }
    };
    this.readyState = WebSocket.CONNECTING;
    console.log('Initial readyState:', this.readyState);
    google.colab.kernel.comms
      .open(this.workspace, {})
      .then((comm) => {
        // Deferred so `this.comm` (assigned below) is set before onopen runs.
        setTimeout(async () => {
          this.readyState = WebSocket.OPEN;
          console.log('===> WebSocket connection opened');
          this.onopen();
          for await (const msg of comm.messages) {
            const data = msg.data;
            console.log('Received message from kernel:', data);
            // Re-attach binary payloads that travelled separately.
            const buffer_paths = data.__buffer_paths__ || [];
            delete data.__buffer_paths__;
            put_buffers(data, buffer_paths, msg.buffers || []);
            if (data.type === "log" || data.type === "info") {
              console.log(data.message);
            } else if (data.type === "error") {
              console.error(data.message);
            } else {
              this.onmessage(data);
            }
          }
        }, 0);
        this.comm = comm;
      })
      .catch((e) => {
        console.error("Failed to connect to kernel comm:", e);
        // The comm never opened: leave CONNECTING and tell the owner,
        // instead of failing silently.
        this.readyState = WebSocket.CLOSED;
        this.onerror(e);
      });
    this._initUI();
  }

  /**
   * Build the debug console (text input, send button, message log) and
   * append it to `document.body`. Also installs an `onmessage` handler
   * that prints each received payload into the console.
   */
  _initUI() {
    console.log('Initializing UI');
    const container = document.createElement('div');
    container.style.border = '1px solid black';
    container.style.padding = '10px';
    container.style.marginTop = '10px';
    const title = document.createElement('h3');
    title.innerText = 'WebSocket Debug Console';
    container.appendChild(title);
    const input = document.createElement('input');
    input.type = 'text';
    input.placeholder = 'Enter message...';
    container.appendChild(input);
    const button = document.createElement('button');
    button.innerText = 'Send';
    button.onclick = () => {
      const message = input.value;
      console.log('Sending message:', message);
      if (message && this.readyState === WebSocket.OPEN) {
        this.send(message);
        const log = document.createElement('p');
        log.innerText = `Sent: ${message}`;
        container.appendChild(log);
      } else {
        console.log('Cannot send message, WebSocket not open:', this.readyState);
      }
    };
    container.appendChild(button);
    const logContainer = document.createElement('div');
    container.appendChild(logContainer);
    this.onmessage = (event) => {
      const log = document.createElement('p');
      log.innerText = `Received: ${JSON.stringify(event.data)}`;
      logContainer.appendChild(log);
    };
    document.body.appendChild(container);
  }

  /**
   * Send a payload to the kernel wrapped as `{type: "message", data}`.
   * Logs and does nothing unless the connection is OPEN.
   *
   * @param {*} data - Payload to forward.
   */
  send(data) {
    if (this.readyState === WebSocket.OPEN) {
      console.log('Sending data:', data);
      this.postMessage({
        type: "message",
        data,
      });
    } else {
      console.log('Cannot send data, WebSocket not open:', this.readyState);
    }
  }

  /**
   * Ask the kernel side to close, then mark this object CLOSED and fire
   * `onclose`.
   */
  close() {
    this.readyState = WebSocket.CLOSING;
    console.log('Closing connection');
    this.postMessage({
      type: "close"
    });
    // Finish the transition; otherwise readyState would stay CLOSING forever.
    this.readyState = WebSocket.CLOSED;
    this.onclose();
  }

  /**
   * Install `listener` as the single handler for `type`
   * ("message", "open", "close" or "error"), replacing any previous one.
   * Other types are ignored.
   *
   * @param {string} type - Event name.
   * @param {Function} listener - Handler to install.
   */
  addEventListener(type, listener) {
    if (type === "message") this.onmessage = listener;
    if (type === "open") this.onopen = listener;
    if (type === "close") this.onclose = listener;
    if (type === "error") this.onerror = listener;
  }
}
// Bootstrap: create the comm-backed socket for this page and log its
// open/close transitions. The two quoted placeholder tokens below are
// substituted by the Python side before this script is injected.
(() => {
  const client_id = "<client_id>";
  const ws_url = "<ws_url>";
  const ws = new LocalWebSocket(ws_url, client_id, "colab_ws_proxy_" + client_id);
  Object.assign(ws, {
    onopen: () => {
      console.log('WebSocket connection opened');
    },
    onclose: () => {
      console.log('WebSocket connection closed');
    },
  });
})();
import asyncio
import websockets
import uuid
from IPython.display import display, HTML
from imjoy_rpc.connection.jupyter_connection import put_buffers, remove_buffers
class ColabWebSocketProxy:
    """Relay messages between a WebSocket server and a notebook front-end.

    The kernel side holds the real WebSocket connection. The front-end side
    is reached through a comm registered under
    ``colab_ws_proxy_<client_id>``, which the injected JavaScript opens.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, uri):
        """Store the target ``uri`` and generate a unique client id."""
        self.uri = uri
        # Unique per proxy; part of the comm target name.
        self.client_id = str(uuid.uuid4())
        # Set once the front-end opens the comm.
        self.comm = None
        # Set while connect() holds the server connection.
        self.websocket = None
        # Signals that ``self.comm`` is available.
        self.connected_event = asyncio.Event()
        # Strong references to in-flight tasks; the event loop itself only
        # keeps weak ones, so unreferenced tasks may be collected mid-run.
        self._pending_tasks = set()

    async def connect(self):
        """Create a WebSocket connection and start proxying messages.

        Runs until the server connection ends. Server messages are only
        forwarded after the front-end has opened its comm.
        """
        loop = asyncio.get_running_loop()
        async with websockets.connect(self.uri) as websocket:
            self.websocket = websocket
            self._setup_comm()
            # The comm is created by the front-end asynchronously; without
            # this wait the first server message would hit ``self.comm`` while
            # it is still None.
            await self.connected_event.wait()
            async for message in websocket:
                self.comm.send({"type": "log", "message": message})
                await loop.run_in_executor(
                    None, self.emit, {"type": "message", "data": message}
                )
                print(f"Proxy received from server: {message}")

    def _setup_comm(self):
        """Set up Colab communication channel.

        Registers the comm target for this client and injects the JavaScript
        client, read from ``colab_websocket.js`` in the working directory,
        with its placeholders filled in.
        """

        def registered(comm, open_msg):
            """Handle registration."""
            self.comm = comm

            def msg_cb(msg):
                """Handle a message."""
                data = msg["content"]["data"]
                if "type" in data:
                    if "__buffer_paths__" in data:
                        # Re-attach binary payloads that travelled separately.
                        buffer_paths = data["__buffer_paths__"]
                        del data["__buffer_paths__"]
                        put_buffers(data, buffer_paths, msg["buffers"])
                    loop = asyncio.get_running_loop()
                    task = loop.create_task(self._handle_comm_message(data))
                    # Keep the task alive until it finishes.
                    self._pending_tasks.add(task)
                    task.add_done_callback(self._pending_tasks.discard)

            comm.on_msg(msg_cb)
            # Unblock connect(): the front-end is now reachable.
            self.connected_event.set()

        get_ipython().kernel.comm_manager.register_target(
            f"colab_ws_proxy_{self.client_id}", registered
        )
        with open('colab_websocket.js', 'r', encoding="utf-8") as f:
            js_code = f.read()
        js_code = js_code.replace('<client_id>', self.client_id).replace(
            '<ws_url>', self.uri
        )
        display(HTML(f"\n<script>\n{js_code}\n</script>\n"))

    async def _handle_comm_message(self, message):
        """Handle incoming messages from JavaScript.

        ``message['type']`` selects the action: ``'message'`` forwards
        ``message['data']`` to the server, ``'close'`` closes the server
        connection. Other types are ignored.
        """
        if message['type'] == 'message':
            await self.websocket.send(message['data'])
            print(f"Sent message to WebSocket server: {message['data']}")
        elif message['type'] == 'close':
            await self.websocket.close()
            print("Closed WebSocket connection")

    def emit(self, msg):
        """Emit a message.

        Binary buffers are split out of ``msg`` with ``remove_buffers`` and
        sent alongside it; their paths go in the ``__buffer_paths__`` key.
        """
        msg, buffer_paths, buffers = remove_buffers(msg)
        if buffers:
            msg["__buffer_paths__"] = buffer_paths
            self.comm.send(msg, buffers=buffers)
        else:
            self.comm.send(msg)
# Example usage: proxy a local test server from a notebook cell.
uri = "ws://127.0.0.1:8765"  # Local WebSocket server for testing
proxy = ColabWebSocketProxy(uri)


async def test_websocket_proxy():
    """Run the example proxy until its server connection ends."""
    await proxy.connect()


# Schedule the proxy on the current event loop so the cell returns at once.
loop = asyncio.get_event_loop()
# Keep a reference: the loop only holds weak references to tasks, so an
# unreferenced task may be garbage-collected before it finishes.
proxy_task = loop.create_task(test_websocket_proxy())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment