Skip to content

Instantly share code, notes, and snippets.

@bnorm
Created February 14, 2018 21:36
Show Gist options
  • Save bnorm/1295b68b7ebe498ba829136b5b70fa1b to your computer and use it in GitHub Desktop.
Save bnorm/1295b68b7ebe498ba829136b5b70fa1b to your computer and use it in GitHub Desktop.
Example CipherSource and CipherSink
import java.io.IOException;
import java.net.ProtocolException;
import java.security.GeneralSecurityException;
import java.util.Locale;
import java.util.Random;
import javax.crypto.Cipher;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.ByteString;
import okio.Okio;
import okio.Sink;
import okio.Source;
import okio.Timeout;
/**
 * Example of streaming encryption and decryption with Okio: a {@link Sink} and a {@link Source}
 * that push bytes through a {@link Cipher} segment by segment.
 *
 * <p>Only unpadded ({@code NoPadding}) transformations are accepted. Data is fed through {@link
 * Cipher#update} only; {@link Cipher#doFinal} is never called. The total stream length must
 * therefore be a whole number of cipher blocks. Modes that emit output from {@code doFinal} (for
 * example AEAD modes such as GCM, whose authentication tag is produced there) are not supported.
 */
public final class Ciphers {
  /** Round-trips a multi-segment payload through AES/CBC/NoPadding and prints each stage. */
  public void run() throws Exception {
    // Fixed demo key and IV only; a real stream must never reuse a constant IV.
    ByteString key = ByteString.encodeUtf8("Bar12345Bar12345"); // 128 bit key
    ByteString initVector = ByteString.encodeUtf8("RandomInitVector"); // 16 bytes IV
    Cipher encrypt = createCipher(Cipher.ENCRYPT_MODE, key, initVector, "AES", "CBC");
    Cipher decrypt = createCipher(Cipher.DECRYPT_MODE, key, initVector, "AES", "CBC");

    // The plaintext is 8192 * blockSize bytes, i.e. a whole number of blocks. Cloning the buffer
    // at arbitrary points presumably forces segment boundaries at odd offsets, so blocks straddle
    // segments inside process() — TODO confirm against Okio's Buffer.clone() semantics.
    Random random = new Random(0);
    Buffer buffer = new Buffer();
    int split = 15;
    for (int i = 0; i < split; i++) {
      buffer.writeByte(i);
    }
    buffer = buffer.clone();
    for (int i = split; i < 8192 * encrypt.getBlockSize(); i++) {
      if (random.nextInt(8192) == 0) {
        buffer = buffer.clone();
      }
      buffer.writeByte(i);
    }
    ByteString original = buffer.snapshot();
    System.out.println("Original : " + original);

    Buffer encrypted = new Buffer();
    try (BufferedSink sink = Okio.buffer(new CipherSink(encrypted, encrypt))) {
      sink.write(original);
    }
    System.out.println("Encrypted : " + encrypted.snapshot());

    ByteString decrypted;
    try (BufferedSource source = Okio.buffer(new CipherSource(encrypted, decrypt))) {
      decrypted = source.readByteString();
    }
    System.out.println("Decrypted : " + decrypted);
  }

  /**
   * Creates and initializes an unpadded cipher.
   *
   * @param mode {@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}
   * @param key raw key bytes for {@code algorithm}
   * @param initVector IV bytes for {@code transform}
   * @param algorithm algorithm name, e.g. {@code "AES"}
   * @param transform block mode, e.g. {@code "CBC"}; the padding is always {@code NoPadding}
   * @return a cipher for {@code algorithm/transform/NoPadding}, initialized for {@code mode}
   * @throws GeneralSecurityException if the transformation is unavailable or the key/IV is invalid
   */
  Cipher createCipher(int mode, ByteString key, ByteString initVector, String algorithm,
      String transform) throws GeneralSecurityException {
    IvParameterSpec iv = new IvParameterSpec(initVector.toByteArray());
    SecretKeySpec keySpec = new SecretKeySpec(key.toByteArray(), algorithm);
    Cipher cipher = Cipher.getInstance(algorithm + "/" + transform + "/NoPadding");
    cipher.init(mode, keySpec, iv);
    return cipher;
  }

  /**
   * Drains {@code byteCount} bytes from {@code source}, runs them through the cipher, and writes
   * all processed output into {@code sink}.
   *
   * <p>Bytes that do not yet form a whole block remain buffered inside {@code cipher} and are
   * emitted by a later call.
   *
   * @param cipher an initialized, unpadded cipher
   * @param source buffer to read from; exactly {@code byteCount} bytes are removed from it
   * @param sourceCursor reusable cursor for reading {@code source}; closed before returning
   * @param byteCount number of bytes to consume, {@code 0 <= byteCount <= source.size()}
   * @param sink buffer that receives the cipher output
   * @param sinkCursor reusable cursor for writing {@code sink}; closed before returning
   * @throws IllegalArgumentException if {@code byteCount} is negative or exceeds {@code source}
   */
  static void process(Cipher cipher,
      Buffer source,
      Buffer.UnsafeCursor sourceCursor,
      long byteCount,
      Buffer sink,
      Buffer.UnsafeCursor sinkCursor) throws IOException {
    if (byteCount < 0) throw new IllegalArgumentException("byteCount < 0: " + byteCount);
    if (byteCount > source.size()) {
      throw new IllegalArgumentException("size=" + source.size() + " byteCount=" + byteCount);
    }
    int blockSize = cipher.getBlockSize();
    source.readUnsafe(sourceCursor);
    sink.readAndWriteUnsafe(sinkCursor);
    try {
      sourceCursor.seek(0);
      long remaining = byteCount;
      while (remaining > 0) {
        // Consume at most the rest of the current source segment.
        int inputSize = (int) Math.min(sourceCursor.end - sourceCursor.start, remaining);
        // getOutputSize() counts bytes the cipher already buffered. update() only emits whole
        // blocks, so round the estimate down to a block multiple.
        int outputSize = cipher.getOutputSize(inputSize);
        if (blockSize > 0) {
          outputSize -= outputSize % blockSize;
        }
        // expandBuffer() only guarantees contiguous capacity within a single 8 KiB segment.
        if (outputSize > 8192) {
          throw new AssertionError(
              String.format("existing=%d blockSize=%d inputSize=%d outputSize=%d",
                  cipher.getOutputSize(0), blockSize, inputSize, outputSize));
        }
        if (outputSize == 0) {
          // Not enough for a whole block yet: let the cipher buffer the input, expecting no output.
          cipher.update(sourceCursor.data, sourceCursor.start, inputSize);
        } else {
          long oldSize = sink.size();
          sinkCursor.expandBuffer(outputSize);
          int written = cipher.update(
              sourceCursor.data,
              sourceCursor.start,
              inputSize,
              sinkCursor.data,
              sinkCursor.start);
          // Trim the expanded capacity back to what the cipher actually produced.
          sinkCursor.resizeBuffer(oldSize + written);
        }
        sourceCursor.seek(sourceCursor.offset + inputSize);
        remaining -= inputSize;
      }
    } catch (ShortBufferException e) {
      // The output region was sized from getOutputSize(), so this indicates a logic error.
      throw new AssertionError(e);
    } finally {
      sourceCursor.close();
      sinkCursor.close();
    }
    // Cursors only read the data; actually consume it from the source now.
    source.skip(byteCount);
  }

  /**
   * Validates that {@code cipher} is non-null and configured without padding. The comparison
   * ignores case because JCE transformation names are case-insensitive.
   *
   * @return {@code cipher}, for use in constructor assignments
   * @throws IllegalArgumentException if {@code cipher} is null or not a {@code NoPadding} cipher
   */
  private static Cipher requireNoPadding(Cipher cipher) {
    if (cipher == null) throw new IllegalArgumentException("cipher == null");
    String transformation = cipher.getAlgorithm();
    if (!transformation.toUpperCase(Locale.ROOT).contains("NOPADDING")) {
      throw new IllegalArgumentException(transformation);
    }
    return cipher;
  }

  /** Builds the error for a stream that ended with a partial block still inside the cipher. */
  private static ProtocolException unprocessed(Cipher cipher) {
    return new ProtocolException(String.format("blockSize=%d unprocessed=%d",
        cipher.getBlockSize(), cipher.getOutputSize(0)));
  }

  /**
   * A {@link Sink} that transforms every byte written to it with a cipher before forwarding it
   * downstream. Closing fails with {@link ProtocolException} if a partial block is left over.
   */
  static final class CipherSink implements Sink {
    private final BufferedSink sink;
    private final Cipher cipher;
    private final Buffer.UnsafeCursor sourceCursor = new Buffer.UnsafeCursor();
    private final Buffer.UnsafeCursor sinkCursor = new Buffer.UnsafeCursor();
    private boolean closed;

    CipherSink(Sink sink, Cipher cipher) {
      this(Okio.buffer(sink), cipher);
    }

    CipherSink(BufferedSink sink, Cipher cipher) {
      this.cipher = requireNoPadding(cipher);
      this.sink = sink;
    }

    @Override
    public void write(Buffer source, long byteCount) throws IOException {
      if (byteCount < 0) throw new IllegalArgumentException("byteCount < 0: " + byteCount);
      if (closed) throw new IllegalStateException("closed");
      process(cipher, source, sourceCursor, byteCount, sink.buffer(), sinkCursor);
      sink.emitCompleteSegments();
    }

    @Override
    public void flush() throws IOException {
      if (closed) throw new IllegalStateException("closed");
      sink.flush();
    }

    @Override
    public void close() throws IOException {
      if (closed) return;
      closed = true;
      try {
        // Without padding, leftover bytes mean the input was not a whole number of blocks.
        if (cipher.getOutputSize(0) > 0) {
          throw unprocessed(cipher);
        }
      } finally {
        sink.close();
      }
    }

    @Override
    public Timeout timeout() {
      return sink.timeout();
    }
  }

  /**
   * A {@link Source} that transforms bytes read from an upstream source with a cipher. Reading
   * fails with {@link ProtocolException} if the upstream ends with a partial block.
   */
  static final class CipherSource implements Source {
    private final BufferedSource source;
    private final Cipher cipher;
    /** Cipher output that has been produced but not yet returned to the caller. */
    private final Buffer ciphered = new Buffer();
    private final Buffer.UnsafeCursor sourceCursor = new Buffer.UnsafeCursor();
    private final Buffer.UnsafeCursor sinkCursor = new Buffer.UnsafeCursor();
    private boolean closed;

    CipherSource(Source source, Cipher cipher) {
      this(Okio.buffer(source), cipher);
    }

    CipherSource(BufferedSource source, Cipher cipher) {
      this.cipher = requireNoPadding(cipher);
      this.source = source;
    }

    @Override
    public long read(Buffer sink, long byteCount) throws IOException {
      if (byteCount < 0) throw new IllegalArgumentException("byteCount < 0: " + byteCount);
      if (closed) throw new IllegalStateException("closed");
      if (byteCount == 0) return 0;
      if (ciphered.size() >= byteCount) return ciphered.read(sink, byteCount);
      // Keep feeding the cipher until it yields output or the upstream runs dry. Returning with
      // nothing deciphered while the source still has data would look like a premature EOF.
      do {
        refill(byteCount);
        process(cipher, source.buffer(), sourceCursor, source.buffer().size(), ciphered,
            sinkCursor);
      } while (ciphered.exhausted() && !source.exhausted());
      // Reaching here with no output implies the source is exhausted; any bytes still buffered in
      // the cipher are an incomplete final block.
      if (ciphered.exhausted() && cipher.getOutputSize(0) > 0) {
        throw unprocessed(cipher);
      }
      return ciphered.read(sink, byteCount);
    }

    /**
     * Requests enough upstream bytes to produce {@code byteCount} deciphered bytes, rounded up to
     * a whole number of cipher blocks.
     *
     * @return the result of {@link BufferedSource#request}
     */
    private boolean refill(long byteCount) throws IOException {
      // How many ciphered bytes are still needed beyond what is already available.
      long needed = byteCount - ciphered.size();
      int blockSize = cipher.getBlockSize();
      if (blockSize > 0) {
        // Round up so the cipher can emit complete blocks.
        long remainder = needed % blockSize;
        if (remainder > 0) {
          needed += blockSize - remainder;
        }
      }
      return source.request(needed);
    }

    @Override
    public void close() throws IOException {
      if (closed) return;
      // Mark closed first so a failing upstream close is not retried, matching CipherSink.
      closed = true;
      source.close();
    }

    @Override
    public Timeout timeout() {
      return source.timeout();
    }
  }

  public static void main(String... args) throws Exception {
    new Ciphers().run();
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment