Skip to content

Instantly share code, notes, and snippets.

@jelmervdl
Created March 14, 2022 15:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jelmervdl/674f7d9bb9eda1c5f0206f275d8b8d75 to your computer and use it in GitHub Desktop.
Save jelmervdl/674f7d9bb9eda1c5f0206f275d8b8d75 to your computer and use it in GitHub Desktop.
Parse marian binary model
#!/usr/bin/env python3
import sys
import struct
from pprint import pprint
import argparse
import mmap
from typing import NamedTuple
class Reader:
    """Sequential ``struct`` decoder over a bytes-like buffer.

    The reader remembers how far into ``buffer`` it has read (``offset``);
    every successful :meth:`unpack` call advances that position by the size
    of the format that was decoded. ``offset`` is a plain attribute, so a
    caller may also move it directly to skip over bytes.
    """

    def __init__(self, buffer):
        """Start reading ``buffer`` (anything ``len()``-able that
        ``struct.unpack_from`` accepts) at offset 0."""
        self.buffer = buffer
        self.offset = 0

    def unpack(self, format):
        """Decode ``format`` at the current offset and advance past it.

        Args:
            format: a ``struct`` format string.

        Returns:
            The tuple produced by ``struct.unpack_from``.

        Raises:
            EOFError: if fewer bytes remain in the buffer than ``format``
                needs. The offset is left unchanged in that case.
        """
        size = struct.calcsize(format)
        end = self.offset + size
        # Explicit check instead of ``assert``: asserts are removed when
        # Python runs with -O, and a truncated file deserves a clear message.
        if end > len(self.buffer):
            raise EOFError(
                f'need {size} bytes at offset {self.offset}, '
                f'but the buffer holds only {len(self.buffer)} bytes'
            )
        data = struct.unpack_from(format, self.buffer, self.offset)
        self.offset = end
        return data
class Header(NamedTuple):
    """Per-item size record: four integers describing one stored item.

    The fields are listed in the order in which they are unpacked from the
    file, so ``Header(*values)`` can be built straight from a decoded tuple.
    """

    # Number of bytes occupied by the item's name.
    name_length: int
    # Numeric type code of the item; passed through uninterpreted.
    type: int
    # Number of integers making up the item's shape.
    shape_length: int
    # Number of bytes occupied by the item's data.
    data_length: int
class Item:
    """One named entry of the model file: type code, shape and raw data.

    ``name`` holds the raw ``bytes`` of the entry name (it is decoded only
    for display). ``shape`` and ``bytes`` start out empty so that a freshly
    constructed item, or one loaded without its data, can still be printed
    and inspected without raising ``AttributeError``.
    """

    # Raw entry name as read from the file.
    name: bytes
    # Numeric type code; stored as given, not interpreted here.
    type: int
    # Dimensions of the entry; empty until assigned.
    shape: tuple[int, ...]
    # Raw payload of the entry; empty until assigned.
    bytes: bytes

    def __init__(self, type, name):
        """Create an item with the given type code and raw ``name`` bytes."""
        self.type = type
        self.name = name
        self.shape = ()
        self.bytes = b''

    def __repr__(self):
        """Return ``'<name>: <type> (<d1>x<d2>x...)'``."""
        dims = "x".join(map(str, self.shape))
        return f'{self.name.decode()}: {self.type} ({dims})'
def read(filename, *, read_data=True):
    """Parse the model file at ``filename`` into a list of ``Item`` objects.

    The file is memory-mapped read-only and decoded front to back:

    1. two unsigned 64-bit integers: a version (unused) and the item count;
    2. one ``Header`` (four unsigned 64-bit integers) per item;
    3. every item's name, each ``name_length`` bytes, trailing NULs removed;
    4. every item's shape, each ``shape_length`` unsigned 32-bit integers;
    5. only when ``read_data`` is true: an unsigned 64-bit count of bytes to
       skip, that many skipped bytes, then every item's ``data_length``
       bytes of payload.

    Args:
        filename: path of the file to read.
        read_data: when false, stop after the shapes and leave each item's
            ``bytes`` unassigned.

    Returns:
        The items, in file order.
    """
    with open(filename, 'rb') as fh:
        with mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) as buffer:
            reader = Reader(buffer)
            _version, count = reader.unpack('@QQ')

            # The size records for all items come first.
            headers = []
            for _ in range(count):
                headers.append(Header(*reader.unpack('@QQQQ')))

            # Then all the names, back to back.
            items = []
            for header in headers:
                raw_name = reader.unpack(f'{header.name_length}s')[0]
                items.append(
                    Item(type=header.type, name=raw_name.rstrip(b'\0'))
                )

            # Then all the shapes.
            for header, item in zip(headers, items):
                item.shape = reader.unpack(f'@{header.shape_length}I')

            if read_data:
                # The stored value is the number of bytes to skip before the
                # payloads begin (the original note says this aligns the data
                # to a 256-byte boundary).
                padding = reader.unpack('@Q')[0]
                reader.offset += padding
                for header, item in zip(headers, items):
                    payload = reader.unpack(f'@{header.data_length}s')[0]
                    item.bytes = payload

            return items
def run_list(options):
    """Handle the ``list`` sub-command: print one line per item.

    Only the metadata is loaded (``read_data=False``). With
    ``options.simple`` set, just the decoded name is printed; otherwise the
    item's ``repr`` is printed.
    """
    for entry in read(options.model, read_data=False):
        line = entry.name.decode() if options.simple else repr(entry)
        print(line)
def run_extract(options):
    """Handle the ``extract`` sub-command: dump one item's raw bytes.

    The first item whose decoded name equals ``options.name`` has its data
    written to standard output in binary. When no item matches, a message is
    printed to standard error and the process exits with status 1.
    """
    entries = read(options.model)
    # Lazy search: stops decoding names as soon as the first match is found.
    match = next(
        (entry for entry in entries if entry.name.decode() == options.name),
        None,
    )
    if match is None:
        print(f"Entry not found: {options.name}", file=sys.stderr)
        sys.exit(1)
    sys.stdout.buffer.write(match.bytes)
def main(args):
    """Parse the command line in ``args`` and run the chosen sub-command.

    Sub-commands:
        list (aliases: l, ls) MODEL [--simple/-s]
            print the entries stored in MODEL.
        extract (alias: e) MODEL NAME
            write the raw bytes of entry NAME to standard output.

    Args:
        args: the argument strings, without the program name.
    """
    parser = argparse.ArgumentParser()
    # A sub-command is mandatory. Without ``required=True`` an invocation
    # with no sub-command parses successfully and then crashes with
    # AttributeError because ``options.func`` was never set. ``dest`` gives
    # argparse a name to use in the resulting usage error.
    subparsers = parser.add_subparsers(dest='command', required=True)

    parser_list = subparsers.add_parser('list', aliases=['l', 'ls'])
    parser_list.add_argument('--simple', '-s', action='store_true')
    parser_list.add_argument('model')
    parser_list.set_defaults(func=run_list)

    parser_extract = subparsers.add_parser('extract', aliases=['e'])
    parser_extract.add_argument('model')
    parser_extract.add_argument('name')
    parser_extract.set_defaults(func=run_extract)

    options = parser.parse_args(args)
    options.func(options)
# Run the command-line interface only when this file is executed as a script,
# passing along everything after the program name.
if __name__ == '__main__':
    main(sys.argv[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment