Skip to content

Instantly share code, notes, and snippets.

@rrika
Last active May 8, 2019 11:38
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rrika/fcfa43aee2ec95b0bd482e1c508f55b5 to your computer and use it in GitHub Desktop.
Save rrika/fcfa43aee2ec95b0bd482e1c508f55b5 to your computer and use it in GitHub Desktop.
Trying to write a borrow checker
class Expr:
    """Base class for expression nodes (Deref, TakeRefMut, Var, Const)."""


class Type:
    """Base class for compound types; primitive types are plain strings such as "i32"."""


class Stmt:
    """Base class for statement nodes (Assign, Call, Return)."""
class RefMut(Type):
    """Mutable reference type ``&'lt mut ty``.

    Attributes:
        lt: name of the reference's lifetime, e.g. "a".
        ty: pointee type -- a primitive type name (str) or another Type.
    """

    def __init__(self, lt, ty):
        self.lt = lt
        self.ty = ty

    def __repr__(self):
        # Render in Rust syntax, e.g. "&'a mut &'b mut i32".  The pointee is
        # formatted with str() so primitive names print without quotes;
        # nested RefMut values fall back to __repr__.
        return "&'{} mut {}".format(self.lt, self.ty)
class NewTypeExpr(Expr):
    """Expression node wrapping a single payload stored in ``value``.

    For Var the payload is a variable name; for Deref and TakeRefMut it is
    the operand expression.
    """

    def __init__(self, value):
        self.value = value
class Deref(NewTypeExpr):
    """Dereference ``*value``; the lifetime pass reads ``.ty`` of the operand's type."""

    def __repr__(self):
        return "*" + repr(self.value)
class TakeRefMut(NewTypeExpr):
    """Mutable borrow ``&mut value`` of the operand expression."""

    def __repr__(self):
        operand = repr(self.value)
        return "&mut " + operand
class Var(NewTypeExpr):
    """Use of a named variable; ``value`` is the name, looked up in Function.vars."""

    def __repr__(self):
        name = self.value
        return name
class Const(Expr):
    """Literal constant of a given type, rendered like ``(99: i32)``.

    Attributes:
        ty: the constant's type (a primitive type name such as "i32", or a Type).
        value: the Python value of the literal.
    """

    def __init__(self, ty, value):
        self.ty = ty
        self.value = value

    def __repr__(self):
        # The value keeps repr() so string literals stay quoted; the type uses
        # str() so primitive names print as "i32" rather than "'i32'".
        return "({!r}: {})".format(self.value, self.ty)
class Assign(Stmt):
    """Store through a reference: ``*lhs = rhs``.

    ``lhs`` is an expression of reference type (the lifetime pass reads
    ``.lt`` and ``.ty`` from its type); ``rhs`` is the value stored.
    """

    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs

    def exprs(self):
        """Yield the operand expressions, lhs first, like Call.exprs/Return.exprs."""
        yield self.lhs
        yield self.rhs

    def __repr__(self):
        return "*{!r} = {!r}".format(self.lhs, self.rhs)
class Call(Stmt):
    """Call of ``fun`` with positional argument expressions.

    Attributes:
        fun: the callee; the lifetime pass reads its ``name``, ``lts`` and ``args``.
        args: tuple of argument expressions.
    """

    def __init__(self, fun, *args):
        self.fun = fun
        self.args = args

    def exprs(self):
        """Yield the argument expressions in order."""
        yield from self.args

    def __repr__(self):
        # Rust-like rendering, e.g. "fun_set_ptr(&mut x, &mut y)".  Fall back
        # to the callee object itself if it has no ``name``.
        callee = getattr(self.fun, "name", self.fun)
        return "{}({})".format(callee, ", ".join(repr(a) for a in self.args))
class Return(Stmt):
    """Return ``arg`` from the enclosing function."""

    def __init__(self, arg):
        self.arg = arg

    def exprs(self):
        """Yield the single returned expression."""
        yield self.arg

    def __repr__(self):
        return "return {!r}".format(self.arg)
class Function:
    """A function: its signature plus a body whose statements start at ``entry``.

    Args:
        name: function name, used in output and generated lifetime names.
        lts: list of ``[longer, shorter]`` pairs, each a signature bound
            ``'longer: 'shorter``, or None when there are none.
        args: list of ``[name, type]`` pairs for the parameters.
    """

    def __init__(self, name, lts, args):
        self.name, self.lts, self.args = name, lts, args
        # Variable environment seeded with the parameters; Builder.declare_mut
        # adds locals to it.
        self.vars = dict(args)
        self.entry = BasicBlock()
class BasicBlock:
    """A list of statements plus a list of successor blocks.

    Both lists start empty.  The lifetime pass shown here only reads
    ``stmts`` of the entry block.
    """

    def __init__(self):
        stmts, succ = [], []
        self.stmts = stmts
        self.succ = succ
class Builder:
    """Appends statements to a function body, starting in its entry block."""

    def __init__(self, fun):
        self.fun = fun
        self.bb = fun.entry  # basic block that receives new statements

    def stmt(self, stmt):
        """Append ``stmt`` to the current basic block."""
        target = self.bb.stmts
        target.append(stmt)

    def declare_mut(self, name, ty):
        """Declare (or redeclare) local variable ``name`` with type ``ty``."""
        env = self.fun.vars
        env[name] = ty
def visit_type(cs, ty):
    """Append the well-formedness constraints of type ``ty`` to ``cs``.

    For every ``&'lt mut T`` inside ``ty`` a triple ``(inner, lt, reason)``
    is appended, read as ``'inner: 'lt``: the pointee ``T`` (whose lifetime
    is ``inner``) must outlive the reference.

    Args:
        cs: list receiving ``(a, b, reason)`` constraint triples.
        ty: a primitive type name (str) or a RefMut.

    Returns:
        "static" for a primitive type, otherwise the outermost reference's lifetime.

    Raises:
        TypeError: if ``ty`` is neither a str nor a RefMut, instead of
            returning None and leaking a ``None`` lifetime into constraints.
    """
    if isinstance(ty, str):
        # Primitive types hold no borrows, so they are valid for 'static.
        return "static"
    if isinstance(ty, RefMut):
        inner = visit_type(cs, ty.ty)
        cs.append((inner, ty.lt, "ref only lives as long as pointed-to type"))
        return ty.lt
    raise TypeError("unsupported type in lifetime analysis: {!r}".format(ty))
def generate_lt_constraints(fun):
    """Generate the outlives constraints implied by the body of ``fun``.

    Program points and borrows are both modelled as lifetime names.  Each
    constraint is a triple ``(a, b, reason)`` read as ``'a: 'b`` ("a
    outlives b").  The special names "begin" and "return" stand for function
    entry and exit; "static" is the lifetime of primitive values.

    Only the statements of ``fun.entry`` are analysed; successor blocks are
    not visited.

    Args:
        fun: the Function to analyse.

    Returns:
        list of constraint triples in generation order, without the trivial
        ones (``'static: x`` and ``x: x``).

    Raises:
        TypeError: for an unsupported expression, a reference/non-reference
            mismatch in an assignment, or a call whose argument count differs
            from the callee's parameter count.
        KeyError: if an expression names a variable missing from ``fun.vars``.
    """
    cs = []
    counter = 0
    time = "begin"  # lifetime name of the current program point
    # var name -> (latest borrow was mutable, lifetimes of still-tracked borrows)
    borrows = {}

    def tmp(kind="tmp"):
        """Return a fresh lifetime name "<kind><n>", n counting up from 1."""
        nonlocal counter
        counter += 1
        return "{}{}".format(kind, counter)

    def record_borrow(var, now, mutable=True):
        """Record a new borrow of ``var`` at point ``now`` and return its lifetime."""
        prior_mut, lts = borrows.get(var, (False, []))
        # Shared borrows may coexist.  A new mutable borrow, or any borrow that
        # follows a mutable one, requires every earlier borrow to end by `now`
        # (previously shared borrows were silently dropped here).
        if prior_mut or mutable:
            for plt in lts:
                cs.append((now, plt, "all prior borrows of {} must end".format(var)))
            lts = []
        lt = tmp("borrow_" + var + "_")
        borrows[var] = (mutable, lts + [lt])
        return lt

    def visit_expr(expr):
        """Return ``(lifetime, type)`` for ``expr`` and add its type's constraints."""
        lt, ty = visit_expr_inner(expr)
        visit_type(cs, ty)
        return lt, ty

    def visit_expr_inner(expr):
        """Compute ``(lifetime, type)`` for ``expr`` without visiting its type."""
        if isinstance(expr, Deref):
            # NOTE(review): this fresh lifetime is not yet related to the
            # operand's lifetimes; deref handling looks unfinished.
            lt = tmp("tmpderef_")
            _vlt, vty = visit_expr(expr.value)
            return lt, vty.ty
        if isinstance(expr, TakeRefMut):
            # TODO: borrowing.  The reference currently reuses the operand's
            # lifetime instead of introducing a fresh, shorter one.
            vlt, vty = visit_expr(expr.value)
            return vlt, RefMut(vlt, vty)
        if isinstance(expr, Var):
            # Every variable use is recorded as a mutable borrow at the current point.
            lt = record_borrow(expr.value, time, True)
            return lt, fun.vars[expr.value]
        if isinstance(expr, Const):
            return "static", expr.ty
        raise TypeError("unsupported expression: {!r}".format(expr))

    def assign_cs(l, r, lremap=None):
        """Constrain storing a value of type ``r`` into a place of type ``l``.

        ``lremap`` optionally renames lifetimes appearing in ``l`` (used to
        instantiate a callee's signature at a call site).  Returns True when
        both types are references.
        """
        def remap(lt):
            return lremap[lt] if lremap is not None else lt

        lr = isinstance(l, RefMut)
        rr = isinstance(r, RefMut)
        if lr != rr:
            raise TypeError("type mismatch: cannot assign {!r} to {!r}".format(r, l))
        if not lr:
            return False
        cs.append((r.lt, remap(l.lt), "assignment may shorten lifetime"))
        if assign_cs(l.ty, r.ty, lremap):  # and mutable
            # &mut T is invariant in T: also add the reverse bound on the
            # pointee lifetimes, forcing them to be equal.
            cs.append((remap(l.ty.lt), r.ty.lt, "TODO explain this one"))
        return True

    def gen_for_call(stmt, before_stmt, after_stmt):
        """Instantiate the callee's signature at this call and pass each argument."""
        callee = stmt.fun
        if len(stmt.args) != len(callee.args):
            raise TypeError("{} expects {} argument(s), got {}".format(
                callee.name, len(callee.args), len(stmt.args)))
        # Every lifetime the signature mentions, in first-seen order so the
        # generated temporary names do not depend on set iteration order.
        plts = []
        for bound in callee.lts or []:
            for lt in bound:
                if lt not in plts:
                    plts.append(lt)
        for _pname, pty in callee.args:
            while isinstance(pty, RefMut):
                if pty.lt not in plts:
                    plts.append(pty.lt)
                pty = pty.ty
        # The callee's 'begin/'return are the points around this call; every
        # other signature lifetime becomes a fresh call-site lifetime.
        lt_insta = {}
        for lt in plts:
            if lt == "begin":
                lt_insta[lt] = before_stmt
            elif lt == "return":
                lt_insta[lt] = after_stmt
            else:
                lt_insta[lt] = tmp("call_{}_{}_".format(callee.name, lt))
        for lta, ltb in callee.lts or []:
            cs.append((lt_insta[lta], lt_insta[ltb],
                       "interface lifetime bound for {} ({}: {})".format(callee.name, lta, ltb)))
        for arg, (pname, pty) in zip(stmt.args, callee.args):
            alt, aty = visit_expr(arg)
            cs.append((alt, after_stmt,
                       "arg must be available at use ({}({}=arg, ...))".format(callee.name, pname)))
            assign_cs(pty, aty, lt_insta)

    def gen_for_stmt(stmt):
        """Advance to a new program point for ``stmt`` and emit its constraints."""
        nonlocal time
        after_stmt = tmp("stmt_")
        cs.append((after_stmt, time, "stmt ordering"))
        before_stmt = time
        time = after_stmt
        if isinstance(stmt, Assign):
            llt, lty = visit_expr(stmt.lhs)
            rlt, rty = visit_expr(stmt.rhs)
            cs.append((lty.lt, after_stmt, "assignment target must be valid at use"))
            cs.append((llt, after_stmt, "arg must be available at use (arg = ...)"))
            cs.append((rlt, after_stmt, "arg must be available at use (... = arg)"))
            # The store goes through the lhs reference, so the place type is its pointee.
            assign_cs(lty.ty, rty)
        elif isinstance(stmt, Call):
            gen_for_call(stmt, before_stmt, after_stmt)
        elif isinstance(stmt, Return):
            alt, _aty = visit_expr(stmt.arg)
            cs.append((alt, after_stmt, "arg must be available at use (return)"))

    for stmt in fun.entry.stmts:
        gen_for_stmt(stmt)
    # Every borrow still tracked at the end must end before the function returns.
    for var, (_mut, lts) in borrows.items():
        for lt in lts:
            cs.append(("return", lt, "borrow of local var {} ends before end of function".format(var)))
    cs.append(("return", time, "stmt ordering"))
    # Drop trivial constraints: 'static: 'a and 'a: 'a.
    return [c for c in cs if c[0] != "static" and c[0] != c[1]]
def demo(fun):
    """Print the lifetime constraints of ``fun`` and check them against its signature.

    Steps:
      1. print every constraint from generate_lt_constraints(fun);
      2. write the reversed, SCC-condensed constraint graph to
         ``<fun.name>_generated.dot`` (requires networkx and pydot);
      3. print the bounds the signature provides ('return: 'begin plus
         ``fun.lts``);
      4. print the bounds between public lifetimes that the body requires and
         the bounds the argument types require, each marked OK when the
         signature implies it and MISSING otherwise.

    Public lifetimes are "begin", "return", every lifetime named in
    ``fun.lts`` and every reference lifetime in the argument types.

    Returns:
        the constraint list from generate_lt_constraints(fun).
    """
    print(fun.name)
    cs = generate_lt_constraints(fun)
    for c in cs:
        print(" {}: {} // {}".format(*c))
    print()

    public_lts = {"begin", "return"}
    for bound in fun.lts or []:
        # Both sides of a signature bound are lifetimes of the interface.
        public_lts.update(bound)
    for _argname, ty in fun.args:
        while isinstance(ty, RefMut):
            public_lts.add(ty.lt)
            ty = ty.ty

    import networkx as nx

    g = nx.DiGraph()
    g.add_edges_from((c[0], c[1]) for c in cs)

    # Reverse the graph and collapse cycles (lifetimes forced to be equal)
    # into single nodes for inspection with Graphviz.
    rg = g.reverse()
    rgscc = list(nx.strongly_connected_components(rg))
    rgc = nx.condensation(rg, rgscc)
    for n in rgc:
        # Sorted so the label, and hence the .dot file, is the same every run.
        rgc.nodes[n]["label"] = "/".join(sorted(rgscc[n]))
    nx.nx_pydot.write_dot(rgc, "{}_generated.dot".format(fun.name))

    # Bounds the body implies between public lifetimes, minimised for display.
    # reflexive=None keeps cycles from adding self-loops.  transitive_reduction
    # needs a DAG, so if public lifetimes are still cyclic (forced equal) the
    # unreduced bounds are listed instead of crashing.
    tg = nx.transitive_closure(g, reflexive=None)
    tgp = nx.subgraph(tg, public_lts)
    if nx.is_directed_acyclic_graph(tgp):
        tgpr = nx.transitive_reduction(tgp)
    else:
        tgpr = tgp

    # Bounds the signature guarantees.
    ig = nx.DiGraph()
    ig.add_edge("return", "begin")
    ig.add_edges_from(fun.lts or [])
    for a, b in ig.edges():
        print(" {}: {} // provided by signature".format(a, b))
    print()

    def status(a, b):
        # 'a: 'b is provided if b is reachable from a in the signature graph.
        provided = a in ig and b in nx.descendants(ig, a)
        return "OK" if provided else "MISSING"

    tyreqs = []
    for _argname, ty in fun.args:
        visit_type(tyreqs, ty)
    for a, b in tgpr.edges():
        print(" {}: {} // required from signature ({})".format(a, b, status(a, b)))
    for a, b, _reason in tyreqs:
        if a == "static":
            continue
        print(" {}: {} // required for argument type validity ({})".format(a, b, status(a, b)))
    print()
    return cs
"""
fun_set_ptr<'a: 'return, 'b: 'a>(x: &'a mut &'b mut i32, y: &'b mut i32)
*x = y
fun_set_val<'a: 'return>(x: &'c mut i32)
*x = 99
fun_main()
let mut y: i32 = 0
let mut x: &'tmpx mut i32 = undef
fun_set_ptr(&mut x, &mut y)
fun_set_val(x)
return y
"""
fun_set_ptr = Function("fun_set_ptr",
[["a", "return"],
["b", "a"]],
[["x", RefMut("a", RefMut("b", "i32"))],
["y", RefMut("b", "i32")]])
Builder(fun_set_ptr).stmt(Assign(Var("x"), Var("y")))
fun_set_val = Function("fun_set_val",
[["c", "return"]],
[["z", RefMut("c", "i32")]])
Builder(fun_set_val).stmt(Assign(Var("z"), Const("i32", 99)))
fun_main = Function("fun_main",
None,
[])
b = Builder(fun_main)
b.declare_mut("y", "i32")
b.declare_mut("x", RefMut("tmpx", "i32"))
b.stmt(Call(fun_set_ptr, TakeRefMut(Var("x")), TakeRefMut(Var("y"))))
b.stmt(Call(fun_set_val, Var("x")))
b.stmt(Return(Var("y")))
cs_fun_set_ptr = demo(fun_set_ptr)
cs_fun_set_val = demo(fun_set_val)
cs_fun_main = demo(fun_main)
"""
fun_set_ptr
stmt_1: begin // stmt ordering
b: a // ref only lives as long as pointed-to type
a: stmt_1 // assignment target must be valid at use
borrow_x_2: stmt_1 // arg must be available at use (arg = ...)
borrow_y_3: stmt_1 // arg must be available at use (... = arg)
return: borrow_x_2 // borrow of local var x ends before end of function
return: borrow_y_3 // borrow of local var y ends before end of function
return: stmt_1 // stmt ordering
return: begin // provided by signature
a: return // provided by signature
b: a // provided by signature
b: a // required from signature (OK)
a: begin // required from signature (OK)
return: begin // required from signature (OK)
b: a // required for argument type validity (OK)
fun_set_val
stmt_1: begin // stmt ordering
c: stmt_1 // assignment target must be valid at use
borrow_z_2: stmt_1 // arg must be available at use (arg = ...)
return: borrow_z_2 // borrow of local var z ends before end of function
return: stmt_1 // stmt ordering
return: begin // provided by signature
c: return // provided by signature
c: begin // required from signature (OK)
return: begin // required from signature (OK)
fun_main
stmt_1: begin // stmt ordering
call_fun_set_ptr_a_3: stmt_1 // interface lifetime bound for fun_set_ptr (a: return)
call_fun_set_ptr_b_2: call_fun_set_ptr_a_3 // interface lifetime bound for fun_set_ptr (b: a)
tmpx: borrow_x_4 // ref only lives as long as pointed-to type
borrow_x_4: stmt_1 // arg must be available at use (fun_set_ptr(x=arg, ...))
borrow_x_4: call_fun_set_ptr_a_3 // assignment may shorten lifetime
tmpx: call_fun_set_ptr_b_2 // assignment may shorten lifetime
call_fun_set_ptr_b_2: tmpx // TODO explain this one
borrow_y_5: stmt_1 // arg must be available at use (fun_set_ptr(y=arg, ...))
borrow_y_5: call_fun_set_ptr_b_2 // assignment may shorten lifetime
stmt_6: stmt_1 // stmt ordering
call_fun_set_val_c_7: stmt_6 // interface lifetime bound for fun_set_val (c: return)
stmt_6: borrow_x_4 // all prior borrows of x must end
borrow_x_8: stmt_6 // arg must be available at use (fun_set_val(z=arg, ...))
tmpx: call_fun_set_val_c_7 // assignment may shorten lifetime
stmt_9: stmt_6 // stmt ordering
stmt_9: borrow_y_5 // all prior borrows of y must end
borrow_y_10: stmt_9 // arg must be available at use (return)
return: borrow_x_8 // borrow of local var x ends before end of function
return: borrow_y_10 // borrow of local var y ends before end of function
return: stmt_9 // stmt ordering
return: begin // provided by signature
return: begin // required from signature (OK)
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment