Skip to content

Instantly share code, notes, and snippets.

@reu
Last active February 15, 2021 22:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save reu/c3fd098c046c139cc259a5027170708f to your computer and use it in GitHub Desktop.
Save reu/c3fd098c046c139cc259a5027170708f to your computer and use it in GitHub Desktop.
Naive http/https proxy
[package]
name = "naive-http-proxy"
version = "0.1.0"
authors = ["Rodrigo Navarro <rnavarro@rnavarro.com.br>"]
edition = "2018"
[[bin]]
name = "proxy"
path = "proxy.rs"
[dependencies]
clap = "2.33"
futures = "0.3"
log = "0.4"
headers = "0.3"
http = "0.2"
hyper = { version = "0.14", features = ["full"] }
stderrlog = "0.5"
tokio = { version = "1", features = ["full"] }
use clap::{crate_version, App, Arg};
use headers::authorization::{Basic, Bearer};
use headers::ProxyAuthorization;
use http::{Method, Request, Response, StatusCode};
use hyper::header;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Client, Server};
use log::{error, info, warn};
use std::convert::TryFrom;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::task::spawn;
/// Credentials a client must present in its `Proxy-Authorization` header.
///
/// The variant selects which authorization scheme the proxy expects from clients.
enum AuthCredentials {
    /// Checked against a `Proxy-Authorization: Basic …` header.
    Basic {
        username: String,
        password: String,
    },
    /// Checked against a `Proxy-Authorization: Bearer …` header.
    Bearer {
        token: String,
    },
}
/// Compares two secrets without stopping at the first mismatching byte.
///
/// A plain `==` returns as soon as a byte differs, which lets a client time how much of a
/// guessed credential is correct. Only a length mismatch returns early, so at most the
/// length of the secret can leak.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}

impl AuthCredentials {
    /// Checks whether `headers` carry a `Proxy-Authorization` value matching these credentials.
    ///
    /// Returns `false` when the typed header for the expected scheme (Basic or Bearer) is
    /// absent or undecodable (`typed_get` yields `None`), or when the credentials differ.
    fn authenticate(&self, headers: &http::HeaderMap) -> bool {
        use headers::HeaderMapExt;
        match self {
            Self::Basic { username, password } => headers
                .typed_get::<ProxyAuthorization<Basic>>()
                .map(|auth| auth.0)
                .map(|auth| {
                    // Non-short-circuiting `&` so the password is always compared too.
                    constant_time_eq(auth.username(), username)
                        & constant_time_eq(auth.password(), password)
                })
                .unwrap_or(false),
            Self::Bearer { token } => headers
                .typed_get::<ProxyAuthorization<Bearer>>()
                .map(|auth| auth.0)
                .map(|auth| constant_time_eq(auth.token(), token))
                .unwrap_or(false),
        }
    }
}
impl TryFrom<&str> for AuthCredentials {
    type Error = &'static str;

    /// Parses the `--auth` command-line value.
    ///
    /// `username:password` selects Basic auth and a value without `:` selects Bearer auth.
    /// Only the first `:` separates username from password, so the password may itself
    /// contain colons (RFC 7617 only forbids them in the user-id).
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let mut parts = value.splitn(2, ':');
        match (parts.next(), parts.next()) {
            (Some(username), Some(password)) => Ok(Self::Basic {
                username: username.to_string(),
                password: password.to_string(),
            }),
            (Some(token), None) => Ok(Self::Bearer {
                token: token.to_string(),
            }),
            // `splitn` always yields at least one piece; this arm is purely defensive.
            (None, _) => Err("Invalid authentication credentials."),
        }
    }
}
/// Pipes bytes between `client` and `server` in both directions.
///
/// Each direction copies until its reader reaches EOF and then shuts down the opposite
/// writer, so a half-close on one side is passed on to the other. `tokio::try_join!` drives
/// both directions concurrently and returns early on the first error.
///
/// Returns `(bytes copied client -> server, bytes copied server -> client)`.
async fn tunnel(
    client: impl AsyncRead + AsyncWrite,
    server: impl AsyncRead + AsyncWrite,
) -> Result<(u64, u64), Box<dyn std::error::Error + Send + Sync>> {
    let (mut from_client, mut to_client) = io::split(client);
    let (mut from_server, mut to_server) = io::split(server);

    let upstream = async {
        let copied = io::copy(&mut from_client, &mut to_server).await?;
        to_server.shutdown().await?;
        Ok(copied)
    };

    let downstream = async {
        let copied = io::copy(&mut from_server, &mut to_client).await?;
        to_client.shutdown().await?;
        Ok(copied)
    };

    tokio::try_join!(upstream, downstream)
}
#[tokio::main]
async fn main() {
let args = App::new("http-proxy")
.version(crate_version!())
.arg(
Arg::with_name("port")
.short("p")
.long("port")
.required(true)
.takes_value(true),
)
.arg(
Arg::with_name("quiet")
.short("q")
.long("quiet")
.help("Silence all output")
.takes_value(false),
)
.arg(
Arg::with_name("auth")
.short("a")
.long("auth")
.help("Proxy authentication credentials. Use <username:password> for basic auth or <token> for bearer token auth")
.takes_value(true),
)
.get_matches();
stderrlog::new()
.module(module_path!())
.verbosity(3)
.quiet(args.is_present("quiet"))
.init()
.unwrap();
let port = args
.value_of("port")
.and_then(|port| port.parse::<u16>().ok())
.expect("Invalid port number");
let credentials = args
.value_of("auth")
.and_then(|auth| AuthCredentials::try_from(auth).ok());
let credentials = Arc::new(credentials);
Server::bind(&SocketAddr::from(([0, 0, 0, 0], port)))
.serve(make_service_fn(|_conn| {
let credentials = credentials.clone();
async move {
let http_client = Arc::new(Client::new());
let http2_client: Client<_, Body> = hyper::Client::builder()
.http2_only(true)
.build_http();
let http2_client = Arc::new(http2_client);
Ok::<_, hyper::Error>(service_fn(move |mut req| {
let http_client = http_client.clone();
let http2_client = http2_client.clone();
let credentials = credentials.clone();
async move {
if let Some(credentials) = &*credentials {
if !credentials.authenticate(req.headers()) {
return Response::builder()
.status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
.body(Body::empty());
}
}
if req.method() == Method::CONNECT {
info!("Tunneling: {}", req.uri());
if let Some(address) = req.uri().authority().map(|a| a.to_string()) {
spawn(async move {
if let Ok(upgraded) = hyper::upgrade::on(req).await {
if let Ok(server) = TcpStream::connect(address.clone()).await {
match tunnel(upgraded, server).await {
Ok((a, b)) => info!("Tunneled bytes written {} read {}", a, b),
Err(err) => error!("CONNECT tunnel error: {}", err),
};
} else {
error!("Could not connect to address: {}", address);
}
}
});
Response::builder().status(StatusCode::OK).body(Body::empty())
} else {
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid socket address informed"))
}
} else {
info!("Proxying: {}", req.uri());
if req.method() == Method::TRACE || req.method() == Method::OPTIONS {
let max_forwards = req
.headers()
.get(header::MAX_FORWARDS)
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse::<i32>().ok());
match max_forwards {
Some(num) if num == 0 => {
return Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Body::empty()) // TODO: add support for proper tracing
}
Some(num) => {
req.headers_mut().insert(header::MAX_FORWARDS, num.into());
}
_ => {}
}
}
if let Some(upgrade_header) = req.headers().get(header::UPGRADE) {
info!("Upgrading: {}", upgrade_header.to_str().unwrap_or_default());
let mut proxy_req = Request::builder()
.uri(req.uri())
.body(Body::empty())
.unwrap();
*proxy_req.headers_mut() = req.headers().clone();
return match http_client.request(proxy_req).await {
Ok(client_res) if client_res.status() == StatusCode::SWITCHING_PROTOCOLS => {
let mut res = Response::new(Body::empty());
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
*res.headers_mut() = client_res.headers().clone();
match hyper::upgrade::on(client_res).await {
Ok(upgraded_client) => {
spawn(async move {
if let Ok(upgraded_server) = hyper::upgrade::on(&mut req).await {
match tunnel(upgraded_client, upgraded_server).await {
Ok((a, b)) => info!("Tunneled bytes written {} read {}", a, b),
Err(err) => error!("Websocket tunnel error: {}", err),
};
}
});
Ok(res)
},
Err(err) => {
Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(format!("Bad gateway {}", err)))
}
}
},
Ok(client_res) => Ok(client_res),
Err(err) => {
error!("Upgrade error: {}", err);
Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(format!("Bad gateway {}", err)))
},
}
}
if req.version() == http::version::Version::HTTP_2 {
info!("HTTP2 connection");
return match http2_client.request(req).await {
Ok(res) => Ok(res),
Err(err) => {
error!("HTTP2 proxy error: {}", err);
Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(format!("Bad gateway {}", err)))
}
}
}
match http_client.request(req).await {
Ok(res) => Ok(res),
Err(err) => {
warn!("Proxy error: {}", err);
Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(format!("Bad gateway {}", err)))
}
}
}
}
}))
}
}))
.await
.unwrap();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment