|
#!/usr/bin/env python3 |
|
# -*- coding: utf-8 -*- |
|
|
|
import os |
|
import sys |
|
import struct |
|
from enum import IntEnum |
|
from collections import namedtuple |
|
|
|
from decodeenums import I32, I32WithImm, RType, Type, VarTypes, Stmt, StmtWithImm, ExportFormat |
|
|
|
# Lightweight immutable records produced by WasmBinary while decoding.

# Function signature: return type plus the list of argument types.
Signature = namedtuple("Signature", "ret args")

# Pairs a signature index with the index of the imported function using it.
FuncImportSignature = namedtuple("FuncImportSignature", "sig_index func_imp_index")

# Function pointer table: the shared signature index and its element list.
FuncPtrTable = namedtuple("FuncPtrTable", "sig_index elems")
|
|
|
class WasmBinary:
    """Decoder for a packed "wasm" binary as laid out by ``decode_wasm``.

    The input starts with the 4-byte magic ``b"wasm"`` and a little-endian
    32-bit unpacked size, followed by these sections in order: constant
    pool, signatures, function imports, globals, function declarations,
    function pointer tables, function definitions and exports.

    Decoded data is accumulated on the instance; progress is printed to
    stdout while decoding.
    """

    def __init__(self):
        self.sigs = []            # Signature records
        self.i32s = []            # constant pool: 32-bit integers
        self.f32s = []            # constant pool: 32-bit floats
        self.f64s = []            # constant pool: 64-bit floats
        self.func_names = []      # names of imported functions
        self.func_imp_sigs = []   # FuncImportSignature records
        self.global_types = []    # Type of each global
        self.global_vals = []     # initial value (0 / 0.0) or import name
        self.func_sigs = []       # signature index of each declared function
        self.func_ptr_tables = []  # FuncPtrTable records
        # Types of the arguments + local variables of the function whose
        # body is currently being decoded.
        self.cur_local_type = []
        # NOTE(review): the original carried C++-style notes for members
        # that are not mirrored here: func_name_base_,
        # func_ptr_table_name_base_, num_labels_, cur_ret_,
        # cur_local_types_.

        self.unpacked_size = -1   # -1 until the header has been read

    @staticmethod
    def _read_exact(f, size):
        """Read exactly ``size`` bytes from ``f`` or raise ``EOFError``."""
        data = f.read(size)
        if len(data) != size:
            raise EOFError("Unexpected end of file: wanted %d byte(s), got %d"
                           % (size, len(data)))
        return data

    def decode_wasm(self, f):
        """Decode a whole binary from the binary file object ``f``.

        Raises:
            Exception: if the stream does not start with ``b"wasm"``.
            EOFError: if the stream ends in the middle of a value.
        """
        # read magic
        magic = f.read(4)
        if magic != b"wasm":
            raise Exception("Not a wasm binary")
        # unpacked length
        self.unpacked_size = struct.unpack("<I", self._read_exact(f, 4))[0]

        self.read_constant_pool_section(f)
        print("Constants", self.i32s, self.f32s, self.f64s)
        self.read_signature_section(f)
        print(self.sigs)
        self.read_function_import_section(f)
        print(self.func_names)
        print(self.func_imp_sigs)
        self.read_global_section(f)
        print("Global", list(zip(self.global_types, self.global_vals)))
        self.read_function_declaration_section(f)
        print("Func sigs", self.func_sigs)
        self.read_function_pointer_tables(f)
        print("Ptr tables", self.func_ptr_tables)
        self.read_function_definition_section(f)

        self.read_export_section(f)

    def read_vlq32(self, f):
        """Read an unsigned variable-length integer from ``f``.

        Each byte contributes its low 7 bits, least significant group
        first; a byte below 0x80 terminates the value.
        See https://en.wikipedia.org/wiki/Variable-length_quantity

        NOTE(review): the length is not bounded, so an over-long encoding
        can yield a value wider than 32 bits.
        """
        value = self._read_exact(f, 1)[0]
        if value < 0x80:
            return value
        value &= 0x7f
        shift = 7
        while True:
            b = self._read_exact(f, 1)[0]
            if b < 0x80:
                return value | (b << shift)
            value |= (b & 0x7f) << shift
            shift += 7

    def read_vlq32i(self, f):
        """Read a variable-length integer and reinterpret it as signed 32-bit."""
        val = self.read_vlq32(f)
        if val < 0x80000000:
            return val
        return val - 0x100000000

    def read_str(self, f):
        """Read a NUL-terminated UTF-8 string from ``f`` (terminator consumed)."""
        raw = bytearray()
        while True:
            c = self._read_exact(f, 1)[0]
            if c == 0:
                break
            raw.append(c)
        return raw.decode("UTF-8")

    def read_code(self, f):
        """Read one opcode byte.

        Returns:
            ``(True, code, None, None)`` when the top bit is clear (a plain
            opcode), otherwise ``(False, code, op, imm)`` where ``op`` is
            bits 5-6 and ``imm`` is the low 5 bits of the byte.
        """
        code = self._read_exact(f, 1)[0]
        if not (code & 0x80):
            return (True, code, None, None)
        op = (code >> 5) & 3   # two bits below the flag bit
        imm = code & 31        # lower 5 bits
        return (False, code, op, imm)

    def read_constant_pool_section(self, f):
        """Read the i32, f32 and f64 constant pools into the instance lists."""
        num_i32s = self.read_vlq32(f)
        num_f32s = self.read_vlq32(f)
        num_f64s = self.read_vlq32(f)
        for _ in range(num_i32s):
            self.i32s.append(self.read_vlq32(f))
        for _ in range(num_f32s):
            self.f32s.append(struct.unpack("<f", self._read_exact(f, 4))[0])
        for _ in range(num_f64s):
            # Doubles belong in f64s (they were previously appended to f32s).
            self.f64s.append(struct.unpack("<d", self._read_exact(f, 8))[0])

    def read_signature_section(self, f):
        """Read all function signatures, replacing ``self.sigs``."""
        self.sigs = []
        num_sigs = self.read_vlq32(f)
        for _ in range(num_sigs):
            ret = RType(self._read_exact(f, 1)[0])
            num_args = self.read_vlq32(f)
            args = [Type(self._read_exact(f, 1)[0]) for _ in range(num_args)]
            self.sigs.append(Signature(ret, args))

    def read_function_import_section(self, f):
        """Read imported function names and the signatures each is used with."""
        num_func_imps = self.read_vlq32(f)
        # Second count in the header; it must be consumed to keep the stream
        # aligned but is not otherwise needed here.
        self.read_vlq32(f)
        for i in range(num_func_imps):
            self.func_names.append(self.read_str(f))
            num_sigs = self.read_vlq32(f)
            for _ in range(num_sigs):
                sig = self.read_vlq32(f)
                self.func_imp_sigs.append(FuncImportSignature(sig, i))

    def read_global_section(self, f):
        """Read global variable declarations.

        Six counts are read first (zero-initialised i32/f32/f64, then
        imported i32/f32/f64). Zero-initialised globals get the value
        ``0``/``0.0``; imported globals store the import name read from
        the stream as their value.
        """
        num_i32_zero = self.read_vlq32(f)
        num_f32_zero = self.read_vlq32(f)
        num_f64_zero = self.read_vlq32(f)
        num_i32_import = self.read_vlq32(f)
        num_f32_import = self.read_vlq32(f)
        num_f64_import = self.read_vlq32(f)

        zero_groups = (
            (num_i32_zero, Type.I32, 0),
            (num_f32_zero, Type.F32, 0.0),
            (num_f64_zero, Type.F64, 0.0),
        )
        for count, tp, zero in zero_groups:
            for _ in range(count):
                self.global_types.append(tp)
                self.global_vals.append(zero)

        import_groups = (
            (num_i32_import, Type.I32),
            (num_f32_import, Type.F32),
            (num_f64_import, Type.F64),
        )
        for count, tp in import_groups:
            for _ in range(count):
                self.global_types.append(tp)
                self.global_vals.append(self.read_str(f))

    def read_function_declaration_section(self, f):
        """Read the signature index of every declared function."""
        num_funcs = self.read_vlq32(f)
        for _ in range(num_funcs):
            self.func_sigs.append(self.read_vlq32(f))

    def read_function_pointer_tables(self, f):
        """Read every function pointer table (signature index + elements)."""
        num_func_ptr_tables = self.read_vlq32(f)
        for _ in range(num_func_ptr_tables):
            sig_index = self.read_vlq32(f)
            num_elems = self.read_vlq32(f)
            elems = [self.read_vlq32(f) for _ in range(num_elems)]
            self.func_ptr_tables.append(FuncPtrTable(sig_index, elems))

    def read_add_sub(self, f, op, tp):
        """Read two operand expressions of type ``tp``; return ``(op, lhs, rhs)``."""
        v1 = self.read_expr(f, tp)
        v2 = self.read_expr(f, tp)
        return (op, v1, v2)

    def read_comma(self, f, tp):
        """Read a comma expression; return ``(",", lhs, rhs)``.

        A one-byte type for the left operand is read first, then the left
        expression of that type, then the right expression of type ``tp``.
        """
        tp1 = RType(self._read_exact(f, 1)[0])
        v1 = self.read_expr(f, tp1)
        v2 = self.read_expr(f, tp)
        # The original returned an undefined name here (NameError); use the
        # operator symbol, mirroring the "+"/"-" tags of read_add_sub.
        return (",", v1, v2)

    def read_stmt(self, f):
        """Read one statement; return ``(stmt_code, location, expr)``.

        Raises:
            Exception: for a statement code this decoder does not support.
        """
        raw, code, stmt, imm = self.read_code(f)
        print("STMT", raw, code, stmt, imm, "stmt=", code & 0x7f)
        if raw:
            # Dispatch before converting to the enum so that an unsupported
            # code reports the message below instead of a bare ValueError.
            if code == Stmt.SetLoc:
                stcode = Stmt.SetLoc
                ret = self.read_set_local(f)
            elif code == Stmt.SetGlo:
                stcode = Stmt.SetGlo
                ret = self.read_set_global(f)
            else:
                raise Exception("Unknown statement " + str(code))
        else:
            # Immediate form: the target index is packed into the opcode.
            if stmt == StmtWithImm.SetLoc:
                stcode = Stmt.SetLoc
                ret = self.read_set_local(f, imm)
            elif stmt == StmtWithImm.SetGlo:
                stcode = Stmt.SetGlo
                ret = self.read_set_global(f, imm)
            else:
                raise Exception("Unknown statement with imm " + str(stmt))
        return (stcode,) + ret

    def read_stmt_list(self, f):
        """Read a counted list of statements and return them as a list."""
        num_stmts = self.read_vlq32(f)
        if not num_stmts:
            print("***No statements")
        return [self.read_stmt(f) for _ in range(num_stmts)]

    def read_set_local(self, f, loc = None):
        """Read a local assignment; return ``(loc, expr)``.

        ``loc`` is read from the stream unless supplied by the caller. The
        expression type is looked up in ``self.cur_local_type``.
        """
        if loc is None:
            loc = self.read_vlq32(f)
        tp = self.cur_local_type[loc]
        expr = self.read_expr(f, tp)
        return (loc, expr)

    def read_set_global(self, f, loc = None):
        """Read a global assignment; return ``(loc, expr)``.

        ``loc`` is read from the stream unless supplied by the caller. The
        expression type is looked up in ``self.global_types``.
        """
        if loc is None:
            loc = self.read_vlq32(f)
        tp = self.global_types[loc]
        expr = self.read_expr(f, tp)
        return (loc, expr)

    def read_expr(self, f, tp):
        """Read an expression of type ``tp``; return ``(tp, code, ...)``.

        Raises:
            Exception: for any type other than ``Type.I32``, the only one
                implemented so far.
        """
        if tp == Type.I32:
            ret = self.read_expr_i32(f)
        else:
            raise Exception("Unknown type for expression " + str(tp))
        return (tp,) + ret

    def read_expr_i32(self, f):
        """Read an i32 expression; return ``(expr_code, operands...)``.

        Raises:
            Exception: for an expression code this decoder does not support.
        """
        raw, code, expr, imm = self.read_code(f)
        print("EXPR", raw, code, expr, imm, "expr=", code & 0x7f)
        if raw:
            # Dispatch first so unsupported codes hit the message below.
            if code == I32.LitImm:
                excode = I32.LitImm
                ret = (self.read_vlq32i(f),)
            elif code == I32.Comma:
                excode = I32.Comma
                ret = self.read_comma(f, tp = Type.I32)
            elif code == I32.Add:
                excode = I32.Add
                ret = self.read_add_sub(f, tp = Type.I32, op = "+")
            elif code == I32.Sub:
                excode = I32.Sub
                ret = self.read_add_sub(f, tp = Type.I32, op = "-")
            else:
                raise Exception("Unknown expression I32 " + str(code))
        else:
            if expr == I32WithImm.LitImm:
                excode = I32WithImm.LitImm
                ret = (imm,)  # literal packed into the opcode byte
            else:
                raise Exception("Unknown expression I32 with imm " + str(expr))
        return (excode,) + ret

    def read_function_definition_section(self, f):
        """Read and print the body of every declared function.

        For each function the local type table starts with the argument
        types of its signature, followed by the declared i32, f32 and f64
        local variables in that order.
        """
        for i, sig_index in enumerate(self.func_sigs):
            # construct local arg types from sig
            sig = self.sigs[sig_index]
            self.cur_local_type = list(sig.args)
            # read vars types
            num_i32_vars = 0
            num_f32_vars = 0
            num_f64_vars = 0
            raw, code, op, imm = self.read_code(f)
            if raw:
                # The code is a bit mask saying which counts follow.
                if code & VarTypes.I32:
                    num_i32_vars = self.read_vlq32(f)
                if code & VarTypes.F32:
                    num_f32_vars = self.read_vlq32(f)
                if code & VarTypes.F64:
                    num_f64_vars = self.read_vlq32(f)
            else:
                # Immediate form: only i32 locals, count packed in the byte.
                num_i32_vars = imm
            # construct local var types from code (ie num_*_vars)
            self.cur_local_type.extend([Type.I32] * num_i32_vars)
            self.cur_local_type.extend([Type.F32] * num_f32_vars)
            self.cur_local_type.extend([Type.F64] * num_f64_vars)
            print("Function f%d: %d, %d, %d" % (i, num_i32_vars,
                                               num_f32_vars, num_f64_vars))
            print(" ", sig)
            print(" local", self.cur_local_type)
            # TODO: save stmt
            stmts = self.read_stmt_list(f)
            print(" stmt", len(stmts))
            for stmt in stmts:
                print(stmt)

    def read_export_section(self, f):
        """Read the export section and print what is exported."""
        fmt = ExportFormat(self._read_exact(f, 1)[0])
        if fmt == ExportFormat.Default:
            funcnum = self.read_vlq32(f)
            print("Export default function #%d" % funcnum)
        elif fmt == ExportFormat.Record:
            elen = self.read_vlq32(f)
            for i in range(elen):
                func = self.read_str(f)
                fidx = self.read_vlq32(f)
                print("Export record #%d: \"%s\": #%d" % (i, func, fidx))
|
|
|
def main():
    """Decode the file named by the first command-line argument.

    Prints the decoder's progress output and, if any bytes remain after
    the last section, how many were left unparsed. Exits with a usage
    message when no file argument is given.
    """
    if len(sys.argv) < 2:
        # Fail with a readable message instead of an IndexError.
        sys.exit("usage: %s <file.wasm>" % sys.argv[0])
    wb = WasmBinary()
    with open(sys.argv[1], "rb") as f:
        wb.decode_wasm(f)
        # Anything still in the stream was not consumed by the decoder.
        data = f.read()
    if data:
        print("Unparsed %d bytes" % len(data))
|
|
|
# Run the decoder only when this file is executed as a script.
if __name__ == "__main__":
    main()