Skip to content

Instantly share code, notes, and snippets.

@ggarber
Created July 13, 2020 12:16
Show Gist options
  • Save ggarber/8e203cfeba0ca7bc73e935d5ddb6c509 to your computer and use it in GitHub Desktop.
Save ggarber/8e203cfeba0ca7bc73e935d5ddb6c509 to your computer and use it in GitHub Desktop.
use std::{
    collections::HashMap,
    fs, io,
    net::SocketAddr,
    path::{Path, PathBuf},
    sync::{Arc, Mutex, MutexGuard},
};
use anyhow::{bail, Context, Result};
use futures::{StreamExt, TryFutureExt};
use structopt::{self, StructOpt};
use tracing::{error, info, info_span};
use tracing_futures::Instrument as _;
mod common;
/// Command-line options for the QUIC datagram server.
///
/// The `///` comments on the fields are rendered by structopt as `--help` text.
#[derive(StructOpt, Debug)]
#[structopt(name = "server")]
struct Opt {
    /// Log TLS session keys for debugging (written to the file named by $SSLKEYLOGFILE)
    #[structopt(long = "keylog")]
    keylog: bool,
    // NOTE(review): the datagram demo only checks that this path exists; nothing is
    // served from it yet.
    /// directory to serve files from
    #[structopt(parse(from_os_str))]
    root: PathBuf,
    /// TLS private key in PEM format (or DER if the file extension is .der)
    #[structopt(parse(from_os_str), short = "k", long = "key", requires = "cert")]
    key: Option<PathBuf>,
    /// TLS certificate in PEM format (or DER if the file extension is .der)
    #[structopt(parse(from_os_str), short = "c", long = "cert", requires = "key")]
    cert: Option<PathBuf>,
    /// Enable stateless retries
    #[structopt(long = "stateless-retry")]
    stateless_retry: bool,
    /// Address to listen on
    #[structopt(long = "listen", default_value = "[::1]:4433")]
    listen: SocketAddr,
}
/// Installs the tracing subscriber, parses the CLI options and runs the server.
///
/// The process exits with status 0 on success or 1 if `run` returns an error.
fn main() {
    tracing::subscriber::set_global_default(
        tracing_subscriber::FmtSubscriber::builder()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .finish(),
    )
    .expect("a global tracing subscriber must not already be installed");
    let opt = Opt::from_args();
    let code = match run(opt) {
        Ok(()) => 0,
        Err(e) => {
            // `{:#}` prints the full anyhow context chain ("outer: cause"), not only the
            // outermost `.context(...)` message, so the underlying I/O/TLS error is visible.
            eprintln!("ERROR: {:#}", e);
            1
        }
    };
    std::process::exit(code);
}
static mut CONNS: Vec<&mut quinn::Connection> = Vec::new();
#[tokio::main]
async fn run(options: Opt) -> Result<()> {
let mut transport_config = quinn::TransportConfig::default();
transport_config.stream_window_uni(0);
let mut server_config = quinn::ServerConfig::default();
server_config.transport = Arc::new(transport_config);
let mut server_config = quinn::ServerConfigBuilder::new(server_config);
server_config.protocols(&[b"wq-vvv-01"]);
if options.keylog {
server_config.enable_keylog();
}
if options.stateless_retry {
server_config.use_stateless_retry(true);
}
if let (Some(key_path), Some(cert_path)) = (&options.key, &options.cert) {
let key = fs::read(key_path).context("failed to read private key")?;
let key = if key_path.extension().map_or(false, |x| x == "der") {
quinn::PrivateKey::from_der(&key)?
} else {
quinn::PrivateKey::from_pem(&key)?
};
let cert_chain = fs::read(cert_path).context("failed to read certificate chain")?;
let cert_chain = if cert_path.extension().map_or(false, |x| x == "der") {
quinn::CertificateChain::from_certs(quinn::Certificate::from_der(&cert_chain))
} else {
quinn::CertificateChain::from_pem(&cert_chain)?
};
server_config.certificate(cert_chain, key)?;
} else {
let dirs = directories_next::ProjectDirs::from("org", "quinn", "quinn-examples").unwrap();
let path = dirs.data_local_dir();
let cert_path = path.join("cert.der");
let key_path = path.join("key.der");
let (cert, key) = match fs::read(&cert_path).and_then(|x| Ok((x, fs::read(&key_path)?))) {
Ok(x) => x,
Err(ref e) if e.kind() == io::ErrorKind::NotFound => {
info!("generating self-signed certificate");
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
let key = cert.serialize_private_key_der();
let cert = cert.serialize_der().unwrap();
fs::create_dir_all(&path).context("failed to create certificate directory")?;
fs::write(&cert_path, &cert).context("failed to write certificate")?;
fs::write(&key_path, &key).context("failed to write private key")?;
(cert, key)
}
Err(e) => {
bail!("failed to read certificate: {}", e);
}
};
info!("{}", &path.display());
let key = quinn::PrivateKey::from_der(&key)?;
let cert = quinn::Certificate::from_der(&cert)?;
server_config.certificate(quinn::CertificateChain::from_certs(vec![cert]), key)?;
}
let mut endpoint = quinn::Endpoint::builder();
endpoint.listen(server_config.build());
let root = Arc::<Path>::from(options.root.clone());
if !root.exists() {
bail!("root path does not exist");
}
let mut incoming = {
let (endpoint, incoming) = endpoint.bind(&options.listen)?;
info!("listening on {}", endpoint.local_addr()?);
incoming
};
while let Some(conn) = incoming.next().await {
info!("connection incoming");
tokio::spawn(
handle_connection(root.clone(), conn).unwrap_or_else(move |e| {
error!("connection failed: {reason}", reason = e.to_string())
}),
);
}
Ok(())
}
async fn handle_connection(_root: Arc<Path>, conn: quinn::Connecting) -> Result<()> {
let quinn::NewConnection {
mut connection,
mut datagrams,
..
} = conn.await?;
unsafe {
CONNS.push(&mut connection);
}
let span = info_span!(
"connection",
remote = %connection.remote_address(),
protocol = %connection
.handshake_data()
.unwrap()
.protocol
.map_or_else(|| "<none>".into(), |x| String::from_utf8_lossy(&x).into_owned())
);
async {
info!("established");
// Each stream initiated by the client constitutes a new request.
while let Some(datagram) = datagrams.next().await {
let _res = match datagram {
Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
info!("connection closed demo closed");
return Ok(());
}
Err(e) => {
return Err(e);
}
Ok(data) => {
info!("RECEIVED");
unsafe {
for conn in &CONNS {
info!("SENT");
let _res = conn.send_datagram(data.clone());
}
}
}
};
}
info!("done");
Ok(())
}
.instrument(span)
.await?;
info!("finished");
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment