# Created October 31, 2024 12:08
# A Wadler-style pretty-printer in Python
from __future__ import annotations
from dataclasses import dataclass
# + --- +
# | ASM |
# + --- +
def interpret(insts: list[AsmInst]) -> str:
"""Interpret the ASM instructions as a string."""
result = ""
for inst in insts:
match inst:
case text if isinstance(text, str):
result += inst
case indent if isinstance(indent, int):
result += f"\n{' ' * indent}"
return result
# An ASM instruction: a str is emitted verbatim, an int emits a newline
# followed by that many spaces (see interpret).
AsmInst = str | int
# + --- +
# | Doc |
# + --- +
def with_add(cls):
"""Add suitable __add__ and __radd__ methods to cls so that instances of it
can be concatenated using "+".
def add(self: DocExpr, other: DocExpr) -> DocExpr:
return Concat(self, other)
def radd(self: DocExpr, other: DocExpr) -> DocExpr:
return Concat(other, self)
setattr(cls, "__add__", add)
setattr(cls, "__radd__", radd)
return cls
class Nil:
class Concat:
car: DocExpr
cdr: DocExpr
class Br:
text: str = ""
class Nest:
indent: int
doc: DocExpr
class Group:
doc: DocExpr
DocExpr = Nil | Concat | str | Br | Nest | Group
# + -- +
# | IR |
# + -- +
class BR:
text: str
class NEST:
indent: int
insts: list[IrExpr]
class GROUP:
insts: list[IrExpr]
# An IR expression: literal text, or one of the BR / NEST / GROUP nodes.
IrExpr = str | BR | NEST | GROUP
def lower(expr: DocExpr) -> list[IrExpr]:
"""Lower a Doc expression to an equivalent sequence of IR expressions."""
match expr:
case Nil():
return []
case Concat(car, cdr):
return lower(car) + lower(cdr)
case text if isinstance(text, str):
return [text]
case Br(text):
return [BR(text)]
case Nest(indent, expr):
return [NEST(indent, lower(expr))]
case Group(expr):
return [GROUP(lower(expr))]
case _:
raise ValueError("expected a doc")
def compile(exprs: list[IrExpr], max_len: int) -> list[AsmInst]:
"""Compile a sequence of IR expressions into a list of ASM instructions."""
pos = 0
asm = []
def process(expr, indent=0, flat=True):
nonlocal pos, asm
match expr:
case text if isinstance(text, str):
pos += len(text)
case BR(text):
if flat:
pos += len(text)
pos = indent
case NEST(nest_indent, nest_exprs):
for nest_expr in nest_exprs:
process(nest_expr, indent=indent + nest_indent, flat=flat)
case GROUP(group_exprs):
flat = fits_flat(group_exprs, max_len - pos)
for group_expr in group_exprs:
process(group_expr, indent=indent, flat=flat)
for expr in exprs:
return asm
def fits_flat(exprs: list[IrExpr], width: int) -> bool:
"""Can the list of IR expressions can be laid out flat without exceeding
the provided width?"""
if width < 0:
return False
elif not exprs:
return True
first, *rest = exprs
match first:
case text if isinstance(text, str):
return fits_flat(rest, width - len(text))
case BR(text):
return fits_flat(rest, width - len(text))
case NEST(_, nest_exprs):
return fits_flat(nest_exprs + rest, width)
case GROUP(group_exprs):
return fits_flat(group_exprs + rest, width)
# + ---------------- +
# | Public Interface |
# + ---------------- +
def layout(doc: DocExpr, max_len: int) -> str:
    """Transform a Doc expression into a string it represents, trying not to
    exceed a line width of max_len.

    The document is lowered to IR, compiled to ASM instructions, and the
    instructions are then interpreted into the final text.

    Args:
        doc: The document to lay out.
        max_len: The line width the layout tries not to exceed.

    Returns:
        The rendered text.
    """
    ir = lower(doc)
    asm = compile(ir, max_len)
    return interpret(asm)
# Lowercase names for building documents: `nil` is a Nil instance, while
# `br`, `nest` and `group` are aliases for the Br, Nest and Group classes.
nil = Nil()
br = Br
nest = Nest
group = Group
