|
"""A register machine for compiling set-based esoteric languages into.""" |
|
|
|
from __future__ import annotations |
|
from typing import (FrozenSet, Optional, Tuple, MutableMapping, Mapping, |
|
Dict, Iterable, Union, Sequence, AbstractSet, Set, |
|
List) |
|
from dataclasses import dataclass |
|
from abc import ABC, abstractmethod |
|
from enum import Enum |
|
|
|
@dataclass(frozen=True)
class HFSet:
    """A hereditarily finite set: an immutable set whose elements are HFSets.

    The dataclass is frozen, so instances are hashable and can themselves be
    stored inside another HFSet's `elements`.
    """

    elements: FrozenSet[HFSet]

    def __str__(self) -> str:
        # Sets are written in bracket notation, e.g. `[[], [[]]]`.
        inner = ", ".join(map(str, self.elements))
        return "[" + inner + "]"

    # --- set algebra: every operation builds a fresh HFSet ---------------

    def union(self, other: HFSet) -> HFSet:
        return HFSet(self.elements.union(other.elements))

    def intersection(self, other: HFSet) -> HFSet:
        return HFSet(self.elements.intersection(other.elements))

    def difference(self, other: HFSet) -> HFSet:
        return HFSet(self.elements.difference(other.elements))

    def symmetric_difference(self, other: HFSet) -> HFSet:
        return HFSet(self.elements.symmetric_difference(other.elements))

    def deconstruct(self) -> Optional[Tuple[HFSet, HFSet]]:
        """Split off one element: `(element, rest)`, or `None` when empty.

        The element taken is the first one in frozenset iteration order.
        `rest` is built as a new frozenset, so this is linear in the size of
        the set (could be improved with an ordered representation).
        """
        if not self.elements:
            return None
        first = next(iter(self.elements))
        return first, HFSet(self.elements - {first})

    @property
    def is_inhabited(self) -> bool:
        """Whether the set has at least one element."""
        return len(self.elements) != 0

    def add(self, element: HFSet) -> HFSet:
        """Return this set with `element` included (`self ∪ [element]`)."""
        return HFSet(self.elements | {element})


# The empty set `[]`, the initial value of every machine register.
EMPTY_SET: HFSet = HFSet(frozenset())
|
|
|
class OpError(ABC):
    """Base class for runtime errors of an Op run with unsuitable arguments.

    These are values, not exceptions: an `Op.run` returns one in place of
    the next state, and `Machine.run` hands it back to the caller.
    """
|
|
|
# Registers are identified by number; a machine's register file maps each
# register number to the set it currently holds.
Register = int
Registers = MutableMapping[Register, HFSet]

# An int names a state that has code attached; `None` is the halting state.
NonHaltState = int
State = Optional[NonHaltState]

# A tuple of set values, e.g. the inputs or outputs of `Machine.eval`.
Values = Tuple[HFSet, ...]
|
|
|
@dataclass
class IOSpec:
    """Which registers receive a machine's arguments and which hold its results."""

    inputs: Sequence[Register]  # loaded from the argument tuple, in order
    outputs: Sequence[Register]  # read back into the result tuple, in order

    def __post_init__(self) -> None:
        distinct_inputs = set(self.inputs)
        assert len(distinct_inputs) == len(self.inputs), "Input register shouldn’t clash."

    def is_allowed_for(self, regs: AbstractSet[Register]) -> bool:
        """Tell whether every input and output register is among `regs`."""
        return all(reg in regs for reg in (*self.inputs, *self.outputs))
|
|
|
class Machine:
    """A register machine which can be run.

    Its execution state is the current state number (`None` means halted),
    a register file mapping register numbers to `HFSet` values, and a
    bounded call stack of return states used by call/return ops.
    """

    _state: State
    _start_state: State  # where `eval` restarts execution from
    _code: Mapping[NonHaltState, Op]
    _registers: Dict[Register, HFSet]
    _call_stack: List[State]
    _call_stack_max_len: int  # useful when runaway recursion happens
    _actions: Actions

    def __init__(
            self, code: Mapping[NonHaltState, Op],
            start_state: State,
            register_names: Optional[Iterable[Register]] = None,
            call_stack_max_len: int = 100) -> None:
        """Validate `code` and set the machine up at `start_state`.

        Every state that `start_state` or any op refers to must be `None`
        (halt) or a key of `code`. If `register_names` is given it must
        include every register any op refers to; otherwise exactly the
        referenced registers are allocated. All registers start empty.
        `call_stack_max_len` bounds how many return states may be pushed.
        """
        # First ensure the code is valid.
        ops = list(code.values())
        referenced_states: Set[State] = {start_state}
        for op in ops:
            referenced_states.update(op.referenced_states())
        referenced_states.discard(None)
        assert referenced_states.issubset(code.keys())
        referenced_registers: Set[Register] = {
            reg for op in ops for reg in op.referenced_regs()}
        if register_names is None:
            register_names = referenced_registers
        else:
            # Materialize once: a one-shot iterator would otherwise be used
            # up by the subset check and leave the machine with no registers.
            register_names = list(register_names)
            assert referenced_registers.issubset(register_names)
        # okay
        self._code = code
        self._start_state = start_state
        self._state = start_state
        self._registers = {reg: EMPTY_SET for reg in register_names}
        self._call_stack = []
        self._call_stack_max_len = call_stack_max_len
        self._actions = Actions(self)

    def _reset(self) -> None:
        """Rewind to the start state with empty registers and call stack."""
        self._state = self._start_state
        # Update in place so existing references to the dict stay valid.
        for reg in self._registers:
            self._registers[reg] = EMPTY_SET
        self._call_stack.clear()

    def run_step(self) -> Union[bool, OpError]:
        """Run the machine a single step.

        Returns `True` if machine has halted normally.
        Returns `False` if there’s still work to do.
        Returns the `OpError` produced by the current op if it failed; the
        current state is then left unchanged.
        """
        if self._state is None:
            return True
        res = self._code[self._state].run(self._registers, self._actions)
        if isinstance(res, OpError):
            return res
        self._state = res
        return self._state is None

    def run(self) -> Optional[OpError]:
        """Run the machine to the end; return the error that stopped it, if any."""
        while True:
            res = self.run_step()
            if isinstance(res, OpError):
                return res
            if res:
                return None

    def eval(self, spec: IOSpec, inputs: Values) -> Union[Values, OpError]:
        """Apply the machine to inputs and return outputs or an error.

        Each call starts from the start state with all registers empty, so
        the machine can be applied repeatedly like a function.
        """
        assert len(spec.inputs) == len(inputs), "Provide as many values as it says in the spec."
        regs = self._registers
        assert spec.is_allowed_for(regs.keys()), "The spec should mention only the registers present in this machine."
        # Without this, a second call would start in the halted state and
        # return whatever the previous run left in the output registers.
        self._reset()
        for reg, val in zip(spec.inputs, inputs):
            regs[reg] = val
        if (res := self.run()) is not None:
            return res
        return tuple(regs[reg] for reg in spec.outputs)
|
|
|
class Actions:
    """Machine actions which are available from inside `Op.run`.

    `Machine.run_step` passes its instance to every op as the `a` argument;
    it exposes the machine's call stack.
    """

    _machine: Machine

    def __init__(self, _machine: Machine) -> None:
        self._machine = _machine

    def call_stack_push(self, return_state: State) -> bool:
        """Push `return_state`; return `False` (pushing nothing) when full."""
        stack = self._machine._call_stack
        # `>=` rather than `==`: with `==` a non-positive limit never
        # triggers and the stack would grow without bound.
        if len(stack) >= self._machine._call_stack_max_len:
            return False
        stack.append(return_state)
        return True

    def call_stack_pop(self) -> Optional[State]:
        """Pop and return the most recently pushed state, or `None` if empty.

        NOTE(review): `None` is also a legitimate pushed state (halt), so a
        caller cannot tell an empty stack from a popped `None`.
        """
        stack = self._machine._call_stack
        if not stack:
            return None
        return stack.pop()
|
|
|
class Op(ABC):
    """An operation of the register machine.

    `Machine` calls the two `referenced_*` methods to validate its code
    before running, and `run` to execute one step.
    """

    @abstractmethod
    def referenced_regs(self) -> Iterable[Register]:
        """Registers this op reads or writes."""
        raise NotImplementedError

    @abstractmethod
    def referenced_states(self) -> Iterable[State]:
        """States this op may hand control to (`None` meaning halt)."""
        raise NotImplementedError

    @abstractmethod
    def run(self, regs: Registers, a: Actions) -> Union[State, OpError]:
        """Execute once; return the next state or an `OpError`."""
        raise NotImplementedError
|
|
|
@dataclass
class UnconditionalNextState(ABC):
    """Mixin holding the single successor state of an op that doesn't branch."""

    next: State  # state to continue at; `None` halts the machine
|
|
|
@dataclass
class MovOp(UnconditionalNextState, Op):
    """Copy the value in register `source` into register `target`."""

    source: Register
    target: Register

    def referenced_regs(self) -> Iterable[Register]:
        return (self.source, self.target)

    def referenced_states(self) -> Iterable[State]:
        return (self.next,)

    def run(self, regs: Registers, a: Actions) -> State:
        value = regs[self.source]
        regs[self.target] = value
        return self.next
|
|
|
@dataclass
class ClearOp(UnconditionalNextState, Op):
    """Reset register `target` to the empty set `[]`."""

    target: Register

    def referenced_regs(self) -> Iterable[Register]:
        return (self.target,)

    def referenced_states(self) -> Iterable[State]:
        return (self.next,)

    def run(self, regs: Registers, a: Actions) -> State:
        regs[self.target] = EMPTY_SET
        return self.next
|
|
|
@dataclass
class AddOp(UnconditionalNextState, Op):
    """Include the value of register `element` into the set in `target`.

    Building `[e1, ...]` amounts to clearing a register to `[]` and then
    adding `e1`, ... one at a time.
    """

    element: Register
    target: Register

    def referenced_regs(self) -> Iterable[Register]:
        return (self.target, self.element)

    def referenced_states(self) -> Iterable[State]:
        return (self.next,)

    def run(self, regs: Registers, a: Actions) -> State:
        current = regs[self.target]
        regs[self.target] = current.add(regs[self.element])
        return self.next
|
|
|
class _SetMethod:
    """Callable applying the binary set method called `name`.

    Plain functions are descriptors, and `Enum` treats descriptor-valued
    class attributes as methods rather than members. Assigning e.g.
    `HFSet.union` directly would therefore leave `BinOpType` with no members,
    and `BinOpType.UNION.value` would fail. Instances of this class are not
    descriptors, so they become proper member values.
    """

    __slots__ = ("name",)

    def __init__(self, name: str) -> None:
        self.name = name

    def __call__(self, left: HFSet, right: HFSet) -> HFSet:
        # For HFSet operands this is `HFSet.<name>(left, right)`.
        return getattr(left, self.name)(right)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.name!r})"


class BinOpType(Enum):
    """The set operation a `BinOp` performs; `member.value(a, b)` applies it."""

    UNION = _SetMethod("union")
    INTERSECTION = _SetMethod("intersection")
    DIFFERENCE = _SetMethod("difference")
    SYM_DIFFERENCE = _SetMethod("symmetric_difference")
|
|
|
@dataclass
class BinOp(UnconditionalNextState, Op):
    """Binary operation (`+`, `*`, `-`, `~`): `res := type_(arg1, arg2)`."""

    type_: BinOpType  # which set operation to apply
    arg1: Register  # left operand
    arg2: Register  # right operand
    res: Register  # destination; both operands are read before it is written

    def referenced_regs(self) -> Iterable[Register]:
        return (self.arg1, self.arg2, self.res)

    def referenced_states(self) -> Iterable[State]:
        return (self.next,)

    def run(self, regs: Registers, a: Actions) -> State:
        operation = self.type_.value
        left = regs[self.arg1]
        right = regs[self.arg2]
        regs[self.res] = operation(left, right)
        return self.next
|
|
|
@dataclass
class JumpOp(Op):
    """Branch on whether the set in register `test` is empty."""

    next_when_empty: State
    next_when_inhabited: State
    test: Register

    def referenced_regs(self) -> Iterable[Register]:
        return (self.test,)

    def referenced_states(self) -> Iterable[State]:
        return (self.next_when_empty, self.next_when_inhabited)

    def run(self, regs: Registers, a: Actions) -> State:
        inhabited = regs[self.test].is_inhabited
        return self.next_when_inhabited if inhabited else self.next_when_empty
|
|
|
@dataclass
class DeconstructOp(Op):
    """Split the set in `test` into one element and the rest, and branch.

    If the set is nonempty, an element goes to register `element`, the
    remaining set goes to `rest`, and control moves to `next_when_inhabited`.
    Otherwise no register changes and control moves to `next_when_empty`.
    `rest` is written after `element`, so it wins if both name one register.
    """

    next_when_empty: State
    next_when_inhabited: State
    test: Register
    element: Register
    rest: Register

    def referenced_regs(self) -> Iterable[Register]:
        return (self.test, self.element, self.rest)

    def referenced_states(self) -> Iterable[State]:
        return (self.next_when_empty, self.next_when_inhabited)

    def run(self, regs: Registers, a: Actions) -> State:
        parts = regs[self.test].deconstruct()
        if parts is None:
            return self.next_when_empty
        head, remainder = parts
        regs[self.element] = head
        regs[self.rest] = remainder
        return self.next_when_inhabited
|
|
|
@dataclass
class AssertOpError(OpError):
    """Returned by `AssertOp` when its tested register holds the empty set."""

    message: str  # the failing op's `fail_message`
|
|
|
@dataclass
class AssertOp(UnconditionalNextState, Op):
    """Assert that the set in register `test` is nonempty.

    On success control moves to `next`; otherwise an `AssertOpError`
    carrying `fail_message` is returned.
    """

    test: Register
    fail_message: str

    def referenced_regs(self) -> Iterable[Register]:
        return (self.test,)

    def referenced_states(self) -> Iterable[State]:
        return (self.next,)

    def run(self, regs: Registers, a: Actions) -> Union[State, AssertOpError]:
        if not regs[self.test].is_inhabited:
            return AssertOpError(self.fail_message)
        return self.next
|
|
|
@dataclass
class CallStackFullOpError(OpError):
    """Returned by a call op when the call stack is already at its limit."""
|
|
|
@dataclass
class CallOp(Op):
    """Function call: push `return_state`, then continue at `call_state`.

    If the call stack is already full, nothing is pushed and a
    `CallStackFullOpError` is returned instead of a next state.
    """

    # Without @dataclass these were bare annotations and `CallOp(c, r)` raised
    # TypeError; subclassing `Op` (not bare ABC) makes it a real machine op.
    call_state: State  # first state of the callee
    return_state: State  # state a later return op resumes at

    def run(self, regs: Registers, a: Actions) -> Union[State, CallStackFullOpError]:
        if a.call_stack_push(self.return_state):
            return self.call_state
        return CallStackFullOpError()

    def referenced_states(self) -> Iterable[State]:
        return (self.call_state, self.return_state)

    def referenced_regs(self) -> Iterable[Register]:
        # Calls move control only; they touch no registers.
        return ()
|
|
|
@dataclass
class ReturnOpError(OpError):
    """Returned by a return op when there is no return state to pop."""
|
|
|
class ReturnOp(Op):
    """Return from function: continue at the state popped off the call stack.

    Returns a `ReturnOpError` when the pop yields no state.
    """

    def run(self, regs: Registers, a: Actions) -> Union[State, ReturnOpError]:
        return_state = a.call_stack_pop()
        # NOTE(review): `call_stack_pop` reports an empty stack as `None`,
        # which is also the halting state. A call whose return state is `None`
        # is therefore reported here as a `ReturnOpError` instead of halting.
        # Telling the two apart needs an `Actions` API that reports emptiness
        # separately.
        if return_state is None:
            return ReturnOpError()
        return return_state

    def referenced_states(self) -> Iterable[State]:
        # The target comes from the call stack at runtime, so no static states.
        return ()

    def referenced_regs(self) -> Iterable[Register]:
        return ()
|
|
|
if __name__ == "__main__":
    def nat(arg: int) -> HFSet:
        """Encode `arg` as a set numeral: 0 is `[]`, and n + 1 is `[n]`."""
        assert arg >= 0
        numeral = EMPTY_SET
        for _ in range(arg):
            numeral = EMPTY_SET.add(numeral)
        return numeral

    def from_nat(arg: HFSet) -> Optional[int]:
        """Decode a numeral built like `nat` does; `None` if `arg` isn't one."""
        depth = 0
        current = arg
        while True:
            parts = current.deconstruct()
            if parts is None:
                return depth
            current, rest = parts
            depth += 1
            # A numeral has at most one element at every level.
            if rest.is_inhabited:
                return None

    def pprint_many(title: str, *values: HFSet):
        """Print `title`, then `values` as sets and as decoded numbers."""
        print(title)
        print(", ".join(str(value) for value in values))
        print(", ".join(str(from_nat(value)) for value in values))

    def main() -> None:
        # Addition: z := y, then while x is nonempty replace x by its element
        # (predecessor) and z by [z] (successor), using t as scratch.
        rx, ry, rz, rt = 0, 1, 2, 3
        code = {
            0: MovOp(1, ry, rz),
            1: DeconstructOp(None, 2, rx, rx, rt),
            2: ClearOp(3, rt),
            3: AddOp(4, rz, rt),
            4: MovOp(1, rt, rz),
        }
        machine = Machine(code, 0)
        spec = IOSpec(inputs=[rx, ry], outputs=[rz])
        inputs = (nat(3), nat(1))
        pprint_many("Inputs", *inputs)
        outputs = machine.eval(spec, inputs)
        if not isinstance(outputs, OpError):
            pprint_many("Outputs:", *outputs)
        else:
            print(f"Error occurred: {outputs}")

    main()