-
-
Save elazarcoh/adbbf8761817c8319b7e496d3fd6f888 to your computer and use it in GitHub Desktop.
simple socket serialization python-nodejs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import json | |
import os | |
import logging | |
import sys | |
from pathlib import Path | |
from socket_client import open_send_and_close | |
def load_port_from_file():
    """Read the TCP port stored in ``port.txt`` beside this script.

    The Node.js tester writes this file after its server starts listening.

    Returns:
        The port number as an int.
    """
    port_path = os.path.join(os.path.dirname(__file__), 'port.txt')
    with open(port_path) as handle:
        contents = handle.read()
    port = int(contents)
    logging.info(f'Loaded port {port} from {port_path}')
    return port
def generate_json_message(length):
    """Return the JSON text of ``{"a": "bb...b"}`` with ``length`` 'b' characters."""
    payload = dict(a='b' * length)
    return json.dumps(payload)
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    port = load_port_from_file()
    logging.info(f'Port is {port}')
    # ~6 MiB JSON message
    length = 6 * 1024 * 1024
    json_message = generate_json_message(length)
    # Arbitrary request id for this one-off test message (packed as a u32).
    request_id = 1
    # 0x02 is the Json object type (socket_client.Json).
    open_send_and_close(port, request_id, json_message, 0x02)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { SocketServer } from "SocketSerialization"; | |
import * as fs from "fs"; | |
import * as path from "path"; | |
/**
 * Persists the listening port to `port.txt` next to this script, where the
 * Python tester reads it.
 */
function writePortToFile(port: number) {
  fs.writeFileSync(path.join(__dirname, "port.txt"), `${port}`);
}
/**
 * Starts a SocketServer and records its port in port.txt for the Python side.
 */
async function main(): Promise<void> {
  console.log("SocketSerialization tester");
  const server = new SocketServer();
  // start() is async: awaiting it means a startup failure is caught below
  // instead of becoming an unhandled rejection, and the port is only read
  // after start() has completed.
  await server.start();
  writePortToFile(server.portNumber);
}

main().catch((error: unknown) => {
  console.error(error);
  process.exitCode = 1;
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import socket | |
import traceback | |
import struct | |
import numpy as np | |
# Wire-protocol codes shared with SocketSerialization.ts, which documents the
# full message layout.

# MessageType codes
PythonSendingObject = 0x01

# ObjectType codes
NumpyArray, Json, ExceptionObject = 0x01, 0x02, 0xFF

# ArrayDataType codes: eleven consecutive values starting at 0x01
(
    Float32,
    Float64,
    Int8,
    Int16,
    Int32,
    Int64,
    Uint8,
    Uint16,
    Uint32,
    Uint64,
    Bool,
) = range(0x01, 0x0C)

# ExceptionType codes keyed by exception class; the None entry is the code for
# any class that is not listed.
ExceptionTypes = {
    BaseException: 0x01,
    Exception: 0x02,
    RuntimeError: 0x03,
    TypeError: 0x04,
    ValueError: 0x05,
    None: 0xFF,  # Unknown exception
}

# numpy scalar types that give each fixed-width header field its size
MessageType = ObjectType = NumDimsType = np.uint8
RequestIdType = ObjectIdType = DimType = np.uint32

# dtype name -> ArrayDataType code; "bool_" is accepted as an alias of "bool".
array_dtype_to_array_data_type = dict(
    zip(
        ("float32", "float64", "int8", "int16", "int32", "int64",
         "uint8", "uint16", "uint32", "uint64", "bool"),
        (Float32, Float64, Int8, Int16, Int32, Int64,
         Uint8, Uint16, Uint32, Uint64, Bool),
    )
)
array_dtype_to_array_data_type["bool_"] = Bool
def create_numpy_message(
    request_id: RequestIdType,
    array: np.ndarray,
):
    """Serialize a numpy array as a PythonSendingObject/NumpyArray message.

    Layout, with header fields in network byte order, matching the reader in
    SocketSerialization.ts:

        message type (u8) | request id (u32) | object id (u32)
        | object type (u8) | data type (u8) | byte order (u8)
        | number of dimensions (u8) | dimensions (u32 each)
        | padding (0-7 bytes) | array data (little-endian, C order)

    Args:
        request_id: request identifier, packed as an unsigned 32-bit int.
        array: array to send. Its dtype name must be a key of
            ``array_dtype_to_array_data_type``; any byte order is accepted.

    Returns:
        The packed message bytes, without the 4-byte length prefix.

    Raises:
        KeyError: if the array's dtype is not supported.
    """
    # Header
    message_type = MessageType(PythonSendingObject)
    request_id = RequestIdType(request_id)
    # id() is usually wider than 32 bits; keep the low 32 bits so it fits the
    # u32 field (NumPy >= 2 raises OverflowError for out-of-range Python ints).
    object_id = ObjectIdType(id(array) & 0xFFFFFFFF)
    object_type = ObjectType(NumpyArray)

    # Body. dtype.name ignores byte order ('>f4' -> 'float32'), unlike str(dtype).
    array_dtype = array_dtype_to_array_data_type[array.dtype.name]
    # Always ship little-endian data and say so in the byte-order field
    # (0x01 == ByteOrder.LittleEndian in SocketSerialization.ts). astype with
    # copy=False is a no-op when the array is already little-endian.
    byte_order = 0x01
    little_endian = array.astype(array.dtype.newbyteorder('<'), copy=False)
    num_dimensions = NumDimsType(array.ndim)
    array_shape = np.array(array.shape, dtype=DimType)
    array_data = little_endian.tobytes()

    header_format = f'!BIIBBBB{int(num_dimensions)}I'
    # Pad so the array data starts on an 8-byte boundary measured from the
    # message-type byte; the receiver treats the leftover bytes in front of
    # the data as padding.
    padding = -struct.calcsize(header_format) % 8
    message_format = f'{header_format}{padding}x{len(array_data)}s'

    return struct.pack(
        message_format,
        message_type,
        request_id,
        object_id,
        object_type,
        array_dtype,
        byte_order,
        num_dimensions,
        *array_shape,
        array_data,
    )
def create_exception_message(
    request_id: RequestIdType,
    exception: Exception,
):
    """Serialize an exception as a PythonSendingObject/ExceptionObject message.

    Layout, with header fields in network byte order:

        message type (u8) | request id (u32) | object id (u32)
        | object type (u8) | exception type (u8) | UTF-8 formatted traceback

    Args:
        request_id: request identifier, packed as an unsigned 32-bit int.
        exception: the exception to report. Its exact class is looked up in
            ``ExceptionTypes``; classes that are not listed use the ``None``
            entry.

    Returns:
        The packed message bytes, without the 4-byte length prefix.
    """
    message_type = MessageType(PythonSendingObject)
    request_id = RequestIdType(request_id)
    # id() is usually wider than 32 bits; keep the low 32 bits so it fits the
    # u32 field (NumPy >= 2 raises OverflowError for out-of-range Python ints).
    object_id = ObjectIdType(id(exception) & 0xFFFFFFFF)
    object_type = ObjectType(ExceptionObject)

    exception_type = ExceptionTypes.get(type(exception), ExceptionTypes[None])
    # Format from the exception object itself. format_exc() only sees the
    # exception currently being handled and yields 'NoneType: None' when
    # called outside an except block.
    exception_message = ''.join(
        traceback.format_exception(type(exception), exception, exception.__traceback__)
    ).encode()

    message_format = f'!BIIBB{len(exception_message)}s'
    return struct.pack(
        message_format,
        message_type,
        request_id,
        object_id,
        object_type,
        exception_type,
        exception_message,
    )
def create_json_message(
    request_id: RequestIdType,
    obj: str,
):
    """Serialize JSON text as a PythonSendingObject/Json message.

    Layout, with header fields in network byte order:

        message type (u8) | request id (u32) | object id (u32)
        | object type (u8) | UTF-8 JSON text

    Args:
        request_id: request identifier, packed as an unsigned 32-bit int.
        obj: already-serialized JSON text (e.g. from json.dumps). It is sent
            as-is, not re-serialized.

    Returns:
        The packed message bytes, without the 4-byte length prefix.
    """
    message_type = MessageType(PythonSendingObject)
    request_id = RequestIdType(request_id)
    # id() is usually wider than 32 bits; keep the low 32 bits so it fits the
    # u32 field (NumPy >= 2 raises OverflowError for out-of-range Python ints).
    object_id = ObjectIdType(id(obj) & 0xFFFFFFFF)
    object_type = ObjectType(Json)

    json_message = obj.encode()

    message_format = f'!BIIB{len(json_message)}s'
    return struct.pack(
        message_format,
        message_type,
        request_id,
        object_id,
        object_type,
        json_message,
    )
def open_send_and_close(port, request_id, obj, type):
    """Connect to localhost:``port``, send one framed message, then close.

    ``type`` selects the serializer (``NumpyArray`` or ``Json``). If building
    the message raises, including for an unknown ``type``, an exception
    message describing that error is sent instead.

    The frame is a 4-byte big-endian length followed by the message bytes.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect(('localhost', port))
        try:
            if type == NumpyArray:
                payload = create_numpy_message(request_id, obj)
            elif type == Json:
                payload = create_json_message(request_id, obj)
            else:
                raise ValueError(f'Unknown type {type}')
        except Exception as error:
            payload = create_exception_message(request_id, error)

        frame_parts = (struct.pack('!I', len(payload)), payload)
        for part in frame_parts:
            sock.sendall(part)
        sock.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Service } from "typedi"; | |
import * as net from "net"; | |
// import { logDebug, logInfo } from "../Logging"; | |
// Message format:
// 1. Message length (4 bytes, unsigned big-endian; counts the message type byte plus the message data)
// 2. Message type (1 byte)
// 3. Message data (length - 1 bytes)
// | |
// Message types: | |
// 1. Python sending an object => 0x01 | |
// Message data format: | |
// Request ID (4 bytes) | |
// Object ID (4 bytes) | |
// Object type (1 byte) | |
// Object data (length bytes) | |
// | |
// Object types: | |
// 1. numpy array => 0x01 | |
// Object data format: | |
// Data type (1 byte) | |
// Byte order of the array data (1 byte) | |
// Number of dimensions (1 byte) | |
// Dimensions (4 bytes each) | |
// Padding (0-7 bytes; the reader treats any bytes beyond the array data size as leading padding)
// Data (length bytes) | |
// | |
// Data types: | |
// 1. float32 => 0x01 | |
// 2. float64 => 0x02 | |
// 3. int8 => 0x03 | |
// 4. int16 => 0x04 | |
// 5. int32 => 0x05 | |
// 6. int64 => 0x06 | |
// 7. uint8 => 0x07 | |
// 8. uint16 => 0x08 | |
// 9. uint32 => 0x09 | |
// 10. uint64 => 0x0a | |
// 11. bool => 0x0b | |
// | |
// 2. Json => 0x02 | |
// Object data format: | |
// Json string (length bytes) | |
// | |
// -1. Exception => 0xff | |
// Object data format: | |
// Exception type (1 byte) | |
// Exception message (length bytes) | |
// | |
// 2. Webview Hello => 0x02 | |
// Message data format: | |
// None | |
/** Message kind: the first byte after the length prefix. */
enum MessageType {
  PythonSendingObject = 1,
  WebviewHello = 2,
}

/** Kind of object carried by a PythonSendingObject message. */
export enum ObjectType {
  NumpyArray = 1,
  Json = 2,
  Exception = 255,
}

/** Byte order of numpy array data, as written by the sender. */
enum ByteOrder {
  LittleEndian = 1,
  BigEndian = 2,
}

/** Element type codes for numpy array payloads. */
enum ArrayDataType {
  Float32 = 1,
  Float64 = 2,
  Int8 = 3,
  Int16 = 4,
  Int32 = 5,
  Int64 = 6,
  Uint8 = 7,
  Uint16 = 8,
  Uint32 = 9,
  Uint64 = 10,
  Bool = 11,
}

/** Exception class codes; the named members are common Python exceptions. */
enum ExceptionType {
  BaseException = 1,
  Exception = 2,
  RuntimeError = 3,
  TypeError = 4,
  ValueError = 5,
  UnknownException = 255,
}
/**
 * Sequential reader over a Buffer. Each read decodes a value from the start
 * of the remaining bytes and then advances past it. Multi-byte values are
 * decoded big-endian.
 */
class StatefulReader {
  /** Byte width and Buffer decoder for each supported read, keyed by name. */
  readonly functions: {
    [key: string]: [number, (offset?: number) => number];
  } = {
    readUInt8: [1, Buffer.prototype.readUInt8],
    readUInt32: [4, Buffer.prototype.readUInt32BE],
    readFloat32: [4, Buffer.prototype.readFloatBE],
    readFloat64: [8, Buffer.prototype.readDoubleBE],
  };

  constructor(private buffer: Buffer) {}

  /** The bytes that have not been consumed yet. */
  get currentBuffer() {
    return this.buffer;
  }

  /** Decodes one value at offset 0, then drops its bytes from the buffer. */
  private read(entry: [number, (offset?: number) => number]) {
    const [width, decode] = entry;
    const value = decode.call(this.buffer, 0);
    this.buffer = this.buffer.subarray(width);
    return value;
  }

  readUInt8() {
    return this.read(this.functions["readUInt8"]);
  }

  readUInt32() {
    return this.read(this.functions["readUInt32"]);
  }

  readFloat32() {
    return this.read(this.functions["readFloat32"]);
  }

  readFloat64() {
    return this.read(this.functions["readFloat64"]);
  }
}
/** Any decoded message, discriminated by its `type` field. */
type Message = PythonSendingObjectMessage<unknown> | WebviewHelloMessage;

/** An object sent from Python, with its request id, object id and object type. */
type PythonSendingObjectMessage<T> = {
  type: MessageType.PythonSendingObject;
  requestId: number;
  objectId: number;
  objectType: ObjectType;
  object: T;
};

/** Sent by the webview to identify its connection; it carries no data. */
type WebviewHelloMessage = {
  type: MessageType.WebviewHello;
};
/**
 * Decodes the data of one message according to its message type.
 *
 * @throws Error if the message type is not recognized.
 */
function parseMessage(messageType: MessageType, buffer: Buffer): Message {
  if (messageType === MessageType.PythonSendingObject) {
    return parsePythonSendingObjectMessage(buffer);
  }
  if (messageType === MessageType.WebviewHello) {
    return { type: MessageType.WebviewHello };
  }
  throw new Error("Unknown message type: " + messageType);
}
/**
 * Decodes the data of a PythonSendingObject message. The data starts with the
 * request id (u32), the object id (u32) and the object type (u8), followed by
 * an object payload whose decoder is chosen by the object type.
 *
 * @throws Error if the object type has no decoder.
 */
function parsePythonSendingObjectMessage(buffer: Buffer) {
  const reader = new StatefulReader(buffer);

  const requestId = reader.readUInt32();
  logDebug("Request ID (ui32): ", requestId, "; ", reader.currentBuffer);
  const objectId = reader.readUInt32();
  logDebug("Object ID (ui32): ", objectId, "; ", reader.currentBuffer);
  const objectType = reader.readUInt8();
  logDebug("Object type (ui8): ", objectType, "; ", reader.currentBuffer);

  const payload = reader.currentBuffer;
  const decodeObject = () => {
    if (objectType === ObjectType.NumpyArray) {
      return parseNumpyArrayMessage(payload);
    }
    if (objectType === ObjectType.Json) {
      return parseJsonMessage(payload);
    }
    if (objectType === ObjectType.Exception) {
      return parseExceptionMessage(payload);
    }
    throw new Error("Unknown object type: " + objectType);
  };

  return {
    type: MessageType.PythonSendingObject,
    requestId,
    objectId,
    objectType,
    object: decodeObject(),
  };
}
/** Decodes a Json object payload: the whole buffer is UTF-8 JSON text. */
function parseJsonMessage(buffer: Buffer) {
  const { currentBuffer: bytes } = new StatefulReader(buffer);
  const text = bytes.toString("utf-8", 0, bytes.length);
  logDebug("Json string: ", text, "; ", bytes);
  return JSON.parse(text);
}
/**
 * Detects the host byte order by writing the bytes 0xaa, 0xbb and reading
 * them back as a single 16-bit value.
 *
 * @throws Error if the probe matches neither byte order.
 */
function checkEndian() {
  const probe = new Uint16Array(new Uint8Array([0xaa, 0xbb]).buffer);
  switch (probe[0]) {
    case 0xbbaa:
      return ByteOrder.LittleEndian;
    case 0xaabb:
      return ByteOrder.BigEndian;
    default:
      throw new Error("Something crazy just happened");
  }
}

/** Byte order of the machine running this module, computed once at load. */
const machineByteOrder = checkEndian();
/**
 * Typed-array constructor used for each wire data type, grouped by element
 * width. Bool is exposed as a Uint8Array (one byte per element).
 */
const typedArrayConstructor = {
  // 1-byte elements
  [ArrayDataType.Int8]: Int8Array,
  [ArrayDataType.Uint8]: Uint8Array,
  [ArrayDataType.Bool]: Uint8Array,
  // 2-byte elements
  [ArrayDataType.Int16]: Int16Array,
  [ArrayDataType.Uint16]: Uint16Array,
  // 4-byte elements
  [ArrayDataType.Float32]: Float32Array,
  [ArrayDataType.Int32]: Int32Array,
  [ArrayDataType.Uint32]: Uint32Array,
  // 8-byte elements; 64-bit integers use the BigInt array types
  [ArrayDataType.Float64]: Float64Array,
  [ArrayDataType.Int64]: BigInt64Array,
  [ArrayDataType.Uint64]: BigUint64Array,
};

/**
 * DataView method that reads one element of each wire data type, keyed the
 * same way as typedArrayConstructor.
 */
const dataviewGetter = {
  // 1-byte elements
  [ArrayDataType.Int8]: DataView.prototype.getInt8,
  [ArrayDataType.Uint8]: DataView.prototype.getUint8,
  [ArrayDataType.Bool]: DataView.prototype.getUint8,
  // 2-byte elements
  [ArrayDataType.Int16]: DataView.prototype.getInt16,
  [ArrayDataType.Uint16]: DataView.prototype.getUint16,
  // 4-byte elements
  [ArrayDataType.Float32]: DataView.prototype.getFloat32,
  [ArrayDataType.Int32]: DataView.prototype.getInt32,
  [ArrayDataType.Uint32]: DataView.prototype.getUint32,
  // 8-byte elements
  [ArrayDataType.Float64]: DataView.prototype.getFloat64,
  [ArrayDataType.Int64]: DataView.prototype.getBigInt64,
  [ArrayDataType.Uint64]: DataView.prototype.getBigUint64,
};
/**
 * Returns a decoder that turns array bytes, preceded by `padding` bytes, into
 * the typed array for `datatype`.
 *
 * When the data is already in the machine's byte order and starts at an
 * offset that is a multiple of the element size, the result is a zero-copy
 * view over the received memory. Otherwise the bytes are copied into a fresh
 * buffer, which starts at offset 0 and is therefore aligned. The copy is
 * byte-swapped per element if needed, and the typed array views it. This one
 * path serves every element type, including the 64-bit BigInt arrays.
 *
 * @param datatype element type of the array
 * @param byteOrder byte order the sender used for the array data
 */
function arrayBuilder(datatype: ArrayDataType, byteOrder: ByteOrder) {
  const ctor = typedArrayConstructor[datatype];
  const bytesPerElement = ctor.BYTES_PER_ELEMENT;
  const sameByteOrder = byteOrder === machineByteOrder;

  return (buffer: Buffer, padding: number) => {
    logDebug(
      sameByteOrder
        ? "Byte order is the same as machine byte order"
        : "Byte order is different from machine byte order"
    );
    const payload = buffer.subarray(padding);
    // Element count excludes the padding bytes.
    const length = Math.floor(payload.byteLength / bytesPerElement);
    const byteLength = length * bytesPerElement;

    // A typed-array view must be built on the underlying ArrayBuffer. Passing
    // the Buffer itself makes the constructor copy it byte-by-byte into
    // elements and ignore the offset/length arguments. The byteOffset must
    // also be a multiple of the element size.
    if (sameByteOrder && payload.byteOffset % bytesPerElement === 0) {
      return new ctor(payload.buffer, payload.byteOffset, length);
    }

    // Buffer.alloc never uses the shared pool, so `aligned` starts at offset 0
    // of its own ArrayBuffer.
    const aligned = Buffer.alloc(byteLength);
    payload.copy(aligned, 0, 0, byteLength);
    if (!sameByteOrder) {
      // Reverse the bytes inside every element, in place.
      if (bytesPerElement === 2) {
        aligned.swap16();
      } else if (bytesPerElement === 4) {
        aligned.swap32();
      } else if (bytesPerElement === 8) {
        aligned.swap64();
      }
    }
    return new ctor(aligned.buffer, aligned.byteOffset, length);
  };
}
/**
 * Converts the array section of a numpy message into a typed array.
 *
 * `buffer` holds optional leading padding followed by `numElements` elements
 * of `datatype`. The padding length is whatever is left over.
 *
 * @throws Error if `buffer` is shorter than `numElements` elements.
 */
function bufferToArray(
  buffer: Buffer,
  datatype: ArrayDataType,
  byteOrder: ByteOrder,
  numElements: number
) {
  logDebug("Buffer: ", buffer);
  logDebug("Buffer length: ", buffer.length);
  logDebug("Buffer byte length: ", buffer.byteLength);
  logDebug("Buffer byte offset: ", buffer.byteOffset);
  const arraySizeBytes =
    numElements * typedArrayConstructor[datatype].BYTES_PER_ELEMENT;
  logDebug("Array size (bytes): ", arraySizeBytes);
  const padding = buffer.length - arraySizeBytes;
  logDebug("Padding (bytes): ", padding);
  // A negative padding means the array data is truncated. Passed on, it would
  // make the decoder start at the wrong position instead of failing.
  if (padding < 0) {
    throw new Error(
      `Array data is truncated: expected ${arraySizeBytes} bytes, got ${buffer.length}`
    );
  }
  const array = arrayBuilder(datatype, byteOrder)(buffer, padding);
  logDebug("Typed array: ", array);
  return array;
}
/** A decoded exception: its type code and the formatted message text. */
type Exception = {
  type: ExceptionType;
  message: string;
};

/** A decoded numpy array: its shape and its elements in a typed array. */
type NumpyArray<T> = {
  dimensions: number[];
  data: T;
};
/**
 * Decodes a NumpyArray object payload: data type (u8), byte order (u8),
 * number of dimensions (u8), dimensions (u32 each), then padding and the
 * element bytes.
 *
 * @throws Error for an unknown data type or byte order code. The check runs
 *   before any element bytes are interpreted.
 */
function parseNumpyArrayMessage(buffer: Buffer) {
  const reader = new StatefulReader(buffer);
  const dataType = reader.readUInt8() as ArrayDataType;
  logDebug("Data type (ui8): ", dataType, "; ", reader.currentBuffer);
  // Numeric enums map values back to member names, so an unknown code has no
  // string entry.
  if (typeof ArrayDataType[dataType] !== "string") {
    throw new Error("Unknown array data type: " + dataType);
  }
  const byteOrder = reader.readUInt8();
  logDebug("Byte order (ui8): ", byteOrder, "; ", reader.currentBuffer);
  if (typeof ByteOrder[byteOrder] !== "string") {
    throw new Error("Unknown byte order: " + byteOrder);
  }
  const numberOfDimensions = reader.readUInt8();
  logDebug(
    "Number of dimensions (ui8): ",
    numberOfDimensions,
    "; ",
    reader.currentBuffer
  );
  const dimensions: number[] = [];
  for (let i = 0; i < numberOfDimensions; i++) {
    dimensions.push(reader.readUInt32());
  }
  logDebug("Dimensions (ui32): ", dimensions, "; ", reader.currentBuffer);
  const numElements = dimensions.reduce((a, b) => a * b, 1);
  const data = reader.currentBuffer;
  const array = bufferToArray(data, dataType, byteOrder, numElements);
  const numpyArray: NumpyArray<typeof array> = {
    dimensions,
    data: array,
  };
  return numpyArray;
}
/**
 * Decodes an Exception object payload: an exception type code (u8) followed
 * by the message as UTF-8 text.
 */
function parseExceptionMessage(buffer: Buffer): Exception {
  const reader = new StatefulReader(buffer);
  const exceptionType = reader.readUInt8();
  logDebug(
    "Exception type (ui8): ",
    exceptionType,
    "; ",
    reader.currentBuffer
  );
  const remaining = reader.currentBuffer;
  const message = remaining.toString("utf-8", 0, remaining.length);
  logDebug("Message: ", message, "; ", remaining);
  return { type: exceptionType, message };
}
/** Collects the chunks of one message whose total size is known up front. */
class MessageChunks {
  private readonly parts: Buffer[] = [];
  private received = 0;

  constructor(private expectedMessageLength: number) {}

  /** Appends a chunk; throws if it would exceed the expected length. */
  addChunk(chunk: Buffer) {
    const total = this.received + chunk.length;
    if (total > this.expectedMessageLength) {
      throw new Error("Chunk is too big");
    }
    this.parts.push(chunk);
    this.received = total;
  }

  /** True once exactly the expected number of bytes has been added. */
  get isComplete() {
    return this.received === this.expectedMessageLength;
  }

  /** Concatenates the collected chunks; throws if the message is incomplete. */
  fullMessage() {
    if (this.isComplete) {
      return Buffer.concat(this.parts);
    }
    throw new Error("Message is not complete");
  }
}
/** Splits a frame into its big-endian u32 length prefix and the bytes after it. */
function parseMessageLength(buffer: Buffer): [number, Buffer] {
  const reader = new StatefulReader(buffer);
  return [reader.readUInt32(), reader.currentBuffer];
}
/** Splits a message into its leading type byte and the message data. */
function parseMessageType(buffer: Buffer): [MessageType, Buffer] {
  const reader = new StatefulReader(buffer);
  return [reader.readUInt8(), reader.currentBuffer];
}
/**
 * Per-connection reassembly of length-prefixed frames.
 *
 * TCP delivers a byte stream, so one 'data' chunk may hold part of a frame,
 * exactly one frame, or several frames (possibly ending mid-frame). Each frame
 * is a big-endian u32 length followed by that many bytes: the message type
 * byte plus the message data.
 */
class Client {
  private static readonly lengthPrefixBytes = 4;

  /** Received bytes not yet consumed as a complete frame. */
  private chunks: Buffer[] = [];
  /** Total byte length of `chunks`. */
  private bufferedLength = 0;
  /** Length of the frame being assembled, once its prefix has arrived. */
  private expectedLength?: number = undefined;

  constructor(private socket: net.Socket) {}

  /**
   * Feeds one socket chunk and calls `onMessage` for every frame it completes.
   *
   * @param data chunk from the socket's 'data' event
   * @param onMessage called with each successfully parsed message
   * @param onError called when a frame cannot be parsed or `onMessage`
   *   throws; later frames on the same connection are still processed
   */
  onData(
    data: Buffer,
    onMessage: (message: Message) => void,
    onError: (error: Error) => void
  ) {
    this.chunks.push(data);
    this.bufferedLength += data.length;
    const prefix = Client.lengthPrefixBytes;

    for (;;) {
      if (this.expectedLength === undefined) {
        // The length prefix itself may be split across chunks.
        if (this.bufferedLength < prefix) {
          return;
        }
        this.expectedLength = this.coalesce().readUInt32BE(0);
      }
      const frameEnd = prefix + this.expectedLength;
      if (this.bufferedLength < frameEnd) {
        return;
      }
      const buffered = this.coalesce();
      const frame = buffered.subarray(prefix, frameEnd);
      // Keep any bytes that already belong to the next frame.
      const rest = buffered.subarray(frameEnd);
      this.chunks = rest.length > 0 ? [rest] : [];
      this.bufferedLength = rest.length;
      this.expectedLength = undefined;
      this.dispatch(frame, onMessage, onError);
    }
  }

  /** Merges the buffered chunks into one Buffer, copying only if there are several. */
  private coalesce(): Buffer {
    const [first] = this.chunks;
    const merged =
      this.chunks.length === 1 && first !== undefined
        ? first
        : Buffer.concat(this.chunks, this.bufferedLength);
    this.chunks = [merged];
    return merged;
  }

  /** Parses one complete frame; parse and handler failures go to `onError`. */
  private dispatch(
    frame: Buffer,
    onMessage: (message: Message) => void,
    onError: (error: Error) => void
  ) {
    try {
      const [messageType, messageData] = parseMessageType(frame);
      onMessage(parseMessage(messageType, messageData));
    } catch (error) {
      onError(error instanceof Error ? error : new Error(String(error)));
    }
  }
}
/**
 * TCP server that receives length-prefixed messages (format described at the
 * top of this file) from Python senders and from a webview client.
 */
@Service()
export class SocketServer {
  private server: net.Server;
  private port?: number = undefined;
  private started: boolean = false;
  /** Socket of the client that most recently sent a WebviewHello message. */
  private webviewClient?: net.Socket = undefined;

  constructor() {
    const options: net.ServerOpts = {
      // Our writable side stays open after a peer ends its side; sockets are
      // ended explicitly in onClientConnected.
      allowHalfOpen: true,
      pauseOnConnect: false,
      keepAlive: true,
    };
    this.server = net.createServer(options);
    // Use an arrow wrapper: EventEmitter invokes listeners with `this` set to
    // the emitter, so passing `this.onClientConnected` directly would make
    // `this.webviewClient = socket` write onto the net.Server.
    this.server.on("connection", (socket: net.Socket) =>
      this.onClientConnected(socket)
    );
  }

  /**
   * Starts listening on an OS-assigned port.
   *
   * @throws Error if already started, or if no TCP address is available
   *   immediately after `listen()` returns.
   */
  async start() {
    if (this.started) {
      throw new Error("SocketServer already started");
    }
    // NOTE(review): listen(0) without a host accepts connections on all
    // interfaces, not only localhost. Reading address() right away also relies
    // on the bind completing inside listen(); Node only guarantees an address
    // after the 'listening' event.
    this.server.listen(0);
    const address = this.server.address();
    if (typeof address === "string") {
      throw new Error("SocketServer address is a string");
    } else if (address === null) {
      throw new Error("SocketServer address is null");
    }
    this.port = address.port;
    logDebug("SocketServer started on port " + this.port);
    this.started = true;
  }

  /**
   * The port being listened on.
   *
   * @throws Error if start() has not completed.
   */
  get portNumber() {
    if (!this.started) {
      throw new Error("SocketServer is not started");
    }
    if (this.port === undefined) {
      throw new Error("SocketServer is not listening");
    }
    return this.port;
  }

  /**
   * Wires up a newly accepted socket. Its data is framed and parsed, the
   * socket is remembered as the webview client after a WebviewHello, and
   * connection errors and shutdown are handled.
   */
  onClientConnected(socket: net.Socket): void {
    const onMessage = (message: Message) => {
      logDebug("Message: ", message);
      if (message.type === MessageType.WebviewHello) {
        this.webviewClient = socket;
        logDebug("Webview client connected");
      }
    };
    const onError = (error: Error) => {
      logDebug("Error: ", error);
    };
    logDebug("Client connected");
    const client = new Client(socket);
    socket.on("data", (data: Buffer) => client.onData(data, onMessage, onError));
    // An 'error' event with no listener is thrown by EventEmitter and would
    // crash the process (e.g. when a peer resets the connection).
    socket.on("error", onError);
    // With allowHalfOpen, Node leaves our side open after the peer ends its
    // side. End it for one-shot senders so their sockets do not linger
    // half-closed; the webview client is kept writable.
    socket.on("end", () => {
      if (socket !== this.webviewClient) {
        socket.end();
      }
    });
    socket.on("close", () => {
      if (this.webviewClient === socket) {
        this.webviewClient = undefined;
      }
    });
  }
}
// TODO: Remove this | |
// function logInfo(...obj: any[]): void { | |
// console.log(...obj); | |
// } | |
/** Debug logger; currently forwards all of its arguments to console.log. */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function logDebug(...obj: any[]): void {
  console.log.apply(console, obj);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment