// Gist by @slavapestov, last active March 13, 2021 16:27
// (gist ID 75dbec34f9eba5fb4a4a00b1ee520d0b).
//
// RequirementMachine.swift
//
// Created by Slava Pestov on 3/9/21.
//
/// Interned identifiers: each distinct string is assigned a small integer ID,
/// and the same string always maps back to the same ID.
struct Identifier: CustomStringConvertible {
    /// Index of this identifier's spelling in `strings`.
    let id: Int

    init(_ str: String) {
        self.id = Identifier.getUniqueID(str)
    }

    /// The interned spelling.
    var description: String {
        Identifier.strings[id]
    }

    /// Interned spellings, indexed by ID.
    static var strings: [String] = []
    /// Reverse map from spelling to ID.
    static var ids: [String: Int] = [:]

    /// Returns the ID for `str`, interning it first if it has not been seen yet.
    static func getUniqueID(_ str: String) -> Int {
        if let existing = ids[str] {
            return existing
        }
        let fresh = strings.count
        strings.append(str)
        ids[str] = fresh
        return fresh
    }
}
// Equality and hashing are synthesized from `id`; since spellings are
// interned, that agrees with comparing the strings themselves.
extension Identifier: Hashable {}

extension Identifier: ExpressibleByStringLiteral {
    /// Interns the literal's text, so `"Element"` can be written wherever an
    /// `Identifier` is expected.
    init(stringLiteral value: String) {
        self.init(value)
    }
}

extension Identifier: Comparable {
    /// Orders identifiers by spelling, not by the order they were interned.
    static func < (lhs: Identifier, rhs: Identifier) -> Bool {
        rhs.description > lhs.description
    }
}
/// A single symbol in a term.
///
/// The cases are declared in a deliberate order: the synthesized `Comparable`
/// conformance on `Token` ranks `.proto` before `.type` before `.name`.
enum Token: CustomStringConvertible {
    /// A protocol. Terms are grounded in a protocol token; the token also
    /// represents conformance of a subject type, given by the preceding tokens.
    case proto(Identifier)
    /// A resolved associated type: the declaring protocol, then the type's name.
    case type(Identifier, Identifier)
    /// An unresolved nested type of the parent type given by the preceding tokens.
    case name(Identifier)

    var description: String {
        switch self {
        case let .proto(name):
            return "[.\(name)]"
        case let .type(proto, name):
            return "[.\(proto):\(name)]"
        case let .name(name):
            return ".\(name)"
        }
    }
}
// Equatable, Hashable and Comparable are all synthesized. Comparable orders
// first by case declaration order, then by associated values.
extension Token: Hashable, Comparable {}

extension Token: ExpressibleByStringLiteral {
    /// A string literal denotes an unresolved nested type name, e.g. `"Element"`.
    init(stringLiteral value: String) {
        let identifier = Identifier(value)
        self = .name(identifier)
    }
}
/// A term is a non-empty sequence of tokens.
struct Term: CustomStringConvertible {
    var tokens: [Token]

    init(tokens: [Token]) {
        assert(!tokens.isEmpty)
        self.tokens = tokens
    }

    var description: String {
        return tokens.map(\.description).joined()
    }

    /// Returns the offset in `tokens` at which the first occurrence of
    /// `subterm` begins, or nil if `subterm` does not occur in this term.
    private func firstOffset(of subterm: Term) -> Int? {
        let needle = subterm.tokens
        guard needle.count <= tokens.count else { return nil }
        for start in 0...(tokens.count - needle.count) {
            if tokens[start..<(start + needle.count)].elementsEqual(needle) {
                return start
            }
        }
        return nil
    }

    /// Returns true if `subterm` occurs as a contiguous run of tokens in this term.
    func contains(subterm: Term) -> Bool {
        return firstOffset(of: subterm) != nil
    }

    /// Replaces the first occurrence of `subterm` with `replacement`.
    ///
    /// - Returns: false, leaving the term unchanged, if `subterm` does not
    ///   occur in this term.
    mutating func rewrite(subterm: Term, with replacement: Term) -> Bool {
        guard let start = firstOffset(of: subterm) else { return false }
        tokens.replaceSubrange(start..<(start + subterm.tokens.count), with: replacement.tokens)
        return true
    }

    /// Attempts to compute a term XYZ such that either:
    /// - XYZ = self, Y = other (`other` occurs inside `self`), or
    /// - XY = self, YZ = other, where Y is non-empty and is a proper suffix of
    ///   `self` and a proper prefix of `other`.
    ///
    /// The first case is checked first. In the second case the longest such Y
    /// is chosen, which gives the shortest overlap term. The second case applies
    /// whatever the relative lengths of the two terms; in particular, `other`
    /// may be longer than `self`.
    ///
    /// - Returns: The term XYZ, or nil if `self` does not overlap with `other`.
    func overlaps(with other: Term) -> Term? {
        if contains(subterm: other) {
            return self
        }
        let longest = min(tokens.count, other.tokens.count) - 1
        for length in stride(from: longest, to: 0, by: -1) {
            if tokens.suffix(length).elementsEqual(other.tokens.prefix(length)) {
                // X Y (all of self) followed by Z (the rest of other).
                return Term(tokens: tokens + other.tokens[length...])
            }
        }
        return nil
    }
}
extension Term: ExpressibleByArrayLiteral {
    /// Builds a term from an array literal of tokens,
    /// e.g. `[.proto("Sequence"), "Element"]`.
    init(arrayLiteral elements: Token...) {
        self.init(tokens: elements)
    }
}
extension Term: Comparable {
    /// Shortlex ("lexshort") order: a shorter term always comes first, and
    /// terms of equal length are compared token by token.
    static func < (lhs: Term, rhs: Term) -> Bool {
        guard lhs.tokens.count == rhs.tokens.count else {
            return lhs.tokens.count < rhs.tokens.count
        }
        return lhs.tokens.lexicographicallyPrecedes(rhs.tokens)
    }
}
/// A rewrite rule `from --> to`, where `from` is the larger side in the
/// shortlex order on terms.
struct Rule: CustomStringConvertible {
    var from: Term
    var to: Term

    /// Orients the two sides so that the larger one rewrites to the smaller one.
    /// Returns nil when neither side is larger, since such a rule is redundant.
    init?(_ from: Term, _ to: Term) {
        let fromIsLarger = to < from
        let toIsLarger = from < to
        guard fromIsLarger || toIsLarger else {
            return nil
        }
        if fromIsLarger {
            self.from = from
            self.to = to
        } else {
            self.from = to
            self.to = from
        }
    }

    /// Simplifies both sides with `rewriteSystem`'s rules, then orients them.
    /// Returns nil when both sides simplify to the same term.
    init?(_ from: Term, _ to: Term, _ rewriteSystem: RewriteSystem) {
        var simplifiedFrom = from
        rewriteSystem.simplify(term: &simplifiedFrom)
        var simplifiedTo = to
        rewriteSystem.simplify(term: &simplifiedTo)
        self.init(simplifiedFrom, simplifiedTo)
    }

    /// Rewrites the first occurrence of `from` in `term` to `to`.
    /// Returns true if `term` contained `from` and was rewritten.
    @discardableResult
    func apply(to term: inout Term) -> Bool {
        term.rewrite(subterm: from, with: to)
    }

    var description: String {
        "\(from) --> \(to)"
    }
}
extension Rule: Comparable {
    /// Orders rules by their left-hand sides, breaking ties by their
    /// right-hand sides. The order is used for sorting in dump().
    ///
    /// The tie-break keeps `<` consistent with the synthesized `==`, which
    /// compares both sides. `Comparable` requires exactly one of `a < b`,
    /// `a == b`, `b < a` to hold, and the tie-break also makes the sorted
    /// output deterministic.
    static func < (lhs: Self, rhs: Self) -> Bool {
        if lhs.from != rhs.from {
            return lhs.from < rhs.from
        }
        return lhs.to < rhs.to
    }
}
/// A rewrite system stores a set of rewrite rules and attempts to compute their completion.
struct RewriteSystem {
    var rules: [Rule] = []

    /// Simplifies, orients and adds a rule, unless both sides already simplify
    /// to the same term (in which case the rule is redundant).
    mutating func add(_ lhs: Term, _ rhs: Term) {
        if let new = Rule(lhs, rhs, self) {
            rules.append(new)
        }
    }

    /// Returns true if the two terms simplify to the same term.
    ///
    /// This decides equivalence only when the rules are confluent, for example
    /// after complete() has returned true.
    func areEquivalent(_ lhs: Term, _ rhs: Term) -> Bool {
        return Rule(lhs, rhs, self) == nil
    }

    /// Checks whether a type term conforms to a protocol, i.e. whether
    /// `type` is equivalent to `type` followed by the protocol token `[.proto]`.
    func conforms(type: Term, to proto: Identifier) -> Bool {
        var other = type
        other.addProtocol(proto)
        return areEquivalent(type, other)
    }

    /// Simplifies a term until a fixed point is reached.
    ///
    /// Each pass applies every rule once, in order; each application rewrites
    /// the first occurrence of the rule's left-hand side. Passes repeat until
    /// one of them changes nothing.
    ///
    /// - Returns: true if the term changed.
    @discardableResult
    func simplify(term: inout Term) -> Bool {
        var changed = false
        var progress = true
        while progress {
            progress = false
            for rule in rules {
                if rule.apply(to: &term) {
                    progress = true
                }
            }
            if progress {
                changed = true
            }
        }
        return changed
    }

    /// Attempts to compute a confluent completion using the Knuth-Bendix algorithm.
    ///
    /// Each ordered pair of distinct rules is checked for an overlap of their
    /// left-hand sides. The overlap term is rewritten by each rule, and if the
    /// two results do not simplify to the same term, the resulting equation is
    /// added as a new rule. Existing rules whose left-hand side contains the
    /// new rule's left-hand side are removed. Each removed rule is then
    /// re-queued against the new rule, so that its equation is re-derived in
    /// simplified form instead of being lost.
    ///
    /// NOTE(review): overlaps(with:) yields at most one overlap per ordered
    /// pair, and a rule is never paired with itself, so self-overlaps and
    /// secondary overlaps between the same two rules are not considered.
    ///
    /// - Parameters:
    ///   - iterations: Maximum number of new rules that may be added.
    ///   - debug: Print a trace of the procedure.
    /// - Returns: false if the limit on new rules was reached.
    mutating func complete(iterations: Int = 100, debug: Bool = false) -> Bool {
        func log(_ str: @autoclosure () -> String) {
            if debug {
                print(str())
            }
        }

        // Seed the worklist with every ordered pair of distinct rules.
        var pairs: [(Rule, Rule)] = []
        for (i, lhs) in rules.enumerated() {
            for (j, rhs) in rules.enumerated() where i != j {
                pairs.append((lhs, rhs))
            }
        }

        log("Starting completion procedure")
        var rulesAdded = 0
        while let (lhs, rhs) = pairs.popLast() {
            guard let overlap = lhs.from.overlaps(with: rhs.from) else { continue }
            log("lhs = \(lhs)")
            log("rhs = \(rhs)")
            log("overlap = \(overlap)")

            // Both rewrites of the overlap term are equal to it, hence to each other.
            var x = overlap
            var y = overlap
            lhs.apply(to: &x)
            rhs.apply(to: &y)
            log("x = \(x)")
            log("y = \(y)")

            // nil means x and y simplify to the same term: the pair is joinable.
            guard let newRule = Rule(x, y, self) else { continue }

            // `>=` so that at most `iterations` rules are added.
            if rulesAdded >= iterations {
                log("Completion failed")
                return false
            }
            rulesAdded += 1
            log("Adding \(newRule)")

            // Remove rules whose left-hand side is now reducible by the new rule,
            // remembering them so their equations can be re-derived below.
            var obsolete: [Rule] = []
            rules.removeAll { rule in
                guard rule.from.contains(subterm: newRule.from) else { return false }
                log("Removing \(rule)")
                obsolete.append(rule)
                return true
            }

            for rule in rules {
                pairs.append((rule, newRule))
                pairs.append((newRule, rule))
            }
            rules.append(newRule)

            // An obsolete rule's left-hand side contains the new one, so this pair
            // rewrites it both ways. Processing the pair yields the removed
            // equation, simplified by the current rules.
            for rule in obsolete {
                pairs.append((rule, newRule))
            }
        }
        return true
    }

    /// Prints all rules, sorted by left-hand side.
    func dump() {
        print("Rewrite system rules:")
        for rule in rules.sorted() {
            print(rule)
        }
    }
}
/// Helpers for building terms out of requirements.
extension Term {
    /// Builds the term `[.proto].path[0].path[1]...`: a protocol token followed
    /// by one unresolved `.name` token per path component.
    init(proto: Identifier, path: [Identifier]) {
        var built: [Token] = [.proto(proto)]
        for component in path {
            built.append(.name(component))
        }
        self.init(tokens: built)
    }

    /// Appends the protocol token `[.proto]` to the end of this term.
    mutating func addProtocol(_ proto: Identifier) {
        tokens += [.proto(proto)]
    }
}
/// A specification of a protocol requirement in a Swift-like form.
enum Requirement {
    /// Protocol refinement -- aka, conformance on `Self`.
    case refines(Identifier)
    /// Associated type introduction. It may conform to zero or more protocols.
    case type(Identifier, conformsTo: [Identifier] = [])
    /// More general conformance relation between a given type and protocol.
    ///
    /// The type is understood to be a non-empty path from `Self`. This
    /// cannot be used to impose conformance constraints on `Self`.
    case conforms([Identifier], Identifier)
    /// Same-type equivalence between two types.
    ///
    /// The two types are understood to be paths from `Self`. One of
    /// the two may be empty, which means `Self`, but this cannot be
    /// used to impose conformance constraints on `Self`.
    case sameType([Identifier], [Identifier])

    /// Desugars `.type(name, conformsTo: protos)` into one
    /// `.conforms([name], proto)` requirement per protocol in `protos`.
    /// Every other kind of requirement desugars to nothing.
    func desugar() -> [Requirement] {
        guard case .type(let name, conformsTo: let conformsTo) = self else { return [] }
        return conformsTo.map { proto in .conforms([name], proto) }
    }

    /// Resolves the requirement's `Self`-relative subject type as a term
    /// grounded at the protocol `proto`.
    ///
    /// - Precondition: `self` is not `.refines`. Refinement has no subject
    ///   term here and must be handled by the caller; this traps otherwise.
    func resolveSubjectType(_ proto: Identifier) -> Term {
        switch self {
        case .refines:
            fatalError("resolveSubjectType(_:) is not defined for .refines requirements")
        case .type(let name, _):
            return Term(proto: proto, path: [name])
        case .conforms(let subject, _):
            return Term(proto: proto, path: subject)
        case .sameType(let subject, _):
            return Term(proto: proto, path: subject)
        }
    }

    /// Resolves the requirement's constraint type as a term grounded at the
    /// protocol `proto`:
    /// - `.type(name, _)`: the resolved associated type token `[.proto:name]`.
    /// - `.conforms(path, other)`: the subject path followed by `[.other]`.
    /// - `.sameType(_, path)`: the second path.
    ///
    /// - Precondition: `self` is not `.refines`; this traps otherwise.
    func resolveConstraintType(_ proto: Identifier) -> Term {
        switch self {
        case .refines:
            fatalError("resolveConstraintType(_:) is not defined for .refines requirements")
        case .type(let name, _):
            return [.type(proto, name)]
        case .conforms(let subject, let conformsTo):
            var result = Term(proto: proto, path: subject)
            result.addProtocol(conformsTo)
            return result
        case .sameType(_, let constraint):
            return Term(proto: proto, path: constraint)
        }
    }
}
/// A protocol, described by the list of its requirements.
struct Proto {
    var requirements: [Requirement]

    /// The protocols this protocol directly refines, in declaration order.
    var refines: [Identifier] {
        var parents: [Identifier] = []
        for case .refines(let parent) in requirements {
            parents.append(parent)
        }
        return parents
    }

    init(requirements: [Requirement]) {
        self.requirements = requirements
    }
}
extension Proto: ExpressibleByArrayLiteral {
    /// Lets a protocol be written as an array literal of its requirements.
    init(arrayLiteral requirements: Requirement...) {
        self.init(requirements: requirements)
    }
}
/// A category is a (possibly mutually-recursive) set of protocols, keyed by name.
struct Category {
    var protocols: [Identifier : Proto] = [:]

    /// Expands protocol refinement when introducing a conformance requirement.
    ///
    /// For `.conforms(path, P)`, returns `.conforms(path, Q)` for each protocol
    /// `Q` that `P` directly refines. Returns `[]` for any other kind of
    /// requirement, or when `P` is not in this category.
    func expand(requirement: Requirement) -> [Requirement] {
        if case let .conforms(path, name) = requirement, let proto = protocols[name] {
            return proto.refines.map { parent in Requirement.conforms(path, parent) }
        }
        return []
    }
}
/// Converting a Category into a series of rewrite rules.
extension RewriteSystem {
/// Adds rewrite rules for every requirement of every protocol in `category`.
///
/// NOTE(review): `category.protocols` is a Dictionary, so the order in which
/// protocols are visited, and hence the order rules are appended, is not
/// fixed. Confirm that completion results do not depend on it.
mutating func add(category: Category) {
for (name, proto) in category.protocols {
for requirement in proto.requirements {
/// Recursively desugar `.type(_, conformsTo: _)` and `.conforms()` to
/// handle protocol refinement.
// For one requirement, this adds in order: the `.conforms` requirements
// produced by desugar(), then those produced by expanding refinement of
// the conformed-to protocol (each handled recursively), and finally the
// rule "subject type = constraint type" for the requirement itself, both
// resolved against the current protocol `name`.
func add(_ requirement: Requirement) {
for desugared in requirement.desugar() {
add(desugared)
}
for expanded in category.expand(requirement: requirement) {
add(expanded)
}
// `self.` selects the two-term RewriteSystem.add(_:_:), not this nested function.
self.add(requirement.resolveSubjectType(name),
requirement.resolveConstraintType(name))
}
// `.refines` produces no rule of its own (resolveSubjectType traps on it).
if case .refines(let other) = requirement {
/// Associated types from refined protocols are re-stated.
// Only the names are re-stated: each is added as `.type(name)` with no
// `conformsTo`, and the refined protocol's other requirements are not
// copied. Recurses through further refinements with no visited set.
// Presumably the category has no refinement cycles; confirm.
func inheritAssociatedTypes(_ other: Identifier) {
guard let otherProto = category.protocols[other] else { return }
for case .type(let name, _) in otherProto.requirements {
add(.type(name))
}
for refined in otherProto.refines {
inheritAssociatedTypes(refined)
}
}
inheritAssociatedTypes(other)
continue
}
add(requirement)
}
}
}
}
var s = RewriteSystem()
let c = Category(protocols: [
    "IteratorProtocol": [
        .type("Element"),
    ],
    "Sequence": [
        .type("Element"),
        .type("Iterator", conformsTo: ["IteratorProtocol"]),
        .sameType(["Iterator", "Element"], ["Element"]),
    ],
    "Collection": [
        .refines("Sequence"),
        .type("SubSequence", conformsTo: ["Collection"]),
        .sameType(["SubSequence", "SubSequence"], ["SubSequence"]),
        .sameType(["SubSequence", "Element"], ["Element"]),
        .sameType(["SubSequence", "Index"], ["Index"]),
        .type("Index", conformsTo: ["Comparable"]),
        .type("Indices", conformsTo: ["Collection"]),
        .sameType(["Indices", "Element"], ["Index"]),
        .sameType(["Indices", "Index"], ["Index"]),
        .sameType(["Indices", "SubSequence"], ["Indices"]),
    ],
])
s.add(category: c)
// complete() returns false when it stops at its limit on added rules. In that
// case the dumped rules may not be confluent, so report it rather than
// silently discarding the result.
if !s.complete() {
    print("Warning: completion did not finish within the rule limit")
}
s.dump()
/*
Rewrite system rules:
[.Collection].Element --> [.Collection:Element]
[.Collection].Index --> [.Collection:Index]
[.Collection].Indices --> [.Collection:Indices]
[.Collection].Iterator --> [.Collection:Iterator]
[.Collection].SubSequence --> [.Collection:SubSequence]
[.IteratorProtocol].Element --> [.IteratorProtocol:Element]
[.Sequence].Element --> [.Sequence:Element]
[.Sequence].Iterator --> [.Sequence:Iterator]
[.Collection:Index][.Comparable] --> [.Collection:Index]
[.Collection:Indices][.Collection] --> [.Collection:Indices]
[.Collection:Indices][.Sequence] --> [.Collection:Indices]
[.Collection:Indices][.Collection:Element] --> [.Collection:Index]
[.Collection:Indices][.Collection:Index] --> [.Collection:Index]
[.Collection:Indices][.Collection:SubSequence] --> [.Collection:Indices]
[.Collection:Indices][.Sequence:Element] --> [.Collection:Index]
[.Collection:Indices][.Sequence:Iterator] --> [.Collection:Indices][.Collection:Iterator]
[.Collection:Indices].Element --> [.Collection:Index]
[.Collection:Indices].Index --> [.Collection:Index]
[.Collection:Indices].Indices --> [.Collection:Indices][.Collection:Indices]
[.Collection:Indices].Iterator --> [.Collection:Indices][.Collection:Iterator]
[.Collection:Indices].SubSequence --> [.Collection:Indices]
[.Collection:SubSequence][.Collection] --> [.Collection:SubSequence]
[.Collection:SubSequence][.Sequence] --> [.Collection:SubSequence]
[.Collection:SubSequence][.Collection:Element] --> [.Collection:Element]
[.Collection:SubSequence][.Collection:Index] --> [.Collection:Index]
[.Collection:SubSequence][.Collection:SubSequence] --> [.Collection:SubSequence]
[.Collection:SubSequence][.Sequence:Element] --> [.Collection:Element]
[.Collection:SubSequence][.Sequence:Iterator] --> [.Collection:SubSequence][.Collection:Iterator]
[.Collection:SubSequence].Element --> [.Collection:Element]
[.Collection:SubSequence].Index --> [.Collection].Index
[.Collection:SubSequence].Indices --> [.Collection:SubSequence][.Collection:Indices]
[.Collection:SubSequence].Iterator --> [.Collection:SubSequence][.Collection:Iterator]
[.Collection:SubSequence].SubSequence --> [.Collection:SubSequence]
[.Sequence:Iterator][.IteratorProtocol] --> [.Sequence:Iterator]
[.Sequence:Iterator][.IteratorProtocol:Element] --> [.Sequence:Element]
[.Sequence:Iterator].Element --> [.Sequence:Element]
[.Collection].Index[.Comparable] --> [.Collection].Index
[.Collection].Indices[.Collection] --> [.Collection].Indices
[.Collection].Indices[.Sequence] --> [.Collection].Indices
[.Collection].SubSequence[.Collection] --> [.Collection].SubSequence
[.Collection].SubSequence[.Sequence] --> [.Collection].SubSequence
[.Sequence].Iterator[.IteratorProtocol] --> [.Sequence].Iterator
[.Collection:Indices][.Collection:Iterator][.IteratorProtocol] --> [.Collection:Indices][.Collection:Iterator]
[.Collection:Indices][.Collection:Iterator][.IteratorProtocol:Element] --> [.Collection:Index]
[.Collection:Indices][.Collection:Iterator].Element --> [.Collection:Indices][.Sequence:Element]
[.Collection:SubSequence][.Collection:Iterator][.IteratorProtocol] --> [.Collection:SubSequence][.Collection:Iterator]
[.Collection:SubSequence][.Collection:Iterator][.IteratorProtocol:Element] --> [.Collection:Element]
[.Collection:SubSequence][.Collection:Iterator].Element --> [.Collection:SubSequence][.Sequence:Element]
*/
// End of gist.