Skip to content

Instantly share code, notes, and snippets.

@timercrack
Created January 20, 2016 07:39
Show Gist options
  • Save timercrack/d8df0972eadcb343b895 to your computer and use it in GitHub Desktop.
Save timercrack/d8df0972eadcb343b895 to your computer and use it in GitHub Desktop.
Python construct: checksum fields made easy for common tasks
#!/usr/bin/env python
# from https://groups.google.com/forum/#!topic/construct3/FzYFmdv4qTg
from construct import *
# copied from core.py
def _read_stream(stream, length):
if length < 0:
raise ValueError("length must be >= 0", length)
data = stream.read(length)
if len(data) != length:
raise FieldError("expected %d, found %d" % (length, len(data)))
return data
# copied from core.py
def _write_stream(stream, length, data):
if length < 0:
raise ValueError("length must be >= 0", length)
if len(data) != length:
raise FieldError("expected %d, found %d" % (length, len(data)))
stream.write(data)
class ChecksumField(Subconstruct):
    """
    A field which can compute the checksum of previously streamed bytes.

    The None value means that while writing to a stream the checksum is
    automatically computed and written; the actual value is untouched (None).
    Any other value is written as-is, without being checked.

    Parameters 'start' and 'stop' can be either None (stream beginning and
    current positions) or functions which compute stream offsets from the
    context. A negative offset counts back from the end of the stream.

    When parsing, the stored value is compared with the recomputed checksum
    and ValidationError is raised on mismatch.

    It requires readable, writable and seekable streams.
    """
    __slots__ = ['start', 'stop', 'compute']

    def __init__(self, subcon, compute, start=None, stop=None):
        Subconstruct.__init__(self, subcon)
        self.start = start      # callable(context) -> offset, or None
        self.stop = stop        # callable(context) -> offset, or None
        self.compute = compute  # callable(data) -> int checksum

    def _resolve_range(self, stream, context, origpos):
        """Return the absolute (start, stop) offsets of the checksummed bytes."""
        startpos = self.start(context) if self.start is not None else 0
        stoppos = self.stop(context) if self.stop is not None else origpos
        if startpos < 0 or stoppos < 0:
            # Negative offsets are relative to the end of the stream; make
            # them absolute so that 'stop - start' is the real byte count.
            stream.seek(0, 2)
            end = stream.tell()
            if startpos < 0:
                startpos += end
            if stoppos < 0:
                stoppos += end
        return startpos, stoppos

    def _compute_checksum(self, stream, context):
        """Checksum the configured byte range, preserving the stream position."""
        origpos = stream.tell()
        try:
            startpos, stoppos = self._resolve_range(stream, context, origpos)
            stream.seek(startpos, 0)
            data = _read_stream(stream, stoppos - startpos)
        finally:
            # Restore the position even if reading fails, so the stream is
            # never left pointing into the middle of already-processed data.
            stream.seek(origpos, 0)
        # Truncate the checksum to the bit width of the subcon.
        mask = (1 << (self.subcon.sizeof() << 3)) - 1
        return self.compute(data) & mask

    def _parse(self, stream, context):
        checksum = self._compute_checksum(stream, context)
        obj = self.subcon._parse(stream, context)
        if obj != checksum:
            fmt = 'wrong checksum %d, expected %d'
            raise ValidationError(fmt % (obj, checksum))
        return obj

    def _build(self, obj, stream, context):
        if obj is None:
            # Auto mode: compute from the bytes already written.
            obj = self._compute_checksum(stream, context)
        self.subcon._build(obj, stream, context)

    def _sizeof(self, context):
        return self.subcon._sizeof(context)
def Checksum8(name, start=None, stop=None):
"""
Creates a checksum-8 field.
Parameters 'start' and 'stop' can be either None (stream beginning and
current positions) or functions which compute stream offsets from the
context.
It requires readable and writable streams.
"""
return ChecksumField(Byte(name), sum, start, stop)
def AutoNoneContainer(subcon, **kw):
    """
    Build a Container for 'subcon' from 'kw', setting every named child that
    is not given in 'kw' to None.

    'subcon' must expose either 'subcons' (e.g. a Struct) or a single
    'subcon'; otherwise KeyError is raised. Values passed in 'kw' always win.
    Children whose name is None are skipped, because a None key cannot be
    passed to Container(**kw).
    """
    if hasattr(subcon, 'subcons'):
        children = subcon.subcons
    elif hasattr(subcon, 'subcon'):
        children = [subcon.subcon]
    else:
        raise KeyError('subcon %r has no children' % subcon.name)
    for child in children:
        if child.name is not None:
            kw.setdefault(child.name, None)
    return Container(**kw)
if __name__ == '__main__':
    from binascii import hexlify, unhexlify
    from collections import OrderedDict
    import six

    def hexstr(data):
        """Render 'data' as an upper-case hex literal, e.g. 0xAABB."""
        return '0x%s' % hexlify(data).decode().upper()

    def report(title, *values):
        """Print 'title' and each value on its own line, then a blank line."""
        six.print_(title, *(values + ('',)), sep='\n')

    def at_b(ctx):
        return ctx._at_b

    def at_d(ctx):
        return ctx._at_d

    # 'x' is the checksum-8 of the bytes of 'b' and 'c' (between the two
    # anchors); 'y' is the checksum-8 of every byte written before it.
    Record = Struct(None,
        Byte('a'),
        Anchor('_at_b'),
        UBInt16('b'),
        UBInt32('c'),
        Anchor('_at_d'),
        Byte('d'),
        Checksum8('x', at_b, at_d),
        Byte('e'),
        Checksum8('y'),
    )

    # 'x' and 'y' are left as None so they are computed while building.
    model = AutoNoneContainer(
        Record, a=0xAA, b=0xEFDC, c=0x12345678, d=0x55, e=0x99)
    report('# model', dict(model))

    built = Record.build(model)
    report('# built', hexstr(built))

    expected = unhexlify(six.b('AAEFDC1234567855DF9956'))
    report('# expected', hexstr(expected))
    report('# built == expected', str(built == expected))

    # Parsing yields the real checksum values; 'model' keeps its None entries.
    parsed = Record.parse(built)
    report('# parsed', OrderedDict(parsed), '', parsed)

    rebuilt = Record.build(parsed)
    report('# rebuilt', hexstr(rebuilt))
    report('# built == rebuilt', str(built == rebuilt))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment