Skip to content

Instantly share code, notes, and snippets.

@tex2e
Last active July 20, 2023 07:05
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tex2e/a55cfe8f006799ff745dc888a0149183 to your computer and use it in GitHub Desktop.
Save tex2e/a55cfe8f006799ff745dc888a0149183 to your computer and use it in GitHub Desktop.
import io # バイトストリーム操作
import textwrap # テキストの折り返しと詰め込み
import re # 正規表現
import shutil # ターミナル幅の取得
from metatype import Type, List, ListMeta
import dataclasses
# TLSメッセージの構造体を表すためのクラス群
# 使い方:
#
# import metastruct as meta
#
# @meta.struct
# class ClientHello(meta.MetaStruct):
# legacy_version: ProtocolVersion
# random: Random
# legacy_session_id: Opaque(size_t=Uint8)
# cipher_suites: List(size_t=Uint16, elem_t=CipherSuite)
# legacy_compression_methods: List(size_t=Uint16, elem_t=Uint8)
# extensions: List(size_t=Uint16, elem_t=Extension)
#
# 構造体のデコレータ
def struct(cls):
for name, elem_t in cls.__annotations__.items():
if not hasattr(cls, name):
setattr(cls, name, None)
return dataclasses.dataclass(repr=False)(cls)
# 構造体の抽象クラス
class MetaStruct(Type):
def __post_init__(self):
self.set_parent(None)
# create_emptyメソッドで生成されたとき(全ての要素がNoneのとき)は何もしない
if all(not getattr(self, name) for name in self.get_struct().keys()):
return
for name, field in self.get_struct().items():
elem = getattr(self, name)
# デフォルト値がラムダのとき、ラムダを評価した値を格納する
if callable(field.default) and not isinstance(elem, field.type):
setattr(self, name, field.default(self))
# 要素が親インスタンスを参照できるようにする
if isinstance(elem, Type):
elem.set_parent(self)
@classmethod
def create_empty(cls):
dict = {}
for name, field in cls.__dataclass_fields__.items():
dict[name] = None
return cls(**dict)
# 全てのMetaStructは親インスタンスを参照できるようにする。
def set_parent(self, parent):
self.parent = parent
def __bytes__(self):
f = io.BytesIO()
for name, field in self.get_struct().items():
elem = getattr(self, name)
if elem is None:
raise Exception('%s.%s is None!' % (self.__class__.__name__, name))
f.write(bytes(elem))
return f.getvalue()
@classmethod
def from_stream(cls, fs, parent=None):
# デフォルト値などを導出せずにインスタンス化する
instance = cls.create_empty()
instance.set_parent(parent) # 子が親インスタンスを参照できるようにする
for name, field in cls.get_struct().items():
elem_t = field.type
if isinstance(elem_t, Select):
# 型がSelectのときは、既に格納した値から型を決定する
elem_t = elem_t.select_type_by_switch(instance)
# バイト列から構造体への変換
elem = elem_t.from_stream(fs, instance)
# 値を構造体へ格納
setattr(instance, name, elem)
return instance
@classmethod
def get_struct(cls):
return cls.__dataclass_fields__
def __repr__(self):
# 出力は次のようにする
# 1. 各要素を表示するときはプラス(+)記号を加えて、出力幅が70を超えないようにする
# + 要素名: 型(値)
# 2. 要素もMetaStructのときは、内部要素をスペース2つ分だけインデントする
# + 要素名: MetaStruct名:
# + 要素: 型(値)
title = "%s:\n" % self.__class__.__name__
elems = []
for name, field in self.get_struct().items():
elem = getattr(self, name)
content = repr(elem)
output = '%s: %s' % (name, content)
def is_MetaStruct(elem):
return isinstance(elem, MetaStruct)
def is_List_of_MetaStruct(elem):
return (isinstance(elem, ListMeta) and
issubclass(elem.__class__.elem_t, MetaStruct))
if is_MetaStruct(elem) or is_List_of_MetaStruct(elem):
# 要素のMetaStructは出力が複数行になるので、その要素をインデントさせる
output = textwrap.indent(output, prefix=" ").strip()
else:
# その他の要素は出力が1行なので、コンソールの幅を超えないように折返し出力させる
nest = self.count_ancestors() + 1
output = '\n '.join(textwrap.wrap(output,
width=shutil.get_terminal_size().columns-(nest*2)))
elems.append('+ ' + output)
return title + "\n".join(elems)
def count_ancestors(self):
tmp = self.parent
count = 0
while tmp is not None:
tmp = tmp.parent
count += 1
return count
def __len__(self):
return len(bytes(self))
# 状況に応じて型を選択するためのクラス。
# 例えば、Handshake.msg_type が client_hello と server_hello で、
# 自身や子要素の構造体フィールドの型が変化する場合に使用する。
class Select:
def __init__(self, switch, cases):
assert isinstance(switch, str)
assert isinstance(cases, dict)
self.switch = switch
self.cases = cases
# 引数 switch の構文が正しいか確認する。
# 自身のプロパティを参照する場合 : "プロパティ名"
# 親のプロパティを参照する場合 : "親クラス名.プロパティ名"
if not re.match(r'^[a-zA-Z0-9_]+(\.[a-zA-Z_]+)?$', self.switch):
raise Exception('Select(%s) is invalid syntax!' % self.switch)
# フィールド .switch の内容を元に、構築中のインスタンスからプロパティを検索し、
# プロパティの値から導出した型を返す。
def select_type_by_switch(self, instance):
if re.match(r'^[^.]+\.[^.]+$', self.switch):
# 条件が「クラス名.プロパティ名」のとき
class_name, prop_name = self.switch.split('.', maxsplit=1)
else:
# 条件が「プロパティ名」のみ
class_name, prop_name = instance.__class__.__name__, self.switch
# インスタンスのクラス名がclass_nameと一致するまで親をさかのぼる
tmp = instance
while tmp is not None:
if tmp.__class__.__name__ == class_name: break
tmp = tmp.parent
if tmp is None:
raise Exception('Not found %s class in ancestors from %s!' % \
(class_name, instance.__class__.__name__))
# 既に格納した値の取得
value = getattr(tmp, prop_name)
# 既に格納した値から使用する型を決定する
ret = self.cases.get(value)
if ret is None:
ret = self.cases.get(Otherwise)
if ret is None:
raise Exception('Select(%s) cannot map to class in %s!' % \
(value, instance.__class__.__name__))
return ret
# Select で条件に当てはまらない場合の default を表すクラス
# Usage:
# meta.Select('fieldName', cases={
# HandshakeType.client_hello: ClientHello,
# meta.Otherwise: OpaqueLength
# })
class Otherwise:
pass
import struct # バイト列の解釈
import io # バイトストリーム操作
import textwrap # テキストの折り返しと詰め込み
from enum import Enum as BuildinEnum
# 全ての型が継承するクラス
class Type:
size = None # 固定長のときに使用する
size_t = None # 可変長のときに使用する
# バイト列から構造体を構築するメソッドの中では、
# バイト列の代わりにストリームを渡すことで、読み取った文字数をストリームが保持する。
@classmethod
def from_bytes(cls, data):
return cls.from_stream(io.BytesIO(data))
# 抽象クラス以外は必ず上書きすること
@classmethod
def from_stream(cls, fs, parent=None):
raise NotImplementedError
# 構造体の構築時には、Opaqueは親インスタンスを参照できるようにする。
def set_parent(self, instance):
self.parent = instance
def __bytes__(self):
raise NotImplementedError(self.__class__.__name__ + "#bytes")
def __repr__(self):
raise NotImplementedError(self.__class__.__name__ + "#repr")
# --- Uint ---------------------------------------------------------------------
# UintNの抽象クラス
class Uint(Type):
def __init__(self, value):
assert self.__class__ != Uint, \
"Uint (Abstract Class) cannot construct instance!"
assert isinstance(value, int)
max_value = 1 << (8 * self.__class__.size)
assert 0 <= value < max_value
self.value = value
def __bytes__(self):
res = []
tmp = self.value
for i in range(self.__class__.size):
res.append(bytes([tmp & 0xff]))
tmp >>= 8
res.reverse()
return self.value.to_bytes(self.__class__.size, byteorder='big')
@classmethod
def from_stream(cls, fs, parent=None):
data = fs.read(cls.size)
return cls(int.from_bytes(data, byteorder='big'))
def __len__(self):
return self.__class__.size
def __int__(self):
return self.value
def __eq__(self, other):
return hasattr(other, 'value') and self.value == other.value
def __repr__(self):
classname = self.__class__.__name__
value = self.value
width = self.__class__.size * 2
return "{}(0x{:0{width}x})".format(classname, value, width=width)
def __hash__(self):
return hash((self.__class__.size, self.value))
class Uint8(Uint):
size = 1 # unsinged char
class Uint16(Uint):
size = 2 # unsigned short
class Uint24(Uint):
size = 3
class Uint32(Uint):
size = 4 # unsigned int
class Uint64(Uint):
size = 8 # unsinged long
# Variable-Length Integer Encoding
#
# +======+========+=============+=======================+
# | 2MSB | Length | Usable Bits | Range |
# +======+========+=============+=======================+
# | 00 | 1 | 6 | 0-63 |
# | 01 | 2 | 14 | 0-16383 |
# | 10 | 4 | 30 | 0-1073741823 |
# | 11 | 8 | 62 | 0-4611686018427387903 |
# +------+--------+-------------+-----------------------+
#
class VarLenIntEncoding(Type):
def __init__(self, value):
assert isinstance(value, Uint)
size_t = value.__class__
size = size_t.size
assert 0 <= int(value) < (1 << (6 + 8*(size-1)))
self.size = size
self.size_t = size_t
self.value = value
@classmethod
def from_stream(cls, fs, parent=None):
head = fs.read(1)
msb2bit = ord(head) >> 6
length, UintN = cls._get_msb2bit_info(msb2bit)
rest = fs.read(length - 1)
byte = bytes([ord(head) & 0b00111111]) + rest
value = UintN.from_bytes(byte)
return VarLenIntEncoding(value)
def __bytes__(self):
size = self._get_size()
msb2bit = self._get_msb2bit()
msb2bit_mask = msb2bit << 6
byte = bytearray(bytes(self.value))
byte[0] |= msb2bit_mask
return bytes(byte)
def __repr__(self):
return 'Quic' + repr(self.value)
def _get_size(self):
size = self.size_t.size
assert size in (1, 2, 4, 8)
return size
def _get_msb2bit(self):
length = self._get_size()
if length == 1: return 0b00
if length == 2: return 0b01
if length == 4: return 0b10
if length == 8: return 0b11
@classmethod
def _get_msb2bit_info(cls, msb2bit):
if msb2bit == 0b00: return (1, Uint8)
if msb2bit == 0b01: return (2, Uint16)
if msb2bit == 0b10: return (4, Uint32)
if msb2bit == 0b11: return (8, Uint64)
@staticmethod
def len2uint(byte_len):
if 0 <= byte_len <= 63: return Uint8
if 0 <= byte_len <= 16383: return Uint16
if 0 <= byte_len <= 1073741823: return Uint32
if 0 <= byte_len <= 4611686018427387903: return Uint64
def __eq__(self, other):
return (self.size_t == other.size_t and self.value == other.value)
def __int__(self):
return int(self.value)
def __len__(self):
return self._get_size()
# --- Opaque -------------------------------------------------------------------
class OpaqueMeta(Type):
def get_raw_bytes(self):
return self.byte
def __eq__(self, other):
return self.byte == other.byte
def __len__(self):
return len(self.byte)
def Opaque(size_t):
if isinstance(size_t, int): # 引数がintのときは固定長
return OpaqueFix(size_t)
if isinstance(size_t, type(lambda: None)): # 引数がラムダのときは実行時に決定する固定長
return OpaqueFix(size_t)
if issubclass(size_t, (Uint, VarLenIntEncoding)): # 引数がUintNのときは可変長
return OpaqueVar(size_t)
raise TypeError("size's type must be an int or Uint class.")
def OpaqueFix(size):
# 固定長のOpaque (e.g. opaque string[16])
# ただし、外部の変数によってサイズが決まる場合もある (e.g. opaque string[Hash.length])
class OpaqueFix(OpaqueMeta):
size = 0
def __init__(self, byte):
assert isinstance(byte, (bytes, bytearray))
size = OpaqueFix.size
if callable(size): # ラムダのときは実行時に評価した値がサイズになる
self.byte = byte
else:
assert len(byte) <= size
self.byte = bytes(byte).rjust(size, b'\x00')
def __bytes__(self):
return self.byte
@classmethod
def from_stream(cls, fs, parent=None):
size = cls.size
if callable(size): # ラムダのときは実行時に評価した値がサイズになる
size = int(size(parent))
opaque = OpaqueFix(fs.read(size))
opaque.set_parent(parent)
return opaque
def __repr__(self):
size = OpaqueFix.size
if callable(size):
size = int(size(self.parent))
return 'Opaque[%d](%s)' % (size, repr(self.byte))
def get_size(self):
if callable(self.size):
return int(OpaqueFix.size(self.parent))
return self.size
OpaqueFix.size = size
return OpaqueFix
def OpaqueVar(size_t):
# 可変長のOpaque (e.g. opaque string<0..15>)
class OpaqueVar(OpaqueMeta):
size_t = Uint
def __init__(self, byte):
assert isinstance(byte, (bytes, bytearray))
self.byte = bytes(byte)
self.size_t = OpaqueVar.size_t
def __bytes__(self):
if issubclass(self.size_t, Uint):
UintN = self.size_t
return bytes(UintN(len(self.byte))) + self.byte
elif issubclass(self.size_t, VarLenIntEncoding):
VarLenInt = self.size_t
byte_len = len(self.byte)
UintN = VarLenIntEncoding.len2uint(byte_len)
return bytes(VarLenInt(UintN(byte_len))) + self.byte
else:
raise NotImplementedError
@classmethod
def from_stream(cls, fs, parent=None):
size_t = OpaqueVar.size_t
length = int(size_t.from_stream(fs))
byte = fs.read(length)
return OpaqueVar(byte)
def __repr__(self):
return 'Opaque<%s>(%s)' % \
(OpaqueVar.size_t.__name__, repr(self.byte))
OpaqueVar.size_t = size_t
return OpaqueVar
OpaqueUint8 = Opaque(Uint8)
OpaqueUint16 = Opaque(Uint16)
OpaqueUint24 = Opaque(Uint24)
OpaqueUint32 = Opaque(Uint32)
OpaqueLength = Opaque(lambda self: self.length)
# --- List ---------------------------------------------------------------------
class ListMeta(Type):
pass
# 配列の構造を表すためのクラス
def List(size_t, elem_t):
# List ではスコープが異なる(グローバルとローカル)と、
# 組み込み関数 issubclass が期待通りに動かない場合があるので、
# 子クラスの基底クラス名の一覧の中に親クラス名が存在すれば True を返す関数を使用する。
# この関数は List クラス内で issubclass の代わりに利用する。
def my_issubclass(child, parent):
if not hasattr(child, '__bases__'):
return False
return parent.__name__ in map(lambda x: x.__name__, child.__bases__)
class List(ListMeta):
size_t = None # リストの長さを表す部分の型
elem_t = None # リストの要素の型
def __init__(self, array):
self.array = array
def get_array(self):
return self.array
# 構造体の構築時には、Listは親インスタンスを参照できるようにする。
# そして要素がMetaStructであれば、各要素の.set_parent()に親インスタンスを渡す。
def set_parent(self, instance):
self.parent = instance
from metastruct import MetaStruct
if my_issubclass(List.elem_t, MetaStruct):
for elem in self.get_array():
elem.set_parent(self.parent)
def __bytes__(self):
size_t = List.size_t
content = b''.join(bytes(elem) for elem in self.get_array())
content_len = len(content)
return bytes(size_t(content_len)) + content
@classmethod
def from_stream(cls, fs, parent=None):
from metastruct import MetaStruct
size_t = cls.size_t
elem_t = cls.elem_t
list_size = int(size_t.from_stream(fs)) # リスト全体の長さ
elem_size = elem_t.size # 要素の長さを表す部分の長さ
array = []
# 現在のストリーム位置が全体の長さを超えない間、繰り返し行う
startpos = fs.tell()
while (fs.tell() - startpos) < list_size:
elem = elem_t.from_stream(fs, parent)
array.append(elem)
return List(array)
def __eq__(self, other):
if len(self.get_array()) != len(other.get_array()):
return False
for self_elem, other_elem in zip(self.get_array(), other.get_array()):
if self_elem != other_elem:
return False
return True
def __repr__(self):
from metastruct import MetaStruct
if my_issubclass(List.elem_t, MetaStruct):
# リストの要素がMetaStructのときは、各要素を複数行で表示する
output = ''
for elem in self.get_array():
content = textwrap.indent(repr(elem), prefix=" ").strip()
output += '+ %s\n' % content
return 'List<%s>:\n%s' % (self.__class__.size_t.__name__, output)
else:
# それ以外のときは配列の中身を一行で表示する
return 'List<%s>%s' % \
(self.__class__.size_t.__name__, repr(self.get_array()))
def __iter__(self):
return iter(self.array)
def find(self, arg):
if callable(arg):
return next((x for x in iter(self) if arg(x)), None)
else:
return next((x for x in iter(self) if x == arg), None)
List.size_t = size_t
List.elem_t = elem_t
return List
# --- Enum ---------------------------------------------------------------------
# 列挙型を表すためのクラス
class Enum(Type, BuildinEnum):
# 親クラスにクラス変数を定義すると、子クラスでEnumが定義できなくなるので注意。
# elem_t = UintN # Enumの要素の型
# Enum は .name でラベル名、.value で値を得ることができる
def __bytes__(self):
return bytes(self.value)
@classmethod
def from_stream(cls, fs, parent=None):
elem_t = cls.get_type()
return cls(elem_t.from_stream(fs))
@classmethod
def get_type(cls):
return cls.elem_t.value
# 列挙型にない値が与えらたとき unknown という名前の値を動的に生成して返すためのクラス
class EnumUnknown(Enum):
@classmethod
def _missing_(cls, value):
obj = object.__new__(cls)
obj._name_ = 'unknown'
obj._value_ = value
return obj
import binascii
def hexdump(data) -> str:
return '\n'.join(__dumpgen(data))
def __dumpgen(data):
# Generator that produces strings (addr, hexstr, ascii):
# 00000000: 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ................
generator = __chunks(data, 16)
for addr, d in enumerate(generator):
# address
line = '%08X: ' % (addr*16)
# hexstr
dumpstr = __dump(d)
line += dumpstr[:8*3]
if len(d) > 8: # insert separator if needed
line += ' ' + dumpstr[8*3:]
# indent
pad = 2
if len(d) < 16:
pad += 3 * (16 - len(d))
if len(d) <= 8:
pad += 1
line += ' ' * pad
# ascii
for byte in d:
# printable ASCII range 0x20 to 0x7E
line += chr(byte) if 0x20 <= byte <= 0x7E else '.'
yield line
# list(chunks([1,2,3,4,5,6,7], 3)) #=> [[1, 2, 3], [4, 5, 6], [7]]
def __chunks(seq, size):
d, m = divmod(len(seq), size)
for i in range(d):
yield seq[i*size:(i+1)*size]
if m:
yield seq[d*size:]
def __dump(binary, size=2, sep=' '):
return sep.join(__chunks(__hexstr(binary).upper(), size))
def __hexstr(binary):
return binascii.hexlify(binary).decode('ascii')
def bytexor(b1, b2):
result = bytearray(b1)
for i, b in enumerate(b2):
result[i] ^= b
return bytes(result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment