Created
October 31, 2024 12:08
-
-
Save wjlewis/3dfc1b02e2d16d2079a98042b0e18c95 to your computer and use it in GitHub Desktop.
A Wadler-style pretty-printer in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

from dataclasses import dataclass
# + --- +
# | ASM |
# + --- +
def interpret(insts: list[AsmInst]) -> str: | |
"""Interpret the ASM instructions as a string.""" | |
result = "" | |
for inst in insts: | |
match inst: | |
case text if isinstance(text, str): | |
result += inst | |
case indent if isinstance(indent, int): | |
result += f"\n{' ' * indent}" | |
return result | |
AsmInst = str | int | |
# + --- +
# | Doc |
# + --- +
def with_add(cls):
    """Add suitable __add__ and __radd__ methods to cls so that instances of it
    can be concatenated using "+".

    Both methods build a ``Concat`` node. ``__radd__`` is what makes
    ``"text" + doc`` work: ``str.__add__`` rejects the doc, so Python falls
    back to the doc's reflected method. Returns ``cls`` itself, so this can be
    used as a class decorator.
    """

    def add(self: DocExpr, other: DocExpr) -> DocExpr:
        return Concat(self, other)

    def radd(self: DocExpr, other: DocExpr) -> DocExpr:
        # Reflected form: `other` is the left-hand operand.
        return Concat(other, self)

    cls.__add__ = add
    cls.__radd__ = radd
    return cls
@with_add
@dataclass
class Nil:
    """The empty document; it lowers to no IR expressions at all."""
@with_add
@dataclass
class Concat:
    """The concatenation of two documents: ``car`` followed by ``cdr``."""

    car: DocExpr  # the document that comes first
    cdr: DocExpr  # the document that comes second
@with_add
@dataclass
class Br:
    """A potential line break.

    When laid out flat it is rendered as ``text``; otherwise it becomes a
    newline followed by the current indentation.
    """

    text: str = ""  # what to emit in place of the break when laid out flat
@with_add
@dataclass
class Nest:
    """A document whose line breaks are indented by an extra ``indent`` spaces."""

    indent: int  # added to the enclosing indentation
    doc: DocExpr  # the document being indented
@with_add
@dataclass
class Group:
    """A document that is laid out flat if it fits in the remaining width,
    and with its line breaks taken otherwise.
    """

    doc: DocExpr  # the document whose breaks are decided together
# Everything accepted as a document; plain strings are literal text.
DocExpr = Nil | Concat | str | Br | Nest | Group
# + -- +
# | IR |
# + -- +
@dataclass
class BR:
    """IR form of a potential line break (lowered from ``Br``)."""

    text: str  # emitted instead of a newline when laid out flat
@dataclass
class NEST:
    """IR form of an indented region (lowered from ``Nest``)."""

    indent: int  # extra indentation applied to breaks inside `insts`
    insts: list[IrExpr]  # the lowered contents of the region
@dataclass
class GROUP:
    """IR form of a group (lowered from ``Group``)."""

    insts: list[IrExpr]  # the lowered contents, laid out flat or broken together
# An IR expression: literal text or one of the IR node classes.
IrExpr = str | BR | NEST | GROUP
def lower(expr: DocExpr) -> list[IrExpr]: | |
"""Lower a Doc expression to an equivalent sequence of IR expressions.""" | |
match expr: | |
case Nil(): | |
return [] | |
case Concat(car, cdr): | |
return lower(car) + lower(cdr) | |
case text if isinstance(text, str): | |
return [text] | |
case Br(text): | |
return [BR(text)] | |
case Nest(indent, expr): | |
return [NEST(indent, lower(expr))] | |
case Group(expr): | |
return [GROUP(lower(expr))] | |
case _: | |
raise ValueError("expected a doc") | |
def compile(exprs: list[IrExpr], max_len: int) -> list[AsmInst]: | |
"""Compile a sequence of IR expressions into a list of ASM instructions.""" | |
pos = 0 | |
asm = [] | |
def process(expr, indent=0, flat=True): | |
nonlocal pos, asm | |
match expr: | |
case text if isinstance(text, str): | |
asm.append(text) | |
pos += len(text) | |
case BR(text): | |
if flat: | |
asm.append(text) | |
pos += len(text) | |
else: | |
asm.append(indent) | |
pos = indent | |
case NEST(nest_indent, nest_exprs): | |
for nest_expr in nest_exprs: | |
process(nest_expr, indent=indent + nest_indent, flat=flat) | |
case GROUP(group_exprs): | |
flat = fits_flat(group_exprs, max_len - pos) | |
for group_expr in group_exprs: | |
process(group_expr, indent=indent, flat=flat) | |
for expr in exprs: | |
process(expr) | |
return asm | |
def fits_flat(exprs: list[IrExpr], width: int) -> bool: | |
"""Can the list of IR expressions can be laid out flat without exceeding | |
the provided width?""" | |
if width < 0: | |
return False | |
elif not exprs: | |
return True | |
first, *rest = exprs | |
match first: | |
case text if isinstance(text, str): | |
return fits_flat(rest, width - len(text)) | |
case BR(text): | |
return fits_flat(rest, width - len(text)) | |
case NEST(_, nest_exprs): | |
return fits_flat(nest_exprs + rest, width) | |
case GROUP(group_exprs): | |
return fits_flat(group_exprs + rest, width) | |
# + ---------------- +
# | Public Interface |
# + ---------------- +
def layout(doc: DocExpr, max_len: int) -> str:
    """Transform a Doc expression into a string it represents, trying not to
    exceed a line width of max_len.

    Runs the full pipeline: the doc is lowered to IR, the IR is compiled to
    ASM instructions, and those are interpreted into the final string.
    """
    ir = lower(doc)
    asm = compile(ir, max_len)
    return interpret(asm)
# Lower-case convenience names for building documents.
nil = Nil()
br = Br
nest = Nest
group = Group
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment