Skip to content

Instantly share code, notes, and snippets.

@Armillus
Last active April 27, 2024 16:10
Show Gist options
  • Save Armillus/dd728df74b08c762adbf41f5a2b2627e to your computer and use it in GitHub Desktop.
Save Armillus/dd728df74b08c762adbf41f5a2b2627e to your computer and use it in GitHub Desktop.
[Rust][Tokio] io::copy_bidirectional issue

Usage

To test it, start the server and in a separate terminal, type cat /dev/urandom | nc 127.0.0.1 8000. Cut the server whenever you want and observe the logs on stdout.

Whenever io::copy_bidirectional is flushing, it does not drive the flush future to completion. Although this is not an issue in such a simple example, it can halt the program when both ends are waiting for data that is never flushed, because io::copy_bidirectional does not poll poll_flush through to completion.

# Minimal crate manifest for the io::copy_bidirectional flush repro.
[package]
name = "TokioCopyTest"
version = "0.1.0"
edition = "2021"
[dependencies]
# "full" enables every tokio feature; this example needs net, io-util,
# sync, time, and macros.
tokio = { version = "1.37.0", features = ["full"] }
// Standard imports
use std::future::Future;
use std::io::Error;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
// Tokio imports
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpListener;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::mpsc::error::SendError;
// --------------------------------------------------------
// Declarations
/// Message exchanged over the mpsc channels between the proxy `Stream`
/// and the server task.
enum Data {
    /// Raw bytes read from, or destined for, the network.
    Network(Vec<u8>),
}
/// In-process peer for `io::copy_bidirectional`: reads are served from
/// `receiver`, writes are buffered locally and only handed to `sender`
/// on flush.
struct Stream {
    /// In-flight send started by `poll_flush`; `Some` while a flush is
    /// still pending so the next `poll_flush` call resumes the same future.
    flushing_future: Option<Pin<Box<dyn Send + Sync + Future<Output = Result<(), SendError<Data>>>>>>,
    /// Incoming data from the server task (drained by `poll_read`).
    receiver: Receiver<Data>,
    /// Outgoing data to the server task (written to by `poll_flush`).
    sender: Sender<Data>,
    /// Bytes accepted by `poll_write` but not yet flushed.
    write_buffer: Vec<u8>,
    // Running byte counters, used only for the stdout logs.
    read_bytes: usize,
    written_bytes: usize,
    flushed_bytes: usize,
}
// --------------------------------------------------------
// Implementations
/// Entry point: accepts TCP clients one at a time, proxies each client
/// socket against an in-process `Stream` via `io::copy_bidirectional`,
/// and answers every received packet with a hardcoded 512-byte response.
#[tokio::main]
async fn main() {
    let listener = TcpListener::bind("127.0.0.1:8000").await.unwrap();
    let (server_tx, mut server_rx) = tokio::sync::mpsc::channel::<Data>(64);
    println!("TCP server listening on 127.0.0.1:8000!");
    loop {
        // Wait for the next client and give it its own channel.
        let (mut socket, _) = listener.accept().await.unwrap();
        let (client_tx, client_rx) = tokio::sync::mpsc::channel::<Data>(64);
        println!("TCP client connected, bidirectional copy will now start.");
        // Run the bidirectional copy on its own task so this loop can keep
        // servicing the client's packets below.
        let mut proxy = Stream::new(client_rx, server_tx.clone());
        tokio::spawn(async move {
            if let Err(e) = tokio::io::copy_bidirectional(&mut socket, &mut proxy).await {
                eprintln!("Bidirectional copy failed: {e}.");
            }
        });
        // Answer each incoming packet until the client side goes away.
        while !client_tx.is_closed() {
            match server_rx.recv().await {
                None => break,
                Some(_data) => {
                    // Process `_data` and get a response, here hardcoded.
                    let response = Data::Network(vec![0; 512]);
                    if let Err(e) = client_tx.send(response).await {
                        eprintln!("Failed to send response to the client: {e}.");
                        break;
                    }
                }
            }
        }
    }
}
impl Stream {
    /// Builds a `Stream` around the two channel halves, with every byte
    /// counter zeroed and no flush in progress.
    pub fn new(receiver: Receiver<Data>, sender: Sender<Data>) -> Self {
        Self {
            flushing_future: None,
            receiver,
            sender,
            write_buffer: Vec::new(),
            read_bytes: 0,
            written_bytes: 0,
            flushed_bytes: 0,
        }
    }

    /// The stream counts as stopped once the server side has dropped its
    /// receiver, i.e. our sender can no longer deliver anything.
    pub fn is_stopped(&self) -> bool {
        self.sender.is_closed()
    }
}
impl AsyncRead for Stream {
    /// Pulls the next `Data::Network` payload out of the server channel and
    /// copies it into `buf`. Errors out once the stream is stopped or the
    /// channel is exhausted.
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
        if self.is_stopped() {
            return Poll::Ready(Err(std::io::ErrorKind::Other.into()));
        }
        // Get some data from the server when it's ready.
        let Poll::Ready(received) = self.receiver.poll_recv(cx) else {
            return Poll::Pending;
        };
        let Some(Data::Network(data)) = received else {
            return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into()));
        };
        // For the sake of simplicity, we assume that the reading buffer is
        // always large enough to hold all of `data` at once.
        assert!(buf.remaining() >= data.len());
        self.read_bytes += data.len();
        println!("Reading {} bytes | {} bytes read totally", data.len(), self.read_bytes);
        // Extend the provided buffer with the server's data.
        buf.put_slice(&data);
        Poll::Ready(Ok(()))
    }
}
impl AsyncWrite for Stream {
    /// Buffers `buf` into `write_buffer`; nothing reaches the server until
    /// `poll_flush` is called. Always accepts the whole slice.
    fn poll_write(mut self: Pin<&mut Self>, _: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
        if self.is_stopped() {
            return Poll::Ready(Err(std::io::ErrorKind::Other.into()));
        }
        // Bufferize written data
        self.written_bytes += buf.len();
        self.write_buffer.extend_from_slice(buf);
        println!("Writing {} bytes | {} bytes written overall.", buf.len(), self.written_bytes);
        Poll::Ready(Ok(buf.len()))
    }
    /// Drains `write_buffer` into a boxed send-future and polls it. When the
    /// send is not immediately ready, the future is stashed back into
    /// `flushing_future` so the NEXT `poll_flush` call resumes the same
    /// in-flight send instead of starting a new one. This pending state is
    /// exactly what the gist demonstrates `io::copy_bidirectional` failing
    /// to drive to completion.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        if self.is_stopped() {
            return Poll::Ready(Err(std::io::ErrorKind::Other.into()));
        }
        // Take the existing flushing future if any, or create a new one otherwise
        let mut flushing_future = self.flushing_future.take().unwrap_or_else(|| {
            let mut data = Vec::new();
            // Take all data out of the writing buffer
            data.append(&mut self.write_buffer);
            self.flushed_bytes += data.len();
            println!(
                "Flushing {} bytes ({}/{} bytes flushed so far)",
                data.len(),
                self.flushed_bytes,
                self.written_bytes,
            );
            let data = Data::Network(data);
            // Clone the sender so the future can own it ('static requirement).
            let sender = self.sender.clone();
            // Send flushed data to the server
            Box::pin(async move {
                // Simulating some slight back-pressure
                tokio::time::sleep(Duration::from_millis(1)).await;
                sender.send(data).await
            })
        });
        // Returns the flushing result if available, or store the future otherwise
        match flushing_future.as_mut().poll(cx) {
            Poll::Pending => {
                // Keep the in-flight send so it can be resumed, not restarted.
                self.flushing_future = Some(flushing_future);
                println!("Flushing is pending.");
                Poll::Pending
            },
            Poll::Ready(Err(_)) => Poll::Ready(Err(std::io::ErrorKind::Other.into())),
            Poll::Ready(Ok(_)) => {
                println!(
                    "Flushing done ({}/{} bytes flushed so far)",
                    self.flushed_bytes,
                    self.written_bytes,
                );
                Poll::Ready(Ok(()))
            },
        }
    }
    /// Shutdown simply closes the receiving half; any data still buffered in
    /// `write_buffer` or in a pending `flushing_future` is dropped.
    fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
        if self.is_stopped() {
            return Poll::Ready(Err(std::io::ErrorKind::Other.into()));
        }
        // `Receiver::close` returns (), which is the Ok payload here.
        Poll::Ready(Ok(self.receiver.close()))
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment