Created June 25, 2024 11:05
Save iseki0/709247f2348df01acb7de33c5f3192ad to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.IOException;
import java.io.InputStream;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.Objects;
class MemorySegmentAsInputStream extends InputStream { | |
private final MemorySegment memorySegment; | |
private long pos; | |
MemorySegmentAsInputStream(MemorySegment memorySegment) { | |
this.memorySegment = memorySegment; | |
} | |
private static void doNotAlive() throws IOException { | |
throw new IOException("I'm dead"); | |
} | |
@Override | |
public int read() throws IOException { | |
if (pos >= memorySegment.byteSize()) return -1; | |
try { | |
return memorySegment.get(ValueLayout.JAVA_BYTE, pos++) & 0xff; | |
} catch (IllegalStateException e) { | |
doNotAlive(); | |
return -1; | |
} | |
} | |
@Override | |
public int read(byte[] b, int off, int len) throws IOException { | |
if (len < 0) throw new IllegalArgumentException("len < 0"); | |
if (off < 0) throw new IllegalArgumentException("off < 0"); | |
if (off + len > b.length) throw new IllegalArgumentException("off + len > b.length"); | |
try { | |
if (pos >= memorySegment.byteSize()) return -1; | |
long r = Math.min(memorySegment.byteSize() - pos, len); | |
MemorySegment.copy(memorySegment, ValueLayout.JAVA_BYTE, pos, b, off, len); | |
pos += r; | |
return (int) r; | |
} catch (IllegalStateException e) { | |
doNotAlive(); | |
return -1; | |
} | |
} | |
@Override | |
public long skip(long n) { | |
if (n <= 0) return 0; | |
long min = Math.min(memorySegment.byteSize() - pos, n); | |
pos += min; | |
return min; | |
} | |
@Override | |
public int available() { | |
return (int) Math.min(memorySegment.byteSize() - pos, Integer.MAX_VALUE); | |
} | |
@Override | |
public String toString() { | |
return "MemorySegmentAsInputStream{" + "memorySegment=" + memorySegment + ", pos=" + pos + '}'; | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.junit.jupiter.api.Test; | |
import java.io.IOException; | |
import java.lang.foreign.Arena; | |
import java.lang.foreign.MemorySegment; | |
import java.nio.charset.StandardCharsets; | |
import static org.junit.jupiter.api.Assertions.assertEquals; | |
import static org.junit.jupiter.api.Assertions.assertThrows; | |
public class MemorySegmentAsInputStreamTest { | |
@Test | |
public void testRead() throws IOException { | |
String testString = "Hello, World!"; | |
byte[] testData = testString.getBytes(StandardCharsets.UTF_8); | |
var segment = MemorySegment.ofArray(testData); | |
try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) { | |
byte[] readData = new byte[testData.length]; | |
int bytesRead = stream.read(readData, 0, testData.length); | |
assertEquals(testData.length, bytesRead); | |
assertEquals(testString, new String(readData, StandardCharsets.UTF_8)); | |
} | |
} | |
@Test | |
public void testReadEmptyStream() throws IOException { | |
MemorySegment segment = MemorySegment.ofArray(new byte[0]); | |
try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) { | |
assertEquals(-1, stream.read()); | |
} | |
} | |
@Test | |
public void testReadOneByte() throws IOException { | |
String testString = "H"; | |
byte[] testData = testString.getBytes(StandardCharsets.UTF_8); | |
var segment = MemorySegment.ofArray(testData); | |
try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) { | |
int readData = stream.read(); | |
assertEquals(testData[0], readData); | |
} | |
} | |
@Test | |
public void testSkip() throws IOException { | |
String testString = "Hello, World!"; | |
byte[] testData = testString.getBytes(StandardCharsets.UTF_8); | |
var segment = MemorySegment.ofArray(testData); | |
try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) { | |
long skipped = stream.skip(6); | |
assertEquals(6, skipped); | |
byte[] readData = new byte[testData.length - 6]; | |
int bytesRead = stream.read(readData, 0, testData.length - 6); | |
assertEquals(testData.length - 6, bytesRead); | |
assertEquals(testString.substring(6), new String(readData, StandardCharsets.UTF_8)); | |
} | |
} | |
@Test | |
public void testAvailable() throws IOException { | |
String testString = "Hello, World!"; | |
byte[] testData = testString.getBytes(StandardCharsets.UTF_8); | |
var segment = MemorySegment.ofArray(testData); | |
try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) { | |
assertEquals(testData.length, stream.available()); | |
} | |
} | |
@Test | |
public void testEOF() throws IOException { | |
String testString = "H"; | |
byte[] testData = testString.getBytes(StandardCharsets.UTF_8); | |
var segment = MemorySegment.ofArray(testData); | |
try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) { | |
int readData = stream.read(); | |
assertEquals(testData[0], readData); | |
// Read again to reach EOF | |
readData = stream.read(); | |
assertEquals(-1, readData); // -1 indicates EOF | |
} | |
} | |
@Test | |
public void testBadIO() { | |
assertThrows(IOException.class, () -> { | |
var arena = Arena.ofConfined(); | |
var segment = arena.allocate(1); | |
arena.close(); // Close the segment to simulate bad I/O | |
try (MemorySegmentAsInputStream stream = new MemorySegmentAsInputStream(segment)) { | |
stream.read(); // This should throw an IOException | |
} | |
}); | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.