Skip to content

Instantly share code, notes, and snippets.

@sma
Last active April 11, 2024 12:40
Show Gist options
  • Save sma/dd7bd6aee192749637ec88405ec6ab22 to your computer and use it in GitHub Desktop.
Save sma/dd7bd6aee192749637ec88405ec6ab22 to your computer and use it in GitHub Desktop.
A tiny, incomplete, proof-of-concept WebAssembly interpreter that can add two numbers.
/// A simple incomplete WebAssembly interpreter.
library;
import 'dart:convert';
import 'dart:io';
void main() {
  // The sample module is embedded in this file as `bytes`. To load a compiled
  // module from disk instead, use:
  // final bytes = File('add.wasm').readAsBytesSync();
  final module = WasmModule(bytes);
  final exports = module.exports;
  print(exports);
  final sum = module.invoke('add', [3, 4]);
  print(sum);
}
/// Runs functions from a WebAssembly module.
///
/// Use [invoke] to call an exported function.
///
/// Note that this is a minimal implementation that only supports the sections
/// required for this example. It supports only single occurrences of each
/// section and only a subset of types and instructions. It doesn't support
/// memory, globals, tables, or imported functions.
///
/// Requires a Dart implementation with `sizeof(int) > 32 bits`.
class WasmModule {
  /// Creates a module from the raw [bytes] of a `.wasm` binary.
  ///
  /// Throws an [Exception] if the magic number or the version is wrong.
  WasmModule(this.bytes) {
    if (_readUint32() != 0x6d736100) throw Exception('Invalid header');
    if (_readUint32() != 1) throw Exception('Invalid version');
  }

  /// The raw module binary.
  final List<int> bytes;

  /// Where to read the next element.
  var _index = 0;

  /// Reads an unsigned byte.
  int _readByte() => bytes[_index++];

  /// Reads an unsigned 32-bit little endian integer.
  int _readUint32() =>
      _readByte() |
      (_readByte() << 8) |
      (_readByte() << 16) |
      (_readByte() << 24);

  /// Reads a LEB128 encoded unsigned integer (of up to 32 bits).
  ///
  /// A 32-bit value takes at most five bytes. Longer encodings, or values that
  /// don't fit into 32 bits, throw an [Exception].
  int _readUint() {
    var result = 0;
    var shift = 0;
    while (true) {
      final byte = _readByte();
      result |= (byte & 0x7f) << shift;
      if ((byte & 0x80) == 0) break;
      shift += 7;
      // A sixth byte would be needed: too long for a 32-bit value.
      if (shift >= 35) throw Exception('Integer too large');
    }
    if (result > 0xffffffff) throw Exception('Integer too large');
    return result;
  }

  /// Reads a LEB128 encoded signed integer (of up to 32 bits).
  ///
  /// Returns the value as a signed 32-bit integer.
  int _readSint() {
    var result = 0;
    var shift = 0;
    var byte = 0;
    do {
      if (shift >= 35) throw Exception('Integer too large');
      byte = _readByte();
      result |= (byte & 0x7f) << shift;
      shift += 7;
    } while ((byte & 0x80) != 0);
    // Bit 6 of the last byte is the sign bit; sign-extend from there.
    if ((byte & 0x40) != 0) result |= -1 << shift;
    return result.toSigned(32);
  }

  /// Reads a UTF-8 encoded, length-prefixed byte string.
  String _readString() {
    final length = _readUint();
    final value = utf8.decode(bytes.sublist(_index, _index + length));
    _index += length;
    return value;
  }

  /// Iterates over all sections of a given type.
  ///
  /// [callback] is invoked with [_index] at the start of each matching
  /// section's payload. Afterwards, reading continues at the next section,
  /// regardless of how much of the payload the callback consumed.
  void _each(WasmSection section, void Function() callback) {
    _index = 8;
    while (_index < bytes.length) {
      final id = _readByte();
      final length = _readUint();
      final next = _index + length;
      if (section.index == id) {
        callback();
        if (_index > next) throw Exception('Section overflow');
      }
      _index = next;
    }
  }

  /// Reads a type definition (incomplete).
  WasmType _readType() {
    final id = _readByte();
    return switch (id) {
      0x7f => WasmTypeI32(),
      0x7e => WasmTypeI64(),
      0x60 => WasmTypeFn(
          List.generate(_readUint(), (_) => _readType()),
          List.generate(_readUint(), (_) => _readType()),
        ),
      _ => throw Exception('Unsupported type ${id.toRadixString(16)}'),
    };
  }

  /// Returns all types declared in this module.
  List<WasmType> get types {
    final types = <WasmType>[];
    _each(WasmSection.type, () {
      types.addAll(Iterable.generate(_readUint(), (_) => _readType()));
    });
    return types;
  }

  /// Returns all functions defined in this module.
  List<WasmFn> get functions {
    final types = this.types;
    final functions = <WasmFn>[];
    _each(WasmSection.function, () {
      functions.addAll(Iterable.generate(_readUint(), (index) {
        return WasmFn(types[_readUint()] as WasmTypeFn, index);
      }));
    });
    return functions;
  }

  /// Returns all exports defined in this module (incomplete).
  List<WasmExport> get exports {
    final functions = this.functions;
    final exports = <WasmExport>[];
    _each(WasmSection.export, () {
      exports.addAll(Iterable.generate(_readUint(), (_) {
        final name = _readString();
        final kind = _readByte();
        switch (kind) {
          case 0x00:
            return WasmExportFn(name, functions[_readUint()]);
          case 0x02:
            return WasmExportMem(name, _readUint());
          default:
            throw Exception(
                'Unsupported export kind ${kind.toRadixString(16)}');
        }
      }));
    });
    return exports;
  }

  /// Invokes an exported function by [name], passing [args] which must match
  /// the function's signature.
  ///
  /// Only a single return value and a small subset of instructions are
  /// supported. `i32` arguments may be given in signed or unsigned form
  /// (`-1` and `0xffffffff` are the same value). `i32` results are returned
  /// in signed form.
  ///
  /// Throws an [Exception] if the number or types of [args] don't match. Throws
  /// a [StateError] if there is not exactly one function export called [name].
  dynamic invoke(String name, List<dynamic> args) {
    final fn = exports
        .whereType<WasmExportFn>()
        .singleWhere((element) => element.name == name)
        .fn;
    final params = fn.params;
    if (params.length != args.length) {
      throw Exception(
          'Invalid number of arguments (expected ${params.length}, got ${args.length})');
    }
    for (var i = 0; i < params.length; i++) {
      if (!params[i].includes(args[i])) {
        throw Exception('Argument ${i + 1} is not of type ${params[i]}');
      }
    }
    _locateCode(fn.index);
    // Parameters come first in the frame. `i32` values are normalized to
    // signed form so that signed instructions such as `i32.ge_s` behave
    // correctly for arguments like 0xffffffff.
    final frame = <int>[
      for (var i = 0; i < args.length; i++)
        params[i] is WasmTypeI32
            ? (args[i] as int).toSigned(32)
            : args[i] as int,
    ];
    // Read the local declarations (groups of `count × type`) and initialize
    // all locals to zero.
    final groups = _readUint();
    for (var g = 0; g < groups; g++) {
      final n = _readUint();
      final t = _readType();
      if (t is! WasmTypeI32) throw Exception('Unsupported local type $t');
      frame.addAll(List.filled(n, 0));
    }
    return _executeCode(frame);
  }

  /// Seeks to the [index]th code block, just after its size prefix.
  void _locateCode(int index) {
    var found = 0;
    _each(WasmSection.code, () {
      final count = _readUint();
      for (var i = 0; i < count; i++) {
        final size = _readUint();
        if (i == index) {
          found = _index;
          return;
        }
        _index += size;
      }
    });
    if (found == 0) throw Exception('Code $index not found');
    _index = found;
  }

  /// Executes the current code block, from [_index] up to the `end` that
  /// closes the function body. Returns the single value left on the operand
  /// stack.
  ///
  /// [frame] holds the parameters followed by the locals. `i32` values are
  /// kept in signed form, and `i32` arithmetic wraps around at 32 bits. Only
  /// empty block types (0x40) are supported.
  int _executeCode(List<int> frame) {
    final stack = <int>[];
    // Active control labels, innermost last. `start` is where a loop body
    // begins (only used for loops). `height` is the operand stack height when
    // the label was entered.
    final labels = <({bool isLoop, int start, int height})>[];

    // Branches to the label `depth` levels out from the innermost one.
    void branch(int depth) {
      if (depth >= labels.length) {
        throw Exception('Unsupported branch depth $depth');
      }
      final target = labels[labels.length - 1 - depth];
      // Empty block types carry no values, so drop everything pushed since
      // the target label was entered.
      stack.removeRange(target.height, stack.length);
      if (target.isLoop) {
        // Restart the loop body. The loop label itself stays active.
        labels.removeRange(labels.length - depth, labels.length);
        _index = target.start;
      } else {
        // Leave the target and every label nested inside it by skipping past
        // their `end`s, innermost first.
        for (var i = 0; i <= depth; i++) {
          _skipToEnd();
        }
        labels.removeRange(labels.length - depth - 1, labels.length);
      }
    }

    while (true) {
      final opcode = _readByte();
      switch (opcode) {
        case 0x00: // unreachable
          throw Exception('Unreachable');
        case 0x01: // nop
          break;
        case 0x02: // block
          if (_readByte() != 0x40) throw Exception('Unsupported block type');
          labels.add((isLoop: false, start: _index, height: stack.length));
        case 0x03: // loop
          if (_readByte() != 0x40) throw Exception('Unsupported block type');
          labels.add((isLoop: true, start: _index, height: stack.length));
        case 0x04: // if
          if (_readByte() != 0x40) throw Exception('Unsupported if type');
          final condition = stack.removeLast();
          // A false condition skips to the `else` branch, if there is one.
          // Otherwise it skips past the matching `end`, and no label is
          // entered.
          if (condition != 0 || _skipToEnd(stopAtElse: true)) {
            labels.add((isLoop: false, start: _index, height: stack.length));
          }
        case 0x05: // else
          // The `then` branch is done; skip the `else` branch and its `end`.
          _skipToEnd();
          labels.removeLast();
        case 0x0b: // end
          if (labels.isEmpty) return stack.single;
          labels.removeLast();
        case 0x0c: // br
          branch(_readUint());
        case 0x0d: // br_if
          final depth = _readUint();
          if (stack.removeLast() != 0) branch(depth);
        case 0x20: // local.get
          stack.add(frame[_readUint()]);
        case 0x21: // local.set
          frame[_readUint()] = stack.removeLast();
        case 0x41: // i32.const (signed LEB128 immediate)
          stack.add(_readSint());
        case 0x4e: // i32.ge_s
          final b = stack.removeLast();
          final a = stack.removeLast();
          stack.add(a >= b ? 1 : 0);
        case 0x6a: // i32.add
          final b = stack.removeLast();
          final a = stack.removeLast();
          stack.add((a + b).toSigned(32));
        case 0x6b: // i32.sub
          final b = stack.removeLast();
          final a = stack.removeLast();
          stack.add((a - b).toSigned(32));
        default:
          throw Exception(
              'Unsupported opcode ${opcode.toRadixString(16)} at ${(_index - 1).toRadixString(16)}');
      }
    }
  }

  /// Skips forward past the `end` that closes the current construct.
  ///
  /// Nested `block`/`loop`/`if` constructs are skipped as a whole. If
  /// [stopAtElse] is true and an `else` belonging to the current construct is
  /// found, stops right after it and returns `true`. Otherwise returns `false`
  /// after consuming the closing `end`.
  ///
  /// Only the immediates of the instructions supported by [_executeCode] are
  /// decoded. Code containing other instructions with immediates may be
  /// mis-skipped.
  bool _skipToEnd({bool stopAtElse = false}) {
    var depth = 1;
    while (true) {
      final op = _readByte();
      switch (op) {
        case 0x02 || 0x03 || 0x04: // block, loop, if
          _readByte(); // single-byte block type
          depth++;
        case 0x05 when stopAtElse && depth == 1: // else
          return true;
        case 0x0b: // end
          depth--;
          if (depth == 0) return false;
        case 0x0c || 0x0d || 0x20 || 0x21: // br, br_if, local.get, local.set
          _readUint();
        case 0x41: // i32.const
          _readSint();
        default:
          // The remaining supported instructions have no immediates.
          break;
      }
    }
  }
}
/// The section kinds of a WebAssembly binary.
///
/// Declaration order matters: each value's `index` is the section id used in
/// the binary format.
enum WasmSection {
  /// 0x00 - a custom section.
  custom,

  /// 0x01 - a vector of function signatures.
  type,

  /// 0x02 - a vector of imports (functions, tables, memory, or globals).
  import,

  /// 0x03 - a vector of indices into the type section.
  function,

  /// 0x04 - a vector of table types.
  table,

  /// 0x05 - a vector of memory types.
  memory,

  /// 0x06 - a vector of type/mutability/initializer expression triples.
  global,

  /// 0x07 - a vector of exports (functions, tables, memory, or globals).
  export,

  /// 0x08 - the index of the start function.
  start,

  /// 0x09 - a vector of element segments (which are complicated).
  element,

  /// 0x0a - a vector of function bodies (locals and instructions).
  code,

  /// 0x0b - a vector of data segments (which are also complicated).
  data,

  /// 0x0c - the number of data segments in this module.
  dataCount,
}
/// Abstract superclass for all value and function types.
///
/// See [WasmTypeI32], [WasmTypeI64], and [WasmTypeFn]. The class is `sealed`,
/// so a switch over a [WasmType] can be exhaustive.
sealed class WasmType {
  /// Allows subclasses to declare `const` constructors.
  const WasmType();

  /// Whether the Dart [value] is acceptable as a value of this type.
  bool includes(Object? value);
}
/// The 32-bit integer type `i32`.
///
/// Accepts both the signed and the unsigned reading of a 32-bit value, that
/// is, integers from `-2^31` up to `2^32 - 1`. Larger integers don't fit and
/// are rejected.
class WasmTypeI32 extends WasmType {
  const WasmTypeI32();

  /// Smallest accepted value (signed reading).
  static const _min = -0x80000000;

  /// Largest accepted value (unsigned reading).
  static const _max = 0xffffffff;

  @override
  String toString() => 'i32';

  @override
  bool includes(Object? value) =>
      value is int && value >= _min && value <= _max;
}
/// The 64-bit integer type `i64`.
///
/// Any Dart [int] is accepted.
class WasmTypeI64 extends WasmType {
  const WasmTypeI64();

  @override
  bool includes(Object? value) {
    return value is int;
  }

  @override
  String toString() {
    return 'i64';
  }
}
/// A function signature mapping [params] to [results].
class WasmTypeFn extends WasmType {
  const WasmTypeFn(this.params, this.results);

  /// The parameter types, in declaration order.
  final List<WasmType> params;

  /// The result types, in declaration order.
  final List<WasmType> results;

  /// Whether [value] is a Dart function. The signature itself isn't checked.
  @override
  bool includes(Object? value) => value is Function;

  @override
  String toString() {
    final from = params.join(', ');
    final to = results.join(', ');
    return '($from) -> ($to)';
  }
}
/// Represents a function definition.
///
/// Pairs the function's signature ([type]) with its position ([index]) in
/// the module's function section.
class WasmFn {
  const WasmFn(this.type, this.index);

  /// The function's signature.
  final WasmTypeFn type;

  /// The function's position in the module's function section.
  final int index;

  /// Shorthand for the signature's parameter types.
  List<WasmType> get params {
    return type.params;
  }

  /// Shorthand for the signature's result types.
  List<WasmType> get results {
    return type.results;
  }

  /// The only result type. Throws a [StateError] unless there is exactly one.
  WasmType get result {
    return results.single;
  }

  @override
  String toString() {
    return '$type [$index]';
  }
}
/// Abstract superclass for all exports.
///
/// See [WasmExportFn] and [WasmExportMem]. The class is `sealed`, so a switch
/// over a [WasmExport] can be exhaustive.
sealed class WasmExport {
  /// Allows subclasses to declare `const` constructors.
  const WasmExport();
}
/// An exported function.
class WasmExportFn extends WasmExport {
  const WasmExportFn(this.name, this.fn);

  /// The name under which the function is exported.
  final String name;

  /// The exported function.
  final WasmFn fn;

  /// Describes the export, e.g. `fn add(i32, i32) -> i32`.
  ///
  /// A single result type is shown bare. Zero or several results are shown as
  /// a parenthesized list, e.g. `-> ()` or `-> (i32, i64)`. Previously, those
  /// cases threw while the export was being printed.
  @override
  String toString() {
    final results = fn.results;
    final returns =
        results.length == 1 ? '${results.single}' : '(${results.join(', ')})';
    return 'fn $name(${fn.params.join(', ')}) -> $returns';
  }
}
/// An exported memory, identified by its [index].
class WasmExportMem extends WasmExport {
  const WasmExportMem(this.name, this.index);

  /// The name under which the memory is exported.
  final String name;

  /// The index of the exported memory.
  final int index;

  @override
  String toString() {
    return 'mem $name[$index]';
  }
}
// $ hexdump -C add.wasm
// 00000000  00 61 73 6d 01 00 00 00  01 07 01 60 02 7f 7f 01  |.asm.......`....|
// 00000010  7f 03 02 01 00 07 07 01  03 61 64 64 00 00 0a 09  |.........add....|
// 00000020  01 07 00 20 00 20 01 6a  0b                       |... . .j.|
// 00000029

/// The sample module from `add.wasm`, grouped by section below.
const bytes = <int>[
  // Magic number "\0asm" and version 1.
  0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
  // Type section (id 1, 7 bytes): one signature, (i32, i32) -> (i32).
  0x01, 0x07, 0x01, 0x60, 0x02, 0x7f, 0x7f, 0x01, 0x7f,
  // Function section (id 3, 2 bytes): one function of type 0.
  0x03, 0x02, 0x01, 0x00,
  // Export section (id 7, 7 bytes): function 0 exported as "add".
  0x07, 0x07, 0x01, 0x03, 0x61, 0x64, 0x64, 0x00, 0x00,
  // Code section (id 10, 9 bytes): one 7-byte body without locals:
  // local.get 0, local.get 1, i32.add, end.
  0x0a, 0x09, 0x01, 0x07, 0x00, 0x20, 0x00, 0x20, 0x01, 0x6a, 0x0b,
];
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment