Skip to content

Instantly share code, notes, and snippets.

@brianm
Last active March 20, 2018 14:28
Show Gist options
  • Save brianm/bd138438d8c9abe3b84ab36eb0909975 to your computer and use it in GitHub Desktop.
Save brianm/bd138438d8c9abe3b84ab36eb0909975 to your computer and use it in GitHub Desktop.
extern crate base64;
use std::fmt;
use std::io::{Read, Result};
/// A streaming base64 decoder that wraps another reader.
///
/// Encoded bytes are pulled from the wrapped reader and decoded bytes are
/// handed out through this type's `Read` implementation.
pub struct Base64Decoder<'a> {
    /// Source of base64-encoded bytes.
    r: &'a mut Read,
    /// Configuration handed to the `base64` crate when decoding.
    config: base64::Config,
    /// Decoded bytes that did not fit in the caller's buffer on a previous
    /// read. Three bytes is the most a single four-byte encoded group can
    /// produce.
    extra: [u8; 3],
    /// Number of valid bytes at the front of `extra`.
    extra_len: usize,
}
impl<'a> fmt::Debug for Base64Decoder<'a> {
    /// Writes only the leftover-buffer state (`extra` and `extra_len`).
    ///
    /// The wrapped reader and the base64 configuration are left out of the
    /// output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let leftover = &self.extra;
        let leftover_len = self.extra_len;
        f.write_fmt(format_args!(
            "extra: {:?}, extra_len: {}",
            leftover, leftover_len
        ))
    }
}
impl<'a> Base64Decoder<'a> {
    /// Creates a decoder that reads base64 text from `r` and decodes it
    /// according to `config`.
    ///
    /// The leftover buffer starts out empty.
    ///
    /// This constructor is public because every field of the struct is
    /// private: without it, code outside this module has no way to build a
    /// `Base64Decoder`.
    pub fn new(r: &'a mut Read, config: base64::Config) -> Base64Decoder<'a> {
        Base64Decoder {
            r,
            config,
            extra: [0u8; 3],
            extra_len: 0,
        }
    }
}
impl<'a> Read for Base64Decoder<'a> {
fn read(&mut self, out: &mut [u8]) -> Result<usize> {
eprintln!("read");
eprintln!(" self: {:?}", self);
eprintln!(" out.len: {}", out.len());
// new approach, just read 4 bytes a time
if self.extra_len > 0 {
// we have leftovers from last time!
if out.len() >= self.extra_len {
eprintln!(" out.len >= extra_len");
// enough space, let's finish up!
out[..self.extra_len].copy_from_slice(&self.extra[..self.extra_len]);
let n = self.extra_len;
self.extra_len = 0;
eprintln!(" returning Ok({})", n);
return Ok(n);
} else /* we know out.len() < extra_len */ {
eprintln!(" out.len < ecra_len");
// copy as much as we can into out
let n = out.len();
out.copy_from_slice(&self.extra[..n]);
self.extra_len = self.extra_len - n;
eprintln!(" self.extra_len: {}, n: {}", self.extra_len, n);
// okay, now we need to jiggle self.extra to hold the rest
let mut scratch = [0u8;3];
scratch.copy_from_slice(&self.extra);
self.extra[.. self.extra_len].copy_from_slice(&scratch[n .. n + self.extra_len]);
eprintln!(" returning Ok({})", n);
return Ok(n);
}
}
let mut n = 0;
let mut in_buf = [0u8; 4];
let mut out_buf = [0u8; 3];
let chunks = if out.len() < 3 {
1
} else {
out.len() / 3
};
for i in 0..chunks {
let mut sz = self.r.read(&mut in_buf)?;
eprintln!(" chunk i: {}", i);
eprintln!(" read sz: {}", sz);
while sz != 4 {
// we got either a partial read or input is unpadded
// so we keep trying to fill our 4 byte buffer until it
// fills or or we get a read of 0, indicating EOF
let nxt_sz = self.r.read(&mut in_buf[sz..])?;
eprintln!(" extra read sz: {}", nxt_sz);
if nxt_sz == 0 {
// EOF, assume input is not padded
break;
}
sz += nxt_sz;
}
if sz == 0 {
// we reached the end of file, we're done!
eprintln!(" think we've hit EOF!");
eprintln!(" returning Ok(0)");
return Ok(0);
}
// TODO convert to io result instead of unwrap
let decoded_sz = base64::decode_config_slice(&in_buf, self.config, &mut out_buf).unwrap();
if out.len() >= n + decoded_sz {
out[i * 3..(i * 3) + 3].copy_from_slice(&out_buf);
n += decoded_sz;
} else {
let out_remaining_sz = out.len() - n;
let start_offset = i * 3;
let end_offset = (i * 3) + out_remaining_sz;
out[start_offset .. end_offset].copy_from_slice(&out_buf[..out_remaining_sz]);
self.extra[..3 - out_remaining_sz].copy_from_slice(&out_buf[out_remaining_sz..]);
self.extra_len = 3 - out_remaining_sz;
n += out_remaining_sz;
}
}
eprintln!(" returning Ok({})", n);
Ok(n)
}
}
#[cfg(test)]
mod tests {
    // `super::` keeps the tests working regardless of the name this file is
    // mounted under in the crate.
    use super::Base64Decoder;
    use base64;
    use std::io::{Cursor, Read};

    /// Encodes `input` with the URL-safe configuration used by every test.
    fn url_safe(input: &[u8]) -> String {
        let mut encoded = String::new();
        base64::encode_config_buf(input, base64::URL_SAFE, &mut encoded);
        encoded
    }

    /// A buffer exactly as large as the decoded data receives all of it.
    #[test]
    fn decode_3_from_3() {
        let mut c = Cursor::new(url_safe(b"abc"));
        let mut dec = Base64Decoder::new(&mut c, base64::URL_SAFE);

        let mut buf = [0u8; 3];
        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 3);
        assert_eq!(&buf, b"abc");
    }

    /// A buffer smaller than one decoded group is filled completely.
    #[test]
    fn decode_2_from_3() {
        let mut c = Cursor::new(url_safe(b"abc"));
        let mut dec = Base64Decoder::new(&mut c, base64::URL_SAFE);

        let mut buf = [0u8; 2];
        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 2);
        assert_eq!(&buf, b"ab");
    }

    /// Two consecutive three-byte reads walk through six decoded bytes.
    #[test]
    fn decode_3_3_from_6() {
        let mut c = Cursor::new(url_safe(b"abcdef"));
        let mut dec = Base64Decoder::new(&mut c, base64::URL_SAFE);

        let mut buf = [0u8; 3];
        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 3);
        assert_eq!(&buf, b"abc");

        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 3);
        assert_eq!(&buf, b"def");
    }

    /// A four-byte buffer only holds one whole group, so each read yields
    /// three bytes.
    #[test]
    fn decode_4_4_from_6() {
        let mut c = Cursor::new(url_safe(b"abcdef"));
        let mut dec = Base64Decoder::new(&mut c, base64::URL_SAFE);

        let mut buf = [0u8; 4];
        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 3);
        assert_eq!(buf[..sz], b"abc"[..]);

        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 3);
        assert_eq!(buf[..sz], b"def"[..]);
    }

    /// Leftovers from a one-byte read are returned on their own before any
    /// new input is decoded. Note that the input here is seven bytes long.
    #[test]
    fn decode_1_4_1_from_6() {
        let mut c = Cursor::new(url_safe(&[0, 1, 2, 3, 4, 5, 6]));
        let mut dec = Base64Decoder::new(&mut c, base64::URL_SAFE);

        let mut buf = [0u8];
        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 1);
        assert_eq!(buf[..sz], [0]);

        let mut buf = [0u8; 4];
        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 2); // only read 2, the leftovers from before
        assert_eq!(buf[..sz], [1, 2]);

        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 3);
        assert_eq!(buf[..sz], [3, 4, 5]);

        let sz = dec.read(&mut buf).unwrap();
        assert_eq!(sz, 1);
        assert_eq!(buf[..sz], [6]);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment