Skip to content

Instantly share code, notes, and snippets.

@jramseygreen
Last active December 18, 2021 13:58
Show Gist options
  • Save jramseygreen/12f988545b0e505aa9e3e20572b892bb to your computer and use it in GitHub Desktop.
Save jramseygreen/12f988545b0e505aa9e3e20572b892bb to your computer and use it in GitHub Desktop.
AES-CBC string encryption interoperable in both directions: Python → JavaScript and JavaScript → Python
from Crypto import Random
from Crypto.Cipher import AES
import base64
from hashlib import md5
BLOCK_SIZE = 16


class CryptoWrapper:
    """Password-based AES-256-CBC, wire-compatible with CryptoJS passphrase mode.

    A ciphertext is the base64 text of
    ``b"Salted__" + salt (8 bytes) + AES-CBC(PKCS#7-padded UTF-8 plaintext)``.
    The 32-byte key and 16-byte IV are derived from password and salt with
    OpenSSL's EVP_BytesToKey construction (MD5, one iteration).

    NOTE(security): the format carries no MAC and uses an MD5-based KDF; it
    is kept only for interoperability with CryptoJS and cannot detect
    tampering.
    """

    # Fixed prefix marking an OpenSSL/CryptoJS salted ciphertext.
    _MAGIC = b"Salted__"
    _SALT_SIZE = 8
    _KEY_SIZE = 32  # AES-256
    _IV_SIZE = 16   # one AES block

    def __init__(self, password):
        """Store *password* (str) as UTF-8 bytes for key derivation."""
        self.__password = password.encode()

    def __pad(self, data):
        """Return *data* with PKCS#7 padding up to a multiple of BLOCK_SIZE.

        A full block of padding is added when *data* is already aligned.
        """
        length = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
        return data + bytes([length]) * length

    def __unpad(self, data):
        """Strip and validate PKCS#7 padding from *data*.

        Raises:
            ValueError: *data* is empty, not block-aligned, or the padding
                bytes are malformed (typically a wrong password or corrupted
                ciphertext).
        """
        if not data or len(data) % BLOCK_SIZE:
            raise ValueError("invalid padded data length")
        length = data[-1]
        if not 1 <= length <= BLOCK_SIZE or data[-length:] != bytes([length]) * length:
            raise ValueError("invalid PKCS#7 padding")
        return data[:-length]

    def __bytes_to_key(self, data, salt, output=48):
        """Derive *output* bytes from *data* and *salt* (OpenSSL EVP_BytesToKey, MD5).

        D_1 = MD5(data + salt), D_i = MD5(D_{i-1} + data + salt); the
        concatenation is truncated to *output* bytes.
        Adapted from https://gist.github.com/gsakkis/4546068

        Raises:
            ValueError: *salt* is not exactly 8 bytes.
        """
        if len(salt) != self._SALT_SIZE:
            raise ValueError("salt must be %d bytes, got %d" % (self._SALT_SIZE, len(salt)))
        material = data + salt
        block = md5(material).digest()
        derived = block
        while len(derived) < output:
            block = md5(block + material).digest()
            derived += block
        return derived[:output]

    def __derive_key_iv(self, salt):
        """Return the (key, iv) pair for the current password and *salt*."""
        key_iv = self.__bytes_to_key(self.__password, salt, self._KEY_SIZE + self._IV_SIZE)
        return key_iv[:self._KEY_SIZE], key_iv[self._KEY_SIZE:]

    def encrypt(self, message):
        """Encrypt the str *message*; return base64 text CryptoJS can decrypt."""
        salt = Random.new().read(self._SALT_SIZE)
        key, iv = self.__derive_key_iv(salt)
        cipher = AES.new(key, AES.MODE_CBC, iv)
        body = cipher.encrypt(self.__pad(message.encode()))
        return base64.b64encode(self._MAGIC + salt + body).decode()

    def decrypt(self, encrypted):
        """Decrypt base64 *encrypted* (as produced by CryptoJS or encrypt()).

        Raises:
            ValueError: bad base64, missing ``Salted__`` header, truncated or
                misaligned ciphertext, bad padding, or non-UTF-8 plaintext.
        """
        raw = base64.b64decode(encrypted)
        header_len = len(self._MAGIC) + self._SALT_SIZE
        body_len = len(raw) - header_len
        # Validate explicitly: asserts vanish under `python -O`.
        if not raw.startswith(self._MAGIC) or body_len < BLOCK_SIZE or body_len % BLOCK_SIZE:
            raise ValueError("not a salted AES-CBC ciphertext")
        salt = raw[len(self._MAGIC):header_len]
        key, iv = self.__derive_key_iv(salt)
        cipher = AES.new(key, AES.MODE_CBC, iv)
        return self.__unpad(cipher.decrypt(raw[header_len:])).decode()

    def set_password(self, password):
        """Replace the password used for subsequent encrypt/decrypt calls."""
        self.__password = password.encode()

    def get_password(self):
        """Return the current password as a str."""
        return self.__password.decode()
import ws_server
import crypto
def on_message(conn, message):
    """Echo handler: print the ciphertext, decrypt and print it, then reply.

    The reply is the same plaintext re-encrypted with a fresh salt. Uses the
    module-level ``crypt`` (CryptoWrapper) and ``ws`` (server) objects that
    this script creates after defining the handler.
    """
    print(message)
    plaintext = crypt.decrypt(message)
    print(plaintext)
    reply = crypt.encrypt(plaintext)
    ws.send(conn, reply)
# Shared secret; must equal the password hard-coded in the browser client.
SHARED_PASSWORD = "password"
crypt = crypto.CryptoWrapper(SHARED_PASSWORD)
ws = ws_server.ws_server(on_message_function=on_message)
# listen() returns at once; the accept loop runs on a separate thread.
ws.listen()
<!DOCTYPE html>
<html>
<body>
<script src="https://cdnjs.cloudflare.com/ajax/libs/crypto-js/4.1.1/crypto-js.min.js"></script>
<script>
// Shared secret: must match the password given to CryptoWrapper on the Python side.
// With a string passphrase, CryptoJS.AES uses its OpenSSL-compatible mode and
// serialises to base64("Salted__" + salt + ciphertext), which CryptoWrapper parses.
var PASSWORD = 'password';

window.onload = function () {
    var connection = new WebSocket("ws://localhost:8080/");
    connection.onopen = function () {
        connection.send(encrypt('hello world'));
    };
    connection.onerror = function (error) {
        // Pass the Event as its own argument: concatenating it prints "[object Event]".
        console.error('WebSocket Error', error);
    };
    connection.onmessage = function (e) {
        console.log(e.data);
        console.log(decrypt(e.data));
    };
};

// Encrypt a UTF-8 string; returns the base64 ciphertext text.
function encrypt(message) {
    return CryptoJS.AES.encrypt(message, PASSWORD).toString();
}

// Decrypt base64 ciphertext text back to a UTF-8 string.
function decrypt(message) {
    return CryptoJS.AES.decrypt(message, PASSWORD).toString(CryptoJS.enc.Utf8);
}
</script>
</body>
</html>
import hashlib
import base64
import socket
import struct
import six
import threading
class ws_server:
def __init__(self, host='localhost', port=8080, on_message_function=None):
self.__running = True
self.__host = host
self.__port = port
self.__sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.__on_message_function = on_message_function
self.__clients = []
def __handshake(self, conn):
request = conn.recv(1024).strip()
specificationGUID = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
websocketKey = b''
lines = request.splitlines()
for line in lines:
args = line.partition(b': ')
if args[0] == b'Sec-WebSocket-Key':
websocketKey = args[2]
break
fullKey = hashlib.sha1((websocketKey + specificationGUID)).digest()
b64Key = base64.b64encode(fullKey)
response = b'HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ' + b64Key + b'\r\n\r\n'
conn.send(response)
def __ws_encode(self, data=""):
if isinstance(data, six.text_type):
data = data.encode('utf-8')
length = len(data)
fin, rsv1, rsv2, rsv3, opcode = 1, 0, 0, 0, 0x1
frame_header = chr(fin << 7 | rsv1 << 6 | rsv2 << 5 | rsv3 << 4 | opcode)
if length < 0x7e:
frame_header += chr(0 << 7 | length)
frame_header = six.b(frame_header)
elif length < 1 << 16:
frame_header += chr(0 << 7 | 0x7e)
frame_header = six.b(frame_header)
frame_header += struct.pack("!H", length)
else:
frame_header += chr(0 << 7 | 0x7f)
frame_header = six.b(frame_header)
frame_header += struct.pack("!Q", length)
return frame_header + data
def __ws_decode(self, data):
frame = bytearray(data)
length = frame[1] & 127
indexFirstMask = 2
if length == 126:
indexFirstMask = 4
elif length == 127:
indexFirstMask = 10
indexFirstDataByte = indexFirstMask + 4
mask = frame[indexFirstMask:indexFirstDataByte]
i = indexFirstDataByte
j = 0
decoded = []
while i < len(frame):
decoded.append(frame[i] ^ mask[j % 4])
i += 1
j += 1
return bytes(decoded).decode("utf-8")
def __client_thread(self, conn):
self.__clients.append(conn)
try:
while True:
msg = conn.recv(2046)
self.__on_message_function(conn, self.__ws_decode(msg))
except:
self.__clients.remove(conn)
conn.close()
def listen(self, running=False):
if not running:
x = threading.Thread(target=self.listen, args=(True,))
x.start()
else:
self.__sock.bind((self.__host, self.__port))
self.__sock.listen()
while self.__running:
conn, addr = self.__sock.accept()
self.__handshake(conn)
x = threading.Thread(target=self.__client_thread, args=(conn,))
x.setDaemon(True)
x.start()
def send(self, conn, message):
conn.send(self.__ws_encode(message))
def on_message(self, func):
self.__on_message_function = func
def get_clients(self):
return self.__clients
def set_port(self, port):
self.__port = port
def get_port(self):
return self.__port
def get_host(self):
return self.__host
@jramseygreen
Copy link
Author

Requires the pycryptodome and six libraries.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment