
@stellingsimon
Last active July 19, 2018 13:54
Plain Java emulation of Kotlin's `sealed` classes: https://kotlinlang.org/docs/reference/sealed-classes.html

Kotlin has an awesome feature called sealed classes:

sealed class Expr
data class Const(val number: Double) : Expr()
data class Sum(val e1: Expr, val e2: Expr) : Expr()

Their main advantage is that they allow exhaustive `when` expressions (Kotlin's counterpart of a switch statement) over the type hierarchy:

fun eval(expr: Expr): Double = when(expr) {
      is Const -> expr.number
      is Sum -> eval(expr.e1) + eval(expr.e2)
      // the `else` clause is not required because we've covered all the cases
}

If a third subtype is added (e.g. NotANumber), every such `when` expression over Expr that has no `else` branch fails to compile until a matching branch is added.

Also, note that we don't need to downcast to access a subtype's properties: inside each branch, Kotlin smart-casts `expr` to the matched subtype.
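
For contrast, here is roughly what a direct Java translation looks like without any special pattern. It is a sketch with hypothetical class names mirroring the Kotlin example: every branch needs an `instanceof` test plus a downcast, and a subtype added later compiles without complaint and only fails at runtime.

```java
// Hypothetical naive Java counterpart of the Kotlin Expr hierarchy above.
abstract class Expr {
}

final class Const extends Expr {
    final double number;

    Const(double number) {
        this.number = number;
    }
}

final class Sum extends Expr {
    final Expr e1;
    final Expr e2;

    Sum(Expr e1, Expr e2) {
        this.e1 = e1;
        this.e2 = e2;
    }
}

class NaiveEvaluator {
    static double eval(Expr expr) {
        if (expr instanceof Const) {
            // explicit downcast needed to reach the subtype's field
            return ((Const) expr).number;
        }
        if (expr instanceof Sum) {
            Sum sum = (Sum) expr;
            return eval(sum.e1) + eval(sum.e2);
        }
        // A subtype added later still compiles; it is only caught here, at runtime.
        throw new IllegalArgumentException("unhandled subtype: " + expr.getClass().getName());
    }
}
```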

Can we emulate Kotlin's sealed classes in an attempt to get the same compile-time safety in plain Java?


TL;DR: It's possible to get somewhat close.

Using the switch

An imitation of the switch statement looks like this:

public int evaluate(MathExpression mathExpression) {
    return mathExpression.apply(
            constant -> constant.getNumber(),
            sum -> evaluate(sum.getLeft()) + evaluate(sum.getRight())
    );
}
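
Since `apply()` is generic in its result type, the same two-lambda switch can compute things other than numbers. As a sketch, the hypothetical `Renderer.render()` below reuses it to print an expression; a cut-down copy of the gist's `MathExpression` classes is included so the block compiles on its own.

```java
import java.util.function.Function;

// Cut-down copy of this gist's MathExpression hierarchy (just enough to compile).
abstract class MathExpression {

    private MathExpression() {
    }

    public <T> T apply(Function<Constant, T> constantCase, Function<Sum, T> sumCase) {
        if (this instanceof Constant) {
            return constantCase.apply((Constant) this);
        }
        if (this instanceof Sum) {
            return sumCase.apply((Sum) this);
        }
        throw new IllegalStateException();
    }

    public static class Constant extends MathExpression {
        private final int number;

        public Constant(int number) {
            this.number = number;
        }

        public int getNumber() {
            return number;
        }
    }

    public static class Sum extends MathExpression {
        private final MathExpression left;
        private final MathExpression right;

        public Sum(MathExpression left, MathExpression right) {
            this.left = left;
            this.right = right;
        }

        public MathExpression getLeft() {
            return left;
        }

        public MathExpression getRight() {
            return right;
        }
    }
}

// Hypothetical second "switch": the same apply(), with T inferred as String.
class Renderer {
    static String render(MathExpression expression) {
        return expression.apply(
                constant -> Integer.toString(constant.getNumber()),
                sum -> "(" + render(sum.getLeft()) + " + " + render(sum.getRight()) + ")");
    }
}
```

For example, rendering `new MathExpression.Sum(new MathExpression.Constant(1), new MathExpression.Constant(2))` yields `(1 + 2)`.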

Defining the switch

Here's the definition for apply() that makes this work:

public <T> T apply(Function<Constant, T> constantCase, Function<Sum, T> sumCase) {
    if (this instanceof Constant) {
        return constantCase.apply((Constant) this);
    }
    if (this instanceof Sum) {
        return sumCase.apply((Sum) this);
    }
    // only reachable if MathExpression has a subtype that apply() doesn't know about
    throw new IllegalStateException();
}

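Because apply() takes exactly one argument per subtype, every call site has to supply a lambda for each case. When a subtype is added, apply() gains a parameter, and every call that hasn't been updated stops compiling. This only holds if MathExpression has no subclasses that apply() doesn't know about: such a subclass would compile fine and fall through to the IllegalStateException at runtime.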

In other words, we get the desired guarantees if MathExpression is closed with regard to inheritance.
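
To see why the hierarchy has to be closed, consider a sketch of what happens when it is not. `OpenExpression` and `Negation` are hypothetical: `OpenExpression` has the same `apply()` but no private constructor, so a class declared outside it can extend it. The call site still compiles with two lambdas, and the gap only shows up as an `IllegalStateException` at runtime.

```java
import java.util.function.Function;

// Hypothetical variant of MathExpression WITHOUT a private constructor.
abstract class OpenExpression {

    public <T> T apply(Function<Constant, T> constantCase, Function<Sum, T> sumCase) {
        if (this instanceof Constant) {
            return constantCase.apply((Constant) this);
        }
        if (this instanceof Sum) {
            return sumCase.apply((Sum) this);
        }
        // Any subclass that apply() was not written for ends up here.
        throw new IllegalStateException("unknown subtype " + getClass().getSimpleName());
    }

    static class Constant extends OpenExpression {
        final int number;

        Constant(int number) {
            this.number = number;
        }
    }

    static class Sum extends OpenExpression {
        final OpenExpression left;
        final OpenExpression right;

        Sum(OpenExpression left, OpenExpression right) {
            this.left = left;
            this.right = right;
        }
    }
}

// Declared outside OpenExpression: compiles because the implicit constructor is accessible here.
class Negation extends OpenExpression {
    final OpenExpression operand;

    Negation(OpenExpression operand) {
        this.operand = operand;
    }
}

class OpenEvaluator {
    // Still passes exactly two lambdas, so the compiler sees nothing missing.
    static int evaluate(OpenExpression expression) {
        return expression.apply(
                constant -> constant.number,
                sum -> evaluate(sum.left) + evaluate(sum.right));
    }
}
```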

Closing the inheritance hierarchy

We need MathExpression to have sub-types (Constant and Sum), so we can't mark it as final.

We can, however, make its constructor private, so that no class outside the body of MathExpression can extend it, and define Constant and Sum as static nested classes inside it (nested classes are still allowed to call the private constructor):

public abstract class MathExpression {

    private MathExpression() {
    }

    // apply() defined here ...

    public static class Constant extends MathExpression {

        private int number;

        public Constant(int number) {
            this.number = number;
        }
        
    }

    public static class Sum extends MathExpression {

        private MathExpression left;
        private MathExpression right;

        public Sum(MathExpression left, MathExpression right) {
            this.left = left;
            this.right = right;
        }

    }
}
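
To see the guarantee at work, here is a sketch that adds the `NotANumber` case from the Kotlin example to a closed hierarchy built the same way. The names `Expression` and `Evaluator`, and the switch to `double` so that NaN is representable, are assumptions for illustration. Adding the subtype meant adding a third `apply()` parameter, so any leftover two-lambda call would no longer compile.

```java
import java.util.function.Function;

abstract class Expression {

    // Private constructor: only the nested classes below can extend Expression.
    private Expression() {
    }

    // One Function per subtype; adding NotANumber added the third parameter.
    public <T> T apply(Function<Constant, T> constantCase,
                       Function<Sum, T> sumCase,
                       Function<NotANumber, T> notANumberCase) {
        if (this instanceof Constant) {
            return constantCase.apply((Constant) this);
        }
        if (this instanceof Sum) {
            return sumCase.apply((Sum) this);
        }
        if (this instanceof NotANumber) {
            return notANumberCase.apply((NotANumber) this);
        }
        throw new IllegalStateException();
    }

    public static final class Constant extends Expression {
        private final double number;

        public Constant(double number) {
            this.number = number;
        }

        public double getNumber() {
            return number;
        }
    }

    public static final class Sum extends Expression {
        private final Expression left;
        private final Expression right;

        public Sum(Expression left, Expression right) {
            this.left = left;
            this.right = right;
        }

        public Expression getLeft() {
            return left;
        }

        public Expression getRight() {
            return right;
        }
    }

    public static final class NotANumber extends Expression {
    }
}

class Evaluator {
    // A two-argument call such as apply(constant -> ..., sum -> ...) would not compile any more.
    static double evaluate(Expression expression) {
        return expression.apply(
                constant -> constant.getNumber(),
                sum -> evaluate(sum.getLeft()) + evaluate(sum.getRight()),
                notANumber -> Double.NaN);
    }
}
```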

Bottom Line

As always, Java requires a certain amount of redundancy to achieve a similar effect to what Kotlin does with a few short lines.

However, much of this redundancy can be checked by unit tests. The demonstration code in this gist includes SealedTypeChecker, a reflection-based test suite that guards against common pitfalls when employing this pattern.
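
The gist's `SealedTypeChecker` relies on JUnit, Reflections, jOOλ, Guava and Mockito. As a rough, dependency-free sketch of the same idea, the hypothetical `SealedRules.check()` below uses only `java.lang.reflect` to test a single class for three of the rules: it is abstract, all its constructors are private, and it has an `apply()` with one parameter per direct nested subclass. Like the gist's checker, it compares only the parameter count.

```java
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Cut-down copy of this gist's MathExpression, reduced to what the checks look at.
abstract class MathExpression {

    private MathExpression() {
    }

    public <T> T apply(Function<Constant, T> constantCase, Function<Sum, T> sumCase) {
        if (this instanceof Constant) {
            return constantCase.apply((Constant) this);
        }
        if (this instanceof Sum) {
            return sumCase.apply((Sum) this);
        }
        throw new IllegalStateException();
    }

    public static class Constant extends MathExpression {
    }

    public static class Sum extends MathExpression {
    }
}

class SealedRules {

    // Hypothetical helper: returns a description of every rule violation (empty if none).
    static List<String> check(Class<?> sealedType) {
        List<String> problems = new ArrayList<>();
        if (!Modifier.isAbstract(sealedType.getModifiers())) {
            problems.add("not abstract");
        }
        // getDeclaredConstructors() includes non-public constructors; getConstructors() would not.
        // Synthetic accessor constructors emitted by pre-Java-11 compilers are skipped.
        for (Constructor<?> constructor : sealedType.getDeclaredConstructors()) {
            if (!constructor.isSynthetic() && !Modifier.isPrivate(constructor.getModifiers())) {
                problems.add("non-private constructor: " + constructor);
            }
        }
        // Direct subtypes are expected to be nested classes of the sealed type.
        int directSubtypes = 0;
        for (Class<?> nested : sealedType.getDeclaredClasses()) {
            if (nested.getSuperclass() == sealedType) {
                directSubtypes++;
            }
        }
        // Only the parameter count is compared, not the Function<SubX, T> type arguments.
        boolean hasMatchingApply = false;
        for (Method method : sealedType.getDeclaredMethods()) {
            if (method.getName().equals("apply") && method.getParameterCount() == directSubtypes) {
                hasMatchingApply = true;
            }
        }
        if (!hasMatchingApply) {
            problems.add("no apply() taking " + directSubtypes + " arguments");
        }
        return problems;
    }
}
```

On the cut-down `MathExpression` the list comes back empty, while `SealedRules.check(Object.class)` reports several violations.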

Calculator.java

public class Calculator {

    public static void main(String[] args) {
        Calculator calculator = new Calculator();
        MathExpression onePlusTwo = calculator.parseSimpleToken(MathExpressionType.SUM, "1+2");
        MathExpression three = calculator.parseSimpleToken(MathExpressionType.CONSTANT, "3");
        System.out.println("evaluating 1+2: " + calculator.evaluate(onePlusTwo));
        System.out.println("evaluating 3: " + calculator.evaluate(three));
        System.out.println(onePlusTwo.spellOut() + " is equal to " + three.spellOut());
    }

    public int evaluate(MathExpression mathExpression) {
        return mathExpression.apply(
                constant -> constant.getNumber(),
                sum -> evaluate(sum.getLeft()) + evaluate(sum.getRight())
        );
    }

    public MathExpression parseSimpleToken(MathExpressionType tokenType, String token) {
        return MathExpression.create(
                tokenType,
                // CONSTANT:
                () -> new MathExpression.Constant(Integer.parseInt(token)),
                // SUM: split "a+b" and parse both sides as constants
                () -> {
                    String[] splitTokens = token.split("\\+");
                    MathExpression left = parseSimpleToken(MathExpressionType.CONSTANT, splitTokens[0]);
                    MathExpression right = parseSimpleToken(MathExpressionType.CONSTANT, splitTokens[1]);
                    return new MathExpression.Sum(left, right);
                });
    }
}

MathExpression.java

import com.example.sealedtype.SealedType;

import java.util.function.Function;
import java.util.function.Supplier;

public abstract class MathExpression implements SealedType<MathExpressionType> {

    // private: only the nested classes below can extend MathExpression
    private MathExpression() {
    }

    // maps each subtype to its enum value via the compile-safe switch
    @Override
    public MathExpressionType getType() {
        return apply(
                x -> MathExpressionType.CONSTANT,
                x -> MathExpressionType.SUM);
    }

    // compile-safe "switch" over the subtypes: exactly one function per subtype
    public <T> T apply(Function<Constant, T> constantCase, Function<Sum, T> sumCase) {
        if (this instanceof Constant) {
            return constantCase.apply((Constant) this);
        }
        if (this instanceof Sum) {
            return sumCase.apply((Sum) this);
        }
        throw new IllegalStateException();
    }

    // creation counterpart: one supplier per subtype, selected by the enum value
    public static MathExpression create(MathExpressionType type, Supplier<Constant> constantCase, Supplier<Sum> sumCase) {
        switch (type) {
            case CONSTANT:
                return constantCase.get();
            case SUM:
                return sumCase.get();
        }
        throw new IllegalStateException();
    }

    public abstract String spellOut();

    public static class Constant extends MathExpression {

        private int number;

        public Constant(int number) {
            this.number = number;
        }

        public int getNumber() {
            return number;
        }

        @Override
        public String spellOut() {
            return "the constant " + number;
        }
    }

    public static class Sum extends MathExpression {

        private MathExpression left;
        private MathExpression right;

        public Sum(MathExpression left, MathExpression right) {
            this.left = left;
            this.right = right;
        }

        public MathExpression getLeft() {
            return left;
        }

        public MathExpression getRight() {
            return right;
        }

        @Override
        public String spellOut() {
            return "the sum of " + left.spellOut() + " and " + right.spellOut();
        }
    }
}

MathExpressionType.java

public enum MathExpressionType {
    CONSTANT,
    SUM;
}

SealedType.java

package com.example.sealedtype;

/**
 * Mark a class as a sealed type (as known in the Kotlin language).
 * <p>
 * Implementors should provide an {@code apply(...)} method to enable compile-safe subtype-switching as follows:
 *
 * <pre>
 * String message = sealedInstance.apply(
 *         subA -> subA.methodOnSubA(),
 *         subB -> subB.methodOnSubB());
 * </pre>
 *
 * Implementors should also provide a static {@code create(T, ...)} method to enable compile-safe subtype-creation:
 *
 * <pre>
 * Sealed sealedInstance = Sealed.create(
 *         SealedTypeEnum.SUB_A,
 *         () -> new Sealed.SubA("SubA-specific constructor argument"),
 *         () -> new Sealed.SubB("SubB-specific constructor argument", "etc"));
 * </pre>
 * </p>
 * <p>
 * The above can be checked by the compiler, and introducing a new subtype will lead to compile-time errors at all use sites. Compare this to the following unsafe alternative:
 *
 * <pre>
 * switch (sealedInstance.getType()) {
 *     case SUB_A:
 *         return ((Sealed.SubA) sealedInstance).methodOnSubA();
 *     case SUB_B:
 *         return ((Sealed.SubB) sealedInstance).methodOnSubB();
 *     default:
 *         throw new UnknownEnumException(sealedInstance.getType());
 * }
 * </pre>
 * </p>
 * <p>
 * <b>Implementation note:</b> Sealed classes must have the following properties:
 * <ul>
 * <li>The sealed class itself is abstract.</li>
 * <li>The sealed class provides an {@code apply()} method with one argument for each enum value.</li>
 * <li>The sealed class' constructors are private.</li>
 * <li>The direct subtypes of a sealed class are <b>static</b> nested classes of the sealed class itself.</li>
 * <li>Each direct subtype of the sealed class returns a different value from {@code getType()}.</li>
 * <li>For every value of the enum returned by {@code getType()}, there exists one direct subtype whose {@code getType()} returns that value.</li>
 * </ul>
 * </p>
 * All guarantees provided by implementors are verified by {@code SealedTypeChecker}.
 */
public interface SealedType<T extends Enum<T>> {

    T getType();
}
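
As a sketch of the contract described in the Javadoc above, here is a second, hypothetical sealed type `Shape` (with enum `ShapeType`) that follows the listed rules. The interface is copied locally so the block compiles on its own; in this gist's layout you would import `com.example.sealedtype.SealedType` instead.

```java
import java.util.function.Function;
import java.util.function.Supplier;

// Local copy of the gist's marker interface, so this sketch compiles on its own.
interface SealedType<T extends Enum<T>> {
    T getType();
}

// Hypothetical enum naming each direct subtype of Shape.
enum ShapeType {
    CIRCLE,
    SQUARE
}

// Hypothetical sealed type: abstract, private constructor, static nested subtypes.
abstract class Shape implements SealedType<ShapeType> {

    private Shape() {
    }

    @Override
    public ShapeType getType() {
        return apply(circle -> ShapeType.CIRCLE, square -> ShapeType.SQUARE);
    }

    // One Function per direct subtype.
    public <T> T apply(Function<Circle, T> circleCase, Function<Square, T> squareCase) {
        if (this instanceof Circle) {
            return circleCase.apply((Circle) this);
        }
        if (this instanceof Square) {
            return squareCase.apply((Square) this);
        }
        throw new IllegalStateException();
    }

    // The enum value plus one Supplier per direct subtype.
    public static Shape create(ShapeType type, Supplier<Circle> circleCase, Supplier<Square> squareCase) {
        switch (type) {
            case CIRCLE:
                return circleCase.get();
            case SQUARE:
                return squareCase.get();
        }
        throw new IllegalStateException();
    }

    public static class Circle extends Shape {
        private final double radius;

        public Circle(double radius) {
            this.radius = radius;
        }

        public double getRadius() {
            return radius;
        }
    }

    public static class Square extends Shape {
        private final double side;

        public Square(double side) {
            this.side = side;
        }

        public double getSide() {
            return side;
        }
    }
}
```

For instance, `Shape.create(ShapeType.SQUARE, () -> new Shape.Circle(1.0), () -> new Shape.Square(2.0))` yields a `Square` whose `getType()` is `ShapeType.SQUARE`.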
SealedTypeChecker.java

package com.example.sealedtype;

import com.google.common.base.MoreObjects;
import com.google.common.collect.Sets;
import org.jooq.lambda.Seq;
import org.junit.BeforeClass;
import org.junit.Test;
import org.reflections.Reflections;
import org.reflections.scanners.SubTypesScanner;

import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.empty;
import static org.jooq.lambda.Seq.seq;
import static org.mockito.Mockito.CALLS_REAL_METHODS;
import static org.mockito.Mockito.mock;

public class SealedTypeChecker {

    private static Reflections reflections;

    @BeforeClass
    public static void setUp() {
        // do this once per test class since it takes a considerable amount of time and can be reused anyway;
        // only sealed types in com.example and its sub-packages are found
        reflections = new Reflections("com.example", new SubTypesScanner());
    }

    @Test
    public void checkThatSealedTypesAreAbstract() {
        Set<Class<? extends SealedType>> nonAbstractClasses = findSealedTypes()
                .filter(clazz -> !Modifier.isAbstract(clazz.getModifiers()))
                .toSet();
        assertThat("Sealed types must be abstract", nonAbstractClasses, empty());
    }

    @Test
    public void checkThatSealedTypesHaveOnlyPrivateConstructors() {
        Set<Constructor<?>> nonPrivateConstructors = findSealedTypes()
                // getDeclaredConstructors() also returns protected and package-private constructors,
                // which getConstructors() would miss; synthetic accessors from pre-Java-11 compilers are skipped
                .flatMap(clazz -> Arrays.stream(clazz.getDeclaredConstructors()))
                .filter(constructor -> !constructor.isSynthetic())
                .filter(constructor -> !Modifier.isPrivate(constructor.getModifiers()))
                .toSet();
        assertThat("Sealed types must only have private constructors", nonPrivateConstructors, empty());
    }

    @Test
    public void checkThatSealedTypesHaveASubTypeForEveryEnumValue() {
        Set<SealedTypeInfo> sealedTypesWithUnimplementedSubtypes = getSealedTypeInfos()
                .filter(sealedTypeInfo -> !Sets.difference(sealedTypeInfo.getDeclaredSubtypes(), sealedTypeInfo.getImplementedSubtypes()).isEmpty())
                .toSet();
        assertThat("There must be a subtype for every enum value", sealedTypesWithUnimplementedSubtypes, empty());
    }

    @Test
    public void checkThatSealedTypesHaveAnEnumValueForEverySubtype() {
        Set<SealedTypeInfo> sealedTypesWithUndeclaredSubtypes = getSealedTypeInfos()
                .filter(sealedTypeInfo -> !Sets.difference(sealedTypeInfo.getImplementedSubtypes(), sealedTypeInfo.getDeclaredSubtypes()).isEmpty())
                .toSet();
        assertThat("There must be an enum value for every subtype", sealedTypesWithUndeclaredSubtypes, empty());
    }

    @Test
    public void checkThatSealedTypesProvideAnApplyMethod() {
        Set<Class<? extends SealedType>> sealedTypesWithoutApplyMethod = findSealedTypes()
                .filter(sealedType -> !seq(Arrays.stream(sealedType.getDeclaredMethods()))
                        .findFirst(applyMethod -> applyMethod.getName().equals("apply")).isPresent())
                .toSet();
        assertThat("There must be an 'apply()' method defined on the sealed type", sealedTypesWithoutApplyMethod, empty());
    }

    @Test
    public void checkThatSealedTypesProvideAnApplyMethodWithMatchingNumberOfArguments() {
        Set<Class<? extends SealedType>> sealedTypesWithoutAppropriateApplyMethod = findSealedTypes()
                .filter(sealedType -> !seq(Arrays.stream(sealedType.getDeclaredMethods()))
                        .filter(method -> method.getName().equals("apply"))
                        .findFirst(applyMethod -> {
                            Set<? extends Class<? extends SealedType>> subtypesDefined = findDirectSubTypes(sealedType);
                            List<Class<?>> parametersOfApply = Seq.seq(Arrays.stream(applyMethod.getParameterTypes())).toList();
                            // only the argument count is compared, not the Function<SubX, T> type arguments
                            return parametersOfApply.size() == subtypesDefined.size();
                        })
                        .isPresent())
                .toSet();
        assertThat("There must be an 'apply(Function<SubA, T>, Function<SubB, T>, ...)' method taking exactly one lambda function for each possible subtype",
                sealedTypesWithoutAppropriateApplyMethod,
                empty());
    }

    @Test
    public void checkThatSealedTypesProvideAStaticCreateMethod() {
        Set<Class<? extends SealedType>> sealedTypesWithoutCreateMethod = findSealedTypes()
                .filter(sealedType -> !seq(Arrays.stream(sealedType.getDeclaredMethods()))
                        .findFirst(this::isStaticCreateMethod).isPresent())
                .toSet();
        assertThat("There must be a static 'create()' method defined on the sealed type", sealedTypesWithoutCreateMethod, empty());
    }

    @Test
    public void checkThatSealedTypesProvideAStaticCreateMethodWithMatchingNumberOfArguments() {
        Set<Class<? extends SealedType>> sealedTypesWithoutAppropriateCreateMethod = findSealedTypes()
                .filter(sealedType -> !seq(Arrays.stream(sealedType.getDeclaredMethods()))
                        .filter(this::isStaticCreateMethod)
                        .findFirst(createMethod -> {
                            Set<? extends Class<? extends SealedType>> subtypesDefined = findDirectSubTypes(sealedType);
                            List<Class<?>> parametersOfCreate = Seq.seq(Arrays.stream(createMethod.getParameterTypes())).toList();
                            // the enum value plus one Supplier per subtype; again only the count is compared
                            return parametersOfCreate.size() == 1 + subtypesDefined.size();
                        })
                        .isPresent())
                .toSet();
        assertThat("There must be a static 'create(T, Supplier<SubA>, Supplier<SubB>, ...)' method taking the enum value plus exactly one supplier for each possible subtype",
                sealedTypesWithoutAppropriateCreateMethod,
                empty());
    }

    private boolean isStaticCreateMethod(Method method) {
        return method.getName().equals("create")
                && Modifier.isStatic(method.getModifiers());
    }

    private Seq<Class<? extends SealedType>> findSealedTypes() {
        return seq(reflections.getSubTypesOf(SealedType.class))
                .filter(clazz -> Arrays.asList(clazz.getInterfaces()).contains(SealedType.class));
    }

    private Seq<SealedTypeInfo> getSealedTypeInfos() {
        return findSealedTypes()
                .map(sealedType -> {
                    Set<? extends Class<? extends SealedType>> subTypes = findDirectSubTypes(sealedType);
                    Set<? extends Enum<?>> implementedTypeEnumValues = seq(subTypes)
                            .map(clazz -> {
                                // CALLS_REAL_METHODS: getType() delegates to apply(), which must run for real too;
                                // on a default mock, apply() would be stubbed to return null
                                SealedType<?> instance = mock(clazz, CALLS_REAL_METHODS);
                                return instance.getType();
                            })
                            .toSet();
                    Set<? extends Enum<?>> declaredEnumValues = getDeclaredEnumValues(sealedType);
                    return new SealedTypeInfo(sealedType, declaredEnumValues, implementedTypeEnumValues);
                });
    }

    private Set<? extends Class<? extends SealedType>> findDirectSubTypes(Class<? extends SealedType> sealedType) {
        return seq(reflections.getSubTypesOf(sealedType))
                .filter(subtype -> subtype.getSuperclass().equals(sealedType))
                .toSet();
    }

    private class SealedTypeInfo {

        private Class<? extends SealedType> sealedType;
        private Set<? extends Enum<?>> declaredSubtypes;
        private Set<? extends Enum<?>> implementedSubtypes;

        public SealedTypeInfo(Class<? extends SealedType> sealedType, Set<? extends Enum<?>> declaredSubtypes, Set<? extends Enum<?>> implementedSubtypes) {
            this.sealedType = sealedType;
            this.declaredSubtypes = declaredSubtypes;
            this.implementedSubtypes = implementedSubtypes;
        }

        public Class<? extends SealedType> getSealedType() {
            return sealedType;
        }

        public Set<? extends Enum<?>> getDeclaredSubtypes() {
            return declaredSubtypes;
        }

        public Set<? extends Enum<?>> getImplementedSubtypes() {
            return implementedSubtypes;
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this)
                    .add("sealedType", sealedType)
                    .add("declaredSubtypes", declaredSubtypes)
                    .add("implementedSubtypes", implementedSubtypes)
                    .toString();
        }
    }

    private Set<? extends Enum<?>> getDeclaredEnumValues(Class<? extends SealedType> sealedType) {
        Class<? extends Enum<?>> enumType = seq(Arrays.stream(sealedType.getGenericInterfaces()))
                .findFirst(genericInterface -> genericInterface.getTypeName().startsWith(SealedType.class.getName() + "<"))
                .map(parametrizedSealedType -> (Class<? extends Enum<?>>) ((ParameterizedType) parametrizedSealedType).getActualTypeArguments()[0])
                .get();
        return seq(Arrays.stream(enumType.getEnumConstants())).toSet();
    }
}