Skip to content

Instantly share code, notes, and snippets.

@Runemoro
Last active March 1, 2020 01:37
Show Gist options
  • Save Runemoro/6515914287d8994d5b36c05bbd771fc2 to your computer and use it in GitHub Desktop.
Save Runemoro/6515914287d8994d5b36c05bbd771fc2 to your computer and use it in GitHub Desktop.
package nofloats;
import org.objectweb.asm.*;
/**
 * Class-level entry point of the float-removal rewrite.
 *
 * <p>Translates {@code F}/{@code D} to {@code J} in field and method descriptors and generic
 * signatures, converts float/double field constant values, and hands every method body to a
 * {@link MethodFloatRemover}.
 *
 * <p>TODO: annotations (and annotation values) are not translated yet.
 */
public class ClassFloatRemover extends ClassVisitor {
    public ClassFloatRemover(ClassVisitor cv) {
        super(Opcodes.ASM7, cv);
    }

    @Override
    public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
        descriptor = NoFloats.translateDescriptor(descriptor);
        signature = NoFloats.translateSignature(signature);
        // A float/double ConstantValue must become a long to match the rewritten descriptor.
        value = NoFloats.translateValue(value);
        return super.visitField(access, name, descriptor, signature, value);
    }

    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        descriptor = NoFloats.translateDescriptor(descriptor);
        // Translate the method's own generic signature (null when it has none). Deriving it from
        // the descriptor would attach a bogus Signature attribute and lose real generic info.
        signature = NoFloats.translateSignature(signature);
        MethodVisitor delegate = super.visitMethod(access, name, descriptor, signature, exceptions);
        // A null delegate means the method is being dropped; let the reader skip its body.
        return delegate == null ? null : new MethodFloatRemover(delegate);
    }
}
package nofloats;
/**
 * Cosine operations for the no-floats demo, offered in a {@code long} and a {@code double}
 * flavour.
 *
 * <p>Neither overload has a usable default: each throws {@link AssertionError} unless the
 * implementing class overrides it. A class written against {@code cos(double)} ends up
 * overriding {@code cos(long)} instead once {@code D} has been replaced by {@code J} in its
 * method descriptors.
 */
public interface IMath {

    /**
     * Computes the cosine of a {@code long}-encoded value.
     *
     * @param x the argument
     * @return the cosine, in the same encoding as {@code x}
     * @throws AssertionError always, unless an implementation overrides this overload
     */
    default long cos(long x) {
        throw unsupported();
    }

    /**
     * Computes the cosine of a {@code double} value.
     *
     * @param x the argument
     * @return the cosine of {@code x}
     * @throws AssertionError always, unless an implementation overrides this overload
     */
    default double cos(double x) {
        throw unsupported();
    }

    /** Creates the message-less error raised by overloads an implementation did not provide. */
    private static AssertionError unsupported() {
        return new AssertionError();
    }
}
package nofloats;
/**
 * Reference cosine implemented with ordinary double arithmetic.
 *
 * <p>The demo's custom class loader feeds this class's bytecode through the float remover, so
 * the exact sequence of double operations below is also what gets rewritten into fixed-point
 * long operations. Editing the arithmetic here changes the transformed version as well.
 */
public class MathImpl implements IMath {
/**
 * Approximates cos(x) by summing the first ten terms of its Maclaurin series,
 * 1 - x^2/2! + x^4/4! - ... + x^18/18!.
 *
 * <p>No range reduction is performed, so the truncated series is only accurate for x near 0.
 *
 * @param x angle in radians
 * @return the ten-term series approximation of cos(x)
 */
@Override
public double cos(double x) {
double x2 = x * x;
// c is the current series term, starting with the x^0 term (1).
double c = 1;
double result = 0;
// Ten iterations (i = 1, 3, ..., 19). Each adds the current term, then turns it into the next:
// multiplying +/-x^(i-1)/(i-1)! by -x^2 / (i * (i + 1)) gives -/+x^(i+1)/(i+1)!.
for (int i = 1; i < 20; i+= 2) {
result += c;
c = -(c * x2) / (i * (i + 1));
}
return result;
}
}
package nofloats;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
/**
 * Rewrites a method body so that {@code float}/{@code double} arithmetic runs on {@code long}
 * values in 32.32 fixed point: a real value {@code v} is represented as {@code v * 2^32}.
 *
 * <p>Known limitations (not handled here):
 * <ul>
 *   <li>A float occupies one local/stack slot but a long occupies two. Local-variable indices and
 *       category-1 stack instructions (DUP, POP, SWAP, DUP_X1, ...) applied to float values are
 *       not remapped, so only {@code double}-typed code is rewritten consistently.</li>
 *   <li>Field and method descriptors are translated for every owner, including classes that
 *       were not themselves transformed (such references will then fail to resolve).</li>
 *   <li>Bootstrap arguments of invokedynamic and {@code Type} constants are not translated.</li>
 *   <li>Stack map frames are dropped, so the {@code ClassWriter} downstream must recompute them
 *       (e.g. {@code ClassWriter.COMPUTE_FRAMES}).</li>
 * </ul>
 */
public class MethodFloatRemover extends MethodVisitor {
    /** Fixed-point representation of 1.0. */
    private static final long ONE = 1L << 32;

    public MethodFloatRemover(MethodVisitor mv) {
        super(Opcodes.ASM7, mv);
    }

    @Override
    public void visitInsn(int opcode) {
        switch (opcode) {
            case Opcodes.FCONST_0:
            case Opcodes.DCONST_0:
                super.visitInsn(Opcodes.LCONST_0);
                return;
            case Opcodes.FCONST_1:
            case Opcodes.DCONST_1:
                super.visitLdcInsn(ONE);
                return;
            case Opcodes.FCONST_2:
                super.visitLdcInsn(2 * ONE);
                return;
            case Opcodes.FALOAD:
            case Opcodes.DALOAD:
                super.visitInsn(Opcodes.LALOAD);
                return;
            case Opcodes.FASTORE:
            case Opcodes.DASTORE:
                super.visitInsn(Opcodes.LASTORE);
                return;
            case Opcodes.FADD:
            case Opcodes.DADD:
                super.visitInsn(Opcodes.LADD);
                return;
            case Opcodes.FSUB:
            case Opcodes.DSUB:
                super.visitInsn(Opcodes.LSUB);
                return;
            case Opcodes.FMUL:
            case Opcodes.DMUL:
                // (a*2^32) * (b*2^32) must be rescaled by 2^-32. Dropping 16 low bits from each
                // operand first does that while keeping the 64-bit product from overflowing
                // for moderate magnitudes (at the cost of low-order precision).
                r16();
                swap();
                r16();
                super.visitInsn(Opcodes.LMUL);
                return;
            case Opcodes.FDIV:
            case Opcodes.DDIV:
                // Stack a, b becomes a << 16, b >> 16, so the quotient is scaled by 2^32 again.
                // NOTE: a << 16 overflows once |a| reaches about 2^15 in real terms.
                r16();
                swap();
                l16();
                swap();
                super.visitInsn(Opcodes.LDIV);
                return;
            case Opcodes.FREM:
            case Opcodes.DREM:
                // Both operands share the 2^32 scale, and Java's truncating remainder satisfies
                // (a*S) % (b*S) == (a % b) * S, so no rescaling is needed.
                super.visitInsn(Opcodes.LREM);
                return;
            case Opcodes.FNEG:
            case Opcodes.DNEG:
                super.visitInsn(Opcodes.LNEG);
                return;
            case Opcodes.F2D:
            case Opcodes.D2F:
                // float and double share the same long representation.
                return;
            case Opcodes.F2L:
            case Opcodes.D2L:
                truncateToInteger();
                return;
            case Opcodes.F2I:
            case Opcodes.D2I:
                truncateToInteger();
                super.visitInsn(Opcodes.L2I);
                return;
            case Opcodes.L2F:
            case Opcodes.L2D:
                l32();
                return;
            case Opcodes.I2F:
            case Opcodes.I2D:
                super.visitInsn(Opcodes.I2L);
                l32();
                return;
            case Opcodes.FCMPL:
            case Opcodes.FCMPG:
            case Opcodes.DCMPL:
            case Opcodes.DCMPG:
                // Fixed point has no NaN, so the L and G variants coincide.
                super.visitInsn(Opcodes.LCMP);
                return;
            case Opcodes.FRETURN:
            case Opcodes.DRETURN:
                super.visitInsn(Opcodes.LRETURN);
                return;
            default:
                super.visitInsn(opcode);
        }
    }

    /**
     * Converts a fixed-point value on the stack to its integer part, rounding toward zero as
     * Java's d2l/f2l do. An arithmetic shift would floor instead (e.g. -0.5 would become -1).
     */
    private void truncateToInteger() {
        super.visitLdcInsn(ONE);
        super.visitInsn(Opcodes.LDIV);
    }

    /** Shifts the long on top of the stack left by 32 (integer to fixed point; wraps if |v| >= 2^31). */
    private void l32() {
        super.visitIntInsn(Opcodes.BIPUSH, 32);
        super.visitInsn(Opcodes.LSHL);
    }

    /** Shifts the long on top of the stack left by 16. */
    private void l16() {
        super.visitIntInsn(Opcodes.BIPUSH, 16);
        super.visitInsn(Opcodes.LSHL);
    }

    /** Arithmetically shifts the long on top of the stack right by 16. */
    private void r16() {
        super.visitIntInsn(Opcodes.BIPUSH, 16);
        super.visitInsn(Opcodes.LSHR);
    }

    /** Swaps the two longs on top of the stack: a, b becomes b, a. */
    private void swap() {
        super.visitInsn(Opcodes.DUP2_X2);
        super.visitInsn(Opcodes.POP2);
    }

    @Override
    public void visitIntInsn(int opcode, int operand) {
        if (opcode == Opcodes.NEWARRAY && (operand == Opcodes.T_FLOAT || operand == Opcodes.T_DOUBLE)) {
            super.visitIntInsn(Opcodes.NEWARRAY, Opcodes.T_LONG);
            return;
        }
        super.visitIntInsn(opcode, operand);
    }

    @Override
    public void visitTypeInsn(int opcode, String type) {
        // Array types (e.g. ANEWARRAY "[D" for double[][], CHECKCAST "[F") are given as
        // descriptors and must be rewritten. Plain internal names never mention float/double.
        if (type != null && type.startsWith("[")) {
            type = NoFloats.translateDescriptor(type);
        }
        super.visitTypeInsn(opcode, type);
    }

    @Override
    public void visitVarInsn(int opcode, int var) {
        switch (opcode) {
            case Opcodes.FLOAD:
            case Opcodes.DLOAD:
                super.visitVarInsn(Opcodes.LLOAD, var);
                return;
            case Opcodes.FSTORE:
            case Opcodes.DSTORE:
                super.visitVarInsn(Opcodes.LSTORE, var);
                return;
            default:
                super.visitVarInsn(opcode, var);
        }
    }

    @Override
    public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
        super.visitFieldInsn(opcode, owner, name, NoFloats.translateDescriptor(descriptor));
    }

    @Override
    public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) {
        super.visitMethodInsn(opcode, owner, name, NoFloats.translateDescriptor(descriptor), isInterface);
    }

    @Override
    public void visitInvokeDynamicInsn(String name, String descriptor, Handle bootstrapMethodHandle, Object... bootstrapMethodArguments) {
        super.visitInvokeDynamicInsn(name, NoFloats.translateDescriptor(descriptor), bootstrapMethodHandle, bootstrapMethodArguments);
    }

    @Override
    public void visitLdcInsn(Object value) {
        // Float/Double constants become Long, which ASM emits as LDC2_W.
        super.visitLdcInsn(NoFloats.translateValue(value));
    }

    @Override
    public void visitMultiANewArrayInsn(String descriptor, int numDimensions) {
        super.visitMultiANewArrayInsn(NoFloats.translateDescriptor(descriptor), numDimensions);
    }

    @Override
    public void visitLocalVariable(String name, String descriptor, String signature, Label start, Label end, int index) {
        super.visitLocalVariable(name, NoFloats.translateDescriptor(descriptor), NoFloats.translateSignature(signature), start, end, index);
    }

    /** Drops the original frames: their float/double entries no longer describe the rewritten code. */
    @Override
    public void visitFrame(int type, int numLocal, Object[] local, int numStack, Object[] stack) {}
}
package nofloats;
import org.objectweb.asm.signature.SignatureReader;
import org.objectweb.asm.signature.SignatureWriter;
public class NoFloats {
public static String translateDescriptor(String descriptor) {
return translateSignature(descriptor);
}
public static String translateSignature(String signature) {
if (signature == null) {
return null;
}
SignatureWriter writer = new SignatureWriter() {
@Override
public void visitBaseType(char descriptor) {
if (descriptor == 'F' || descriptor == 'D') {
descriptor = 'J';
}
super.visitBaseType(descriptor);
}
};
new SignatureReader(signature).accept(writer);
return writer.toString();
}
public static Object translateValue(Object value) {
if (value == null) {
return null;
}
if (value instanceof Float) {
return (long) ((Float) value * (1L << 32L));
}
if (value instanceof Double) {
return (long) ((Double) value * (1L << 32L));
}
return value;
}
}
package nofloats;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import java.nio.file.Files;
import java.nio.file.Path;
/**
 * Demo driver. It loads {@code nofloats.MathImpl} through a class loader that rewrites its
 * floating-point bytecode to fixed-point longs, prints both versions' cos(1), and then times
 * both versions.
 */
public class Test {
    /**
     * Receives each benchmark's final value so the JIT cannot treat the timed loops as dead code
     * and eliminate them.
     */
    private static volatile long sink;

    public static void main(String[] args) throws Exception {
        Class<?> c = new ClassLoader() {
            @Override
            protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
                if (excluded(name)) {
                    return super.loadClass(name, resolve);
                }
                synchronized (getClassLoadingLock(name)) {
                    // defineClass rejects a second definition of the same name, so reuse a
                    // class this loader has already transformed.
                    Class<?> loaded = findLoadedClass(name);
                    if (loaded == null) {
                        loaded = createTransformedClass(name);
                    }
                    if (resolve) {
                        resolveClass(loaded);
                    }
                    return loaded;
                }
            }

            /** Only MathImpl is transformed; every other class goes to the parent loader. */
            private boolean excluded(String name) {
                return !name.contains("MathImpl");
            }

            /**
             * Reads the original class file, runs it through ClassFloatRemover, dumps the result
             * under debug/ for inspection, and defines it in this loader.
             */
            protected Class<?> createTransformedClass(String name) throws ClassNotFoundException {
                try {
                    ClassReader cr = new ClassReader(name.replace('/', '.'));
                    // Frames must be recomputed: the method rewriter discards the original ones.
                    ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
                    cr.accept(new ClassFloatRemover(cw), 0);
                    byte[] bytes = cw.toByteArray();
                    Path path = Path.of("debug/" + name.replace('.', '/') + ".class");
                    Files.createDirectories(path.getParent());
                    Files.write(path, bytes);
                    return defineClass(name, bytes, 0, bytes.length);
                } catch (Exception e) {
                    throw new ClassNotFoundException("exception while transforming class", e);
                }
            }
        }.loadClass("nofloats.MathImpl");
        IMath fastMath = (IMath) c.getDeclaredConstructor().newInstance();
        IMath math = new MathImpl();
        // Sanity check: the real value 1.0 is 1L << 32 in the fixed-point encoding.
        System.out.println(math.cos(1.0));
        System.out.println((double) fastMath.cos(1L << 32) / (1L << 32));
        // Alternate the two benchmarks four times; later rounds run after JIT warm-up.
        for (int run = 0; run < 4; run++) {
            testSlow(math);
            testFast(fastMath);
        }
        System.out.println();
    }

    /** Times 10^8 chained calls of the double-based cos and prints the elapsed seconds. */
    private static void testSlow(IMath math) {
        long t = System.nanoTime();
        double x = 0;
        for (long i = 0; i < 100_000_000; i++) {
            x = math.cos(x);
        }
        sink = Double.doubleToRawLongBits(x);
        System.out.println("doubles: " + round((double) (System.nanoTime() - t) / 1_000_000_000) + "s");
    }

    /** Times 10^8 chained calls of the long-based cos and prints the elapsed seconds. */
    private static void testFast(IMath math) {
        long t = System.nanoTime();
        long x = 0;
        for (long i = 0; i < 100_000_000; i++) {
            x = math.cos(x);
        }
        sink = x;
        System.out.println("longs: " + round((double) (System.nanoTime() - t) / 1_000_000_000) + "s");
    }

    /** Truncates {@code t} to three decimal places. */
    private static double round(double t) {
        return (double) (int) (t * 1000) / 1000;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment