Skip to content

Instantly share code, notes, and snippets.

@benaubin
Created January 26, 2021 21:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save benaubin/698f7a59fd2f8a8db72167ed13f4cf6f to your computer and use it in GitHub Desktop.
use std::{borrow::Cow, sync::{Arc, Weak}};
use std::io::{Read, Write, IoSlice, IoSliceMut};
/// Size in bytes (8 KiB) of a frame buffer.
///
/// NOTE(review): nothing in the code shown here references this constant; the frame reader
/// hard-codes `8 * 1024` for its buffers. Presumably the two are meant to agree — confirm.
const FRAME_BUFFER_MAX_LEN: usize = 8192;
/// Sending half of a framed connection: a queue of outgoing frames plus the `mio` waker that is
/// poked whenever the writer has something new to observe.
pub struct FrameChannel {
    /// Frames waiting to be written to the stream.
    queue: crossbeam::queue::SegQueue<Vec<u8>>,
    /// Woken after every `send` and when the channel is dropped.
    waker: mio::Waker,
}
impl FrameChannel {
    /// Creates a channel together with the [`FramedIO`] that reads frames from, and writes
    /// queued frames to, `stream`.
    ///
    /// The caller keeps the returned `Arc` to send frames. `FramedIO` holds only a `Weak`
    /// reference, so once every clone of the `Arc` has been dropped the writer can tell that no
    /// more frames will arrive.
    pub fn new<T>(stream: T, waker: mio::Waker) -> (Arc<FrameChannel>, FramedIO<T>) {
        let channel = Arc::new(FrameChannel {
            queue: crossbeam::queue::SegQueue::new(),
            waker,
        });
        // Take the weak reference before `channel` is moved into the returned tuple.
        let framed = FramedIO {
            stream,
            channel: Arc::downgrade(&channel),
            out_in_progress_frame: None,
            out_bytes_written: 0,
            read_state: FrameReadState::HeaderRead {
                header_buf: [0; 4],
                header_bytes_read: 0,
            },
        };
        (channel, framed)
    }

    /// Queues `frame` for writing and wakes the writer.
    ///
    /// # Panics
    ///
    /// Panics if waking the writer fails.
    pub fn send(&self, frame: Vec<u8>) {
        self.queue.push(frame);
        self.waker.wake().expect("failed to wake writer");
    }
}
impl Drop for FrameChannel {
    fn drop(&mut self) {
        // Wake the writer one last time. Its `Weak` reference to this channel can no longer be
        // upgraded, which is how it learns the channel is gone.
        // Best effort: a panic inside `drop` aborts the process if the drop happens while
        // already unwinding, and a failed wake leaves nothing useful to do here.
        let _ = self.waker.wake();
    }
}
struct FramedIO<T> {
stream: T,
channel: Weak<FrameChannel>,
out_bytes_written: usize,
out_in_progress_frame: Option<Vec<u8>>,
read_state: FrameReadState
}
enum FrameReadState {
HeaderRead {
header_buf: [u8; 4],
header_bytes_read: usize
},
BufferedFrameRead {
buf: [u8; 8 * 1024],
frame_start: usize,
frame_bytes_read: usize,
frame_len: usize,
},
LargeFrameRead {
frame: Vec<u8>,
frame_bytes_read: usize
}
}
impl<T: Read> FramedIO<T> {
/// Attempt to read the next frame from the stream. Returns Ok(None) if no more bytes are available
fn read(&mut self) -> std::io::Result<Option<Cow<[u8]>>> {
loop {
match self.read_state {
FrameReadState::HeaderRead { header_buf, ref mut header_bytes_read } => {
let mut overflow_buf = [0; 8 * 1024];
while 4 > header_bytes_read {
let bytes_read = self.stream.read_vectored(&mut [
IoSliceMut::new(&mut header_buf[header_bytes_read..]),
IoSliceMut::new(&mut overflow_buf)
])?;
if bytes_read == 0 { return Ok(None) }
header_bytes_read += bytes_read; // no bytes available right now
}
self.read_state = FrameReadState::BufferedFrameRead {
buf: overflow_buf,
frame_start: 0,
frame_bytes_read: header_bytes_read - 4,
frame_len: u32::from_le_bytes(header_buf)
};
}
FrameReadState::BufferedFrameRead {
ref mut buf,
frame_bytes_read,
frame_len,
..
} if frame_len > 8 * 1024 => {
// the maximum buffer size is too small to read this frame into. allocate a vector for the frame and copy the bytes into it
let mut frame = vec! [ 0; frame_len ];
frame[..frame_bytes_read] = buf[..frame_bytes_read];
self.read_state = FrameReadState::LargeFrameRead {
frame,
frame_bytes_read
};
}
FrameReadState::BufferedFrameRead {
ref mut buf,
ref mut frame_start,
frame_bytes_read,
frame_len
} if frame_start + frame_len > buf.len() => {
// the current buffer is too small to read the frame into.
// create a new buffer so that we don't have to heap allocate a vector to hold this frame
let next_buf = [0; 8 * 1024];
next_buf[..frame_bytes_read] = buf[(*frame_start)..(*frame_start + frame_bytes_read)];
buf = next_buf;
frame_start = 0;
}
FrameReadState::BufferedFrameRead {
ref mut buf,
ref mut frame_start,
ref mut frame_bytes_read,
ref mut frame_len
} => {
// we only care about at the buffer after the start position of the frame
let buf = &mut buf[*frame_start..];
// read from the stream until we've read the frame (note that we may already have read the next frame)
while frame_len > frame_bytes_read {
let bytes_read = self.stream.read(buf[frame_bytes_read..])?;
if bytes_read == 0 { return Ok(None) } // no bytes to read at the moment
frame_bytes_read += bytes_read;
}
let (frame_buf, overflow_buf) = (&mut buf[..*frame_bytes_read]).split_at_mut(frame_len);
self.read_state = if overflow_buf.len() > 4 {
// we've read the current frame and the header for the next frame
frame_start += frame_len + 4;
frame_bytes_read -= frame_len + 4;
frame_len = u32::from_le_bytes(overflow_buf[..4]);
} else {
// we've read the current frame, but not the complete header for the next frame
let mut header_buf = [0; 4];
// calculate the number of header bytes read
let header_bytes_read = *frame_bytes_read - frame_len;
// copy the header bytes from the overflow buffer
header_buf[..header_bytes_read] = overflow_buf[..header_bytes_read];
FrameReadState::HeaderRead {
header_buf,
header_bytes_read
};
};
return Ok(Some(Cow::Borrowed(frame_buf)));
}
FrameReadState::LargeFrameRead { mut frame, ref mut frame_bytes_read } => {
let mut header_buf = [0; 4];
let mut overflow_buf = [0; 8 * 1024];
while frame.len() > frame_bytes_read {
let unread_buf = &mut frame[*frame_bytes_read..];
let read_bytes = self.stream.read_vectored(&mut [
IoSliceMut::new(unread_buf),
IoSliceMut::new(header_buf),
IoSliceMut::new(overflow_buf)
])?;
if read_bytes == 0 { return Ok(None) } // no bytes to read at the moment
frame_bytes_read += read_bytes;
}
let overflow_bytes_read = frame_bytes_read - frame.len();
self.read_state = if overflow_bytes_read >= 4 {
FrameReadState::BufferedFrameRead {
buf: overflow_buf,
frame_start: 0,
frame_bytes_read: overflow_bytes_read - 4,
frame_len: u32::from_le_bytes(header_buf)
}
} else {
FrameReadState::HeaderRead {
header_buf,
header_bytes_read: overflow_bytes_read
}
}
}
}
}
}
}
/// Outcome of a call to `FramedIO::write`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteStatus {
    /// The stream cannot accept more bytes right now; call `write` again once it is writable.
    WantsWrite,
    /// Every queued frame has been written; call `write` again after more frames are sent.
    WantsFrames,
    /// The frame channel has been dropped, so no more frames will be sent.
    Disconnected,
}
impl<T: Write> FramedIO<T> {
/// Attempt to write the next frame into the stream. Returns Ok(None) if no frames available
fn write(&mut self) -> std::io::Result<WriteStatus> {
loop {
match self.out_in_progress_frame {
Some(ref frame) => {
let bufs =
&[
IoSlice::new(&(frame.len() as u32).to_le_bytes()[self.out_bytes_written.max(4)..]),
IoSlice::new(frame[(self.out_bytes_written).min(4) - 4..])
];
match self.stream.write_vectored(bufs)? {
0 => return Ok(WriteStatus::WantsWrite), // can't write anymore
written => self.out_bytes_written += written
};
if self.out_bytes_written == frame.len() + 4 {
self.out_in_progress_frame = None; // finished writing
self.out_bytes_written -= frame.len() + 4;
}
}
None => {
match self.channel.upgrade().map(|src| src.queue.pop() ) {
Some(Some(frame)) => {
self.out_in_progress_frame = Some(frame);
},
Some(None) => return Ok(WriteStatus::WantsFrames), // no more frames to write
None => return Ok(WriteStatus::Disconnected)
}
}
}
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment