
Interfaces and Sum Types

How a programming language maps runtime-polymorphic values to the code that operates on them, both at the source level and in the language implementation, is one of its defining characteristics. There are two opposite approaches that languages facilitate to varying degrees:

  1. Interfaces / Traits / Abstract Classes / Inheritance / Method Overriding (from here on, just "Interfaces")
  2. Sum Types / Algebraic Data Types / Tagged Unions / Pattern Matching (from here on, just "Sum Types")

Example

You have two types: A, containing a string, and B, containing an int, and two operations: x and y.

There is a different implementation of x and y for each of A and B (4 implementations total).

You have a list of values, each of which can be either A or B, and to each value you wish to apply the operations x and y, finding the correct implementation for each.

This has to be done in a type safe way. Specifically:

  • The compiler should check that an implementation exists for each of the four combinations of x or y with A or B.
  • The compiler should permit access to the string inside A and the int inside B if and only if the runtime value is of the matching type.
  • It should be possible to pass extra data from the caller to the implementation, and/or to return data from the implementation to the caller, in a type-safe way. (This is not used in the examples, but it is why the Java visitor example has a type parameter for the output of the visitor.)

If Foo is the type of elements of the list, there are two possible definitions for Foo:

  1. Define Foo as anything implementing x and y (an interface), and in A and B define their implementations of x and y.
  2. Define Foo as a value that is either A or B (a sum type), and in x and y define their implementations for A and B.

An interface fixes the set of operations (x and y) in the definition of Foo, but allows extension in the data types (A, B, ...).

A sum type fixes the set of types (A and B) in the definition of Foo, but allows extension in the operations (x, y, ...).

The need for extensibility in both directions at the same time is known as the expression problem.
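
To make the tension concrete, here is a small illustrative C sketch of the sum-type side (the names and layout are my own, not taken from the examples below): adding a new operation z is a purely local change, but adding a new type C would force edits to every existing operation.

#include <stdio.h>

enum Tag { TAG_A, TAG_B };           /* adding a type C means a new tag...   */

struct Foo {
  enum Tag tag;
  union { const char *a; int b; };   /* ...a new union member (needs C11)... */
};

void x(const struct Foo *foo) {
  switch (foo->tag) {                /* ...and a new case in every switch    */
    case TAG_A: printf("A(%s).x\n", foo->a); break;
    case TAG_B: printf("B(%d).x\n", foo->b); break;
  }
}

/* Adding a new operation z, by contrast, touches nothing above: */
void z(const struct Foo *foo) {
  switch (foo->tag) {
    case TAG_A: printf("A(%s).z\n", foo->a); break;
    case TAG_B: printf("B(%d).z\n", foo->b); break;
  }
}

int main(void) {
  struct Foo foo = { .tag = TAG_A, .a = "hello" };
  x(&foo);
  z(&foo);
}

With -Wswitch, the compiler warns about any enum case missing from a switch, the C analogue of the exhaustiveness check above. The interface encoding has the mirror-image property: a new type is a local change, while a new operation touches every type.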

Pseudocode

Sum Type

type A = { a: String }
type B = { b: Int }
type Foo = A | B
function x(foo: Foo) = when foo {
  is A -> ...
  is B -> ...
}
function y(foo: Foo) = when foo {
  is A -> ...
  is B -> ...
}

Interface

interface Foo {
  function x()
  function y()
}
type A implements Foo = {
  var a: String
  function x() = ...
  function y() = ...
}
type B implements Foo = {
  var b: Int
  function x() = ...
  function y() = ...
}

Compilation

Both approaches can be compiled to efficient machine code with a compact memory layout. Languages designed for ahead-of-time compilation, like Rust, C++, Swift and Haskell, will typically guarantee this, while just-in-time compilers will typically make a best-effort attempt.

Sum type

If Foo is defined as the sum of A and B, the compiler can assign each type an index, A=0 and B=1 (the type tag), and implement Foo as a tagged union. The tag is then used as the index into an array of implementations inside x and y (a branch table).

In other words, the operation contains an array of implementations for each type, and the value is paired with an index into that array.

 ┌─────────┐                       ┌───────────┐
 │ A       │                       │ x(Foo)    │
 ├─────────┤                       ├───┬───────┤
 │ 0       ├────────────┬────────► │ 0 │ x(A)  │
 ├─────────┤            │          ├───┼───────┤
 │ String  │            │          │ 1 │ x(B)  │
 └─────────┘            │          └───┴───────┘
                        │
 ┌─────────┐            │          ┌───────────┐
 │ B       │            │          │ y(Foo)    │
 ├─────────┤            │          ├───┬───────┤
 │ 1       │            └────────► │ 0 │ y(A)  │
 ├─────────┤                       ├───┼───────┤
 │ Int     │                       │ 1 │ y(B)  │
 └─────────┘                       └───┴───────┘
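
A minimal hand-written C sketch of this layout (the names, and the use of a C11 anonymous union, are illustrative assumptions rather than any particular compiler's output):

#include <stdio.h>

enum Tag { TAG_A = 0, TAG_B = 1 };

/* Foo as a tagged union: the tag is the index, the union is the payload. */
struct Foo {
  enum Tag tag;
  union {
    const char *a;  /* valid when tag == TAG_A */
    int b;          /* valid when tag == TAG_B */
  };
};

static void x_A(const struct Foo *foo) { printf("A(%s).x\n", foo->a); }
static void x_B(const struct Foo *foo) { printf("B(%d).x\n", foo->b); }

/* The branch table lives with the operation; the value supplies the index.
   y would get its own table in the same way. */
static void (*const x_table[])(const struct Foo *) = { x_A, x_B };

static void x(const struct Foo *foo) { x_table[foo->tag](foo); }

int main(void) {
  struct Foo foos[] = { { .tag = TAG_A, .a = "hello" },
                        { .tag = TAG_B, .b = 100 } };
  for (int i = 0; i < 2; i++) x(&foos[i]);
}

In practice the compiler may lower the dispatch to a jump table or a chain of conditional branches rather than an array of function pointers, but the shape is the same: the data carries the index, the code carries the table.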

Interface

If Foo is defined as an interface with methods x and y, the compiler can assign each operation an index, x=0 and y=1, and implement Foo as a value paired with an array of implementations of each operation (the virtual method table). A call to x or y is then an index into that array, followed by a call to the selected implementation, passing in the value as a parameter (which becomes this inside the implementation).

In other words, the value is paired with an array of implementations of each operation, and the operation is an index into that array.

 ┌───────────┐
 │ A         │
 ├───┬───────┤                     ┌─────────┐
 │ 0 │ A.x() │ ◄──────────┐        │ Foo.x() │
 ├───┼───────┤            │        ├─────────┤
 │ 1 │ A.y() │            ├────────┤ 0       │
 ├───┴───────┤            │        └─────────┘
 │ String    │            │
 └───────────┘            │
                          │
 ┌───────────┐            │
 │ B         │            │
 ├───┬───────┤            │        ┌─────────┐
 │ 0 │ B.x() │ ◄──────────┘        │ Foo.y() │
 ├───┼───────┤                     ├─────────┤
 │ 1 │ B.y() │                     │ 1       │
 ├───┴───────┤                     └─────────┘
 │ Int       │
 └───────────┘

(To save memory, typically the value contains a pointer to a global array of implementations for that type, instead of the array itself. For the same reason, it is good practice to define a JavaScript object's methods on a shared prototype (class) instead of in a field on each object.)
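
The same program hand-compiled the other way around, again as an illustrative C sketch (slot 0 is x and slot 1 is y, matching the diagram, and each value holds a pointer to its type's shared table, as the note above describes):

#include <stdio.h>

struct Foo;

/* The vtable lives with the type; the operation supplies the slot. */
struct FooVtable {
  void (*x)(const struct Foo *);
  void (*y)(const struct Foo *);
};

/* Every Foo begins with a pointer to its type's shared vtable. */
struct Foo { const struct FooVtable *vtable; };

struct A {
  struct Foo base;  /* must come first so a Foo* can point at an A */
  const char *a;
};

static void A_x(const struct Foo *foo) { printf("A(%s).x\n", ((const struct A *)foo)->a); }
static void A_y(const struct Foo *foo) { printf("A(%s).y\n", ((const struct A *)foo)->a); }

/* One global table per type, shared by every A rather than stored in each. */
static const struct FooVtable A_vtable = { A_x, A_y };

int main(void) {
  struct A a = { { &A_vtable }, "hello" };
  const struct Foo *foo = &a.base;
  foo->vtable->x(foo);  /* a "virtual call": index into the value's table */
  foo->vtable->y(foo);
}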

So the only difference between the two approaches is which side, the data (A and B) or the code (x and y), holds the array of implementations, and which holds the index into that array.

A branch table is an array of implementations in the code, indexed by the data; a virtual method table is an array of implementations in the data, indexed by the code.

Implementations

Below, the same example is implemented using each approach in Kotlin, Scala, Java, C++, TypeScript, Haskell and Rust.

Each implementation should produce the output:

A(hello).x
A(hello).y
B(100).x
B(100).y

Notes

  • Java and C++ both require recent versions for sum type functionality.

    • Java requires version 17 with --enable-preview for sealed interfaces and pattern matching in switch (sealed classes are standard in 17; pattern matching for switch was only finalized later, in Java 21). Alternatively, the visitor pattern can be used (also included).
    • C++ doesn't support pattern matching, but a similar effect can be achieved with std::visit and std::variant, combined with a template that merges multiple lambdas into a single visitor object. This requires C++17.
  • Kotlin, Scala and Rust support both approaches very well.

  • TypeScript has two alternative sum type representations:

    1. A union of classes distinguished by instanceof checks.
    2. A union of object types each with an explicit tag field set to a different literal type, distinguished by a switch on the tag field.

    I have used the former.

  • Haskell's type class instances and Rust's trait impls are very similar in that they are written separately from the data type definition. This enables a new trait/type class to come with impls/instances for existing types, and enables the implementation itself to be generic.

    • They are also similar in that the this or self parameter is explicit in each trait/type class method.
  • All languages support the interface approach very well except Haskell, which requires the ExistentialQuantification extension and an extra wrapper type to build a list of heterogeneous values implementing an interface (type class).

    • The \(FooImpl foo) -> destructuring lambda is interesting in that it brings into scope both the value foo and the implementations of the Foo methods for that runtime value. The FooImpl wrapper type is presumably needed as a place to implicitly store those method implementations.
  • The systems languages Rust and C++ both require a heap allocation for the heterogeneous list, namely Vec<Box<dyn Foo>> and std::vector<std::unique_ptr<Foo>>, because a Foo is of unknown size. When Foo is a sum type its size is known in advance and no per-element heap allocation is required, so Vec<Foo> and std::vector<Foo> both work.

    • C++ requires an explicit definition of a virtual destructor for Foo to deallocate the std::string inside A through the std::unique_ptr<Foo>, but in Rust when a Box is dropped it will call the Drop::drop implementation of the contained value automatically.

Kotlin

Interface

interface Foo {
  fun x()
  fun y()
}

class A(val a: String) : Foo {
  override fun x() = println("A($a).x")
  override fun y() = println("A($a).y")
}

class B(val b: Int) : Foo {
  override fun x() = println("B($b).x")
  override fun y() = println("B($b).y")
}

val foos: List<Foo> = listOf(
  A("hello"),
  B(100),
)

fun main() {
  for (foo in foos) {
    foo.x()
    foo.y()
  }
}

Sum Type

sealed class Foo
class A(val a: String) : Foo()
class B(val b: Int) : Foo()

fun x(foo: Foo) = when (foo) {
  is A -> println("A(${foo.a}).x")
  is B -> println("B(${foo.b}).x")
}

fun y(foo: Foo) = when (foo) {
  is A -> println("A(${foo.a}).y")
  is B -> println("B(${foo.b}).y")
}

val foos: List<Foo> = listOf(
  A("hello"),
  B(100),
)

fun main() {
  for (foo in foos) {
    x(foo)
    y(foo)
  }
}

Scala

Interface

object Main {
  trait Foo {
    def x(): Unit
    def y(): Unit
  }
  
  class A(val a: String) extends Foo {
    override def x() = println(s"A($a).x")
    override def y() = println(s"A($a).y")
  }
  
  class B(val b: Int) extends Foo {
    override def x() = println(s"B($b).x")
    override def y() = println(s"B($b).y")
  }

  val foos: List[Foo] = List(
    new A("hello"),
    new B(100),
  )

  def main(args: Array[String]): Unit = {
    for (foo <- foos) {
      foo.x()
      foo.y()
    }
  }
}

Sum Type

object Main {
  sealed trait Foo
  case class A(a: String) extends Foo
  case class B(b: Int) extends Foo

  def x(foo: Foo) = foo match {
    case A(a) => println(s"A($a).x")
    case B(b) => println(s"B($b).x")
  }
  
  def y(foo: Foo) = foo match {
    case A(a) => println(s"A($a).y")
    case B(b) => println(s"B($b).y")
  }
  
  val foos: List[Foo] = List(
    A("hello"),
    B(100),
  )

  def main(args: Array[String]): Unit = {
    for (foo <- foos) {
      x(foo)
      y(foo)
    }
  }
}

Java

Interface

import java.util.List;
import java.util.Arrays;

class Main {
  interface Foo {
    void x();
    void y();
  }

  static class A implements Foo {
    String a;
    A(String a) { this.a = a; }
    @Override public void x() { System.out.println("A(" + a + ").x"); }
    @Override public void y() { System.out.println("A(" + a + ").y"); }
  }
  
  static class B implements Foo {
    int b;
    B(int b) { this.b = b; }
    @Override public void x() { System.out.println("B(" + b + ").x"); }
    @Override public void y() { System.out.println("B(" + b + ").y"); }
  }

  static List<Foo> foos = Arrays.asList(
    new A("hello"),
    new B(100)
  );

  public static void main(String[] args) {
    for (var foo : foos) {
      foo.x();
      foo.y();
    }
  } 
}

Sum Type

import java.util.List;
import java.util.Arrays;

class Main {
  sealed interface Foo permits A, B {}
  static final record A(String a) implements Foo {}
  static final record B(int b) implements Foo {}

  static void x(Foo foo) {
    switch (foo) {
      case A a -> System.out.println("A(" + a.a + ").x");
      case B b -> System.out.println("B(" + b.b + ").x");
    }
  }

  static void y(Foo foo) {
    switch (foo) {
      case A a -> System.out.println("A(" + a.a + ").y");
      case B b -> System.out.println("B(" + b.b + ").y");
    }
  }

  static List<Foo> foos = Arrays.asList(
    new A("hello"),
    new B(100)
  );

  public static void main(String[] args) {
    for (var foo : foos) {
      x(foo);
      y(foo);
    }
  } 
}

Visitor

import java.util.List;
import java.util.Arrays;

class Main {
  interface Foo {
    interface Visitor<T> {
      T visitA(A a);
      T visitB(B b);
    }
    <T> T accept(Visitor<? extends T> visitor);
  }

  static class A implements Foo {
    String a;
    A(String a) { this.a = a; }
    @Override public <T> T accept(Visitor<? extends T> visitor) { return visitor.visitA(this); }
  }
  
  static class B implements Foo {
    int b;
    B(int b) { this.b = b; }
    @Override public <T> T accept(Visitor<? extends T> visitor) { return visitor.visitB(this); }
  }

  static void x(Foo foo) {
    foo.accept(new Foo.Visitor<Void>() {
      @Override public Void visitA(A a) { System.out.println("A(" + a.a + ").x"); return null; }
      @Override public Void visitB(B b) { System.out.println("B(" + b.b + ").x"); return null; }
    });
  }

  static void y(Foo foo) {
    foo.accept(new Foo.Visitor<Void>() {
      @Override public Void visitA(A a) { System.out.println("A(" + a.a + ").y"); return null; }
      @Override public Void visitB(B b) { System.out.println("B(" + b.b + ").y"); return null; }
    });
  }

  static List<Foo> foos = Arrays.asList(
    new A("hello"),
    new B(100)
  );

  public static void main(String[] args) {
    for (var foo : foos) {
      x(foo);
      y(foo);
    }
  } 
}

C++

Interface

#include <iostream>
#include <vector>
#include <memory>

class Foo {
public:
  virtual void x() const = 0;
  virtual void y() const = 0;
  virtual ~Foo() = default;
};

class A : public Foo {
  std::string a;
public:
  A(std::string a): a(a) {}
  virtual void x() const override { std::cout << "A(" << a << ").x\n"; }
  virtual void y() const override { std::cout << "A(" << a << ").y\n"; }
};

class B : public Foo {
  int b;
public:
  B(int b): b(b) {}
  virtual void x() const override { std::cout << "B(" << b << ").x\n"; }
  virtual void y() const override { std::cout << "B(" << b << ").y\n"; }
};

auto createFoos() {
  std::vector<std::unique_ptr<Foo>> foos;
  foos.emplace_back(std::make_unique<A>("hello"));
  foos.emplace_back(std::make_unique<B>(100));
  return foos;
}

int main() {
  for (auto& foo : createFoos()) {
    foo->x();
    foo->y();
  }
}

Sum Type

#include <iostream>
#include <vector>
#include <variant>

class A {
public:
  std::string a;
};

class B {
public:
  int b;
};

using Foo = std::variant<A, B>;

template<typename...T> struct overload : T... { using T::operator()...; };
template<typename...T> overload(T...) -> overload<T...>;

auto x(Foo const& foo) {
  std::visit(overload {
    [](A const& a) { std::cout << "A(" << a.a << ").x\n"; },
    [](B const& b) { std::cout << "B(" << b.b << ").x\n"; },
  }, foo);
}

auto y(Foo const& foo) {
  std::visit(overload {
    [](A const& a) { std::cout << "A(" << a.a << ").y\n"; },
    [](B const& b) { std::cout << "B(" << b.b << ").y\n"; },
  }, foo);
}

auto const foos = std::vector<Foo> {
  A{"hello"},
  B{100},
};

int main() {
  for (auto& foo : foos) {
    x(foo);
    y(foo);
  }
}

TypeScript

Interface

interface Foo {
  x(): void
  y(): void
}

class A implements Foo {
  constructor(public a: string) {}
  x() { console.log("A(" + this.a + ").x"); }
  y() { console.log("A(" + this.a + ").y"); }
}

class B implements Foo {
  constructor(public b: number) {}
  x() { console.log("B(" + this.b + ").x"); }
  y() { console.log("B(" + this.b + ").y"); }
}

const foos: Foo[] = [
  new A("hello"),
  new B(100),
];

function main() {
  for (const foo of foos) {
    foo.x();
    foo.y();
  }
}

main();

Sum Type

class A {
  constructor(public a: string) {}
}

class B {
  constructor(public b: number) {}
}

type Foo = A | B

const assertNever = (x: never) => x

function x(foo: Foo): void {
  if (foo instanceof A) return console.log("A(" + foo.a + ").x");
  if (foo instanceof B) return console.log("B(" + foo.b + ").x");
  assertNever(foo);
}

function y(foo: Foo): void {
  if (foo instanceof A) return console.log("A(" + foo.a + ").y");
  if (foo instanceof B) return console.log("B(" + foo.b + ").y");
  assertNever(foo);
}

const foos: Foo[] = [
  new A("hello"),
  new B(100),
];

function main() {
  for (const foo of foos) {
    x(foo);
    y(foo);
  }
}

main();

Haskell

Interface

{-# LANGUAGE ExistentialQuantification #-}

import Control.Monad

class Foo f where
  x :: f -> IO ()
  y :: f -> IO ()

data A = A String
data B = B Int

instance Foo A where
  x (A a) = putStrLn $ "A(" ++ a ++ ").x"
  y (A a) = putStrLn $ "A(" ++ a ++ ").y"

instance Foo B where
  x (B b) = putStrLn $ "B(" ++ (show b) ++ ").x"
  y (B b) = putStrLn $ "B(" ++ (show b) ++ ").y"

data FooImpl = forall f. Foo f => FooImpl f

foos :: [FooImpl]
foos = [
  FooImpl $ A "hello",
  FooImpl $ B 100 ]

main = do
  forM_ foos $ \(FooImpl foo) -> do
    x foo
    y foo

Sum Type

import Control.Monad

data Foo = A String | B Int

x :: Foo -> IO ()
x (A a) = putStrLn $ "A(" ++ a ++ ").x"
x (B b) = putStrLn $ "B(" ++ (show b) ++ ").x"

y :: Foo -> IO ()
y (A a) = putStrLn $ "A(" ++ a ++ ").y"
y (B b) = putStrLn $ "B(" ++ (show b) ++ ").y"

foos :: [Foo]
foos = [
  A "hello",
  B 100 ]

main = do
  forM_ foos $ \foo -> do
    x foo
    y foo

Rust

Interface

trait Foo {
  fn x(&self);
  fn y(&self);
}

struct A { a: String }
struct B { b: i32 }

impl Foo for A {
  fn x(&self) { println!("A({}).x", self.a) }
  fn y(&self) { println!("A({}).y", self.a) }
}

impl Foo for B {
  fn x(&self) { println!("B({}).x", self.b) }
  fn y(&self) { println!("B({}).y", self.b) }
}

fn create_foos() -> Vec<Box<dyn Foo>> {
  vec![
    Box::new(A { a: "hello".to_string() }),
    Box::new(B { b: 100 }),
  ]
}

fn main() {
  for ref foo in create_foos() {
    foo.x();
    foo.y();
  }
}

Sum Type

enum Foo {
  A(String),
  B(i32),
}

fn x(foo: &Foo) {
  match foo {
    Foo::A(a) => println!("A({}).x", a),
    Foo::B(b) => println!("B({}).x", b),
  }
}

fn y(foo: &Foo) {
  match foo {
    Foo::A(a) => println!("A({}).y", a),
    Foo::B(b) => println!("B({}).y", b),
  }
}

fn create_foos() -> Vec<Foo> {
  vec![
    Foo::A("hello".to_string()),
    Foo::B(100),
  ]
}

fn main() {
  for ref foo in create_foos() {
    x(foo);
    y(foo);
  }
}