Skip to content

Instantly share code, notes, and snippets.

@forax
Created January 3, 2018 16:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save forax/5d3c68733c4ff39ae20b8bdcfa3d8b0a to your computer and use it in GitHub Desktop.
Save forax/5d3c68733c4ff39ae20b8bdcfa3d8b0a to your computer and use it in GitHub Desktop.
A compiler that takes a toy AST and generates the corresponding method handle tree
import static java.lang.invoke.MethodHandles.constant;
import static java.lang.invoke.MethodHandles.dropArguments;
import static java.lang.invoke.MethodHandles.empty;
import static java.lang.invoke.MethodHandles.identity;
import static java.lang.invoke.MethodHandles.loop;
import static java.lang.invoke.MethodHandles.publicLookup;
import static java.lang.invoke.MethodType.methodType;
import static java.util.stream.IntStream.range;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.function.Function;
/**
 * A tiny compiler that turns a toy AST (a straight-line list of statements) into a single
 * method handle built with {@link MethodHandles#loop}.
 *
 * <p>Each {@link VarAssignmentNode} of the body becomes a loop iteration variable and the
 * compiled method's parameters become the loop's external parameters. The predicate of the last
 * clause is the constant {@code false}, so the clauses run exactly once, in body order, and the
 * last statement provides the result.
 */
public interface ASTToMHCompiler {
  /** Marker interface of all AST nodes. */
  interface Node { /* empty */ }

  /** A reference, by name, to a parameter or to a local variable. */
  class VarNode implements Node {
    final String name;

    VarNode(String name) { this.name = name; }
  }

  /** Binds the local variable {@code name} of type {@code type} to the value of {@code expr}. */
  class VarAssignmentNode implements Node {
    final Class<?> type;
    final String name;
    final Node expr;

    VarAssignmentNode(Class<?> type, String name, Node expr) {
      this.type = type;
      this.name = name;
      this.expr = expr;
    }
  }

  /** A call to {@code mh}; each element of {@code args} computes one argument of the call. */
  class CallNode implements Node {
    final MethodHandle mh;
    final List<Node> args;

    /**
     * @throws IllegalArgumentException if the number of argument nodes differs from the number
     *     of parameters of {@code mh}
     */
    CallNode(MethodHandle mh, Node... args) {
      if (mh.type().parameterCount() != args.length) {
        throw new IllegalArgumentException("invalid number of arguments");
      }
      this.args = List.of(args);
      this.mh = mh;
    }
  }

  /**
   * A minimal visitor that dispatches on the exact runtime class of its argument
   * (subclasses of a registered class are not matched).
   *
   * @param <R> the result type of a visit
   */
  class Visitor<R> {
    final HashMap<Class<?>, Function<Object, ? extends R>> map = new HashMap<>();

    /** Registers {@code fun} for objects whose class is exactly {@code type}; returns this. */
    <T> Visitor<R> when(Class<T> type, Function<? super T, ? extends R> fun) {
      map.put(type, fun.compose(type::cast));
      return this;
    }

    /**
     * Applies the function registered for the exact class of {@code o}.
     *
     * @throws IllegalArgumentException if no function is registered for that class
     */
    R call(Object o) {
      Function<Object, ? extends R> fun = map.get(o.getClass());
      if (fun == null) {
        throw new IllegalArgumentException("no visitor case for " + o.getClass().getName());
      }
      return fun.apply(o);
    }
  }

  /**
   * Compiles {@code body} into a method handle of type {@code (paramTypes...) -> R}, where
   * {@code R} is the type of the last statement: the value of the expression, or of the assigned
   * variable when the last statement is a {@link VarAssignmentNode}.
   *
   * <p>If a name is declared more than once (including a local named like a parameter), every
   * reference to it resolves to its last declaration; parameters count as declared after all
   * the locals.
   *
   * @param params the parameter names
   * @param paramTypes the parameter types, in the same order as {@code params}
   * @param body the statements, executed in order; must not be empty
   * @return a method handle running the statements and returning the last one's value
   * @throws IllegalArgumentException if {@code params} and {@code paramTypes} differ in size,
   *     if {@code body} is empty, if a variable is used but never declared, or if a node class
   *     is not supported; the underlying {@code MethodHandles} combinators also reject
   *     ill-typed programs
   */
  static MethodHandle compile(List<String> params, List<Class<?>> paramTypes, List<Node> body) {
    if (params.size() != paramTypes.size()) {
      throw new IllegalArgumentException("params and paramTypes must have the same size");
    }
    if (body.isEmpty()) {
      throw new IllegalArgumentException("empty body");
    }

    // types.get(slot) is the type of the variable living in that slot. Slots are laid out as
    // the locals (one per assignment, in body order) followed by the parameters, which matches
    // the (V..., A...) parameter list that MethodHandles.loop passes to clause functions.
    ArrayList<Class<?>> types = new ArrayList<>();
    HashMap<String, Integer> localSlot = new HashMap<>();
    class Local {
      /** Allocates the next slot for {@code name}; a re-declared name points to the new slot. */
      void newLocal(Class<?> type, String name) {
        // The slot must be the index in 'types', not the map size: the two diverge as soon
        // as a name is declared twice.
        localSlot.put(name, types.size());
        types.add(type);
      }

      /** Returns a {@code (types...) -> types[slot]} handle that reads the variable's slot. */
      MethodHandle get(String name) {
        Integer boxedSlot = localSlot.get(name);
        if (boxedSlot == null) {
          throw new IllegalArgumentException("unknown variable " + name);
        }
        int slot = boxedSlot;
        MethodHandle mh = identity(types.get(slot));
        mh = dropArguments(mh, 1, types.subList(slot + 1, types.size()));
        mh = dropArguments(mh, 0, types.subList(0, slot));
        return mh;
      }
    }
    Local local = new Local();

    // register all local variables
    body.stream()
        .filter(VarAssignmentNode.class::isInstance)
        .map(VarAssignmentNode.class::cast)
        .forEach(assign -> local.newLocal(assign.type, assign.name));
    // add all parameters
    range(0, params.size()).forEach(i -> local.newLocal(paramTypes.get(i), params.get(i)));

    // Compiles an expression to a handle of type (types...) -> expression type.
    Visitor<MethodHandle> exprVisitor = new Visitor<>();
    exprVisitor
        .when(VarNode.class, v -> local.get(v.name))
        .when(CallNode.class, call -> {
          // Start from (P0, ..., Pk-1, types...) -> R and replace the call parameters from the
          // last to the first: folding Pi at position i takes a combiner of exact type
          // (types...) -> Pi, i.e. the compiled argument expression itself. The fold of P0 ends
          // up outermost and a fold runs its combiner first, so arguments are evaluated from
          // left to right.
          MethodHandle mh = dropArguments(call.mh, call.mh.type().parameterCount(), types);
          for (int i = call.args.size() - 1; i >= 0; i--) {
            MethodHandle arg = exprVisitor.call(call.args.get(i));
            mh = MethodHandles.foldArguments(mh, i, arg);
          }
          return mh;
        });

    // Each statement becomes a loop clause {init, step, pred, fini}.
    // An expression statement is evaluated by its step (result discarded), and its fini returns
    // the value in case it is the last statement.
    Function<Node, MethodHandle[]> exprAction = expr -> {
      MethodHandle mh = exprVisitor.call(expr);
      return new MethodHandle[] { null, mh.asType(mh.type().changeReturnType(void.class)), null, mh };
    };
    Visitor<MethodHandle[]> instrVisitor = new Visitor<MethodHandle[]>()
        .when(VarNode.class, exprAction)
        .when(CallNode.class, exprAction)
        .when(VarAssignmentNode.class, assign -> new MethodHandle[] {
            empty(methodType(assign.type, paramTypes)), // init: default value of the type
            exprVisitor.call(assign.expr),              // step: computes the assigned value
            null,                                       // pred
            local.get(assign.name)                      // fini: returns the variable
        });

    ArrayList<MethodHandle[]> clauses = new ArrayList<>();
    for (int i = 0; i < body.size(); i++) {
      MethodHandle[] clause = instrVisitor.call(body.get(i));
      if (i == body.size() - 1) { // last instruction
        // A constant false predicate exits the loop after this clause, returning its fini.
        clause[2] = constant(boolean.class, false);
        if (clause[0] == null) { // expression: fini evaluates it, so the step would run it twice
          clause[1] = null;
        }
      } else {
        // loop requires all non-null fini handles to share the loop's return type; only the
        // last statement produces the result.
        clause[3] = null;
      }
      clauses.add(clause);
    }
    return loop(clauses.toArray(new MethodHandle[0][]));
  }

  /** Compiles and runs two small example programs, printing their results. */
  public static void main(String[] args) throws Throwable {
    MethodHandle sum = publicLookup().findStatic(Integer.class, "sum", methodType(int.class, int.class, int.class));
    MethodHandle parseInt = publicLookup().findStatic(Integer.class, "parseInt", methodType(int.class, String.class));

    // first example
    List<Node> body = List.of(
        new VarAssignmentNode(int.class, "c", new VarNode("a")),
        new VarAssignmentNode(int.class, "d", new CallNode(sum, new VarNode("c"), new VarNode("b"))),
        new VarNode("d")
    );
    MethodHandle mh = compile(List.of("a", "b"), List.of(int.class, int.class), body);
    int result = (int)mh.invokeExact(2, 3);
    System.out.println(result);

    // second example
    List<Node> body2 = List.of(
        new VarAssignmentNode(int.class, "c", new CallNode(parseInt, new VarNode("a"))),
        new VarAssignmentNode(int.class, "d", new CallNode(parseInt, new VarNode("b"))),
        new CallNode(sum, new VarNode("c"), new VarNode("d"))
    );
    MethodHandle mh2 = compile(List.of("a", "b"), List.of(String.class, String.class), body2);
    int result2 = (int)mh2.invokeExact("100", "7");
    System.out.println(result2);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment