Created
January 5, 2018 05:59
-
-
Save bkase/34df2a8fb48df61a0069cb5a8ab02d84 to your computer and use it in GitHub Desktop.
Type Inference Algorithm M (ported from Swift Sandbox)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Write some awesome Swift code, or import libraries like "Foundation",
// "Dispatch", or "Glibc".
/// Outcome of a computation: either a failure carrying an `E` or a success carrying a `T`.
///
/// Note the parameter order: the error type comes first (`Result<E, T>`), the reverse of the
/// standard library's `Result<Success, Failure>`.
enum Result<E, T> {
  case fail(E)
  case success(T)

  /// `true` for `.success`, `false` for `.fail`.
  var succeeded: Bool {
    if case .success = self {
      return true
    }
    return false
  }

  /// Transforms the success value with `f`; a failure is passed through unchanged.
  func map<U>(_ f: (T) -> U) -> Result<E, U> {
    switch self {
    case let .fail(error):
      return .fail(error)
    case let .success(value):
      return .success(f(value))
    }
  }

  /// Chains a further fallible step onto a success; a failure short-circuits.
  func flatMap<U>(_ f: (T) -> Result<E, U>) -> Result<E, U> {
    switch self {
    case let .fail(error):
      return .fail(error)
    case let .success(value):
      return f(value)
    }
  }

  /// Transforms the failure value with `f`; a success is passed through unchanged.
  func mapError<E2>(_ f: (E) -> E2) -> Result<E2, T> {
    switch self {
    case let .fail(error):
      return .fail(f(error))
    case let .success(value):
      return .success(value)
    }
  }

  /// Lets `f` replace a failure with a new result; a success is passed through unchanged.
  func flatMapError<E2>(_ f: (E) -> Result<E2, T>) -> Result<E2, T> {
    switch self {
    case let .fail(error):
      return f(error)
    case let .success(value):
      return .success(value)
    }
  }

  /// Extracts the success value, or computes a fallback from the failure with `f`.
  func recover(_ f: (E) -> T) -> T {
    switch self {
    case let .fail(error):
      return f(error)
    case let .success(value):
      return value
    }
  }
}
/// The name of a term-level variable in an `Expr`.
struct Identifier { let v: String }

extension Identifier: Equatable {
  static func ==(lhs: Identifier, rhs: Identifier) -> Bool {
    return lhs.v == rhs.v
  }
}

extension Identifier: Hashable {
  // Implements `hash(into:)` rather than the deprecated `hashValue` requirement (`hashValue`
  // is still available, derived from this). It hashes exactly the field `==` compares, so
  // equal identifiers always hash equally.
  func hash(into hasher: inout Hasher) {
    hasher.combine(v)
  }
}

extension Identifier: ExpressibleByStringLiteral {
  /// Allows writing `"x"` wherever an `Identifier` is expected.
  init(stringLiteral value: String) {
    self.v = value
  }
}

extension Identifier: CustomStringConvertible {
  /// The bare name.
  var description: String {
    return v
  }
}
/// The name of a type variable.
struct TypeIdentifier {
  let v: String

  /// This name wrapped as a type variable.
  var type: `Type` {
    return .typeVariable(self)
  }
}

extension TypeIdentifier: Equatable {
  static func ==(lhs: TypeIdentifier, rhs: TypeIdentifier) -> Bool {
    return lhs.v == rhs.v
  }
}

extension TypeIdentifier: Hashable {
  // Implements `hash(into:)` rather than the deprecated `hashValue` requirement (`hashValue`
  // is still available, derived from this). It hashes exactly the field `==` compares, so
  // equal identifiers always hash equally.
  func hash(into hasher: inout Hasher) {
    hasher.combine(v)
  }
}

extension TypeIdentifier: ExpressibleByStringLiteral {
  /// Allows writing `"a"` wherever a `TypeIdentifier` is expected.
  init(stringLiteral value: String) {
    self.v = value
  }
}

extension TypeIdentifier: CustomStringConvertible {
  /// The bare name.
  var description: String {
    return v
  }
}
/// A literal constant that can appear in an expression.
enum Literal {
  case int(Int)
  case string(String)
}

extension Literal: CustomStringConvertible {
  /// Integers print in decimal; strings print wrapped in double quotes (contents not escaped).
  var description: String {
    switch self {
    case .int(let number):
      return String(number)
    case .string(let text):
      return "\"\(text)\""
    }
  }
}
/// Abstract syntax of the object language.
indirect enum Expr {
  /// A constant.
  case literal(Literal)
  /// A reference to a variable in scope.
  case variable(Identifier)
  /// Function application `e1 e2`.
  case application(Expr, Expr)
  /// Lambda abstraction `λx.e`.
  case abstraction(Identifier, Expr)
  /// `let id = be in inside end`.
  case `let`(id: Identifier, be: Expr, inside: Expr)
  /// Recursive function `fix f in λx.e end`, where `f` may be used inside `e`.
  case fix(f: Identifier, inside: (Identifier, Expr))
}

extension Expr: ExpressibleByIntegerLiteral {
  /// Lets an integer literal stand for `.literal(.int(n))`.
  init(integerLiteral value: Int) {
    let wrapped = Literal.int(value)
    self = .literal(wrapped)
  }
}

extension Expr: CustomStringConvertible {
  /// A compact, human-readable rendering; application is written with `·`.
  var description: String {
    switch self {
    case .literal(let lit):
      return lit.description
    case .variable(let name):
      return name.v
    case .application(let callee, let argument):
      return "\(callee)·\(argument)"
    case .abstraction(let param, let body):
      return "λ(\(param).\(body))"
    case .`let`(id: let name, be: let bound, inside: let body):
      return "let \(name) = \(bound) in \(body) end"
    case .fix(f: let name, inside: (let param, let body)):
      return "fix \(name) in λ(\(param).\(body)) end"
    }
  }
}
/// A value whose structure mentions type variables.
protocol HasTypeVariables {
  /// The set of type variables this value exposes.
  var typeVariables: Set<TypeIdentifier> { get }
}
/// The primitive (constant) types, one per `Literal` case.
enum LiteralType: Equatable {
  case intType
  case stringType
}

extension LiteralType: CustomStringConvertible {
  /// The type's surface name: `Int` or `String`.
  var description: String {
    return self == .intType ? "Int" : "String"
  }
}
/// A monotype of the object language.
indirect enum `Type` {
  /// A primitive type such as `Int` or `String`.
  case constant(LiteralType)
  /// A type variable, to be solved by unification.
  case typeVariable(TypeIdentifier)
  /// A function type `argument → result`.
  case function(`Type`, `Type`)

  /// Applies `typeContext` to this type in place.
  ///
  /// Substitution is applied transitively: a variable is replaced by its mapped type, which is
  /// then substituted again, so a chain such as `a ↦ b, b ↦ Int` resolves `a` to `Int`.
  /// A variable mapped to itself is left unchanged.
  /// NOTE(review): a cyclic mapping (e.g. `a ↦ b, b ↦ a`) would recurse without bound.
  mutating func sub(_ typeContext: TypeContext) {
    switch self {
    case .constant(_):
      return
    case let .typeVariable(ident):
      if let type = typeContext[ident], type != self {
        var mutType = type
        mutType.sub(typeContext)
        self = mutType
      }
    case let .function(tIn, tOut):
      var mutTIn = tIn
      var mutTOut = tOut
      mutTIn.sub(typeContext)
      mutTOut.sub(typeContext)
      self = .function(mutTIn, mutTOut)
    }
  }

  /// Generalizes this type with respect to `context` (Clos_Γ(τ)).
  ///
  /// Every type variable of `self` that is not free in `context` becomes quantified in the
  /// returned scheme.
  func closure(_ context: Context) -> TypeScheme {
    // A scheme's stored `typeVariables` are the variables it quantifies over, not its free
    // variables. The free variables of `∀ᾱ.τ` are those of `τ` minus `ᾱ`, and only those
    // must stay monomorphic. Using the quantified set instead would let a lambda-bound
    // variable's type be generalized (e.g. in `λx. let y = x in y`).
    let freeInContext = context.values.reduce(into: Set<TypeIdentifier>()) { free, scheme in
      free.formUnion(scheme.type.typeVariables.subtracting(scheme.typeVariables))
    }
    return TypeScheme(forall: typeVariables.subtracting(freeInContext), type: self)
  }
}
extension `Type`: Equatable {
  /// Structural equality. Type variables are equal only when their names are equal
  /// (no renaming up to α-equivalence).
  static func ==(lhs: `Type`, rhs: `Type`) -> Bool {
    if case let .constant(a) = lhs, case let .constant(b) = rhs {
      return a == b
    }
    if case let .typeVariable(a) = lhs, case let .typeVariable(b) = rhs {
      return a == b
    }
    if case let .function(aIn, aOut) = lhs, case let .function(bIn, bOut) = rhs {
      return aIn == bIn && aOut == bOut
    }
    // Different constructors never compare equal.
    return false
  }
}
extension `Type`: HasTypeVariables {
  /// Every type variable occurring anywhere in this type.
  var typeVariables: Set<TypeIdentifier> {
    switch self {
    case .constant:
      return []
    case .typeVariable(let name):
      return Set([name])
    case .function(let input, let output):
      var collected = input.typeVariables
      collected.formUnion(output.typeVariables)
      return collected
    }
  }
}
extension `Type`: CustomStringConvertible {
  /// Renders the type with `→` as a right-associative arrow.
  ///
  /// Only a function-typed *argument* needs parentheses, e.g. `(a → b) → c`, while
  /// `a → b → c` means `a → (b → c)`. Parenthesizing the whole argument, rather than
  /// destructuring it one level, keeps deeper left nesting correct: `((a → b) → c) → d`.
  var description: String {
    switch self {
    case let .constant(literal):
      return literal.description
    case let .typeVariable(name):
      return name.description
    case let .function(input, output):
      let argument: String
      if case .function = input {
        argument = "(" + input.description + ")"
      } else {
        argument = input.description
      }
      return argument + " → " + output.description
    }
  }
}
/// A type scheme `∀ typeVariables. type`.
struct TypeScheme {
  /// The quantified (bound) type variables.
  var typeVariables: Set<TypeIdentifier>
  /// The body of the scheme.
  var type: `Type`

  /// A monomorphic scheme, which quantifies over nothing.
  init(type: `Type`) {
    self.init(forall: [], type: type)
  }

  /// A scheme quantifying `type` over `typeVariables`.
  init(forall typeVariables: Set<TypeIdentifier>, type: `Type`) {
    self.typeVariables = typeVariables
    self.type = type
  }

  /// Replaces each quantified variable with a fresh type variable drawn from `freshTypes`
  /// and returns the resulting monotype.
  func instantiate(_ freshTypes: FreshTypeFactory) -> `Type` {
    var renaming = TypeContext()
    for bound in typeVariables {
      renaming[bound] = freshTypes.next()!.type
    }
    var instance = type
    instance.sub(renaming)
    return instance
  }

  /// Applies `typeContext` to the body, then drops from the quantifier list every variable
  /// that `typeContext` maps.
  /// NOTE(review): the body substitution also rewrites quantified variables that appear in
  /// `typeContext` — confirm this un-generalizing behavior is intended.
  mutating func sub(_ typeContext: TypeContext) {
    type.sub(typeContext)
    typeVariables = typeVariables.filter { typeContext[$0] == nil }
  }
}

// NOTE(review): this conformance is satisfied by the stored `typeVariables`, i.e. the
// *quantified* variables, not the scheme's free variables.
extension TypeScheme: HasTypeVariables { }

extension TypeScheme: CustomStringConvertible {
  /// Rendered as `∀a,b.τ`; the binder order follows `Set` iteration order.
  var description: String {
    let binders = typeVariables.map { $0.description }.joined(separator: ",")
    let body = type.description
    return "∀" + binders + "." + body
  }
}
/// A substitution from type variables to types.
typealias TypeContext = [TypeIdentifier: `Type`]
/// A typing environment from term variables to their type schemes.
typealias Context = [Identifier: TypeScheme]

extension Dictionary where Key == Identifier, Value == TypeScheme {
  /// Applies `typeContext` to every scheme in this environment.
  mutating func sub(_ typeContext: TypeContext) {
    self = mapValues { (scheme: TypeScheme) -> TypeScheme in
      var substituted = scheme
      substituted.sub(typeContext)
      return substituted
    }
  }
}
extension Dictionary where Value: HasTypeVariables {
  /// Union of the `typeVariables` of every value in the dictionary.
  /// NOTE(review): for `TypeScheme` values these are the schemes' quantified variables,
  /// not their free variables.
  var typeVariables: Set<TypeIdentifier> {
    var collected = Set<TypeIdentifier>()
    for value in values {
      collected.formUnion(value.typeVariables)
    }
    return collected
  }
}
/// Why two types could not be unified.
enum UnificationError {
  /// Two different primitive types.
  case primitiveFail(LiteralType, LiteralType)
  /// Any other failure, carrying the two types that were being unified.
  case generalError(`Type`, `Type`)
}

extension UnificationError: Equatable {
  static func ==(lhs: UnificationError, rhs: UnificationError) -> Bool {
    switch (lhs, rhs) {
    case let (.primitiveFail(l1, l2), .primitiveFail(r1, r2)):
      return l1 == r1 && l2 == r2
    // Destructure both payloads explicitly; matching a multi-payload case with a single
    // tuple pattern is deprecated.
    case let (.generalError(l1, l2), .generalError(r1, r2)):
      return l1 == r1 && l2 == r2
    default:
      return false
    }
  }
}
/// Computes a substitution that makes `t1` and `t2` equal.
///
/// - Returns: `.success` with the unifying substitution; `.fail(.primitiveFail)` for two
///   different constants; or `.fail(.generalError(t1, t2))` for a shape mismatch or a
///   failed occurs check.
func unify(_ t1: `Type`, _ t2: `Type`) -> Result<UnificationError, TypeContext> {
  // Binds `variable` to `type`, rejecting bindings that would make the substitution cyclic.
  func bind(_ variable: TypeIdentifier, to type: `Type`) -> Result<UnificationError, TypeContext> {
    if type == .typeVariable(variable) {
      // a ~ a holds trivially; no binding needed.
      return .success([:])
    }
    if type.typeVariables.contains(variable) {
      // Occurs check: a ~ (… a …) has no finite solution, and such a binding would make
      // `Type.sub` recurse forever.
      return .fail(.generalError(t1, t2))
    }
    return .success([variable: type])
  }

  switch (t1, t2) {
  case let (.constant(l1), .constant(l2)):
    return l1 == l2 ? .success([:]) : .fail(.primitiveFail(l1, l2))
  case let (_, .typeVariable(v2)):
    return bind(v2, to: t1)
  case let (.typeVariable(v1), _):
    return bind(v1, to: t2)
  case let (.function(t1in, t1out), .function(t2in, t2out)):
    return unify(t1in, t2in).flatMap { (inputSubst: TypeContext) -> Result<UnificationError, TypeContext> in
      // The result types must be unified under what the arguments already forced.
      // Otherwise conflicting bindings (e.g. a→a ~ Int→String) go unnoticed.
      var out1 = t1out
      var out2 = t2out
      out1.sub(inputSubst)
      out2.sub(inputSubst)
      return unify(out1, out2).map { (outputSubst: TypeContext) -> TypeContext in
        // Compose: outputSubst ∘ inputSubst.
        var composed = outputSubst
        for (variable, type) in inputSubst {
          var resolved = type
          resolved.sub(outputSubst)
          composed[variable] = resolved
        }
        return composed
      }
    }
  default:
    return .fail(.generalError(t1, t2))
  }
}
/// An unbounded source of distinct type-variable names `$1`, `$2`, ….
///
/// A class, so every holder of the same factory draws from one shared counter.
class FreshTypeFactory: IteratorProtocol {
  /// Number of names handed out so far.
  var count = 0

  /// The next unused name. Never returns `nil`.
  func next() -> TypeIdentifier? {
    self.count = self.count + 1
    let name = "$\(self.count)"
    return TypeIdentifier(v: name)
  }
}
/// Why type inference failed.
enum InferError {
  /// A constraint between two types could not be solved.
  case unificationError(UnificationError)
  /// The expression referenced a variable that is not in the context.
  case undefinedVariable
  case other
}

extension InferError: Equatable {
  static func ==(lhs: InferError, rhs: InferError) -> Bool {
    switch lhs {
    case .unificationError(let l):
      if case .unificationError(let r) = rhs {
        return l == r
      }
      return false
    case .undefinedVariable:
      if case .undefinedVariable = rhs {
        return true
      }
      return false
    case .other:
      if case .other = rhs {
        return true
      }
      return false
    }
  }
}
/// Infers the type of `e` in `context`.
///
/// Checks `e` against a fresh type variable with `synM`, then reads that variable back
/// through the resulting substitution.
func syn(freshTypes: FreshTypeFactory, context: Context, e: Expr) -> Result<InferError, `Type`> {
  let root = freshTypes.next()!
  return synM(freshTypes: freshTypes, context: context, e: e, rho: root.type)
    .map { substitution in
      var inferred = root.type
      inferred.sub(substitution)
      return inferred
    }
}
/// Algorithm M: finds a substitution under which `e` has the expected type `rho` in `context`.
///
/// The expected type is pushed down: each case unifies against `rho` (or against a type
/// built from it) and recurses with substituted contexts and expected types.
///
/// - Parameters:
///   - freshTypes: source of fresh type variables.
///   - context: schemes of the variables in scope.
///   - e: the expression to check.
///   - rho: the type `e` is expected to have.
/// - Returns: the accumulated substitution; `.fail(.undefinedVariable)` when `e` mentions a
///   variable missing from `context`; or `.fail(.unificationError(_))` when a constraint
///   cannot be solved.
func synM(freshTypes: FreshTypeFactory, context: Context, e: Expr, rho: `Type`) -> Result<InferError, TypeContext> {
  switch e {
  case let .literal(lit):
    let litType: LiteralType
    switch lit {
    case .int: litType = .intType
    case .string: litType = .stringType
    }
    return unify(rho, .constant(litType)).mapError { e in .unificationError(e) }

  case let .variable(x):
    guard let typeScheme = context[x] else {
      return .fail(.undefinedVariable)
    }
    return unify(rho, typeScheme.instantiate(freshTypes))
      .mapError { e in .unificationError(e) }

  case let .abstraction(parameter, functionBody):
    // rho must be some function type b1 → b2; check the body at b2 with parameter: b1.
    let b1 = freshTypes.next()!.type
    var b2 = freshTypes.next()!.type
    return unify(rho, .function(b1, b2))
      .mapError { e in .unificationError(e) }
      .flatMap { (s1: TypeContext) -> Result<InferError, TypeContext> in
        var mutContext = context
        mutContext[parameter] = TypeScheme(type: b1)
        mutContext.sub(s1)
        b2.sub(s1)
        return synM(
          freshTypes: freshTypes,
          context: mutContext,
          e: functionBody,
          rho: b2
        ).map { s2 in
          var mutS2 = s2
          mutS2.merge(s1, uniquingKeysWith: { _, fromS1 in fromS1 })
          return mutS2
        }
      }

  case let .application(e1, e2):
    // Check the function at b → rho for a fresh argument type b, then the argument at S1(b).
    var b = freshTypes.next()!.type
    return synM(
      freshTypes: freshTypes,
      context: context,
      e: e1,
      rho: .function(b, rho)
    ).flatMap { (s1: TypeContext) -> Result<InferError, TypeContext> in
      var mutContext = context
      mutContext.sub(s1)
      b.sub(s1)
      return synM(
        freshTypes: freshTypes,
        context: mutContext,
        e: e2,
        rho: b
      ).map { s2 in
        var mutS2 = s2
        mutS2.merge(s1, uniquingKeysWith: { _, fromS1 in fromS1 })
        return mutS2
      }
    }

  case let .`let`(id: x, be: e1, inside: e2):
    var b = freshTypes.next()!.type
    return synM(
      freshTypes: freshTypes,
      context: context,
      e: e1,
      rho: b
    ).flatMap { (s1: TypeContext) -> Result<InferError, TypeContext> in
      // S1.Ctx
      var mutContext = context
      mutContext.sub(s1)
      // S1.b
      b.sub(s1)
      // Clos_{S1.Ctx}(S1.b): generalize over variables not free in S1.Ctx.
      let scheme = b.closure(mutContext)
      // S1.Ctx + x: Clos_{S1.Ctx}(S1.b)
      mutContext[x] = scheme
      var mutRho = rho
      mutRho.sub(s1)
      return synM(
        freshTypes: freshTypes,
        context: mutContext,
        e: e2,
        rho: mutRho
      ).map { s2 in
        var mutS2 = s2
        mutS2.merge(s1, uniquingKeysWith: { _, fromS1 in fromS1 })
        return mutS2
      }
    }

  case let .fix(f: f, inside: (x, body)):
    // f is in scope (monomorphically, at the expected type rho) inside its own definition.
    var mutContext = context
    mutContext[f] = TypeScheme(type: rho)
    return synM(
      freshTypes: freshTypes,
      context: mutContext,
      e: .abstraction(x, body),
      rho: rho
    )
  }
}
// TDD
infix operator =?=: TernaryPrecedence

/// Whether `lhs` and `rhs` are identical up to a consistent, one-to-one renaming of type
/// variables (α-equivalence): `$3 → $3` matches `a → a` but not `a → b` or `Int → Int`.
private func alphaEquivalent(_ lhs: `Type`, _ rhs: `Type`) -> Bool {
  var leftToRight: [TypeIdentifier: TypeIdentifier] = [:]
  var rightToLeft: [TypeIdentifier: TypeIdentifier] = [:]

  func walk(_ l: `Type`, _ r: `Type`) -> Bool {
    switch (l, r) {
    case let (.constant(a), .constant(b)):
      return a == b
    case let (.typeVariable(a), .typeVariable(b)):
      if let mapped = leftToRight[a] {
        return mapped == b
      }
      // `b` already stands for a different left-hand variable: not a bijection.
      if rightToLeft[b] != nil {
        return false
      }
      leftToRight[a] = b
      rightToLeft[b] = a
      return true
    case let (.function(aIn, aOut), .function(bIn, bOut)):
      return walk(aIn, bIn) && walk(aOut, bOut)
    default:
      return false
    }
  }

  return walk(lhs, rhs)
}

/// Prints `✓ x` when `x` matches `y`, otherwise `✖ x != y`.
///
/// `Type` values are compared up to renaming of type variables, so an inferred `$3 → Int`
/// matches an expected `a → Int`. Unification is deliberately not used for this check: a
/// bare type variable unifies with anything, which would let an unresolved result pass as
/// any expected type. Other values are compared with `==`.
func =?=<T: Equatable>(x: T, y: T) {
  // Parametricity... I'm sorry :(
  let equivalent: Bool
  if let x = x as? `Type`,
     let y = y as? `Type` {
    equivalent = alphaEquivalent(x, y)
  } else {
    equivalent = x == y
  }
  if equivalent {
    print("✓ \(x)")
  } else {
    print("✖ \(x) != \(y)")
  }
}
extension Result where T: Equatable {
  /// Reports whether this result succeeded with a value matching `t`.
  ///
  /// A success is compared with `=?=`, which prints ✓ or ✖; a failure prints ✖ with the error.
  func expect(_ t: T) {
    if case let .success(actual) = self {
      actual =?= t
    } else if case let .fail(error) = self {
      print("✖ Failed with \(error), didn't get \(t)")
    }
  }
}
extension Result where E: Equatable {
  /// Reports whether this result failed with an error matching `e`.
  ///
  /// A failure is compared against `e` with `=?=`, which prints ✓ or ✖. The original code
  /// bound the actual error to `e`, shadowing the expected one, and printed ✓ for any failure.
  /// A success prints ✖ with the unexpected value.
  func expectError(_ e: E) {
    switch self {
    case .fail(let actual):
      actual =?= e
    case .success(let v):
      print("✖ Succeeded with \(v), didn't get \(e)")
    }
  }
}
/// Prints the expression being checked, then infers its type in `context` with a new
/// fresh-name factory, so numbering restarts at `$1` for every check.
func check(context: Context, e: Expr) -> Result<InferError, `Type`> {
  print("Checking: \(e)")
  return syn(freshTypes: FreshTypeFactory(), context: context, e: e)
}
// Test cases. Expected types that are polymorphic are written with named type variables
// ("a", "b"); the inferred types use fresh names such as "$2".

// A variable takes its type from the context.
check(
  context: ["a": TypeScheme(type: .constant(.intType))],
  e: .variable("a")
).expect(.constant(.intType))

// An unbound variable is rejected.
check(
  context: [:],
  e: .variable("a")
).expectError(.undefinedVariable)

// λa.4 ignores its argument, so the argument type stays polymorphic: a → Int.
check(
  context: [:],
  e: .abstraction("a", 4)
).expect(
  .function(
    .typeVariable("a"),
    .constant(.intType)
  )
)

check(
  context: [:],
  e: .application(.abstraction("a", .variable("a")), 1)
).expect(.constant(.intType))

check(
  context: [:],
  e: .let(id: "x", be: 5, inside: .variable("x"))
).expect(.constant(.intType))

check(
  context: [:],
  e: .let(
    id: "identity",
    be: .abstraction("x", .variable("x")),
    inside: .application(.variable("identity"), 3)
  )
).expect(.constant(.intType))

check(
  context: [:],
  e: .let(
    id: "const",
    be: .abstraction("c", .abstraction("x", .variable("c"))),
    inside: .application(
      .application(.variable("const"), 3),
      4
    )
  )
).expect(.constant(.intType))

// Un-applied, `const` keeps its generalized type a → b → a.
check(
  context: [:],
  e: .let(
    id: "const",
    be: .abstraction("c", .abstraction("x", .variable("c"))),
    inside: .variable("const")
  )
).expect(
  .function(
    .typeVariable("a"),
    .function(.typeVariable("b"), .typeVariable("a"))
  )
)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.