Skip to content

Instantly share code, notes, and snippets.

@iseki0
Created June 25, 2024 11:05
Show Gist options
  • Save iseki0/709247f2348df01acb7de33c5f3192ad to your computer and use it in GitHub Desktop.
Save iseki0/709247f2348df01acb7de33c5f3192ad to your computer and use it in GitHub Desktop.
import java.io.IOException;
import java.io.InputStream;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.Objects;
/**
 * An {@link InputStream} that reads the bytes of a {@link MemorySegment} sequentially, from offset
 * 0 up to {@link MemorySegment#byteSize()}.
 *
 * <p>If a segment access throws {@link IllegalStateException} (for example because the segment's
 * arena has been closed), the failure is reported as an {@link IOException} carrying that
 * exception as its cause.
 *
 * <p>Closing this stream does not close the segment or its arena: {@link #close()} is inherited
 * unchanged from {@link InputStream}. Instances are not safe for concurrent use, because the read
 * position is updated without synchronization.
 */
class MemorySegmentAsInputStream extends InputStream {
    /** The segment being read. */
    private final MemorySegment memorySegment;

    /** Offset of the next byte to read; always within {@code [0, memorySegment.byteSize()]}. */
    private long pos;

    /**
     * Creates a stream positioned at the start of {@code memorySegment}.
     *
     * @param memorySegment the segment to read
     * @throws NullPointerException if {@code memorySegment} is {@code null}
     */
    MemorySegmentAsInputStream(MemorySegment memorySegment) {
        // Fail fast instead of on the first read/skip/available call.
        this.memorySegment = Objects.requireNonNull(memorySegment, "memorySegment");
    }

    /**
     * Builds the exception reported when the segment can no longer be accessed.
     *
     * @param cause the exception raised by the segment access
     * @return an {@link IOException} wrapping {@code cause}
     */
    private static IOException notAlive(IllegalStateException cause) {
        return new IOException("I'm dead", cause);
    }

    /**
     * Reads the next byte of the segment.
     *
     * @return the next byte as an unsigned value in {@code 0..255}, or {@code -1} at end of stream
     * @throws IOException if the segment access throws {@link IllegalStateException}
     */
    @Override
    public int read() throws IOException {
        if (pos >= memorySegment.byteSize()) return -1;
        final byte value;
        try {
            value = memorySegment.get(ValueLayout.JAVA_BYTE, pos);
        } catch (IllegalStateException e) {
            throw notAlive(e);
        }
        // Advance only after a successful access, so a failed read leaves the position unchanged.
        pos++;
        return value & 0xff;
    }

    /**
     * Copies up to {@code len} bytes from the current position into {@code b[off..off+len)}.
     *
     * @param b destination array
     * @param off start offset in {@code b}
     * @param len maximum number of bytes to copy
     * @return the number of bytes copied, {@code 0} if {@code len == 0}, or {@code -1} if the
     *     stream is at end and {@code len > 0}
     * @throws NullPointerException if {@code b} is {@code null}
     * @throws IllegalArgumentException if {@code off} or {@code len} is negative or
     *     {@code off + len > b.length} (NOTE(review): the {@link InputStream} contract specifies
     *     {@link IndexOutOfBoundsException} here; kept as-is for existing callers)
     * @throws IOException if the segment access throws {@link IllegalStateException}
     */
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        if (len < 0) throw new IllegalArgumentException("len < 0");
        if (off < 0) throw new IllegalArgumentException("off < 0");
        // Written as a subtraction so that a large off + len cannot overflow int and slip through.
        if (len > b.length - off) throw new IllegalArgumentException("off + len > b.length");
        // InputStream contract: a zero-length request reads nothing and returns 0, even at EOF.
        if (len == 0) return 0;
        long remaining = memorySegment.byteSize() - pos;
        if (remaining <= 0) return -1;
        // Clamp to the bytes actually available; copying `len` bytes would run past the segment end.
        int count = (int) Math.min(remaining, len);
        try {
            MemorySegment.copy(memorySegment, ValueLayout.JAVA_BYTE, pos, b, off, count);
        } catch (IllegalStateException e) {
            throw notAlive(e);
        }
        pos += count;
        return count;
    }

    /**
     * Advances the position by up to {@code n} bytes without touching segment memory.
     *
     * @param n number of bytes to skip; non-positive values skip nothing
     * @return the number of bytes actually skipped, at most the bytes remaining
     */
    @Override
    public long skip(long n) {
        if (n <= 0) return 0;
        long min = Math.min(memorySegment.byteSize() - pos, n);
        pos += min;
        return min;
    }

    /**
     * Returns the number of bytes left in the segment.
     *
     * @return the number of remaining bytes, capped at {@link Integer#MAX_VALUE}
     */
    @Override
    public int available() {
        return (int) Math.min(memorySegment.byteSize() - pos, Integer.MAX_VALUE);
    }

    @Override
    public String toString() {
        return "MemorySegmentAsInputStream{" + "memorySegment=" + memorySegment + ", pos=" + pos + '}';
    }
}
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
/**
 * JUnit 5 tests for {@link MemorySegmentAsInputStream}.
 *
 * <p>Most tests use heap segments from {@link MemorySegment#ofArray(byte[])}. The failure case uses
 * a segment whose confined arena has already been closed.
 */
public class MemorySegmentAsInputStreamTest {
    @Test
    public void testRead() throws IOException {
        String testString = "Hello, World!";
        byte[] testData = testString.getBytes(StandardCharsets.UTF_8);
        var segment = MemorySegment.ofArray(testData);
        try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) {
            byte[] readData = new byte[testData.length];
            int bytesRead = stream.read(readData, 0, testData.length);
            assertEquals(testData.length, bytesRead);
            assertEquals(testString, new String(readData, StandardCharsets.UTF_8));
        }
    }

    @Test
    public void testReadIntoOffset() throws IOException {
        byte[] testData = "Hello".getBytes(StandardCharsets.UTF_8);
        try (MemorySegmentAsInputStream stream =
                new MemorySegmentAsInputStream(MemorySegment.ofArray(testData))) {
            byte[] readData = new byte[7];
            assertEquals(3, stream.read(readData, 2, 3));
            assertEquals("Hel", new String(readData, 2, 3, StandardCharsets.UTF_8));
            // Bytes outside [off, off + len) must be left untouched.
            assertEquals(0, readData[0]);
            assertEquals(0, readData[6]);
        }
    }

    @Test
    public void testReadEmptyStream() throws IOException {
        MemorySegment segment = MemorySegment.ofArray(new byte[0]);
        try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) {
            assertEquals(-1, stream.read());
        }
    }

    @Test
    public void testReadOneByte() throws IOException {
        String testString = "H";
        byte[] testData = testString.getBytes(StandardCharsets.UTF_8);
        var segment = MemorySegment.ofArray(testData);
        try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) {
            int readData = stream.read();
            assertEquals(testData[0], readData);
        }
    }

    @Test
    public void testReadByteIsUnsigned() throws IOException {
        var segment = MemorySegment.ofArray(new byte[] {(byte) 0xFF});
        try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) {
            // read() must return 0..255, never a sign-extended negative value.
            assertEquals(0xFF, stream.read());
        }
    }

    @Test
    public void testSkip() throws IOException {
        String testString = "Hello, World!";
        byte[] testData = testString.getBytes(StandardCharsets.UTF_8);
        var segment = MemorySegment.ofArray(testData);
        try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) {
            long skipped = stream.skip(6);
            assertEquals(6, skipped);
            byte[] readData = new byte[testData.length - 6];
            int bytesRead = stream.read(readData, 0, testData.length - 6);
            assertEquals(testData.length - 6, bytesRead);
            assertEquals(testString.substring(6), new String(readData, StandardCharsets.UTF_8));
        }
    }

    @Test
    public void testSkipPastEnd() throws IOException {
        byte[] testData = "Hello".getBytes(StandardCharsets.UTF_8);
        try (MemorySegmentAsInputStream stream =
                new MemorySegmentAsInputStream(MemorySegment.ofArray(testData))) {
            // Skipping is clamped to the remaining bytes.
            assertEquals(testData.length, stream.skip(100));
            assertEquals(0, stream.skip(1));
            assertEquals(0, stream.available());
            assertEquals(-1, stream.read());
        }
    }

    @Test
    public void testAvailable() throws IOException {
        String testString = "Hello, World!";
        byte[] testData = testString.getBytes(StandardCharsets.UTF_8);
        var segment = MemorySegment.ofArray(testData);
        try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) {
            assertEquals(testData.length, stream.available());
        }
    }

    @Test
    public void testAvailableTracksPosition() throws IOException {
        byte[] testData = "Hello, World!".getBytes(StandardCharsets.UTF_8);
        try (MemorySegmentAsInputStream stream =
                new MemorySegmentAsInputStream(MemorySegment.ofArray(testData))) {
            stream.read();
            assertEquals(testData.length - 1, stream.available());
            stream.skip(6);
            assertEquals(testData.length - 7, stream.available());
        }
    }

    @Test
    public void testEOF() throws IOException {
        String testString = "H";
        byte[] testData = testString.getBytes(StandardCharsets.UTF_8);
        var segment = MemorySegment.ofArray(testData);
        try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) {
            int readData = stream.read();
            assertEquals(testData[0], readData);
            // Read again to reach EOF
            readData = stream.read();
            assertEquals(-1, readData); // -1 indicates EOF
        }
    }

    @Test
    public void testBadIO() throws IOException {
        var arena = Arena.ofConfined();
        var segment = arena.allocate(1);
        arena.close(); // Close the arena so the segment can no longer be accessed
        try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) {
            // Only the read itself may satisfy the assertion; setup problems must not mask a failure.
            assertThrows(IOException.class, () -> stream.read());
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment