Created
November 9, 2022 16:11
-
-
Save b-zee/9edd14a051434e8be8e3ad43da9e0eb8 to your computer and use it in GitHub Desktop.
`quinn` test failing on 0.9
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use quinn::{ | |
ClientConfig, ConnectionError, Endpoint, ServerConfig, VarInt, | |
}; | |
use std::{error::Error, net::SocketAddr, sync::Arc}; | |
const COUNT: u32 = 1_000; | |
#[tokio::test] | |
async fn test_many_uni() -> Result<(), Box<dyn Error>> { | |
// A vector of numbers counting up from 0. | |
let msgs_to_send: Vec<u32> = (0..COUNT).collect(); | |
let server = tokio::spawn(run_server()); | |
run_client(&msgs_to_send).await?; | |
let mut msgs_received = server.await?; | |
// Server might/will receive uni streams out of order. | |
msgs_received.sort(); | |
assert_eq!(msgs_to_send, msgs_received); | |
Ok(()) | |
} | |
async fn run_server() -> Vec<u32> { | |
let endpoint = make_server_endpoint("127.0.0.1:5000".parse().unwrap()).unwrap(); | |
let connection = endpoint.accept().await.unwrap().await.unwrap(); | |
println!( | |
"[server] accepted connection: addr={}", | |
connection.remote_address() | |
); | |
let mut tasks = vec![]; | |
loop { | |
let recv = connection.accept_uni().await; | |
let recv = match recv { | |
Ok(recv) => recv, | |
// Gracefully stop when the peer closes the connection on us. | |
Err(ConnectionError::ApplicationClosed(_)) => break, | |
Err(err) => panic!("accepting stream error: {err:?}"), | |
}; | |
// Receive the u32 from the stream. | |
tasks.push(tokio::spawn(async move { | |
let msg = recv.read_to_end(3000).await.unwrap(); | |
match <[u8; 4]>::try_from(msg) { | |
Ok(msg) => u32::from_be_bytes(msg), | |
Err(err) => panic!("received data is invalid u32: {err:?}"), | |
} | |
})); | |
} | |
// Split task results into two vectors (`Ok` and `Err`). | |
let (ok, err): (Vec<_>, Vec<_>) = futures::future::join_all(tasks) | |
.await | |
.into_iter() | |
.partition(|r| r.is_ok()); | |
let ok: Vec<_> = ok.into_iter().map(Result::unwrap).collect(); | |
let err: Vec<_> = err.into_iter().map(Result::unwrap_err).collect(); | |
if !err.is_empty() { | |
panic!("[server] recv tasks failed: {err:?}"); | |
} | |
ok | |
} | |
async fn run_client(msgs_to_send: &[u32]) -> Result<(), Box<dyn Error>> { | |
let client_cfg = configure_client(); | |
let mut endpoint = Endpoint::client("127.0.0.1:0".parse().unwrap())?; | |
endpoint.set_default_client_config(client_cfg); | |
// Connect to server | |
let connection = endpoint | |
.connect("127.0.0.1:5000".parse().unwrap(), "localhost") | |
.unwrap() | |
.await | |
.unwrap(); | |
let mut tasks = vec![]; | |
for id in msgs_to_send.iter().copied() { | |
let connection = connection.clone(); | |
tasks.push(tokio::spawn(async move { | |
let mut send = connection.open_uni().await.unwrap(); | |
send.write_all(&id.to_be_bytes()).await.unwrap(); | |
send.finish().await.unwrap(); | |
})); | |
} | |
// Wait for messages to get sent and close connection immediately. | |
futures::future::join_all(tasks).await; | |
connection.close(VarInt::from_u32(0), b""); | |
// Make sure the server has a chance to clean up | |
endpoint.wait_idle().await; | |
Ok(()) | |
} | |
/// Dummy certificate verifier that treats any certificate as valid. | |
/// NOTE, such verification is vulnerable to MITM attacks, but convenient for testing. | |
struct SkipServerVerification; | |
impl SkipServerVerification { | |
fn new() -> Arc<Self> { | |
Arc::new(Self) | |
} | |
} | |
impl rustls::client::ServerCertVerifier for SkipServerVerification { | |
fn verify_server_cert( | |
&self, | |
_end_entity: &rustls::Certificate, | |
_intermediates: &[rustls::Certificate], | |
_server_name: &rustls::ServerName, | |
_scts: &mut dyn Iterator<Item = &[u8]>, | |
_ocsp_response: &[u8], | |
_now: std::time::SystemTime, | |
) -> Result<rustls::client::ServerCertVerified, rustls::Error> { | |
Ok(rustls::client::ServerCertVerified::assertion()) | |
} | |
} | |
pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<Endpoint, Box<dyn Error>> { | |
let (server_config, _) = configure_server()?; | |
Ok(Endpoint::server(server_config, bind_addr)?) | |
} | |
fn configure_client() -> ClientConfig { | |
let crypto = rustls::ClientConfig::builder() | |
.with_safe_defaults() | |
.with_custom_certificate_verifier(SkipServerVerification::new()) | |
.with_no_client_auth(); | |
ClientConfig::new(Arc::new(crypto)) | |
} | |
fn configure_server() -> Result<(ServerConfig, Vec<u8>), Box<dyn Error>> { | |
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); | |
let cert_der = cert.serialize_der().unwrap(); | |
let priv_key = cert.serialize_private_key_der(); | |
let priv_key = rustls::PrivateKey(priv_key); | |
let cert_chain = vec![rustls::Certificate(cert_der.clone())]; | |
let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key)?; | |
let mut transport = quinn::TransportConfig::default(); | |
transport.max_concurrent_uni_streams(VarInt::from_u32(1)); | |
server_config.transport = Arc::new(transport); | |
Ok((server_config, cert_der)) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment