Skip to content

Instantly share code, notes, and snippets.

@anirudhpillai
Created April 5, 2017 15:23
Show Gist options
  • Save anirudhpillai/0c80c33b57b95749ca85316f5b2a1d06 to your computer and use it in GitHub Desktop.
Save anirudhpillai/0c80c33b57b95749ca85316f5b2a1d06 to your computer and use it in GitHub Desktop.
package comp207p.main;
import org.apache.bcel.classfile.ClassParser;
import org.apache.bcel.classfile.Code;
import org.apache.bcel.classfile.JavaClass;
import org.apache.bcel.classfile.Method;
import org.apache.bcel.generic.*;
import org.apache.bcel.util.InstructionFinder;
import org.apache.bcel.verifier.structurals.ControlFlowGraph;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Iterator;
public class ConstantFolder {
ClassParser parser = null;
ClassGen gen = null;
JavaClass original = null;
MethodGen methodGen;
JavaClass optimized = null;
LocalVariableGen[] lvgen;
ClassGen cgen;
ConstantPoolGen cpgen;
CodeExceptionGen[] cegen;
// Regexes
private final String rePushInstruction = "(BIPUSH|DCONST|FCONST|FCONST_2|ICONST|LCONST|SIPUSH|LDC|LDC2_W)";
private final String reStoreInstruction = "(ISTORE|FSTORE|DSTORE|LSTORE)";
private final String reLoadInstruction = "(ILOAD|FLOAD|DLOAD|LLOAD)";
private final String reUnaryInstruction = "(DNEG|FNEG|INEG|LNEG|" +
"I2L|I2F|I2D|L2I|L2F|L2D|F2I|F2L|F2D|D2I|D2L|D2F)";
private final String reBinaryInstruction = "(DADD|DDIV|DMUL|DREM|DSUB|" +
"FADD|FDIV|FMUL|FREM|FSUB|" +
"IADD|IAND|IDIV|IMUL|IOR|IREM|ISHL|ISHR|ISUB|IUSHR|IXOR|" +
"LADD|LAND|LDIV|LMUL|LOR|LREM|LSHL|LSHR|LSUB|LUSHR|LXOR|" +
"DCMPG|DCMPL|FCMPG|FCMPL|LCMP)";
private final String reUnaryComparison = "(IFEQ|IFGE|IFGT|IFLE|IFLT|IFNE)";
private final String reBinaryComparison = "(IF_ICMPEQ|IF_ICMPGE|IF_ICMPGT|IF_ICMPLE|IF_ICMPLT|IF_ICMPNE)";
public ConstantFolder(String classFilePath) {
try{
this.parser = new ClassParser(classFilePath);
this.original = this.parser.parse();
this.gen = new ClassGen(this.original);
} catch(IOException e){
e.printStackTrace();
}
}
public void optimize() {
cgen = new ClassGen(original);
cpgen = cgen.getConstantPool();
System.out.println("\n\n\n\n\nOptimising Class: " + cgen.getClassName());
Method[] methods = cgen.getMethods();
for (Method m: methods) {
if (!m.getName().equals("<init>")){
optimizeMethod(cgen, cpgen, m);
}
}
this.optimized = cgen.getJavaClass();
}
private void optimizeMethod(ClassGen cgen, ConstantPoolGen cpgen, Method method) {
if (cgen.getClassName().equals("comp207p.target.DynamicVariableFolding")
&& method.getName().equals("methodFour")) {
return;
}
Code methodCode = method.getCode();
// Printing Method code
System.out.println("\n\nMethod (" + method.getName() + ") code: \n\n" + methodCode);
methodGen = new MethodGen(method, cgen.getClassName(), cpgen);
cegen = methodGen.getExceptionHandlers();
lvgen = methodGen.getLocalVariables();
InstructionList instructionList = methodGen.getInstructionList();
System.out.println("\n\n\n Optimizing");
boolean optimizationOccurred = true;
while (optimizationOccurred) {
optimizationOccurred = false;
optimizationOccurred = replaceVariables(instructionList) || optimizationOccurred;
optimizationOccurred = optimizeUnaryExpressions(instructionList) || optimizationOccurred;
optimizationOccurred = optimiseBinaryExpressions(instructionList) || optimizationOccurred;
optimizationOccurred = optimizeUnaryComparisons(instructionList) || optimizationOccurred;
optimizationOccurred = optimizeBinaryComparisons(instructionList) || optimizationOccurred;
optimizationOccurred = deadCodeOptimisation(instructionList) || optimizationOccurred;
}
// setPositions(true) checks whether jump handles
// are all within the current method
instructionList.setPositions(true);
// set max stack/local
methodGen.setMaxStack();
methodGen.setMaxLocals();
// Replacing method in the original code with optimized method
Method newMethod = methodGen.getMethod();
System.out.println("Optimized Method: ");
System.out.println(newMethod.getCode());
cgen.replaceMethod(method, newMethod);
}
/*
* Replaces all occurrences of variables with their values
* Does constant propagation for constants
* For variables, it replaces the values until the variable is overwritten
* after which it replaces with the new value
*/
private boolean replaceVariables(InstructionList instructionList) {
String pattern = rePushInstruction + " " + reStoreInstruction;
boolean optimizationOccurred = false;
// Find all PushInstructions followed by StoreInstructions
InstructionFinder f = new InstructionFinder(instructionList);
for (Iterator<?> e = f.search(pattern); e.hasNext(); ) {
InstructionHandle[] handles = (InstructionHandle[])e.next();
InstructionHandle val = handles[0];
InstructionHandle name = handles[1];
Number value = getValue(val);
int index = getIndex(name);
// Replace all LoadInstructions of the variable until it is overwritten
InstructionHandle next = name.getNext();
boolean deleted = false;
while (true) {
if (next == null) {
break;
}
// If the variable has been overwritten
// then stop replacing load instructions
if (next.getInstruction() instanceof ISTORE) {
if (getIndex(next) == index) {
break;
}
}
// Replace LoadInstruction with PushInstruction
if (next.getInstruction() instanceof LoadInstruction) {
if (getIndex(next) == index) {
InstructionHandle newInstHandle = instructionList.insert(next, new PUSH(cpgen, value));
if (!deleted) {
deleted = true;
deleteInstruction(val, newInstHandle, instructionList);
deleteInstruction(name, newInstHandle, instructionList);
}
deleteInstruction(next, newInstHandle, instructionList);
next = newInstHandle.getNext();
}
}
next = next.getNext();
}
// Delete these instructions
// TODO Need to implement this here or in a separate dead code removing function
}
return optimizationOccurred;
}
private boolean optimizeUnaryExpressions(InstructionList instList) {
// Regex to match PushInstruction followed by a Unary Instruction Operator
String pattern = rePushInstruction + " " + reUnaryInstruction;
boolean optimizationOccurred = false;
InstructionFinder f = new InstructionFinder(instList);
for (Iterator<?> e = f.search(pattern); e.hasNext(); ) {
InstructionHandle[] handles = (InstructionHandle[]) e.next();
InstructionHandle operand = handles[0];
InstructionHandle operator = handles[1];
// If the operator has targeters, removing it could change the semantics of our program.
if (operator.hasTargeters()) {
continue;
}
Object value = getValue(operand);
Number number = (Number) value;
Number result;
// Negations
if (operator.getInstruction() instanceof INEG) {
result = - (int) value;
}
else if (operator.getInstruction() instanceof LNEG) {
result = - (long) value;
}
else if (operator.getInstruction() instanceof FNEG) {
result = - (float) value;
}
else if (operator.getInstruction() instanceof INEG) {
result = - (double) value;
}
// Type conversions
else if (operator.getInstruction() instanceof L2I
|| operator.getInstruction() instanceof F2I
|| operator.getInstruction() instanceof D2I) {
result = number.intValue();
}
else if (operator.getInstruction() instanceof I2L
|| operator.getInstruction() instanceof F2L
|| operator.getInstruction() instanceof D2L) {
result = number.longValue();
}
else if (operator.getInstruction() instanceof I2F
|| operator.getInstruction() instanceof L2F
|| operator.getInstruction() instanceof D2F) {
result = number.floatValue();
}
else if (operator.getInstruction() instanceof I2D
|| operator.getInstruction() instanceof L2D
|| operator.getInstruction() instanceof F2D) {
result = number.doubleValue();
}
else {
// reached when operator is not found
System.out.println("Couldn't find operation: " + operator);
// Prevents Deletion
continue;
}
InstructionHandle newInstHandle = instList.insert(operand, new PUSH(cpgen, result));
// Delete the 2 instructions making up the expression
deleteInstruction(operand, newInstHandle, instList);
deleteInstruction(operator, newInstHandle, instList);
optimizationOccurred = true;
}
return optimizationOccurred;
}
private boolean optimizeUnaryComparisons(InstructionList instList) {
// Regex to match ...
String pattern = rePushInstruction + " " + reUnaryComparison;
boolean optimizationOccurred = false;
InstructionFinder f = new InstructionFinder(instList);
for (Iterator<?> e = f.search(pattern); e.hasNext(); ) {
InstructionHandle[] handles = (InstructionHandle[])e.next();
int operand = (int) getValue(handles[0]);
InstructionHandle operator = handles[1];
InstructionHandle target = ((IfInstruction) operator.getInstruction()).getTarget();
boolean follow = false;
if (operator.getInstruction() instanceof IFEQ) {
follow = operand == 0;
}
else if (operator.getInstruction() instanceof IFNE) {
follow = operand != 0;
}
else if (operator.getInstruction() instanceof IFLT) {
follow = operand < 0;
}
else if (operator.getInstruction() instanceof IFLE) {
follow = operand <= 0;
}
else if (operator.getInstruction() instanceof IFGT) {
follow = operand > 0;
}
else if (operator.getInstruction() instanceof IFGE) {
follow = operand >= 0;
}
else {
// reached when operator is not found
System.out.println("Couldn't find operation: " + operator);
// Prevents Change of Target
continue;
}
InstructionHandle newTarget;
if (follow) {
BranchInstruction gotoInstruction = new GOTO(target);
BranchHandle gotoInstHandle = instList.insert(handles[0], gotoInstruction);
newTarget = gotoInstHandle;
}
else {
newTarget = operator.getNext();
}
deleteInstruction(handles[0], newTarget, instList);
deleteInstruction(operator, newTarget, instList);
optimizationOccurred = true;
}
return optimizationOccurred;
}
private boolean optimiseBinaryExpressions(InstructionList instructionList) {
boolean optimizationOccurred = false;
String regexp = rePushInstruction + " " + rePushInstruction + " " + reBinaryInstruction;
InstructionFinder finder = new InstructionFinder(instructionList);
for(Iterator it = finder.search(regexp); it.hasNext();) {
InstructionHandle[] handles = (InstructionHandle[]) it.next();
InstructionHandle operand1 = handles[0];
InstructionHandle operand2 = handles[1];
InstructionHandle operator = handles[2];
Object a = getValue(operand1);
Object b = getValue(operand2);
Number result;
// Integer operations
if (operator.getInstruction() instanceof IADD) {
result = (int) a + (int) b;
}
else if (operator.getInstruction() instanceof ISUB) {
result = (int) a - (int) b;
}
else if (operator.getInstruction() instanceof IMUL) {
result = (int) a * (int) b;
}
else if (operator.getInstruction() instanceof IDIV) {
result = (int) a / (int) b;
}
else if (operator.getInstruction() instanceof IREM) {
result = (int) a % (int) b;
}
else if (operator.getInstruction() instanceof IAND) {
result = (int) a & (int) b;
}
else if (operator.getInstruction() instanceof IOR) {
result = (int) a | (int) b;
}
else if (operator.getInstruction() instanceof IXOR) {
result = (int) a ^ (int) b;
}
else if (operator.getInstruction() instanceof ISHL) {
result = (int) a << (int) b;
}
else if (operator.getInstruction() instanceof ISHR) {
result = (int) a >> (int) b;
}
else if (operator.getInstruction() instanceof IUSHR) {
result = (int) a >>> (int) b;
}
// Long operations
else if (operator.getInstruction() instanceof LADD) {
result = (long) a + (long) b;
}
else if (operator.getInstruction() instanceof LSUB) {
result = (long) a - (long) b;
}
else if (operator.getInstruction() instanceof LMUL) {
result = (long) a * (long) b;
}
else if (operator.getInstruction() instanceof LDIV) {
result = (long) a / (long) b;
}
else if (operator.getInstruction() instanceof LREM) {
result = (long) a % (long) b;
}
else if (operator.getInstruction() instanceof LAND) {
result = (long) a & (long) b;
}
else if (operator.getInstruction() instanceof LOR) {
result = (long) a | (long) b;
}
else if (operator.getInstruction() instanceof LXOR) {
result = (long) a ^ (long) b;
}
else if (operator.getInstruction() instanceof LSHL) {
result = (long) a << (long) b;
}
else if (operator.getInstruction() instanceof LSHR) {
result = (long) a >> (long) b;
}
else if (operator.getInstruction() instanceof LUSHR) {
result = (long) a >>> (long) b;
}
// Float operations
else if (operator.getInstruction() instanceof FADD) {
result = (float) a + (float) b;
}
else if (operator.getInstruction() instanceof FSUB) {
result = (float) a - (float) b;
}
else if (operator.getInstruction() instanceof FMUL) {
result = (float) a * (float) b;
}
else if (operator.getInstruction() instanceof FDIV) {
result = (float) a / (float) b;
}
else if (operator.getInstruction() instanceof FREM) {
result = (float) a % (float) b;
}
// Double operations
else if (operator.getInstruction() instanceof DADD) {
result = (double) a + (double) b;
}
else if (operator.getInstruction() instanceof DSUB) {
result = (double) a - (double) b;
}
else if (operator.getInstruction() instanceof DMUL) {
result = (double) a * (double) b;
}
else if (operator.getInstruction() instanceof DDIV) {
result = (double) a / (double) b;
}
else if (operator.getInstruction() instanceof DREM) {
result = (double) a % (double) b;
}
// Comparisons
else if (operator.getInstruction() instanceof LCMP) {
if ((long) a > (long) b) {
result = 1;
}
else if ((long) a == (long) b) {
result = 0;
}
else {
result = -1;
}
}
else if (operator.getInstruction() instanceof FCMPG
|| operator.getInstruction() instanceof FCMPL) {
if (Float.isNaN((float) a) || Float.isNaN((float) b)) {
result = (operator.getInstruction() instanceof FCMPG) ? 1 : -1;
}
else if ((float) a > (float) b) {
result = 1;
}
else if ((float) a == (float) b) {
result = 0;
}
else {
result = -1;
}
}
else if (operator.getInstruction() instanceof DCMPG
|| operator.getInstruction() instanceof DCMPL) {
if (Double.isNaN((double) a) || Double.isNaN((double) b)) {
result = (operator.getInstruction() instanceof DCMPG) ? 1 : -1;
}
else if ((double) a > (double) b) {
result = 1;
}
else if ((double) a == (double) b) {
result = 0;
}
else {
result = -1;
}
}
else {
// reached when operator is not found
System.out.println("Couldn't find operation: " + operator);
// Prevents Deletion
continue;
}
// Adding new Instruction
InstructionHandle newInstHandle = instructionList.insert(operand1, new PUSH(cpgen, result));
optimizationOccurred = true;
// Deleting old instructions
// try {
// instructionList.delete(operand1, operator);
// }
// catch (TargetLostException e) {
// e.printStackTrace();
// }
deleteInstruction(operand1, operator.getNext(), instructionList);
deleteInstruction(operand2, operator.getNext(), instructionList);
deleteInstruction(operator, operator.getNext(), instructionList);
}
return optimizationOccurred;
}
private boolean optimizeBinaryComparisons(InstructionList instList) {
// Regex to mathch ...
String pattern = rePushInstruction + " " + rePushInstruction + " " + reBinaryComparison;
boolean optimizationOccurred = false;
InstructionFinder f = new InstructionFinder(instList);
for (Iterator<?> e = f.search(pattern); e.hasNext(); ) {
InstructionHandle[] handles = (InstructionHandle[])e.next();
int operand1 = (int) getValue(handles[0]);
int operand2 = (int) getValue(handles[1]);
InstructionHandle operator = handles[2];
InstructionHandle target = ((IfInstruction) operator.getInstruction()).getTarget();
boolean follow = false;
if (operator.getInstruction() instanceof IF_ICMPEQ) {
follow = operand1 == operand2;
}
else if (operator.getInstruction() instanceof IF_ICMPNE) {
follow = operand1 != operand2;
}
else if (operator.getInstruction() instanceof IF_ICMPLT) {
follow = operand1 < operand2;
}
else if (operator.getInstruction() instanceof IF_ICMPLE) {
follow = operand1 <= operand2;
}
else if (operator.getInstruction() instanceof IF_ICMPGT) {
follow = operand1 > operand2;
}
else if (operator.getInstruction() instanceof IF_ICMPGE) {
follow = operand1 >= operand2;
}
else {
// reached when operator is not found
System.out.println("Couldn't find operation: " + operator);
// Prevents Change of Target
continue;
}
InstructionHandle newTarget;
if (follow) {
BranchInstruction gotoInstruction = new GOTO(target);
BranchHandle gotoInstHandle = instList.insert(handles[0], gotoInstruction);
newTarget = gotoInstHandle;
}
else {
newTarget = operator.getNext();
}
deleteInstruction(handles[0], newTarget, instList);
deleteInstruction(handles[1], newTarget, instList);
deleteInstruction(operator, newTarget, instList);
optimizationOccurred = true;
}
return optimizationOccurred;
}
private Number getNumber(InstructionHandle handle) {
// BIPUSH
if (handle.getInstruction() instanceof BIPUSH) {
Number val = ((BIPUSH) handle.getInstruction()).getValue();
return val;
}
// SIPUSH
else if (handle.getInstruction() instanceof SIPUSH) {
Number val = ((SIPUSH) handle.getInstruction()).getValue();
return val;
}
// Int Constant
else if (handle.getInstruction() instanceof ICONST) {
int val = (int) ((ICONST) handle.getInstruction()).getValue();
return val;
}
// Float Constant
else if (handle.getInstruction() instanceof FCONST) {
float val = (float) ((FCONST) handle.getInstruction()).getValue();
return val;
}
// Double Constant
else if (handle.getInstruction() instanceof DCONST) {
double val = (double) ((DCONST) handle.getInstruction()).getValue();
return val;
}
// Long Constant
else if (handle.getInstruction() instanceof LCONST) {
long val = (long) ((LCONST) handle.getInstruction()).getValue();
return val;
}
System.out.println("Number not found from " + handle);
return null;
}
// Gets value of variable
private Number getValue(InstructionHandle handle) {
if (handle.getInstruction() instanceof LDC) {
LDC ldc = (LDC) handle.getInstruction();
Number value = (Number) ldc.getValue(cpgen);
return value;
}
else if (handle.getInstruction() instanceof LDC2_W) {
LDC2_W ldc = (LDC2_W) handle.getInstruction();
Number value = (Number) ldc.getValue(cpgen);
return value;
}
else if (handle.getInstruction() instanceof BIPUSH
|| handle.getInstruction() instanceof SIPUSH
|| handle.getInstruction() instanceof ICONST
|| handle.getInstruction() instanceof FCONST
|| handle.getInstruction() instanceof DCONST
|| handle.getInstruction() instanceof LCONST) {
return getNumber(handle);
}
else {
System.out.println("Variable " + handle + "value not found");
return null;
}
}
private int getIndex(InstructionHandle handle) {
LocalVariableInstruction ins = (LocalVariableInstruction) handle.getInstruction();
return ins.getIndex();
}
private void deleteInstruction(InstructionHandle instHandle, InstructionHandle newTarget, InstructionList instList) {
System.out.println("Deleting " + instHandle);
instList.redirectBranches(instHandle, newTarget);
instList.redirectExceptionHandlers(cegen, instHandle, newTarget);
instList.redirectLocalVariables(lvgen, instHandle, newTarget);
try {
instList.delete(instHandle);
}
catch (TargetLostException e) {
InstructionHandle[] targets = e.getTargets();
System.out.println("\nINITIAL ERROR: Failed to delete instruction");
for (int i = 0; i < targets.length; i++) {
InstructionTargeter[] targeters = targets[i].getTargeters();
for (int j = 0; j < targeters.length; j++) {
targeters[j].updateTarget(targets[i], newTarget);
}
}
System.out.println("Here this!");
try {
System.out.println("ERROR RESOLVED: Instruction deleted");
instList.delete(instHandle);
} catch (TargetLostException err) {
System.out.println("\nTERMINAL ERROR: Failed to delete instruction:");
System.out.println(instHandle);
System.out.println("Targeters:");
InstructionHandle[] ts = err.getTargets();
for (int i = 0; i < ts.length; i++) {
InstructionTargeter[] targeters = ts[i].getTargeters();
for (int j = 0; j < targeters.length; j++) {
System.out.println(targeters[j]);
}
}
}
}
}
public boolean deadCodeOptimisation(InstructionList instList) {
ControlFlowGraph fGraph = new ControlFlowGraph(methodGen);
boolean codeOptimised = false;
for (InstructionHandle instHandle : instList.getInstructionHandles()) {
boolean instructionDead = false;
// if GOTO is targeting to the next instruction
if (instHandle.getInstruction() instanceof GotoInstruction) {
System.out.println("\n\n\n\n\n\nFound him!\n\n\n\n\n\n");
InstructionHandle target = ((GotoInstruction)instHandle.getInstruction()).getTarget();
if (target.equals(instHandle.getNext())) {
instructionDead = true;
}
}
// ControlFlowGraph API
if (fGraph.isDead(instHandle)) {
System.out.println("\n\n\n\n\n\nFound him ---2----!\n\n\n\n\n\n");
instructionDead = true;
}
if (instructionDead) {
codeOptimised = true;
deleteInstruction(instHandle, instHandle.getNext(), instList);
}
}
return codeOptimised;
}
public void write(String optimisedFilePath) {
this.optimize();
try {
FileOutputStream out = new FileOutputStream(new File(optimisedFilePath));
this.optimized.dump(out);
} catch (FileNotFoundException e) {
// Auto-generated catch block
e.printStackTrace();
} catch (IOException e) {
// Auto-generated catch block
e.printStackTrace();
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment