Last active
April 11, 2024 12:40
-
-
Save sma/dd7bd6aee192749637ec88405ec6ab22 to your computer and use it in GitHub Desktop.
A tiny, incomplete proof-of-concept WebAssembly (wasm) interpreter that can add two numbers.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// A simple incomplete WebAssembly interpreter. | |
library; | |
import 'dart:convert';
import 'dart:io';
/// Loads the embedded sample module, lists its exports and calls `add(3, 4)`.
void main() {
  // To run a module from disk instead of the embedded sample:
  // final bytes = File('sum.wasm').readAsBytesSync();
  final wasm = WasmModule(bytes);
  final exported = wasm.exports;
  print(exported);
  final sum = wasm.invoke('add', [3, 4]);
  print(sum);
}
/// Runs functions from a WebAssembly module.
///
/// Use [invoke] to call an exported function.
///
/// Note that this is a minimal implementation that only supports the sections
/// required for this example. It supports only single occurrences of each
/// section and only a subset of types and instructions. It doesn't support
/// memory, globals, tables, or imported functions.
///
/// Requires a Dart implementation with `sizeof(int) > 32 bits`.
class WasmModule {
  /// Creates a module from its binary encoding in [bytes].
  ///
  /// Only the 8-byte preamble (magic number and version) is checked here;
  /// sections are parsed on demand. Throws an [Exception] if the preamble is
  /// missing or wrong.
  WasmModule(this.bytes) {
    if (bytes.length < 8) throw Exception('Invalid header');
    if (_readUint32() != 0x6d736100) throw Exception('Invalid header');
    if (_readUint32() != 1) throw Exception('Invalid version');
  }

  /// Control-stack entry for a `block` or `if`, whose branch target is the end
  /// of the block. Loops store the (non-negative) index of their first
  /// instruction instead.
  static const _forward = -1;

  /// The binary encoding of the module.
  final List<int> bytes;

  /// Where to read the next element.
  var _index = 0;

  /// Reads an unsigned byte.
  int _readByte() => bytes[_index++];

  /// Reads an unsigned 32-bit little endian integer.
  int _readUint32() => _readByte() | (_readByte() << 8) | (_readByte() << 16) | (_readByte() << 24);

  /// Reads the raw payload of a LEB128 encoded integer of at most five bytes.
  ///
  /// Returns the accumulated 7-bit groups and the number of payload bits read.
  /// Throws an [Exception] before consuming a sixth byte.
  (int, int) _readLeb128() {
    var result = 0;
    var shift = 0;
    int byte;
    do {
      // A 32-bit value needs at most ceil(32 / 7) = 5 bytes.
      if (shift > 28) throw Exception('Integer too large');
      byte = _readByte();
      result |= (byte & 0x7f) << shift;
      shift += 7;
    } while ((byte & 0x80) != 0);
    return (result, shift);
  }

  /// Reads a LEB128 encoded unsigned integer (of up to 32 bits).
  int _readUint() {
    final (value, _) = _readLeb128();
    if (value > 0xffffffff) throw Exception('Integer too large');
    return value;
  }

  /// Reads a LEB128 encoded signed integer (of up to 32 bits).
  int _readInt() {
    final (value, bits) = _readLeb128();
    // The sign bit is the topmost payload bit that was read, capped at bit 31.
    return value.toSigned(bits < 32 ? bits : 32);
  }

  /// Reads a length-prefixed UTF-8 encoded string.
  ///
  /// Throws a [FormatException] if the bytes are not valid UTF-8.
  String _readString() {
    final length = _readUint();
    final end = _index + length;
    final value = utf8.decoder.convert(bytes, _index, end);
    _index = end;
    return value;
  }

  /// Iterates over all sections of a given type.
  ///
  /// Calls [callback] with [_index] positioned at the start of each matching
  /// section's content. Throws an [Exception] if the callback reads beyond
  /// the end of the section.
  void _each(WasmSection section, void Function() callback) {
    _index = 8; // skip magic number and version
    while (_index < bytes.length) {
      final id = _readByte();
      final length = _readUint();
      final next = _index + length;
      if (section.index == id) {
        callback();
        if (_index > next) throw Exception('Section overflow');
      }
      _index = next;
    }
  }

  /// Reads a type definition (incomplete).
  ///
  /// Throws an [Exception] for anything but `i32`, `i64` and function types.
  WasmType _readType() {
    final id = _readByte();
    switch (id) {
      case 0x7f:
        return const WasmTypeI32();
      case 0x7e:
        return const WasmTypeI64();
      case 0x60:
        // Parameters are encoded before results.
        final params = _readTypes();
        final results = _readTypes();
        return WasmTypeFn(params, results);
      default:
        throw Exception('Unsupported type ${id.toRadixString(16)}');
    }
  }

  /// Reads a count-prefixed vector of types.
  List<WasmType> _readTypes() {
    final count = _readUint();
    return [for (var i = 0; i < count; i++) _readType()];
  }

  /// Returns all types declared in this module.
  List<WasmType> get types {
    final types = <WasmType>[];
    _each(WasmSection.type, () => types.addAll(_readTypes()));
    return types;
  }

  /// Returns all functions defined in this module.
  List<WasmFn> get functions {
    final types = this.types;
    final functions = <WasmFn>[];
    _each(WasmSection.function, () {
      final count = _readUint();
      for (var index = 0; index < count; index++) {
        functions.add(WasmFn(types[_readUint()] as WasmTypeFn, index));
      }
    });
    return functions;
  }

  /// Returns all exports defined in this module (incomplete).
  ///
  /// Throws an [Exception] for export kinds other than functions and memories.
  List<WasmExport> get exports {
    final functions = this.functions;
    final exports = <WasmExport>[];
    _each(WasmSection.export, () {
      final count = _readUint();
      for (var i = 0; i < count; i++) {
        final name = _readString();
        final kind = _readByte();
        exports.add(switch (kind) {
          0x00 => WasmExportFn(name, functions[_readUint()]),
          0x02 => WasmExportMem(name, _readUint()),
          _ => throw Exception('Unsupported export kind ${kind.toRadixString(16)}'),
        });
      }
    });
    return exports;
  }

  /// Invokes an exported function by [name], passing [args] which must match
  /// the function's signature. Only a single return value is supported.
  ///
  /// Very incomplete: only `i32` locals and the handful of instructions known
  /// to [_executeCode] are supported. Throws a [StateError] if there is no
  /// single exported function called [name] and an [Exception] if [args]
  /// don't match the signature or the code uses unsupported features.
  dynamic invoke(String name, List<dynamic> args) {
    final fn = exports.whereType<WasmExportFn>().singleWhere((export) => export.name == name).fn;
    final params = fn.params;
    if (params.length != args.length) {
      throw Exception('Invalid number of arguments (expected ${params.length}, got ${args.length})');
    }
    for (var i = 0; i < params.length; i++) {
      if (!params[i].includes(args[i])) {
        // Arguments are reported 1-based.
        throw Exception('Argument ${i + 1} is not of type ${params[i]}');
      }
    }
    _locateCode(fn.index);
    // The frame holds the parameters followed by the declared locals.
    final frame = <int>[...args];
    // Read locals and initialize them to zero.
    final groups = _readUint();
    for (var i = 0; i < groups; i++) {
      final n = _readUint();
      final type = _readType();
      if (type is! WasmTypeI32) throw Exception('Unsupported local type $type');
      frame.addAll(List.filled(n, 0));
    }
    return _executeCode(frame);
  }

  /// Seeks to the [index]th code block.
  ///
  /// Leaves [_index] at the first byte after the body's size prefix. Throws
  /// an [Exception] if there is no such body.
  void _locateCode(int index) {
    int? found;
    _each(WasmSection.code, () {
      final count = _readUint();
      for (var i = 0; i < count; i++) {
        final size = _readUint();
        if (i == index) {
          found = _index;
          return;
        }
        _index += size;
      }
    });
    final start = found;
    if (start == null) throw Exception('Code $index not found');
    _index = start;
  }

  /// Executes the current code block.
  ///
  /// [frame] holds the parameters and locals. Returns the single value left
  /// on the operand stack when the function's final `end` is reached, or the
  /// top of the stack when a `br` targets the function itself. Throws an
  /// [Exception] for unsupported opcodes or block types.
  ///
  /// Branches do not unwind the operand stack.
  int _executeCode(List<int> frame) {
    final stack = <int>[];
    // One entry per open block, innermost last: the index of the first
    // instruction for a loop, [_forward] for a block or if.
    final labels = <int>[];
    while (true) {
      final opcode = _readByte();
      switch (opcode) {
        case 0x00: // unreachable
          throw Exception('Unreachable');
        case 0x01: // nop
          break;
        case 0x02: // block
          if (_readByte() != 0x40) throw Exception('Unsupported block type');
          labels.add(_forward);
        case 0x03: // loop
          if (_readByte() != 0x40) throw Exception('Unsupported block type');
          labels.add(_index);
        case 0x04: // if
          if (_readByte() != 0x40) throw Exception('Unsupported if type');
          // A false condition continues in the else branch if there is one,
          // otherwise after the end, in which case no block has been entered.
          if (stack.removeLast() != 0 || _skipToEnd(1, stopAtElse: true)) {
            labels.add(_forward);
          }
        case 0x05: // else, reached at the end of a taken then-branch
          if (labels.isEmpty) throw Exception('Unexpected else');
          labels.removeLast();
          _skipToEnd(1);
        case 0x0b: // end
          if (labels.isEmpty) return stack.single;
          labels.removeLast();
        case 0x0c: // br
          final depth = _readUint();
          if (depth > labels.length) throw Exception('Invalid branch depth $depth');
          // Branching to the function's own label returns from the function.
          if (depth == labels.length) return stack.last;
          // Depth 0 is the innermost open block.
          final target = labels.length - 1 - depth;
          final start = labels[target];
          if (start == _forward) {
            // Leave the target block and everything nested inside it.
            labels.removeRange(target, labels.length);
            _skipToEnd(depth + 1);
          } else {
            // Restart the loop, which stays open.
            labels.removeRange(target + 1, labels.length);
            _index = start;
          }
        case 0x20: // local.get
          stack.add(frame[_readUint()]);
        case 0x21: // local.set
          frame[_readUint()] = stack.removeLast();
        case 0x41: // i32.const
          stack.add(_readInt());
        case 0x4e: // i32.ge_s
          final b = stack.removeLast();
          final a = stack.removeLast();
          stack.add(a.toSigned(32) >= b.toSigned(32) ? 1 : 0);
        case 0x6a: // i32.add
          final b = stack.removeLast();
          final a = stack.removeLast();
          stack.add((a + b).toSigned(32)); // wraps around on overflow
        case 0x6b: // i32.sub
          final b = stack.removeLast();
          final a = stack.removeLast();
          stack.add((a - b).toSigned(32)); // wraps around on overflow
        default:
          throw Exception('Unsupported opcode ${opcode.toRadixString(16)} at ${(_index - 1).toRadixString(16)}');
      }
    }
  }

  /// Skips instructions until the `end`s of [depth] enclosing blocks have
  /// been passed.
  ///
  /// If [stopAtElse] is set and the innermost enclosing block has an `else`,
  /// stops right after that `else` and returns `true`. Otherwise returns
  /// `false`.
  ///
  /// Only the immediates of the opcodes listed below are skipped. Other
  /// instructions with immediates (memory access, `br_table`, other
  /// constants, ...) are not supported.
  bool _skipToEnd(int depth, {bool stopAtElse = false}) {
    while (depth > 0) {
      final op = _readByte();
      switch (op) {
        case 0x02 || 0x03 || 0x04: // block, loop, if
          _readByte(); // block type
          depth++;
        case 0x05: // else
          if (stopAtElse && depth == 1) return true;
        case 0x0b: // end
          depth--;
        case 0x41: // i32.const
          _readInt();
        // br, br_if, call, local.get/set/tee, global.get/set
        case 0x0c || 0x0d || 0x10 || 0x20 || 0x21 || 0x22 || 0x23 || 0x24:
          _readUint();
        default:
          // assumed to have no immediates
          break;
      }
    }
    return false;
  }
}
/// The kinds of sections in a module.
///
/// The values are declared in the order of their binary section ids, so each
/// value's `index` is the id stated in its documentation.
enum WasmSection {
  /// 0x00
  custom,

  /// 0x01 - a vector of function signatures
  type,

  /// 0x02 - a vector of imports (either functions, tables, memory, or globals)
  import,

  /// 0x03 - a vector of indices into the type section
  function,

  /// 0x04 - a vector of table types
  table,

  /// 0x05 - a vector of memory types
  memory,

  /// 0x06 - a vector of type/mutability/initializer expression triples
  global,

  /// 0x07 - a vector of exports (either functions, tables, memory, or globals)
  export,

  /// 0x08 - index of the start function
  start,

  /// 0x09 - a vector of element segments (which are complicated)
  element,

  /// 0x0a - a vector of function bodies (a vector of locals and instructions)
  code,

  /// 0x0b - a vector of data segments (which are also complicated)
  data,

  /// 0x0c - the count of the data segments of this module
  dataCount,
}
/// Base of the closed hierarchy of types.
///
/// The concrete types are [WasmTypeI32], [WasmTypeI64], and [WasmTypeFn].
sealed class WasmType {
  /// Const, so that subclasses can declare const constructors.
  const WasmType();

  /// Whether the Dart [value] is acceptable where this type is expected.
  bool includes(Object? value);
}
/// The `i32` value type.
class WasmTypeI32 extends WasmType {
  const WasmTypeI32();

  /// Whether [value] is an [int]; its range is not checked.
  @override
  bool includes(Object? value) {
    return value is int;
  }

  @override
  String toString() {
    return 'i32';
  }
}
/// The `i64` value type.
class WasmTypeI64 extends WasmType {
  const WasmTypeI64();

  /// Whether [value] is an [int]; its range is not checked.
  @override
  bool includes(Object? value) {
    return value is int;
  }

  @override
  String toString() {
    return 'i64';
  }
}
/// A function signature: the types of its parameters and of its results.
class WasmTypeFn extends WasmType {
  const WasmTypeFn(this.params, this.results);

  /// The parameter types, in declaration order.
  final List<WasmType> params;

  /// The result types, in declaration order.
  final List<WasmType> results;

  /// Whether [value] is a Dart [Function]; its signature is not checked.
  @override
  bool includes(Object? value) {
    return value is Function;
  }

  /// Formats the signature as `(params) -> (results)`.
  @override
  String toString() {
    final from = params.join(', ');
    final to = results.join(', ');
    return '($from) -> ($to)';
  }
}
/// A function defined by a module: its signature and its position.
class WasmFn {
  const WasmFn(this.type, this.index);

  /// The signature of this function.
  final WasmTypeFn type;

  /// The position of this function among the module's functions.
  final int index;

  /// Shortcut for the parameter types of [type].
  List<WasmType> get params {
    return type.params;
  }

  /// Shortcut for the result types of [type].
  List<WasmType> get results {
    return type.results;
  }

  /// The one and only result type.
  ///
  /// Throws a [StateError] unless there is exactly one result.
  WasmType get result {
    return results.single;
  }

  @override
  String toString() {
    return '$type [$index]';
  }
}
/// Base of the closed hierarchy of exports.
///
/// The concrete exports are [WasmExportFn] and [WasmExportMem].
sealed class WasmExport {
  /// Const, so that subclasses can declare const constructors.
  const WasmExport();
}
/// An exported function.
class WasmExportFn extends WasmExport {
  const WasmExportFn(this.name, this.fn);

  /// The name under which the function is exported.
  final String name;

  /// The exported function.
  final WasmFn fn;

  /// Formats the export as `fn name(params) -> result`.
  ///
  /// A single result type is printed as is. Zero or several result types are
  /// printed as a parenthesized list, so that this never throws.
  @override
  String toString() {
    final results = fn.results;
    final result = results.length == 1 ? '${results.single}' : '(${results.join(', ')})';
    return 'fn $name(${fn.params.join(', ')}) -> $result';
  }
}
/// An exported memory.
class WasmExportMem extends WasmExport {
  const WasmExportMem(this.name, this.index);

  /// The name under which the memory is exported.
  final String name;

  /// The index read from the export entry.
  final int index;

  /// Formats the export as `mem name[index]`.
  @override
  String toString() {
    return 'mem $name[$index]';
  }
}
/// The sample module, which exports a single function `add`.
///
/// Laid out like the output of `hexdump -C add.wasm`:
///
/// ```
/// 00000000  00 61 73 6d 01 00 00 00  01 07 01 60 02 7f 7f 01  |.asm.......`....|
/// 00000010  7f 03 02 01 00 07 07 01  03 61 64 64 00 00 0a 09  |.........add....|
/// 00000020  01 07 00 20 00 20 01 6a  0b                       |... . .j.|
/// 00000029
/// ```
const bytes = <int>[
  0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x07, 0x01, 0x60, 0x02, 0x7f, 0x7f, 0x01, //
  0x7f, 0x03, 0x02, 0x01, 0x00, 0x07, 0x07, 0x01, 0x03, 0x61, 0x64, 0x64, 0x00, 0x00, 0x0a, 0x09, //
  0x01, 0x07, 0x00, 0x20, 0x00, 0x20, 0x01, 0x6a, 0x0b,
];
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment