Skip to content

Instantly share code, notes, and snippets.

@pfmiles
Last active August 29, 2015 14:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pfmiles/41676da3f3441b8af4b3 to your computer and use it in GitHub Desktop.
Save pfmiles/41676da3f3441b8af4b3 to your computer and use it in GitHub Desktop.
Resettable input stream/Tee input stream/with configurable transparent file cache/可重复读取的inputStream实现, 带透明文件缓存
package test;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import javax.servlet.ServletInputStream;
/**
 * A resettable ("tee") {@link ServletInputStream}: every byte read from the wrapped stream is also
 * recorded, so that after {@link #reset()} the same content can be read again from the beginning.
 * <p>
 * The first {@code maxBytesInMem} recorded bytes are kept in memory; bytes beyond that are
 * transparently spilled to a temporary file, which is deleted by {@link #close()}. The wrapped
 * stream is closed as soon as it reports end of stream.
 * <p>
 * No synchronization is performed; an instance should be confined to a single thread.
 *
 * @author pf-miles
 * @since 2014-12-12
 */
public class TeeServletInputStream extends ServletInputStream {

    private static final int DEFAULT_MAX_BYTES_IN_MEM = 1048576;

    // Reading the wrapped stream while recording each byte into the backup.
    private static final int PEND_READ = 0;
    // Replaying the backup first; switches to PEND_READ once the backup has been read up.
    private static final int BUF_READ = 1;
    // The wrapped stream reached EOF; every operation is served from the backup.
    private static final int ALL_BUF = 2;

    // Maximum number of bytes kept in memory; further bytes are cached in the file backup.
    private final int maxBytesInMem;
    // The wrapped stream.
    private final InputStream stream;
    private int state = PEND_READ;
    // In-memory backup: the first maxBytesInMem bytes that were read.
    private final List<Byte> memBytes = new ArrayList<>();
    // File backup: writer used while recording, reader used while replaying.
    private FileOutputStream fileOut = null;
    private FileInputStream fileIn = null;
    // Current reading position, counted from the first byte of the wrapped stream.
    private int pos;
    // Total number of bytes of the wrapped stream; set when EOF is reached (ALL_BUF state).
    private int total;
    // The backing file of the file backup; null until the memory backup overflows.
    private File tmpFile = null;

    /**
     * Reads the next byte, replaying recorded bytes first when the stream has been reset.
     *
     * @return the next byte as an {@code int} in the range 0-255, or -1 at end of stream
     * @throws IOException if reading the wrapped stream or the file backup fails
     */
    @Override
    public int read() throws IOException {
        while (true) {
            int b;
            switch (this.state) {
            case PEND_READ:
                b = this.stream.read();
                if (b == -1) {
                    // EOF: remember the length; from now on everything is served from the backup
                    state = ALL_BUF;
                    total = pos;
                    this.stream.close();
                } else {
                    appendToBackup((byte) b);
                    pos++;
                }
                return b;
            case BUF_READ:
                if (pos < maxBytesInMem) {
                    // the position lies in the memory backup
                    if (pos >= memBytes.size()) {
                        // memory backup replayed completely; resume reading the wrapped stream
                        state = PEND_READ;
                        continue;
                    }
                    // mask to 0-255: unboxing a Byte sign-extends, and 0xFF would otherwise read as EOF
                    b = memBytes.get(pos) & 0xFF;
                    pos++;
                    return b;
                }
                // the position lies in the file backup
                if (tmpFile == null) {
                    // end of backup, case 1: nothing was ever spilled to a file
                    state = PEND_READ;
                    continue;
                }
                if (this.fileIn == null) {
                    // fileOut is a plain (unbuffered) FileOutputStream whose flush() is a no-op, so the
                    // recorded bytes are already readable through a new stream on the same file
                    this.fileIn = new FileInputStream(this.tmpFile);
                }
                b = this.fileIn.read();
                if (b == -1) {
                    // end of backup, case 2: the file backup has been replayed completely
                    state = PEND_READ;
                    this.fileIn.close();
                    this.fileIn = null;
                    continue;
                }
                pos++;
                return b;
            case ALL_BUF:
                if (pos >= total) {
                    return -1;
                }
                if (pos < maxBytesInMem) {
                    // same sign-extension guard as in BUF_READ
                    b = this.memBytes.get(pos) & 0xFF;
                } else {
                    if (this.fileIn == null) {
                        this.fileIn = new FileInputStream(this.tmpFile);
                    }
                    b = this.fileIn.read();
                }
                pos++;
                return b;
            default:
                throw new IllegalStateException("Impossible.");
            }
        }
    }

    /**
     * Records one byte read from the wrapped stream: into memory while the memory backup has room,
     * otherwise appended to the temporary file, which is created on first use.
     */
    private void appendToBackup(byte b) throws IOException {
        if (memBytes.size() >= maxBytesInMem) {
            // write to the file backup
            if (this.fileOut == null) {
                this.tmpFile = File.createTempFile("TisTmp_", ".buf");
                this.tmpFile.deleteOnExit();
                this.fileOut = new FileOutputStream(this.tmpFile, true);
            }
            this.fileOut.write(b);
        } else {
            // write to the memory backup
            memBytes.add(b);
        }
    }

    /**
     * Wraps {@code stream} with the default in-memory limit of 1 MB (1048576 bytes).
     *
     * @param stream the stream to wrap; must not be null
     * @throws NullPointerException if {@code stream} is null
     */
    public TeeServletInputStream(InputStream stream) {
        this(stream, DEFAULT_MAX_BYTES_IN_MEM);
    }

    /**
     * Wraps {@code stream}, making it resettable.
     *
     * @param stream        the stream to wrap; must not be null
     * @param maxBytesInMem maximum number of bytes cached in memory; bytes read beyond this count are
     *                      transparently cached in a temporary file. A negative value selects the
     *                      default (1 MB); 0 sends every byte to the file cache.
     * @throws NullPointerException if {@code stream} is null
     */
    public TeeServletInputStream(InputStream stream, int maxBytesInMem) {
        if (stream == null) throw new NullPointerException("Stream must not be null.");
        this.stream = stream;
        this.maxBytesInMem = maxBytesInMem < 0 ? DEFAULT_MAX_BYTES_IN_MEM : maxBytesInMem;
    }

    /**
     * Estimates the number of bytes readable without blocking.
     * <p>
     * While replaying, this is the not-yet-replayed part of the backup plus what the wrapped stream
     * reports. The value is clamped to {@link Integer#MAX_VALUE}.
     */
    @Override
    public int available() throws IOException {
        switch (this.state) {
        case PEND_READ:
            return this.stream.available();
        case BUF_READ:
            // computed in long so a large temp file cannot silently overflow the int result
            long expected = (long) this.memBytes.size() + this.stream.available();
            if (this.tmpFile != null) expected += this.tmpFile.length();
            return (int) Math.min(expected - pos, Integer.MAX_VALUE);
        case ALL_BUF:
            return total - pos;
        default:
            throw new IllegalStateException("Impossible.");
        }
    }

    /**
     * Closes the wrapped stream and the file backup streams, then deletes the temporary file.
     * <p>
     * Every resource is released even if closing an earlier one fails. The first failure is
     * rethrown, and later ones are attached to it as suppressed exceptions.
     */
    @Override
    public void close() throws IOException {
        IOException failure = null;
        for (Closeable c : new Closeable[] { this.stream, this.fileIn, this.fileOut }) {
            // null-safe: the finalizer may run on an instance whose constructor threw
            if (c == null) continue;
            try {
                c.close();
            } catch (IOException e) {
                if (failure == null) {
                    failure = e;
                } else if (e != failure) {
                    failure.addSuppressed(e);
                }
            }
        }
        if (this.tmpFile != null) this.tmpFile.delete();
        if (failure != null) throw failure;
    }

    /**
     * Rewinds to the first byte of the wrapped stream.
     * <p>
     * No prior {@code mark} call is needed: all bytes read so far are replayed from the backup
     * before reading from the wrapped stream resumes. {@code markSupported()} is inherited unchanged.
     */
    @Override
    public void reset() throws IOException {
        switch (this.state) {
        case PEND_READ:
            this.state = BUF_READ;
            this.pos = 0;
            break;
        case BUF_READ:
        case ALL_BUF:
            this.pos = 0;
            // the file reader may be positioned mid-file; drop it so replay re-opens it from the start
            if (this.fileIn != null) {
                this.fileIn.close();
                this.fileIn = null;
            }
            break;
        default:
            throw new IllegalStateException("Impossible.");
        }
    }

    /** Last-resort cleanup for instances that were never closed explicitly. */
    @Override
    protected void finalize() throws Throwable {
        try {
            this.close();
        } finally {
            super.finalize();
        }
    }
}
package test;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
/**
 * Tests for {@link TeeServletInputStream}.
 * <p>
 * The assertions inspect the stream's private fields via reflection, so they are coupled to the
 * field names {@code state}, {@code memBytes}, {@code fileOut}, {@code fileIn}, {@code pos},
 * {@code total} and {@code tmpFile}.
 *
 * @author pf-miles
 * @since 2014-12-13
 */
public class TeeServletInputStreamTest {

    // Mirrors of the private state codes in TeeServletInputStream.
    private static final int PEND_READ = 0;
    private static final int BUF_READ = 1;
    private static final int ALL_BUF = 2;

    /** A stream over bytes 0..7 that keeps at most 3 bytes in memory, so bytes 3..7 spill to a file. */
    private TeeServletInputStream newTis() {
        ByteArrayInputStream bis = new ByteArrayInputStream(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 });
        return new TeeServletInputStream(bis, 3);
    }

    /** Boxes {@code values} into the list form used by the stream's in-memory backup. */
    private static List<Byte> bytes(int... values) {
        Byte[] boxed = new Byte[values.length];
        for (int i = 0; i < values.length; i++) {
            boxed[i] = (byte) values[i];
        }
        return Arrays.asList(boxed);
    }

    /** Reads {@code count} bytes and asserts they are 0, 1, ..., count - 1. */
    private static void assertReadsFromZero(TeeServletInputStream tis, int count) throws Exception {
        for (int i = 0; i < count; i++) {
            Assert.assertEquals("byte at index " + i, i, tis.read());
        }
    }

    /** Asserts the complete internal state of {@code tis}; the boolean flags mean "field is non-null". */
    private void assertStates(TeeServletInputStream tis, int state, List<Byte> memBytes, boolean fos, boolean fis,
            int pos, int total, boolean tmpFile, int tmpFileLength) throws Exception {
        int stateVal = getField(tis, "state");
        Assert.assertEquals("state", state, stateVal);
        List<Byte> memBytesVal = getField(tis, "memBytes");
        Assert.assertEquals("memBytes", memBytes, memBytesVal);
        boolean fosVal = getField(tis, "fileOut") != null;
        Assert.assertEquals("fileOut present", fos, fosVal);
        boolean fisVal = getField(tis, "fileIn") != null;
        Assert.assertEquals("fileIn present", fis, fisVal);
        int posVal = getField(tis, "pos");
        Assert.assertEquals("pos", pos, posVal);
        int totalVal = getField(tis, "total");
        Assert.assertEquals("total", total, totalVal);
        File tmpFileVal = getField(tis, "tmpFile");
        Assert.assertEquals("tmpFile present", tmpFile, tmpFileVal != null);
        if (tmpFileVal != null) {
            Assert.assertEquals("tmpFile length", tmpFileLength, tmpFileVal.length());
        }
    }

    /** Reads a private field reflectively; the caller's target type decides the (unchecked) cast. */
    @SuppressWarnings("unchecked")
    private <T> T getField(TeeServletInputStream tis, String fn) throws Exception {
        Field f = TeeServletInputStream.class.getDeclaredField(fn);
        f.setAccessible(true);
        return (T) f.get(tis);
    }

    /*
     * pending read
     */
    @Test
    public void testPendingRead() throws Exception {
        // below max-bytes-in-mem
        TeeServletInputStream tis = newTis();
        assertReadsFromZero(tis, 2);
        this.assertStates(tis, PEND_READ, bytes(0, 1), false, false, 2, 0, false, 0);
        // on max-bytes-in-mem
        tis = newTis();
        assertReadsFromZero(tis, 3);
        this.assertStates(tis, PEND_READ, bytes(0, 1, 2), false, false, 3, 0, false, 0);
        // exceeds max-bytes-in-mem
        tis = newTis();
        assertReadsFromZero(tis, 4);
        this.assertStates(tis, PEND_READ, bytes(0, 1, 2), true, false, 4, 0, true, 1);
        // read to the end, EOS(-1) not read
        tis = newTis();
        assertReadsFromZero(tis, 8);
        this.assertStates(tis, PEND_READ, bytes(0, 1, 2), true, false, 8, 0, true, 5);
        Assert.assertEquals(-1, tis.read());
        this.assertStates(tis, ALL_BUF, bytes(0, 1, 2), true, false, 8, 8, true, 5);
        // read exceeds the end
        Assert.assertEquals(-1, tis.read());
        this.assertStates(tis, ALL_BUF, bytes(0, 1, 2), true, false, 8, 8, true, 5);
    }

    /*
     * buf read
     */
    @Test
    public void testBufRead() throws Exception {
        // zero buf read
        TeeServletInputStream tis = newTis();
        tis.reset();
        Assert.assertEquals(0, tis.read());
        this.assertStates(tis, PEND_READ, bytes(0), false, false, 1, 0, false, 0);
        // below max-bytes-in-mem
        tis = newTis();
        assertReadsFromZero(tis, 2);
        tis.reset();
        this.assertStates(tis, BUF_READ, bytes(0, 1), false, false, 0, 0, false, 0);
        assertReadsFromZero(tis, 2);
        this.assertStates(tis, BUF_READ, bytes(0, 1), false, false, 2, 0, false, 0);
        Assert.assertEquals(2, tis.read());
        this.assertStates(tis, PEND_READ, bytes(0, 1, 2), false, false, 3, 0, false, 0);
        Assert.assertEquals(3, tis.read());
        // on max-bytes-in-mem
        tis = newTis();
        assertReadsFromZero(tis, 3);
        tis.reset();
        this.assertStates(tis, BUF_READ, bytes(0, 1, 2), false, false, 0, 0, false, 0);
        assertReadsFromZero(tis, 3);
        this.assertStates(tis, BUF_READ, bytes(0, 1, 2), false, false, 3, 0, false, 0);
        Assert.assertEquals(3, tis.read());
        this.assertStates(tis, PEND_READ, bytes(0, 1, 2), true, false, 4, 0, true, 1);
        Assert.assertEquals(4, tis.read());
        // exceeds max-bytes-in-mem
        tis = newTis();
        assertReadsFromZero(tis, 4);
        tis.reset();
        this.assertStates(tis, BUF_READ, bytes(0, 1, 2), true, false, 0, 0, true, 1);
        assertReadsFromZero(tis, 4);
        this.assertStates(tis, BUF_READ, bytes(0, 1, 2), true, true, 4, 0, true, 1);
        Assert.assertEquals(4, tis.read());
        this.assertStates(tis, PEND_READ, bytes(0, 1, 2), true, false, 5, 0, true, 2);
        Assert.assertEquals(5, tis.read());
        // full read reset
        tis = newTis();
        assertReadsFromZero(tis, 8);
        this.assertStates(tis, PEND_READ, bytes(0, 1, 2), true, false, 8, 0, true, 5);
        Assert.assertEquals(-1, tis.read());
        this.assertStates(tis, ALL_BUF, bytes(0, 1, 2), true, false, 8, 8, true, 5);
        tis.reset();
        this.assertStates(tis, ALL_BUF, bytes(0, 1, 2), true, false, 0, 8, true, 5);
        assertReadsFromZero(tis, 8);
        this.assertStates(tis, ALL_BUF, bytes(0, 1, 2), true, true, 8, 8, true, 5);
        Assert.assertEquals(-1, tis.read());
    }

    /*
     * buf read reset
     */
    @Test
    public void testBufReadReset() throws Exception {
        // below max-bytes-in-mem reset
        TeeServletInputStream tis = newTis();
        assertReadsFromZero(tis, 2);
        tis.reset();
        assertReadsFromZero(tis, 2);
        tis.reset();
        this.assertStates(tis, BUF_READ, bytes(0, 1), false, false, 0, 0, false, 0);
        // on max-bytes-in-mem reset
        assertReadsFromZero(tis, 3);
        tis.reset();
        this.assertStates(tis, BUF_READ, bytes(0, 1, 2), false, false, 0, 0, false, 0);
        assertReadsFromZero(tis, 4);
        this.assertStates(tis, PEND_READ, bytes(0, 1, 2), true, false, 4, 0, true, 1);
        // exceeds max-bytes-in-mem reset
        tis = newTis();
        assertReadsFromZero(tis, 4);
        tis.reset();
        assertReadsFromZero(tis, 4);
        tis.reset();
        this.assertStates(tis, BUF_READ, bytes(0, 1, 2), true, false, 0, 0, true, 1);
        assertReadsFromZero(tis, 4);
        this.assertStates(tis, BUF_READ, bytes(0, 1, 2), true, true, 4, 0, true, 1);
        Assert.assertEquals(4, tis.read());
        this.assertStates(tis, PEND_READ, bytes(0, 1, 2), true, false, 5, 0, true, 2);
    }

    /*
     * all buf reset
     */
    @Test
    public void allBufRead() throws Exception {
        TeeServletInputStream tis = newTis();
        assertReadsFromZero(tis, 8);
        Assert.assertEquals(-1, tis.read());
        tis.reset();
        this.assertStates(tis, ALL_BUF, bytes(0, 1, 2), true, false, 0, 8, true, 5);
        assertReadsFromZero(tis, 8);
        Assert.assertEquals(-1, tis.read());
        this.assertStates(tis, ALL_BUF, bytes(0, 1, 2), true, true, 8, 8, true, 5);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment