Created April 5, 2021 18:28
-
-
Save Maldivia/be5a2908658f0e75cb09a1e81264e9ee to your computer and use it in GitHub Desktop.
Java Record to MemorySegment mapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.test; | |
import jdk.incubator.foreign.GroupLayout; | |
import jdk.incubator.foreign.MemoryAddress; | |
import jdk.incubator.foreign.MemoryLayout; | |
import jdk.incubator.foreign.MemoryLayouts; | |
import jdk.incubator.foreign.MemorySegment; | |
import java.lang.reflect.RecordComponent; | |
import java.util.Arrays; | |
import java.util.Map; | |
public class LayoutMaker { | |
static final Map<Class<?>, Class<?>> primitiveWrapper = Map.of( | |
Boolean.class, boolean.class, | |
Byte.class, byte.class, | |
Short.class, short.class, | |
Character.class, char.class, | |
Integer.class, int.class, | |
Long.class, long.class, | |
Float.class, float.class, | |
Double.class, double.class | |
); | |
public static MemoryLayout makeLayout(Class<?> klass) { | |
klass = primitiveWrapper.getOrDefault(klass, klass); | |
if (klass.isPrimitive()) { | |
if (klass == boolean.class) return MemoryLayouts.JAVA_BYTE.withBitAlignment(8); | |
if (klass == byte.class) return MemoryLayouts.JAVA_BYTE.withBitAlignment(8); | |
if (klass == short.class) return MemoryLayouts.JAVA_SHORT.withBitAlignment(8); | |
if (klass == char.class) return MemoryLayouts.JAVA_CHAR.withBitAlignment(8); | |
if (klass == int.class) return MemoryLayouts.JAVA_INT.withBitAlignment(8); | |
if (klass == long.class) return MemoryLayouts.JAVA_LONG.withBitAlignment(8); | |
if (klass == float.class) return MemoryLayouts.JAVA_FLOAT.withBitAlignment(8); | |
if (klass == double.class) return MemoryLayouts.JAVA_DOUBLE.withBitAlignment(8); | |
} | |
else if (klass.isRecord()) { | |
RecordComponent[] components = klass.getRecordComponents(); | |
MemoryLayout[] elements = new MemoryLayout[components.length]; | |
for (int i = 0; i < components.length; i++) { | |
elements[i] = makeLayout(components[i].getType()).withName(components[i].getName()); | |
} | |
return MemoryLayout.ofStruct(elements).withBitAlignment(8); | |
} | |
throw new IllegalArgumentException("Only supports primitive and record types:" + klass); | |
} | |
private static <T> T readRecord(GroupLayout layout, MemoryAddress base, Class<T> klass, MemoryLayout.PathElement... path) throws Throwable { | |
RecordComponent[] components = klass.getRecordComponents(); | |
Object[] values = new Object[components.length]; | |
MemoryLayout.PathElement[] elementPath = Arrays.copyOf(path, path.length + 1); | |
for (int i = 0; i < components.length; i++) { | |
Class<?> type = primitiveWrapper.getOrDefault(components[i].getType(), components[i].getType()); | |
elementPath[elementPath.length - 1] = MemoryLayout.PathElement.groupElement(components[i].getName()); | |
if (type.isRecord()) { | |
values[i] = readRecord(layout, base, type, elementPath); | |
} | |
else if (type == boolean.class) { | |
values[i] = (((byte)layout.varHandle(byte.class, elementPath).get(base)) != 0); | |
} | |
else { | |
values[i] = layout.varHandle(type, elementPath).get(base); | |
} | |
} | |
return (T)klass.getDeclaredConstructors()[0].newInstance(values); | |
} | |
private static <T> void writeRecord(GroupLayout layout, MemoryAddress base, Record record, MemoryLayout.PathElement... path) throws Throwable { | |
MemoryLayout.PathElement[] elementPath = Arrays.copyOf(path, path.length + 1); | |
for (RecordComponent component : record.getClass().getRecordComponents()) { | |
Class<?> type = primitiveWrapper.getOrDefault(component.getType(), component.getType()); | |
elementPath[elementPath.length - 1] = MemoryLayout.PathElement.groupElement(component.getName()); | |
Object value = component.getAccessor().invoke(record); | |
if (value == null) { | |
throw new IllegalArgumentException(record.getClass().getName() + "." + component.getName() + " is null"); | |
} | |
else if (type.isRecord()) { | |
writeRecord(layout, base, (Record) value, elementPath); | |
} | |
else if (type == boolean.class) { | |
layout.varHandle(byte.class, elementPath).set(base, (byte) ((Boolean) value ? 1 : 0)); | |
} | |
else { | |
layout.varHandle(type, elementPath).set(base, value); | |
} | |
} | |
} | |
public static <T> T read(MemoryAddress base, Class<T> klass) throws Throwable { | |
MemoryLayout memoryLayout = makeLayout(klass); | |
if (memoryLayout instanceof GroupLayout) { | |
return readRecord((GroupLayout)memoryLayout, base, klass); | |
} | |
else if (klass == boolean.class || klass == Boolean.class) { | |
return (T)(Object)(((byte)memoryLayout.varHandle(byte.class).get(base)) != 0); | |
} | |
else { | |
return (T)memoryLayout.varHandle(klass).get(base); | |
} | |
} | |
public static <T> void write(MemoryAddress base, T value) throws Throwable { | |
Class<?> klass = value.getClass(); | |
MemoryLayout memoryLayout = makeLayout(klass); | |
if (memoryLayout instanceof GroupLayout) { | |
writeRecord((GroupLayout)memoryLayout, base, (Record)value); | |
} | |
else { | |
memoryLayout.varHandle(klass).set(base, value); | |
} | |
} | |
public record Foo(int a, long b, Integer c) {} | |
public record Bar(Foo one, double d, Foo two, boolean b) {} | |
public static void main(String[] args) throws Throwable { | |
MemoryLayout barLayout = makeLayout(Bar.class); | |
MemoryLayout memoryLayout = MemoryLayout.ofSequence(2, barLayout); | |
try (MemorySegment segment = MemorySegment.allocateNative(memoryLayout)) { | |
MemoryAddress base = segment.baseAddress(); | |
write(base, new Bar(new Foo(1, 2, 3), Math.PI, new Foo(6, 7, 8), false)); | |
write(base.addOffset(barLayout.byteSize()), new Bar(new Foo(99, 88, 77), Math.E, new Foo(55, 44, 33), true)); | |
System.out.println(read(base, Bar.class)); | |
System.out.println(read(base.addOffset(barLayout.byteSize()), Bar.class)); | |
// "Random" offset | |
System.out.println(read(base.addOffset(16), Bar.class)); | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.