Skip to content

Instantly share code, notes, and snippets.

@lithdew
Created May 4, 2023 10:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lithdew/4a009884829ad4c0ce41e3c3e8bf039b to your computer and use it in GitHub Desktop.
Save lithdew/4a009884829ad4c0ce41e3c3e8bf039b to your computer and use it in GitHub Desktop.
rust (quinn, rustls): quic holepunching w/ basic stun client
[package]
name = "quic-holepunching"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.71"
bitflags = "2.2.1"
num_enum = "0.6.1"
quinn = { git = "https://github.com/quinn-rs/quinn.git", rev = "99e462f8c868b1ac511a360a496e90a73a878f07", features = ["runtime-tokio"] }
rcgen = "0.10.0"
rustls = { version = "0.21.0", features = ["quic", "dangerous_configuration"] }
serde = { version = "1.0.160", features = ["derive"] }
serde_json = "1.0.96"
tokio = { version = "1.28.0", features = ["full"] }
use anyhow::Result;
/// Minimal STUN (RFC 5389) codec: enough to build a Binding request header and to
/// decode the header and attributes of a response.
mod stun {
    use core::fmt::Debug;

    /// Magic cookie carried in bytes 4..8 of every RFC 5389 message header.
    pub const MAGIC_COOKIE: u32 = 0x2112a442u32;
    /// [`MAGIC_COOKIE`] in network (big-endian) byte order, as used when XOR-ing addresses.
    pub const MAGIC_COOKIE_BYTES: [u8; 4] = MAGIC_COOKIE.to_be_bytes();

    /// Builds the `InvalidData` I/O error reported for every malformed input in this module.
    fn invalid_data(message: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidData, message)
    }

    /// Message class, expressed as the C1 (bit 8) and C0 (bit 4) bits of the message type.
    #[derive(
        Debug,
        Copy,
        Clone,
        Eq,
        PartialEq,
        num_enum::FromPrimitive,
        num_enum::IntoPrimitive,
        serde::Serialize,
        serde::Deserialize,
    )]
    #[repr(u16)]
    #[serde(rename_all = "snake_case")]
    pub enum Class {
        Request = 0x0000,
        Indication = 0x0010,
        Success = 0x0100,
        Error = 0x0110,
        #[num_enum(catch_all)]
        Unknown(u16),
    }

    /// Message method. Only Binding is named; every other value decodes to `Unknown`.
    #[derive(
        Debug,
        Copy,
        Clone,
        Eq,
        PartialEq,
        num_enum::FromPrimitive,
        num_enum::IntoPrimitive,
        serde::Serialize,
        serde::Deserialize,
    )]
    #[repr(u8)]
    #[serde(rename_all = "snake_case")]
    pub enum Method {
        Binding = 0x0001,
        #[num_enum(catch_all)]
        Unknown(u8),
    }

    /// Attribute type codes this module recognizes.
    #[derive(
        Debug,
        Copy,
        Clone,
        Eq,
        PartialEq,
        num_enum::FromPrimitive,
        num_enum::IntoPrimitive,
        serde::Serialize,
        serde::Deserialize,
    )]
    #[repr(u16)]
    #[serde(rename_all = "snake_case")]
    pub enum Type {
        MappedAddress = 0x0001,
        Username = 0x0006,
        MessageIntegrity = 0x0008,
        ErrorCode = 0x0009,
        UnknownAttributes = 0x000a,
        Realm = 0x0014,
        Nonce = 0x0015,
        XorMappedAddress = 0x0020,
        Software = 0x8022,
        AlternateServer = 0x8023,
        Fingerprint = 0x8028,
        #[num_enum(catch_all)]
        Unknown(u16),
    }

    /// Family byte of a (XOR-)MAPPED-ADDRESS style attribute.
    #[derive(
        Debug,
        Copy,
        Clone,
        Eq,
        PartialEq,
        num_enum::FromPrimitive,
        num_enum::IntoPrimitive,
        serde::Serialize,
        serde::Deserialize,
    )]
    #[repr(u8)]
    #[serde(rename_all = "lowercase")]
    pub enum AddressFamily {
        IPv4 = 0x01,
        IPv6 = 0x02,
        #[num_enum(catch_all)]
        Unknown(u8),
    }

    /// The 16-bit STUN message-type word (class and method bits interleaved), held in
    /// host byte order.
    #[derive(Copy, Clone, serde::Deserialize)]
    #[repr(transparent)]
    #[serde(rename_all = "camelCase")]
    pub struct HeaderFlags(pub u16);

    impl HeaderFlags {
        /// Extracts the class from bits 4 (C0) and 8 (C1).
        pub fn class(self) -> Class {
            Class::from(self.0 & 0x0110)
        }

        /// Extracts the method. RFC 5389 §6 splits it around the class bits:
        /// M0–M3 sit in bits 0–3, M4–M6 in bits 5–7 and M7–M11 in bits 9–13.
        ///
        /// [`Method`] is only 8 bits wide. A method above `0xff` is therefore reported as
        /// `Method::Unknown` holding its low 8 bits, and never as a named method.
        pub fn method(self) -> Method {
            let bits = self.0;
            let method = (bits & 0x000f) | ((bits & 0x00e0) >> 1) | ((bits & 0x3e00) >> 2);
            match u8::try_from(method) {
                Ok(method) => Method::from(method),
                // Deliberate truncation: the wide method cannot be represented exactly.
                Err(_) => Method::Unknown(method as u8),
            }
        }
    }

    impl TryFrom<u16> for HeaderFlags {
        type Error = anyhow::Error;

        /// Rejects values whose two most significant bits are set; STUN requires them
        /// to be zero.
        fn try_from(value: u16) -> std::result::Result<Self, Self::Error> {
            if value & 0xc000 != 0x0000 {
                return Err(invalid_data("Invalid header flags").into());
            }
            Ok(HeaderFlags(value))
        }
    }

    impl Debug for HeaderFlags {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.debug_struct("HeaderFlags")
                .field("class", &self.class())
                .field("method", &self.method())
                .finish()
        }
    }

    impl serde::Serialize for HeaderFlags {
        /// Serializes as a `{ class, method }` struct rather than the raw word.
        fn serialize<S: serde::Serializer>(
            &self,
            serializer: S,
        ) -> core::result::Result<S::Ok, S::Error> {
            use serde::ser::SerializeStruct;
            let mut flags = serializer.serialize_struct("HeaderFlags", 2)?;
            flags.serialize_field("class", &self.class())?;
            flags.serialize_field("method", &self.method())?;
            flags.end()
        }
    }

    /// The fixed 20-byte STUN header.
    ///
    /// `flags`, `length` and `magic_cookie` are held in host byte order. `transaction_id`
    /// is opaque: its three words hold the wire bytes unchanged (native-endian view), so
    /// an ID survives an encode/decode round trip byte-for-byte.
    #[repr(C)]
    #[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Header {
        pub flags: HeaderFlags,
        pub length: u16,
        pub magic_cookie: u32,
        pub transaction_id: [u32; 3],
    }

    impl Header {
        /// Encoded size of a header in bytes.
        pub const LENGTH: usize = 20;
    }

    // Compile-time check that the in-memory layout matches the wire size.
    const _: [(); std::mem::size_of::<Header>()] = [(); Header::LENGTH];

    impl TryFrom<&[u8]> for Header {
        type Error = anyhow::Error;

        /// Decodes the first [`Header::LENGTH`] bytes of `value`.
        ///
        /// # Errors
        /// Fails if `value` is too short or the type word has either top bit set.
        fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
            if value.len() < Header::LENGTH {
                return Err(invalid_data("Invalid header length").into());
            }
            let be_u16 = |at: usize| u16::from_be_bytes([value[at], value[at + 1]]);
            let word = |at: usize| [value[at], value[at + 1], value[at + 2], value[at + 3]];
            Ok(Header {
                flags: HeaderFlags::try_from(be_u16(0))?,
                length: be_u16(2),
                magic_cookie: u32::from_be_bytes(word(4)),
                // Opaque ID: keep the wire bytes exactly as received.
                transaction_id: [
                    u32::from_ne_bytes(word(8)),
                    u32::from_ne_bytes(word(12)),
                    u32::from_ne_bytes(word(16)),
                ],
            })
        }
    }

    impl From<Header> for [u8; Header::LENGTH] {
        /// Encodes the header in network byte order. The transaction ID's bytes are
        /// written unchanged.
        fn from(header: Header) -> [u8; Header::LENGTH] {
            let mut buffer = [0u8; Header::LENGTH];
            buffer[0..2].copy_from_slice(&header.flags.0.to_be_bytes());
            buffer[2..4].copy_from_slice(&header.length.to_be_bytes());
            buffer[4..8].copy_from_slice(&header.magic_cookie.to_be_bytes());
            for (slot, word) in buffer[8..].chunks_exact_mut(4).zip(header.transaction_id) {
                slot.copy_from_slice(&word.to_ne_bytes());
            }
            buffer
        }
    }

    /// A transport address carried by MAPPED-ADDRESS-style attributes.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "lowercase")]
    pub enum MappedAddress {
        IPv4(std::net::SocketAddrV4),
        IPv6(std::net::SocketAddrV6),
        Unknown {
            family: AddressFamily,
            port: u16,
            address: Vec<u8>,
        },
    }

    /// XORs `bytes` in place with `key`, cycling the key if `bytes` is longer.
    fn xor_in_place(bytes: &mut [u8], key: &[u8; 16]) {
        for (byte, k) in bytes.iter_mut().zip(key.iter().cycle()) {
            *byte ^= *k;
        }
    }

    impl MappedAddress {
        /// Smallest valid encoding: reserved, family, port (2) and an IPv4 address (4).
        pub const MIN_LENGTH: usize = 8;

        /// XORs the address with the magic cookie repeated across every address byte.
        ///
        /// This is correct for IPv4 (RFC 5389 §15.2). It is NOT correct for IPv6, whose
        /// key continues with the transaction ID; use
        /// [`MappedAddress::xor_with_transaction_id`] when that ID is known.
        pub fn xor(&self) -> Self {
            let mut key = [0u8; 16];
            for chunk in key.chunks_exact_mut(4) {
                chunk.copy_from_slice(&MAGIC_COOKIE_BYTES);
            }
            self.xor_with_key(&key)
        }

        /// Applies the RFC 5389 XOR-MAPPED-ADDRESS transform.
        ///
        /// The port is XORed with the high 16 bits of the magic cookie. The address is
        /// XORed with the magic cookie followed by the 96-bit `transaction_id`, given in
        /// the same raw-byte form as [`Header::transaction_id`]. The transform is its own
        /// inverse.
        pub fn xor_with_transaction_id(&self, transaction_id: [u32; 3]) -> Self {
            let mut key = [0u8; 16];
            key[..4].copy_from_slice(&MAGIC_COOKIE_BYTES);
            for (slot, word) in key[4..].chunks_exact_mut(4).zip(transaction_id) {
                slot.copy_from_slice(&word.to_ne_bytes());
            }
            self.xor_with_key(&key)
        }

        /// Shared XOR: port with the cookie's high half, address bytes with `key`.
        fn xor_with_key(&self, key: &[u8; 16]) -> Self {
            let port_mask = (MAGIC_COOKIE >> 16) as u16;
            let mut address = self.clone();
            match &mut address {
                Self::IPv4(socket) => {
                    socket.set_port(socket.port() ^ port_mask);
                    let mut octets = socket.ip().octets();
                    xor_in_place(&mut octets, key);
                    socket.set_ip(octets.into());
                }
                Self::IPv6(socket) => {
                    socket.set_port(socket.port() ^ port_mask);
                    let mut octets = socket.ip().octets();
                    xor_in_place(&mut octets, key);
                    socket.set_ip(octets.into());
                }
                Self::Unknown { port, address, .. } => {
                    *port ^= port_mask;
                    xor_in_place(address, key);
                }
            }
            address
        }
    }

    impl TryFrom<MappedAddress> for std::net::SocketAddr {
        type Error = std::io::Error;

        /// Converts known families; `Unknown` families are rejected.
        fn try_from(value: MappedAddress) -> std::result::Result<Self, Self::Error> {
            match value {
                MappedAddress::IPv4(address) => Ok(std::net::SocketAddr::V4(address)),
                MappedAddress::IPv6(address) => Ok(std::net::SocketAddr::V6(address)),
                MappedAddress::Unknown { .. } => Err(invalid_data("Unknown address family")),
            }
        }
    }

    impl TryFrom<&[u8]> for MappedAddress {
        type Error = anyhow::Error;

        /// Decodes an attribute value: reserved byte, family byte, big-endian port,
        /// then the address.
        ///
        /// # Errors
        /// Fails if the value is shorter than [`MappedAddress::MIN_LENGTH`], or if an
        /// IPv6 value is not exactly 20 bytes long.
        fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
            if value.len() < MappedAddress::MIN_LENGTH {
                return Err(invalid_data("Invalid length").into());
            }
            let family = AddressFamily::from(value[1]);
            let port = u16::from_be_bytes([value[2], value[3]]);
            let address = &value[4..];
            Ok(match family {
                AddressFamily::IPv4 => {
                    let octets = [address[0], address[1], address[2], address[3]];
                    Self::IPv4(std::net::SocketAddrV4::new(octets.into(), port))
                }
                AddressFamily::IPv6 => {
                    // Exactly 16 address bytes must follow the 4-byte prefix.
                    let octets = <[u8; 16]>::try_from(address)
                        .map_err(|_| invalid_data("Invalid length"))?;
                    Self::IPv6(std::net::SocketAddrV6::new(octets.into(), port, 0, 0))
                }
                AddressFamily::Unknown(_) => Self::Unknown {
                    family,
                    port,
                    address: address.to_vec(),
                },
            })
        }
    }

    /// Decoded value of a single attribute.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub enum AttributeValue {
        MappedAddress(MappedAddress),
        /// Already un-XORed, i.e. the plain transport address.
        XorMappedAddress(MappedAddress),
        Software { value: String },
        AlternateServer(MappedAddress),
        Fingerprint { value: u32 },
        Unknown(Vec<u8>),
    }

    /// One TLV attribute; `length` is the unpadded value length from the wire.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Attribute {
        pub length: u16,
        pub value: AttributeValue,
    }

    impl Attribute {
        /// Size of the type + length prefix.
        pub const HEADER_LENGTH: usize = 4;

        /// Bytes the attribute occupies on the wire: its header plus the value padded
        /// to a 4-byte boundary.
        pub fn padded_length(&self) -> usize {
            Self::HEADER_LENGTH + (usize::from(self.length) + 3) / 4 * 4
        }

        /// Decodes the attribute at the start of `bytes`.
        ///
        /// When `transaction_id` is given, XOR-MAPPED-ADDRESS is decoded per RFC 5389,
        /// which is correct for IPv6 too. Without it, the legacy [`MappedAddress::xor`]
        /// is used.
        fn decode(bytes: &[u8], transaction_id: Option<[u32; 3]>) -> anyhow::Result<Self> {
            if bytes.len() < Self::HEADER_LENGTH {
                return Err(invalid_data("Invalid attribute length").into());
            }
            let r#type = Type::from(u16::from_be_bytes([bytes[0], bytes[1]]));
            let length = u16::from_be_bytes([bytes[2], bytes[3]]);
            // A declared length past the end of the buffer is an error, not a panic.
            let value = bytes[Self::HEADER_LENGTH..]
                .get(..usize::from(length))
                .ok_or_else(|| invalid_data("Invalid attribute length"))?;
            let value = match r#type {
                Type::MappedAddress => {
                    AttributeValue::MappedAddress(MappedAddress::try_from(value)?)
                }
                Type::XorMappedAddress => {
                    let address = MappedAddress::try_from(value)?;
                    AttributeValue::XorMappedAddress(match transaction_id {
                        Some(transaction_id) => address.xor_with_transaction_id(transaction_id),
                        None => address.xor(),
                    })
                }
                Type::Software => AttributeValue::Software {
                    value: std::str::from_utf8(value)?.to_owned(),
                },
                Type::AlternateServer => {
                    AttributeValue::AlternateServer(MappedAddress::try_from(value)?)
                }
                Type::Fingerprint => {
                    let raw: [u8; 4] = value
                        .try_into()
                        .map_err(|_| invalid_data("Invalid fingerprint length"))?;
                    AttributeValue::Fingerprint {
                        value: u32::from_be_bytes(raw),
                    }
                }
                _ => AttributeValue::Unknown(value.to_owned()),
            };
            Ok(Attribute { length, value })
        }
    }

    impl TryFrom<&[u8]> for Attribute {
        type Error = anyhow::Error;

        /// Decodes one attribute without message context. XOR-MAPPED-ADDRESS therefore
        /// uses the legacy [`MappedAddress::xor`], which is only correct for IPv4.
        fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
            Attribute::decode(value, None)
        }
    }

    /// A decoded STUN message.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Message {
        pub header: Header,
        pub attributes: Vec<Attribute>,
    }

    impl Message {
        /// Receive-buffer size used for responses.
        pub const MAX_LENGTH: usize = 548;
    }

    impl TryFrom<&[u8]> for Message {
        type Error = anyhow::Error;

        /// Decodes a header followed by the `header.length` bytes of attributes. Each
        /// attribute is padded to a 4-byte boundary, and any bytes after the declared
        /// length are ignored.
        ///
        /// # Errors
        /// Fails if the header is invalid, if `header.length` exceeds the buffer, or if
        /// any attribute is malformed.
        fn try_from(value: &[u8]) -> std::result::Result<Self, Self::Error> {
            let header = Header::try_from(value)?;
            let end = Header::LENGTH + usize::from(header.length);
            let mut remaining = value
                .get(Header::LENGTH..end)
                .ok_or_else(|| invalid_data("Invalid message length"))?;
            let mut attributes = Vec::new();
            while !remaining.is_empty() {
                let attribute = Attribute::decode(remaining, Some(header.transaction_id))?;
                // padded_length() >= 4 guarantees progress. Clamp in case a sloppy sender
                // omitted the final attribute's padding.
                let consumed = attribute.padded_length().min(remaining.len());
                remaining = &remaining[consumed..];
                attributes.push(attribute);
            }
            Ok(Self { header, attributes })
        }
    }
}
/// A zero-sized TLS certificate verifier that trusts every server certificate (see its
/// `ServerCertVerifier` impl).
struct SkipServerVerification;

impl SkipServerVerification {
    /// Returns the verifier wrapped in an `Arc`, the form `configure_client` hands to
    /// rustls's `with_custom_certificate_verifier`.
    fn new() -> std::sync::Arc<Self> {
        let verifier = SkipServerVerification;
        std::sync::Arc::new(verifier)
    }
}
/// Accepts every certificate the server presents, without looking at it.
///
/// SECURITY: this turns off TLS server authentication entirely, so anyone on the path can
/// impersonate the peer. It exists because each side serves a freshly generated
/// self-signed certificate (see `configure_server`), and there is no shared trust anchor.
impl rustls::client::ServerCertVerifier for SkipServerVerification {
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::Certificate,
        _intermediates: &[rustls::Certificate],
        _server_name: &rustls::ServerName,
        _scts: &mut dyn Iterator<Item = &[u8]>,
        _ocsp_response: &[u8],
        _now: std::time::SystemTime,
    ) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
        // Unconditionally vouch for the certificate; none of the arguments are inspected.
        let verdict = rustls::client::ServerCertVerified::assertion();
        Ok(verdict)
    }
}
/// Builds the QUIC client configuration used to dial the peer.
///
/// The TLS config starts from rustls's `with_safe_defaults()` and presents no client
/// certificate. Through `SkipServerVerification` it accepts whatever certificate the peer
/// serves.
fn configure_client() -> quinn::ClientConfig {
    let verifier = SkipServerVerification::new();
    let tls_config = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth();
    let shared_tls = std::sync::Arc::new(tls_config);
    quinn::ClientConfig::new(shared_tls)
}
/// Builds the QUIC server configuration that accepts the peer's connection.
///
/// A new self-signed certificate for `localhost` is generated on every call. That is the
/// server name the dialer uses, and the dialer skips verification anyway. The transport
/// config sets `max_concurrent_uni_streams` to 0; only bidirectional streams are used.
///
/// # Errors
/// Propagates certificate generation/serialization errors and errors from
/// `quinn::ServerConfig::with_single_cert`.
fn configure_server() -> Result<quinn::ServerConfig> {
    let certificate = rcgen::generate_simple_self_signed(vec!["localhost".into()])?;
    let cert_chain = vec![rustls::Certificate(certificate.serialize_der()?)];
    let private_key = rustls::PrivateKey(certificate.serialize_private_key_der());
    let mut server_config = quinn::ServerConfig::with_single_cert(cert_chain, private_key)?;
    // Nothing else can hold a clone of the Arc yet, so get_mut cannot fail here.
    let transport_config = std::sync::Arc::get_mut(&mut server_config.transport)
        .expect("a freshly built ServerConfig uniquely owns its TransportConfig");
    transport_config.max_concurrent_uni_streams(0_u8.into());
    Ok(server_config)
}
#[tokio::main]
async fn main() -> Result<()> {
let runtime = quinn::default_runtime()
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "no async runtime found"))?;
let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
println!("Bound UDP socket to: {}", socket.local_addr()?);
let stun_request: [u8; stun::Header::LENGTH] = (stun::Header {
flags: stun::HeaderFlags(Into::<u8>::into(stun::Method::Binding) as u16),
length: 0,
magic_cookie: stun::MAGIC_COOKIE,
transaction_id: [0u32; 3],
})
.into();
socket
.send_to(&stun_request, "stun.l.google.com:19302")
.await?;
let message = {
let mut buffer = [0u8; stun::Message::MAX_LENGTH];
let (num_read, _) = socket.recv_from(&mut buffer).await?;
stun::Message::try_from(&buffer[..num_read])?
};
println!("{}", serde_json::to_string_pretty(&message)?);
let external_address = message
.attributes
.iter()
.find_map(|attribute| match &attribute.value {
stun::AttributeValue::MappedAddress(address) => Some(address),
stun::AttributeValue::XorMappedAddress(address) => Some(address),
_ => None,
})
.unwrap();
println!(
"Send this over: {}",
serde_json::to_string(external_address)?
);
use std::io::Write;
std::io::stdout().write_all(b"Provide the JSON from the other side: ")?;
std::io::stdout().flush()?;
let mut peer_address_json = String::new();
std::io::stdin().read_line(&mut peer_address_json)?;
let peer_address: stun::MappedAddress = serde_json::from_str(&peer_address_json)?;
let mut endpoint = quinn::Endpoint::new(
quinn::EndpointConfig::default(),
Some(configure_server()?),
socket.into_std()?,
runtime,
)?;
endpoint.set_default_client_config(configure_client());
let client_endpoint = endpoint.clone();
let server_endpoint = endpoint.clone();
tokio::spawn(async move {
while let Ok(peer) =
client_endpoint.connect(peer_address.to_owned().try_into().unwrap(), "localhost")
{
let peer = match peer.await {
Ok(peer) => peer,
Err(e) => {
println!("Error establishing connection on outgoing socket: {}", e);
continue;
}
};
let (mut send, mut recv) = match peer.open_bi().await {
Err(quinn::ConnectionError::ApplicationClosed { .. }) => return,
Err(err) => {
println!(
"Error opening bidirectional stream on outgoing socket: {}",
err
);
return;
}
Ok(stream) => stream,
};
use tokio::io::AsyncWriteExt;
tokio::spawn(async move {
let mut buffer = [0u8; 64 * 1024];
loop {
if let Err(e) = recv.read_exact(&mut buffer[.."PONG".len()]).await {
println!("Error reading from outgoing socket: {}", e);
return;
}
println!(
"Got '{}' from: {:?}",
String::from_utf8_lossy(&buffer[.."PONG".len()]),
peer.remote_address()
);
}
});
loop {
if let Err(e) = send.write_all("PING".as_bytes()).await {
println!("Error writing to outgoing socket: {}", e);
return;
}
if let Err(e) = send.flush().await {
println!("Error flushing writes to outgoing socket: {}", e);
return;
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
});
while let Some(peer) = server_endpoint.accept().await {
let peer = match peer.await {
Ok(peer) => peer,
Err(e) => {
println!("Error accepting connection on incoming socket: {}", e);
continue;
}
};
println!("Peer {} connected.", peer.remote_address());
use tokio::io::AsyncWriteExt;
tokio::spawn(async move {
let (mut send, mut recv) = match peer.accept_bi().await {
Err(quinn::ConnectionError::ApplicationClosed { .. }) => return,
Err(err) => {
println!(
"Error opening bidirectional stream on incoming socket: {}",
err
);
return;
}
Ok(stream) => stream,
};
let mut buffer = [0u8; 64 * 1024];
loop {
if let Err(e) = recv.read_exact(&mut buffer[.."PING".len()]).await {
println!("Error reading from incoming socket: {}", e);
return;
};
println!(
"Got '{}' from: {:?}",
String::from_utf8_lossy(&buffer[.."PING".len()]),
peer.remote_address()
);
if let Err(e) = send.write_all("PONG".as_bytes()).await {
println!("Error writing to incoming socket: {}", e);
return;
}
if let Err(e) = send.flush().await {
println!("Error flushing writes to incoming socket: {}", e);
return;
}
}
});
}
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment