Skip to content

Instantly share code, notes, and snippets.

@bkase
Created January 5, 2018 05:59
Show Gist options
  • Save bkase/34df2a8fb48df61a0069cb5a8ab02d84 to your computer and use it in GitHub Desktop.
Save bkase/34df2a8fb48df61a0069cb5a8ab02d84 to your computer and use it in GitHub Desktop.
Type Inference Algorithm M (ported from Swift Sandbox)
// Write some awesome Swift code, or import libraries like "Foundation",
// "Dispatch", or "Glibc"
/// A two-case outcome container: `.fail` carries an error of type `E`,
/// `.success` carries a value of type `T`.
enum Result<E, T> {
    case fail(E)
    case success(T)

    /// `true` when this is `.success`, `false` when it is `.fail`.
    var succeeded: Bool {
        if case .success = self {
            return true
        }
        return false
    }

    /// Applies `f` to the success value; a failure is passed through unchanged.
    func map<U>(_ f: (T) -> U) -> Result<E, U> {
        switch self {
        case let .fail(error):
            return .fail(error)
        case let .success(value):
            return .success(f(value))
        }
    }

    /// Feeds the success value to `f` and returns whatever `f` produces;
    /// a failure is passed through unchanged.
    func flatMap<U>(_ f: (T) -> Result<E, U>) -> Result<E, U> {
        switch self {
        case let .fail(error):
            return .fail(error)
        case let .success(value):
            return f(value)
        }
    }

    /// Applies `f` to the error; a success is passed through unchanged.
    func mapError<E2>(_ f: (E) -> E2) -> Result<E2, T> {
        switch self {
        case let .fail(error):
            return .fail(f(error))
        case let .success(value):
            return .success(value)
        }
    }

    /// Feeds the error to `f` and returns whatever `f` produces;
    /// a success is passed through unchanged.
    func flatMapError<E2>(_ f: (E) -> Result<E2, T>) -> Result<E2, T> {
        switch self {
        case let .fail(error):
            return f(error)
        case let .success(value):
            return .success(value)
        }
    }

    /// Collapses the result to a plain `T`: the success value itself, or
    /// the value `f` computes from the error.
    func recover(_ f: (E) -> T) -> T {
        switch self {
        case let .fail(error):
            return f(error)
        case let .success(value):
            return value
        }
    }
}
/// Name of a term-level variable; a thin wrapper around the string `v`.
struct Identifier {
    let v: String
}

// Equality is equality of the wrapped strings.
extension Identifier: Equatable {
    static func ==(lhs: Identifier, rhs: Identifier) -> Bool {
        let (left, right) = (lhs.v, rhs.v)
        return left == right
    }
}

// Hashing delegates to the wrapped string.
extension Identifier: Hashable {
    var hashValue: Int {
        return v.hashValue
    }
}

// Lets a string literal be written wherever an `Identifier` is expected.
extension Identifier: ExpressibleByStringLiteral {
    init(stringLiteral: String) {
        v = stringLiteral
    }
}

// Prints as the bare wrapped string.
extension Identifier: CustomStringConvertible {
    var description: String {
        return v
    }
}
/// Name of a type variable; a thin wrapper around the string `v`.
struct TypeIdentifier {
    let v: String

    /// This identifier wrapped as a `.typeVariable` type.
    var type: `Type` {
        return `Type`.typeVariable(self)
    }
}

// Equality is equality of the wrapped strings.
extension TypeIdentifier: Equatable {
    static func ==(lhs: TypeIdentifier, rhs: TypeIdentifier) -> Bool {
        let (left, right) = (lhs.v, rhs.v)
        return left == right
    }
}

// Hashing delegates to the wrapped string.
extension TypeIdentifier: Hashable {
    var hashValue: Int {
        return v.hashValue
    }
}

// Lets a string literal be written wherever a `TypeIdentifier` is expected.
extension TypeIdentifier: ExpressibleByStringLiteral {
    init(stringLiteral: String) {
        v = stringLiteral
    }
}

// Prints as the bare wrapped string.
extension TypeIdentifier: CustomStringConvertible {
    var description: String {
        return v
    }
}
/// A literal value in the expression language: an integer or a string.
enum Literal {
    case int(Int)
    case string(String)
}

// Integers print as their decimal form; strings print wrapped in double quotes.
extension Literal: CustomStringConvertible {
    var description: String {
        switch self {
        case let .string(text):
            return "\"" + text + "\""
        case let .int(number):
            return String(number)
        }
    }
}
/// Abstract syntax of the expression language being type-checked.
indirect enum Expr {
    /// A literal constant.
    case literal(Literal)
    /// A reference to a named variable.
    case variable(Identifier)
    /// Application of the first expression to the second.
    case application(Expr, Expr)
    /// A lambda: parameter name and body.
    case abstraction(Identifier, Expr)
    /// `let id = be in inside`.
    case `let`(id: Identifier, be: Expr, inside: Expr)
    /// A named fixpoint: `f` is the name bound for the lambda `inside`
    /// (parameter name and body).
    case fix(f: Identifier, inside: (Identifier, Expr))
}

// An integer literal can be written wherever an `Expr` is expected.
extension Expr: ExpressibleByIntegerLiteral {
    init(integerLiteral: Int) {
        let wrapped = Literal.int(integerLiteral)
        self = .literal(wrapped)
    }
}

// Human-readable rendering of an expression tree.
extension Expr: CustomStringConvertible {
    var description: String {
        switch self {
        case let .literal(value):
            return value.description
        case let .variable(name):
            return name.v
        case let .application(function, argument):
            return function.description + "·" + argument.description
        case let .abstraction(parameter, body):
            return "λ(\(parameter).\(body))"
        case let .`let`(id: name, be: bound, inside: body):
            let binding = "let \(name) = \(bound)"
            return binding + " in \(body) end"
        case let .fix(f: name, inside: (parameter, body)):
            let lambda = "λ(\(parameter).\(body))"
            return "fix \(name) in " + lambda + " end"
        }
    }
}
/// Anything that can report a set of type-variable identifiers.
/// What that set means is up to the conforming type.
protocol HasTypeVariables {
/// The type-variable identifiers associated with this value.
var typeVariables: Set<TypeIdentifier> { get }
}
/// The primitive (constant) types of the language.
enum LiteralType: Equatable {
    case intType
    case stringType
}

// Prints as the primitive type's name.
extension LiteralType: CustomStringConvertible {
    var description: String {
        switch self {
        case .stringType: return "String"
        case .intType: return "Int"
        }
    }
}
/// A monomorphic type: a primitive constant, a type variable, or a
/// function from one type to another.
indirect enum `Type` {
    case constant(LiteralType)
    case typeVariable(TypeIdentifier)
    case function(`Type`, `Type`)

    /// Applies the substitution `typeContext` to this type in place.
    ///
    /// A type variable bound in `typeContext` is replaced by its binding, and
    /// the binding is itself substituted again, so chains of bindings are
    /// followed to the end. A variable bound to itself is left alone.
    ///
    /// - Parameter typeContext: Mapping from type variables to replacement types.
    mutating func sub(_ typeContext: TypeContext) {
        switch self {
        case .constant:
            return
        case let .typeVariable(ident):
            // The `!= self` guard stops a self-binding from recursing forever.
            guard let replacement = typeContext[ident], replacement != self else {
                return
            }
            var resolved = replacement
            resolved.sub(typeContext)
            self = resolved
        case let .function(tIn, tOut):
            var newIn = tIn
            var newOut = tOut
            newIn.sub(typeContext)
            newOut.sub(typeContext)
            self = .function(newIn, newOut)
        }
    }

    /// Generalizes this type against `context`: quantifies every type variable
    /// of this type that is not free in `context`.
    ///
    /// The free variables of a scheme are the variables of its body minus the
    /// ones it quantifies. Subtracting the schemes' *quantified* variables
    /// instead would wrongly generalize variables that belong to
    /// monomorphic (e.g. lambda-bound) entries of the context.
    ///
    /// - Parameter context: The typing context to generalize against.
    /// - Returns: A scheme over this type.
    func closure(_ context: Context) -> TypeScheme {
        var freeInContext = Set<TypeIdentifier>()
        for scheme in context.values {
            let freeInScheme = scheme.type.typeVariables.subtracting(scheme.typeVariables)
            freeInContext.formUnion(freeInScheme)
        }
        return TypeScheme(forall: typeVariables.subtracting(freeInContext), type: self)
    }
}
// Structural equality: same case and equal payloads.
extension `Type`: Equatable {
    static func ==(lhs: `Type`, rhs: `Type`) -> Bool {
        switch lhs {
        case let .constant(left):
            if case let .constant(right) = rhs {
                return left == right
            }
        case let .typeVariable(left):
            if case let .typeVariable(right) = rhs {
                return left == right
            }
        case let .function(leftIn, leftOut):
            if case let .function(rightIn, rightOut) = rhs {
                return leftIn == rightIn && leftOut == rightOut
            }
        }
        // The two sides are different cases.
        return false
    }
}
// Every type variable that occurs anywhere inside the type.
extension `Type`: HasTypeVariables {
    var typeVariables: Set<TypeIdentifier> {
        switch self {
        case let .typeVariable(ident):
            return Set([ident])
        case let .function(tIn, tOut):
            var collected = tIn.typeVariables
            collected.formUnion(tOut.typeVariables)
            return collected
        case .constant:
            return []
        }
    }
}
// Renders a type with `→` for functions. The arrow is treated as
// right-associative, so a function that appears as the *input* of another
// function is wrapped in parentheses.
extension `Type`: CustomStringConvertible {
    var description: String {
        switch self {
        case let .constant(x):
            return x.description
        case let .typeVariable(a):
            return a.description
        case let .function(input, output):
            // Render the input through its own `description` so parentheses
            // nested deeper inside it are kept:
            // ((a → b) → c) → d must not print as (a → b → c) → d.
            if case .function = input {
                return "(" + input.description + ") → " + output.description
            }
            return input.description + " → " + output.description
        }
    }
}
/// A type together with the set of type variables it quantifies over (∀).
struct TypeScheme {
    /// The quantified type variables.
    var typeVariables: Set<TypeIdentifier>
    /// The body of the scheme.
    var type: `Type`

    /// A scheme that quantifies over nothing.
    init(type: `Type`) {
        self.init(forall: [], type: type)
    }

    /// A scheme quantifying `typeVariables` over `type`.
    init(forall typeVariables: Set<TypeIdentifier>, type: `Type`) {
        self.typeVariables = typeVariables
        self.type = type
    }

    /// Returns the body with every quantified variable replaced by a fresh
    /// type variable drawn from `freshTypes`.
    func instantiate(_ freshTypes: FreshTypeFactory) -> `Type` {
        var renaming = TypeContext()
        for quantified in typeVariables {
            renaming[quantified] = freshTypes.next()!.type
        }
        var instance = type
        instance.sub(renaming)
        return instance
    }

    /// Applies `typeContext` to the body, then drops from the quantified set
    /// every variable that `typeContext` has a binding for.
    mutating func sub(_ typeContext: TypeContext) {
        type.sub(typeContext)
        typeVariables = typeVariables.filter { typeContext[$0] == nil }
    }
}

// The protocol requirement is met by the stored `typeVariables` property,
// i.e. the quantified variables.
extension TypeScheme: HasTypeVariables { }

// Prints as `∀a,b.body`.
extension TypeScheme: CustomStringConvertible {
    var description: String {
        let names = typeVariables.map { $0.description }
        let quantifier = "∀" + names.joined(separator: ",")
        return quantifier + "." + type.description
    }
}
// A substitution: maps type variables to the types that replace them.
typealias TypeContext = [TypeIdentifier: `Type`]
// A typing context: maps term variables to their type schemes.
typealias Context = [Identifier: TypeScheme]
// Substitution lifted to a whole typing context.
extension Dictionary where Key == Identifier, Value == TypeScheme {
    /// Applies `typeContext` to every scheme stored in the dictionary.
    mutating func sub(_ typeContext: TypeContext) {
        self = mapValues { (scheme: TypeScheme) -> TypeScheme in
            var substituted = scheme
            substituted.sub(typeContext)
            return substituted
        }
    }
}
// Collects type variables across all values of a dictionary.
extension Dictionary where Value: HasTypeVariables {
    /// The union of `typeVariables` over every value in the dictionary.
    var typeVariables: Set<TypeIdentifier> {
        var collected = Set<TypeIdentifier>()
        for entry in values {
            collected.formUnion(entry.typeVariables)
        }
        return collected
    }
}
/// Why two types could not be unified.
enum UnificationError {
    /// Two different primitive types were required to be equal.
    case primitiveFail(LiteralType, LiteralType)
    /// Any other mismatch between the two given types.
    case generalError(`Type`, `Type`)
}

// Two errors are equal when they are the same case with equal payloads.
extension UnificationError: Equatable {
    static func ==(lhs: UnificationError, rhs: UnificationError) -> Bool {
        switch (lhs, rhs) {
        case let (.primitiveFail(l1, l2), .primitiveFail(r1, r2)):
            return l1 == r1 && l2 == r2
        // Bind both payloads explicitly instead of matching the
        // two-value case as a single tuple.
        case let (.generalError(l1, l2), .generalError(r1, r2)):
            return l1 == r1 && l2 == r2
        default:
            return false
        }
    }
}
/// Computes a substitution that makes `t1` and `t2` equal, if one exists.
///
/// - Two constants unify only when they are the same primitive type.
/// - A type variable unifies with any type that does not contain it, by
///   binding the variable to that type. A variable unifies with itself
///   without producing a binding.
/// - Two function types unify when their inputs unify and, after applying the
///   input substitution, their outputs unify as well.
///
/// - Parameters:
///   - t1: The first type.
///   - t2: The second type.
/// - Returns: `.success` with the substitution, `.fail(.primitiveFail)` for
///   mismatched constants, or `.fail(.generalError)` for any other mismatch,
///   including an attempt to bind a variable to a type that contains it.
func unify(_ t1: `Type`, _ t2: `Type`) -> Result<UnificationError, TypeContext> {
    switch (t1, t2) {
    case let (.constant(l1), .constant(l2)):
        if l1 == l2 {
            return .success([:])
        }
        return .fail(.primitiveFail(l1, l2))
    case let (_, .typeVariable(v2)):
        if t1 == t2 {
            // Same variable on both sides: nothing to bind.
            return .success([:])
        }
        if t1.typeVariables.contains(v2) {
            // Occurs check: binding v2 to a type that mentions v2 would make
            // substitution recurse without end.
            return .fail(.generalError(t1, t2))
        }
        return .success([v2: t1])
    case let (.typeVariable(v1), _):
        // t2 is not a type variable here; the previous case would have matched.
        if t2.typeVariables.contains(v1) {
            return .fail(.generalError(t1, t2))
        }
        return .success([v1: t2])
    case let (.function(t1in, t1out), .function(t2in, t2out)):
        return unify(t1in, t2in).flatMap { (inSub: TypeContext) -> Result<UnificationError, TypeContext> in
            // Apply what the inputs taught us before looking at the outputs,
            // so that e.g. `a → a` cannot unify with `Int → String`.
            var out1 = t1out
            var out2 = t2out
            out1.sub(inSub)
            out2.sub(inSub)
            return unify(out1, out2).map { (outSub: TypeContext) -> TypeContext in
                // On a key collision the output binding wins.
                return inSub.merging(outSub) { _, fromOutputs in fromOutputs }
            }
        }
    default:
        return .fail(.generalError(t1, t2))
    }
}
/// Hands out type-variable names `$1`, `$2`, ... that have not been used
/// before by this factory. `next()` never returns `nil`.
class FreshTypeFactory: IteratorProtocol {
    /// Number of identifiers handed out so far.
    var count = 0

    /// Increments the counter and returns the identifier named after it.
    func next() -> TypeIdentifier? {
        count += 1
        let name = "$" + String(count)
        return TypeIdentifier(v: name)
    }
}
/// Why type inference failed.
enum InferError {
    /// Unification of two types failed.
    case unificationError(UnificationError)
    /// A variable was used that the context has no entry for.
    case undefinedVariable
    case other
}

// Same case, and for `.unificationError` an equal payload.
extension InferError: Equatable {
    static func ==(lhs: InferError, rhs: InferError) -> Bool {
        switch lhs {
        case let .unificationError(left):
            if case let .unificationError(right) = rhs {
                return left == right
            }
            return false
        case .undefinedVariable:
            if case .undefinedVariable = rhs {
                return true
            }
            return false
        case .other:
            if case .other = rhs {
                return true
            }
            return false
        }
    }
}
/// Infers the type of `e` in `context`.
///
/// Starts from a fresh type variable as the expected type, runs `synM`
/// against it, and applies the resulting substitution to that variable.
///
/// - Parameters:
///   - freshTypes: Source of fresh type variables.
///   - context: Types of the variables already in scope.
///   - e: The expression to infer a type for.
/// - Returns: The inferred type, or the error `synM` reported.
func syn(freshTypes: FreshTypeFactory, context: Context, e: Expr) -> Result<InferError, `Type`> {
    let expected = `Type`.typeVariable(freshTypes.next()!)
    let solved = synM(freshTypes: freshTypes, context: context, e: e, rho: expected)
    return solved.map { (solution: TypeContext) -> `Type` in
        var inferred = expected
        inferred.sub(solution)
        return inferred
    }
}
/// Algorithm M: checks `e` against the expected type `rho` in `context` and
/// returns the substitution that makes it fit.
///
/// Each case with sub-expressions follows the same shape: obtain a first
/// substitution `s1`, apply it to the context and to the types still needed,
/// check the remaining sub-expression to obtain `s2`, and return `s2` merged
/// with `s1` (bindings from `s1` win on a key collision).
///
/// - Parameters:
///   - freshTypes: Source of fresh type variables. `next()` never returns
///     `nil`, which is why its result is force-unwrapped below.
///   - context: Type schemes of the variables in scope.
///   - e: The expression to check.
///   - rho: The type `e` is expected to have.
/// - Returns: The resulting substitution, `.fail(.undefinedVariable)` when a
///   variable is missing from `context`, or `.fail(.unificationError)` when
///   two types cannot be unified.
func synM(freshTypes: FreshTypeFactory, context: Context, e: Expr, rho: `Type`) -> Result<InferError, TypeContext> {
    // Every case of `Expr` is handled below, so there is deliberately no `default`.
    switch e {
    case let .literal(lit):
        let litType: LiteralType
        switch lit {
        case .int:
            litType = .intType
        case .string:
            litType = .stringType
        }
        return unify(rho, .constant(litType))
            .mapError { InferError.unificationError($0) }

    case let .variable(x):
        guard let typeScheme = context[x] else {
            return .fail(.undefinedVariable)
        }
        return unify(rho, typeScheme.instantiate(freshTypes))
            .mapError { InferError.unificationError($0) }

    case let .abstraction(parameter, functionBody):
        // rho must be a function type b1 → b2.
        let b1 = freshTypes.next()!.type
        let b2 = freshTypes.next()!.type
        return unify(rho, .function(b1, b2))
            .mapError { InferError.unificationError($0) }
            .flatMap { (s1: TypeContext) -> Result<InferError, TypeContext> in
                // S1.(Ctx + parameter: b1)
                var bodyContext = context
                bodyContext.merge([parameter: TypeScheme(type: b1)], uniquingKeysWith: { _, new in new })
                bodyContext.sub(s1)
                // S1.b2
                var bodyType = b2
                bodyType.sub(s1)
                return synM(
                    freshTypes: freshTypes,
                    context: bodyContext,
                    e: functionBody,
                    rho: bodyType
                ).map { (s2: TypeContext) -> TypeContext in
                    return s2.merging(s1) { _, fromS1 in fromS1 }
                }
            }

    case let .application(e1, e2):
        // The function must have type b → rho for some argument type b.
        let b = freshTypes.next()!.type
        return synM(
            freshTypes: freshTypes,
            context: context,
            e: e1,
            rho: .function(b, rho)
        ).flatMap { (s1: TypeContext) -> Result<InferError, TypeContext> in
            // S1.Ctx
            var argumentContext = context
            argumentContext.sub(s1)
            // S1.b
            var argumentType = b
            argumentType.sub(s1)
            return synM(
                freshTypes: freshTypes,
                context: argumentContext,
                e: e2,
                rho: argumentType
            ).map { (s2: TypeContext) -> TypeContext in
                return s2.merging(s1) { _, fromS1 in fromS1 }
            }
        }

    case let .`let`(id: x, be: e1, inside: e2):
        let b = freshTypes.next()!.type
        return synM(
            freshTypes: freshTypes,
            context: context,
            e: e1,
            rho: b
        ).flatMap { (s1: TypeContext) -> Result<InferError, TypeContext> in
            // S1.Ctx
            var bodyContext = context
            bodyContext.sub(s1)
            // S1.b
            var boundType = b
            boundType.sub(s1)
            // Clos_S1.b(S1.Ctx)
            let scheme = boundType.closure(bodyContext)
            // S1.Ctx + x: Clos_S1.b(S1.Ctx)
            bodyContext.merge([x: scheme], uniquingKeysWith: { _, new in new })
            // S1.rho
            var bodyRho = rho
            bodyRho.sub(s1)
            return synM(
                freshTypes: freshTypes,
                context: bodyContext,
                e: e2,
                rho: bodyRho
            ).map { (s2: TypeContext) -> TypeContext in
                return s2.merging(s1) { _, fromS1 in fromS1 }
            }
        }

    case let .fix(f: f, inside: (x, e)):
        // f is bound monomorphically to rho while checking the lambda itself.
        var fixContext = context
        fixContext.merge([f: TypeScheme(type: rho)], uniquingKeysWith: { _, new in new })
        return synM(
            freshTypes: freshTypes,
            context: fixContext,
            e: .abstraction(x, e),
            rho: rho
        )
    }
}
// TDD
infix operator =?=: TernaryPrecedence

/// Prints a check mark when `x` and `y` are considered equivalent and a
/// cross otherwise.
///
/// When both values are `Type`s, "equivalent" means `unify` succeeds on them
/// rather than `==`; for every other type it means `==`.
func =?=<T: Equatable>(x: T, y: T) {
    // Parametricity... I'm sorry :(
    func isEquivalent() -> Bool {
        guard let lhsType = x as? `Type`, let rhsType = y as? `Type` else {
            return x == y
        }
        return unify(lhsType, rhsType).succeeded
    }

    guard isEquivalent() else {
        print("✖ \(x) != \(y)")
        return
    }
    print("✓ \(x)")
}
// Test helper for results that are expected to succeed.
extension Result where T: Equatable {
    /// Compares a success value against `t` using `=?=`, which prints the
    /// outcome; prints a failure line when this result is `.fail`.
    func expect(_ t: T) {
        switch self {
        case let .success(actual):
            actual =?= t
        case let .fail(error):
            print("✖ Failed with \(error), didn't get \(t)")
        }
    }
}
// Test helper for results that are expected to fail.
extension Result where E: Equatable {
    /// Compares the actual error against `e` using `=?=`, which prints the
    /// outcome; prints a failure line when this result is `.success`.
    ///
    /// - Parameter e: The error this result is expected to carry.
    func expectError(_ e: E) {
        switch self {
        case .fail(let actual):
            // Bind under a different name: binding `e` here would shadow the
            // expected error and make every failure count as a pass.
            actual =?= e
        case .success(let v):
            print("✖ Succeeded with \(v), didn't get \(e)")
        }
    }
}
/// Prints the expression being checked, then infers its type with `syn`
/// using a brand-new `FreshTypeFactory`.
///
/// - Parameters:
///   - context: Types of the variables already in scope.
///   - e: The expression to infer a type for.
/// - Returns: Whatever `syn` returns for `e`.
func check(context: Context, e: Expr) -> Result<InferError, `Type`> {
    print("Checking: \(e)")
    return syn(freshTypes: FreshTypeFactory(), context: context, e: e)
}
// Ad-hoc test run. Each call prints a ✓ or ✖ line. Note that `expect`
// compares types with `=?=`, which for `Type` values only requires that the
// two types unify, not that they are equal.

// A variable the context binds to Int.
check(
context: ["a": TypeScheme(type: .constant(.intType))],
e: .variable("a")
).expect(.constant(.intType))
// A variable missing from the context.
check(
context: [:],
e: .variable("a")
).expectError(.undefinedVariable)
// λ(a.4), expected to fit Int → Int.
check(
context: [:],
e: .abstraction("a", 4)
).expect(
.function(
.constant(.intType),
.constant(.intType)
)
)
// The identity lambda applied to 1.
check(
context: [:],
e: .application(.abstraction("a", .variable("a")), 1)
).expect(.constant(.intType))
// let x = 5 in x
check(
context: [:],
e: .let(id: "x", be: 5, inside: .variable("x"))
).expect(.constant(.intType))
// A let-bound identity function applied to 3.
check(
context: [:],
e: .let(
id: "identity",
be: .abstraction("x", .variable("x")),
inside: .application(.variable("identity"), 3)
)
).expect(.constant(.intType))
// A let-bound two-argument constant function applied to 3 and then 4.
check(
context: [:],
e: .let(
id: "const",
be: .abstraction("c", .abstraction("x", .variable("c"))),
inside: .application(
.application(.variable("const"), 3),
4
)
)
).expect(.constant(.intType))
// The same constant function, not applied, expected to fit Int → Int → Int.
check(
context: [:],
e: .let(
id: "const",
be: .abstraction("c", .abstraction("x", .variable("c"))),
inside: .variable("const")
)
).expect(
.function(
.constant(.intType),
.function(.constant(.intType), .constant(.intType))
)
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment