Skip to content

Instantly share code, notes, and snippets.

@Maldivia
Created April 5, 2021 18:28
Show Gist options
  • Save Maldivia/be5a2908658f0e75cb09a1e81264e9ee to your computer and use it in GitHub Desktop.
Save Maldivia/be5a2908658f0e75cb09a1e81264e9ee to your computer and use it in GitHub Desktop.
Java Record to MemorySegment mapper
package com.test;
import jdk.incubator.foreign.GroupLayout;
import jdk.incubator.foreign.MemoryAddress;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.MemoryLayouts;
import jdk.incubator.foreign.MemorySegment;
import java.lang.reflect.RecordComponent;
import java.util.Arrays;
import java.util.Map;
/**
 * Maps Java records (and primitive / boxed-primitive values) to and from native memory
 * using the incubating {@code jdk.incubator.foreign} API.
 *
 * <p>Layouts produced here are packed: every element is given 8-bit (byte) alignment and
 * struct members follow record-component declaration order with no padding. Nested record
 * components become nested struct layouts. {@code boolean} values occupy one byte: they are
 * written as 0/1 and any non-zero byte reads back as {@code true}.
 */
public class LayoutMaker {
    /** Maps each boxed primitive wrapper class to its primitive counterpart. */
    static final Map<Class<?>, Class<?>> primitiveWrapper = Map.of(
            Boolean.class, boolean.class,
            Byte.class, byte.class,
            Short.class, short.class,
            Character.class, char.class,
            Integer.class, int.class,
            Long.class, long.class,
            Float.class, float.class,
            Double.class, double.class
    );

    /**
     * Returns the primitive counterpart of a wrapper class, or {@code klass} unchanged.
     * Memory-access var handles take primitive carrier types, so every carrier is
     * normalised through this before use.
     */
    private static Class<?> toPrimitive(Class<?> klass) {
        return primitiveWrapper.getOrDefault(klass, klass);
    }

    /**
     * Builds a packed memory layout for {@code klass}.
     *
     * @param klass a primitive type, a boxed primitive wrapper, or a record whose components
     *              are (recursively) of those kinds
     * @return a value layout for primitives/wrappers, or a struct layout whose members are
     *         named after the record components
     * @throws IllegalArgumentException if {@code klass} (or a nested component type) is not a
     *                                  supported primitive, wrapper or record type
     */
    public static MemoryLayout makeLayout(Class<?> klass) {
        klass = toPrimitive(klass);
        if (klass.isPrimitive()) {
            if (klass == boolean.class) { return MemoryLayouts.JAVA_BYTE.withBitAlignment(8); }
            if (klass == byte.class) { return MemoryLayouts.JAVA_BYTE.withBitAlignment(8); }
            if (klass == short.class) { return MemoryLayouts.JAVA_SHORT.withBitAlignment(8); }
            if (klass == char.class) { return MemoryLayouts.JAVA_CHAR.withBitAlignment(8); }
            if (klass == int.class) { return MemoryLayouts.JAVA_INT.withBitAlignment(8); }
            if (klass == long.class) { return MemoryLayouts.JAVA_LONG.withBitAlignment(8); }
            if (klass == float.class) { return MemoryLayouts.JAVA_FLOAT.withBitAlignment(8); }
            if (klass == double.class) { return MemoryLayouts.JAVA_DOUBLE.withBitAlignment(8); }
            // void.class is primitive but has no layout; fall through to the error below.
        } else if (klass.isRecord()) {
            RecordComponent[] components = klass.getRecordComponents();
            MemoryLayout[] elements = new MemoryLayout[components.length];
            for (int i = 0; i < components.length; i++) {
                elements[i] = makeLayout(components[i].getType()).withName(components[i].getName());
            }
            return MemoryLayout.ofStruct(elements).withBitAlignment(8);
        }
        throw new IllegalArgumentException("Only supports primitive and record types:" + klass);
    }

    /**
     * Reads a record of type {@code klass} whose struct sits at {@code path} inside
     * {@code layout}, relative to {@code base}. Nested record components are read recursively.
     */
    private static <T> T readRecord(GroupLayout layout, MemoryAddress base, Class<T> klass,
                                    MemoryLayout.PathElement... path) throws Throwable {
        RecordComponent[] components = klass.getRecordComponents();
        Object[] values = new Object[components.length];
        Class<?>[] componentTypes = new Class<?>[components.length];
        // Path to this struct plus one trailing slot for the current component's name.
        // Recursive calls copy this array, so overwriting the last slot here is safe.
        MemoryLayout.PathElement[] elementPath = Arrays.copyOf(path, path.length + 1);
        for (int i = 0; i < components.length; i++) {
            componentTypes[i] = components[i].getType();
            Class<?> type = toPrimitive(componentTypes[i]);
            elementPath[elementPath.length - 1] = MemoryLayout.PathElement.groupElement(components[i].getName());
            if (type.isRecord()) {
                values[i] = readRecord(layout, base, type, elementPath);
            } else if (type == boolean.class) {
                // Booleans are stored as a single byte; any non-zero value is true.
                values[i] = ((byte) layout.varHandle(byte.class, elementPath).get(base)) != 0;
            } else {
                values[i] = layout.varHandle(type, elementPath).get(base);
            }
        }
        // getDeclaredConstructors() returns constructors in no particular order and a record
        // may declare extra constructors, so select the canonical one by its component types.
        return klass.getDeclaredConstructor(componentTypes).newInstance(values);
    }

    /**
     * Writes every component of {@code record} into the struct at {@code path} inside
     * {@code layout}, relative to {@code base}. Nested records are written recursively.
     *
     * @throws IllegalArgumentException if any component value is {@code null}
     */
    private static void writeRecord(GroupLayout layout, MemoryAddress base, Record record,
                                    MemoryLayout.PathElement... path) throws Throwable {
        MemoryLayout.PathElement[] elementPath = Arrays.copyOf(path, path.length + 1);
        for (RecordComponent component : record.getClass().getRecordComponents()) {
            Class<?> type = toPrimitive(component.getType());
            elementPath[elementPath.length - 1] = MemoryLayout.PathElement.groupElement(component.getName());
            Object value = component.getAccessor().invoke(record);
            if (value == null) {
                // Boxed components can be null but native memory has no null representation.
                throw new IllegalArgumentException(record.getClass().getName() + "." + component.getName() + " is null");
            } else if (type.isRecord()) {
                writeRecord(layout, base, (Record) value, elementPath);
            } else if (type == boolean.class) {
                layout.varHandle(byte.class, elementPath).set(base, (byte) ((Boolean) value ? 1 : 0));
            } else {
                layout.varHandle(type, elementPath).set(base, value);
            }
        }
    }

    /**
     * Reads a value of type {@code klass} stored at {@code base}.
     *
     * @param base  address of the value's first byte
     * @param klass a primitive, boxed primitive or record type (see {@link #makeLayout})
     * @return the decoded value (primitives are returned boxed)
     */
    @SuppressWarnings("unchecked") // var-handle results are boxed values of the requested type
    public static <T> T read(MemoryAddress base, Class<T> klass) throws Throwable {
        MemoryLayout memoryLayout = makeLayout(klass);
        if (memoryLayout instanceof GroupLayout) {
            return readRecord((GroupLayout) memoryLayout, base, klass);
        }
        // Wrapper classes (e.g. Integer.class) must be turned into primitive carriers.
        Class<?> carrier = toPrimitive(klass);
        if (carrier == boolean.class) {
            return (T) (Object) (((byte) memoryLayout.varHandle(byte.class).get(base)) != 0);
        }
        return (T) memoryLayout.varHandle(carrier).get(base);
    }

    /**
     * Writes {@code value} at {@code base}.
     *
     * @param base  address of the destination's first byte
     * @param value a boxed primitive or a record (see {@link #makeLayout}); must not be null
     * @throws NullPointerException     if {@code value} is null
     * @throws IllegalArgumentException if a record component is null or a type is unsupported
     */
    public static <T> void write(MemoryAddress base, T value) throws Throwable {
        // value.getClass() is always a wrapper for primitive values; unwrap it for the carrier.
        Class<?> klass = toPrimitive(value.getClass());
        MemoryLayout memoryLayout = makeLayout(klass);
        if (memoryLayout instanceof GroupLayout) {
            writeRecord((GroupLayout) memoryLayout, base, (Record) value);
        } else if (klass == boolean.class) {
            memoryLayout.varHandle(byte.class).set(base, (byte) ((Boolean) value ? 1 : 0));
        } else {
            memoryLayout.varHandle(klass).set(base, value);
        }
    }

    public record Foo(int a, long b, Integer c) {}

    public record Bar(Foo one, double d, Foo two, boolean b) {}

    /** Demo: writes two {@link Bar} values back-to-back, then reads them (and a misaligned one) back. */
    public static void main(String[] args) throws Throwable {
        MemoryLayout barLayout = makeLayout(Bar.class);
        MemoryLayout memoryLayout = MemoryLayout.ofSequence(2, barLayout);
        try (MemorySegment segment = MemorySegment.allocateNative(memoryLayout)) {
            MemoryAddress base = segment.baseAddress();
            write(base, new Bar(new Foo(1, 2, 3), Math.PI, new Foo(6, 7, 8), false));
            write(base.addOffset(barLayout.byteSize()), new Bar(new Foo(99, 88, 77), Math.E, new Foo(55, 44, 33), true));
            System.out.println(read(base, Bar.class));
            System.out.println(read(base.addOffset(barLayout.byteSize()), Bar.class));
            // "Random" offset
            System.out.println(read(base.addOffset(16), Bar.class));
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment