Created
January 3, 2018 16:05
A compiler that takes a toy AST and generates the corresponding method handle tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import static java.lang.invoke.MethodHandles.constant;
import static java.lang.invoke.MethodHandles.dropArguments;
import static java.lang.invoke.MethodHandles.empty;
import static java.lang.invoke.MethodHandles.identity;
import static java.lang.invoke.MethodHandles.loop;
import static java.lang.invoke.MethodHandles.publicLookup;
import static java.lang.invoke.MethodType.methodType;
import static java.util.stream.IntStream.range;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
/**
 * Toy compiler that lowers a small AST (variable reads, variable assignments and
 * method-handle calls) to a single {@link MethodHandle} assembled with
 * {@link MethodHandles#loop(MethodHandle[]...)}.
 *
 * <p>Every statement of the body becomes one loop clause. The loop runs exactly one
 * iteration: the steps of the clauses execute in order, and the predicate of the last
 * clause is the constant {@code false}, so its finalizer produces the result.
 */
public interface ASTToMHCompiler {
  /** Marker type for every node of the toy AST. */
  interface Node { /* empty */ }

  /** Expression node that reads a parameter or a local variable by name. */
  class VarNode implements Node {
    final String name;

    VarNode(String name) {
      this.name = name;
    }
  }

  /** Statement node that declares the local {@code name} of type {@code type} and assigns {@code expr} to it. */
  class VarAssignmentNode implements Node {
    final Class<?> type;
    final String name;
    final Node expr;

    VarAssignmentNode(Class<?> type, String name, Node expr) {
      this.type = type;
      this.name = name;
      this.expr = expr;
    }
  }

  /** Expression node that calls {@code mh} with the values of the {@code args} expressions. */
  class CallNode implements Node {
    final MethodHandle mh;
    final List<Node> args;

    /**
     * @param mh the method handle to call
     * @param args one argument expression per parameter of {@code mh}
     * @throws IllegalArgumentException if the number of arguments differs from the
     *         parameter count of {@code mh}
     */
    CallNode(MethodHandle mh, Node... args) {
      if (mh.type().parameterCount() != args.length) {
        throw new IllegalArgumentException("invalid number of arguments");
      }
      this.args = List.of(args);
      this.mh = mh;
    }
  }

  /**
   * Minimal dispatch table keyed by the exact runtime class of the visited object.
   *
   * @param <R> the type of the value produced by a visit
   */
  class Visitor<R> {
    final HashMap<Class<?>, Function<Object, ? extends R>> map = new HashMap<>();

    /**
     * Registers {@code fun} as the handler for objects whose class is exactly {@code type}.
     *
     * @return this visitor, to allow chaining
     */
    <T> Visitor<R> when(Class<T> type, Function<? super T, ? extends R> fun) {
      map.put(type, fun.compose(type::cast));
      return this;
    }

    /**
     * Applies the handler registered for the class of {@code o}.
     *
     * @throws IllegalArgumentException if no handler was registered for that class
     */
    R call(Object o) {
      Function<Object, ? extends R> handler = map.get(o.getClass());
      if (handler == null) {
        throw new IllegalArgumentException("no handler registered for " + o.getClass().getName());
      }
      return handler.apply(o);
    }
  }

  /**
   * Compiles {@code body} to a method handle taking {@code paramTypes} as parameters.
   *
   * <p>The result of the handle is the value of the last statement of the body.
   *
   * @param params the parameter names, in declaration order
   * @param paramTypes the parameter types, {@code paramTypes.get(i)} being the type of
   *        {@code params.get(i)}
   * @param body the statements to compile, in execution order
   * @return a method handle that executes the body
   * @throws IllegalArgumentException if {@code params} and {@code paramTypes} differ in
   *         size, if a variable name is declared more than once, or if an expression
   *         reads a variable that is neither a parameter nor an assigned local
   */
  static MethodHandle compile(List<String> params, List<Class<?>> paramTypes, List<Node> body) {
    if (params.size() != paramTypes.size()) {
      throw new IllegalArgumentException("params and paramTypes must have the same size");
    }

    // Slot layout seen by every step/pred/fini function: locals first (the loop
    // iteration variables), then the parameters (the loop arguments).
    List<Class<?>> types = new ArrayList<>();
    Map<String, Integer> localSlot = new HashMap<>();

    class Local {
      /** Allocates the next slot for {@code name}; each name may be declared only once. */
      void newLocal(Class<?> type, String name) {
        // Slot index and types index must stay in sync, so a duplicate is an error.
        if (localSlot.putIfAbsent(name, types.size()) != null) {
          throw new IllegalArgumentException("variable " + name + " is declared more than once");
        }
        types.add(type);
      }

      /** Returns a handle of type {@code (types...)T} that yields the value stored in the slot of {@code name}. */
      MethodHandle get(String name) {
        Integer slot = localSlot.get(name);
        if (slot == null) {
          throw new IllegalArgumentException("unknown variable " + name);
        }
        MethodHandle mh = identity(types.get(slot));
        mh = dropArguments(mh, 1, types.subList(slot + 1, types.size()));
        mh = dropArguments(mh, 0, types.subList(0, slot));
        return mh;
      }
    }
    Local local = new Local();

    // register all local variables
    for (Node node : body) {
      if (node instanceof VarAssignmentNode) {
        VarAssignmentNode assign = (VarAssignmentNode) node;
        local.newLocal(assign.type, assign.name);
      }
    }
    // add all parameters
    for (int i = 0; i < params.size(); i++) {
      local.newLocal(paramTypes.get(i), params.get(i));
    }

    // Expressions compile to handles of type (types...)T.
    Visitor<MethodHandle> exprVisitor = new Visitor<>();
    exprVisitor
        .when(VarNode.class, variable -> local.get(variable.name))
        .when(CallNode.class, call -> {
          List<Node> args = call.args;
          int argCount = args.size();
          // (P1..Pn)R -> (P1..Pn, types...)R
          MethodHandle mh = dropArguments(call.mh, call.mh.type().parameterCount(), types);
          for (int i = 0; i < argCount; i++) {
            MethodHandle arg = exprVisitor.call(args.get(i));
            // mh is currently (P(i+1)..Pn, types...)R. foldArguments fills P(i+1) with
            // the combiner result and requires the combiner to accept the parameters
            // that follow it, so the combiner is padded with the still pending
            // P(i+2)..Pn (an empty list for the last argument).
            List<Class<?>> pending = mh.type().parameterList().subList(1, argCount - i);
            arg = dropArguments(arg, 0, pending);
            // Each fold wraps the previous one, so the argument expressions end up
            // being evaluated from the last one to the first one.
            mh = MethodHandles.foldArguments(mh, arg);
          }
          return mh;
        });

    // Clause layout expected by MethodHandles.loop: { init, step, pred, fini }.
    Function<Node, MethodHandle[]> exprAction = expr -> {
      MethodHandle mh = exprVisitor.call(expr);
      MethodHandle discardResult = mh.asType(mh.type().changeReturnType(void.class));
      return new MethodHandle[] { null, discardResult, null, mh };
    };
    Visitor<MethodHandle[]> instrVisitor = new Visitor<MethodHandle[]>()
        .when(VarNode.class, exprAction)
        .when(CallNode.class, exprAction)
        .when(VarAssignmentNode.class, assign -> new MethodHandle[] {
            empty(methodType(assign.type, paramTypes)), // init: default value of the type
            exprVisitor.call(assign.expr),              // step: evaluate the right-hand side
            null,                                       // pred
            local.get(assign.name)                      // fini: read the variable back
        });

    List<MethodHandle[]> clauses = new ArrayList<>();
    int lastIndex = body.size() - 1;
    for (int i = 0; i < body.size(); i++) {
      MethodHandle[] clause = instrVisitor.call(body.get(i));
      if (i == lastIndex) {
        // The last statement stops the loop and its fini provides the result.
        clause[2] = constant(boolean.class, false);
        if (clause[0] == null) {
          // Plain expression: fini evaluates it, so the void step would evaluate it twice.
          clause[1] = null;
        }
      } else {
        // Only the last statement may produce the result.
        clause[3] = null;
      }
      clauses.add(clause);
    }
    return loop(clauses.toArray(new MethodHandle[0][]));
  }

  /** Compiles and runs two small sample programs, printing their results. */
  public static void main(String[] args) throws Throwable {
    MethodHandle sum = publicLookup().findStatic(Integer.class, "sum", methodType(int.class, int.class, int.class));
    MethodHandle parseInt = publicLookup().findStatic(Integer.class, "parseInt", methodType(int.class, String.class));

    // first example: c = a; d = sum(c, b); d
    List<Node> body = List.of(
        new VarAssignmentNode(int.class, "c", new VarNode("a")),
        new VarAssignmentNode(int.class, "d", new CallNode(sum, new VarNode("c"), new VarNode("b"))),
        new VarNode("d")
    );
    MethodHandle mh = compile(List.of("a", "b"), List.of(int.class, int.class), body);
    int result = (int) mh.invokeExact(2, 3);
    System.out.println(result);

    // second example: c = parseInt(a); d = parseInt(b); sum(c, d)
    List<Node> body2 = List.of(
        new VarAssignmentNode(int.class, "c", new CallNode(parseInt, new VarNode("a"))),
        new VarAssignmentNode(int.class, "d", new CallNode(parseInt, new VarNode("b"))),
        new CallNode(sum, new VarNode("c"), new VarNode("d"))
    );
    MethodHandle mh2 = compile(List.of("a", "b"), List.of(String.class, String.class), body2);
    int result2 = (int) mh2.invokeExact("100", "7");
    System.out.println(result2);
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment