Last active
February 15, 2021 22:57
-
-
Save reu/c3fd098c046c139cc259a5027170708f to your computer and use it in GitHub Desktop.
Naive http/https proxy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "naive-http-proxy"
version = "0.1.0"
authors = ["Rodrigo Navarro <rnavarro@rnavarro.com.br>"]
edition = "2018"

[[bin]]
name = "proxy"
path = "proxy.rs"

[dependencies]
clap = "2.33"
futures = "0.3"
log = "0.4"
headers = "0.3"
http = "0.2"
hyper = { version = "0.14", features = ["full"] }
stderrlog = "0.5"
tokio = { version = "1", features = ["full"] }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use clap::{crate_version, App, Arg}; | |
use headers::authorization::{Basic, Bearer}; | |
use headers::ProxyAuthorization; | |
use http::{Method, Request, Response, StatusCode}; | |
use hyper::header; | |
use hyper::service::{make_service_fn, service_fn}; | |
use hyper::{Body, Client, Server}; | |
use log::{error, info, warn}; | |
use std::convert::TryFrom; | |
use std::net::SocketAddr; | |
use std::sync::Arc; | |
use tokio::io; | |
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; | |
use tokio::net::TcpStream; | |
use tokio::task::spawn; | |
/// Credentials a client must present (via `Proxy-Authorization`) to use the proxy.
enum AuthCredentials {
    /// Basic scheme: the client must send exactly this `username` and `password`.
    Basic {
        username: String,
        password: String,
    },
    /// Bearer scheme: the client must send exactly this `token`.
    Bearer {
        token: String,
    },
}
impl AuthCredentials { | |
fn authenticate(&self, headers: &http::HeaderMap) -> bool { | |
use headers::HeaderMapExt; | |
match self { | |
Self::Basic { username, password } => headers | |
.typed_get::<ProxyAuthorization<Basic>>() | |
.map(|auth| auth.0) | |
.map(|auth| auth.username() == username && auth.password() == password) | |
.unwrap_or(false), | |
Self::Bearer { token } => headers | |
.typed_get::<ProxyAuthorization<Bearer>>() | |
.map(|auth| auth.0) | |
.map(|auth| auth.token() == token) | |
.unwrap_or(false), | |
} | |
} | |
} | |
impl TryFrom<&str> for AuthCredentials { | |
type Error = &'static str; | |
fn try_from(value: &str) -> Result<Self, Self::Error> { | |
match value.split(":").collect::<Vec<&str>>().as_slice() { | |
[username, password] => Ok(Self::Basic { | |
username: username.to_string(), | |
password: password.to_string(), | |
}), | |
[token] => Ok(Self::Bearer { | |
token: token.to_string(), | |
}), | |
_ => Err("Invalid authentication credentials."), | |
} | |
} | |
} | |
async fn tunnel(client: impl AsyncRead + AsyncWrite, server: impl AsyncRead + AsyncWrite) -> Result<(u64, u64), Box<dyn std::error::Error + Send + Sync>> { | |
let (mut client_reader, mut client_writer) = io::split(client); | |
let (mut server_reader, mut server_writer) = io::split(server); | |
let client_to_server = async { | |
let bytes = io::copy(&mut client_reader, &mut server_writer).await?; | |
server_writer.shutdown().await?; | |
Ok(bytes) | |
}; | |
let server_to_client = async { | |
let bytes = io::copy(&mut server_reader, &mut client_writer).await?; | |
client_writer.shutdown().await?; | |
Ok(bytes) | |
}; | |
tokio::try_join!(client_to_server, server_to_client) | |
} | |
#[tokio::main] | |
async fn main() { | |
let args = App::new("http-proxy") | |
.version(crate_version!()) | |
.arg( | |
Arg::with_name("port") | |
.short("p") | |
.long("port") | |
.required(true) | |
.takes_value(true), | |
) | |
.arg( | |
Arg::with_name("quiet") | |
.short("q") | |
.long("quiet") | |
.help("Silence all output") | |
.takes_value(false), | |
) | |
.arg( | |
Arg::with_name("auth") | |
.short("a") | |
.long("auth") | |
.help("Proxy authentication credentials. Use <username:password> for basic auth or <token> for bearer token auth") | |
.takes_value(true), | |
) | |
.get_matches(); | |
stderrlog::new() | |
.module(module_path!()) | |
.verbosity(3) | |
.quiet(args.is_present("quiet")) | |
.init() | |
.unwrap(); | |
let port = args | |
.value_of("port") | |
.and_then(|port| port.parse::<u16>().ok()) | |
.expect("Invalid port number"); | |
let credentials = args | |
.value_of("auth") | |
.and_then(|auth| AuthCredentials::try_from(auth).ok()); | |
let credentials = Arc::new(credentials); | |
Server::bind(&SocketAddr::from(([0, 0, 0, 0], port))) | |
.serve(make_service_fn(|_conn| { | |
let credentials = credentials.clone(); | |
async move { | |
let http_client = Arc::new(Client::new()); | |
let http2_client: Client<_, Body> = hyper::Client::builder() | |
.http2_only(true) | |
.build_http(); | |
let http2_client = Arc::new(http2_client); | |
Ok::<_, hyper::Error>(service_fn(move |mut req| { | |
let http_client = http_client.clone(); | |
let http2_client = http2_client.clone(); | |
let credentials = credentials.clone(); | |
async move { | |
if let Some(credentials) = &*credentials { | |
if !credentials.authenticate(req.headers()) { | |
return Response::builder() | |
.status(StatusCode::PROXY_AUTHENTICATION_REQUIRED) | |
.body(Body::empty()); | |
} | |
} | |
if req.method() == Method::CONNECT { | |
info!("Tunneling: {}", req.uri()); | |
if let Some(address) = req.uri().authority().map(|a| a.to_string()) { | |
spawn(async move { | |
if let Ok(upgraded) = hyper::upgrade::on(req).await { | |
if let Ok(server) = TcpStream::connect(address.clone()).await { | |
match tunnel(upgraded, server).await { | |
Ok((a, b)) => info!("Tunneled bytes written {} read {}", a, b), | |
Err(err) => error!("CONNECT tunnel error: {}", err), | |
}; | |
} else { | |
error!("Could not connect to address: {}", address); | |
} | |
} | |
}); | |
Response::builder().status(StatusCode::OK).body(Body::empty()) | |
} else { | |
Response::builder() | |
.status(StatusCode::BAD_REQUEST) | |
.body(Body::from("Invalid socket address informed")) | |
} | |
} else { | |
info!("Proxying: {}", req.uri()); | |
if req.method() == Method::TRACE || req.method() == Method::OPTIONS { | |
let max_forwards = req | |
.headers() | |
.get(header::MAX_FORWARDS) | |
.and_then(|value| value.to_str().ok()) | |
.and_then(|value| value.parse::<i32>().ok()); | |
match max_forwards { | |
Some(num) if num == 0 => { | |
return Response::builder() | |
.status(StatusCode::NO_CONTENT) | |
.body(Body::empty()) // TODO: add support for proper tracing | |
} | |
Some(num) => { | |
req.headers_mut().insert(header::MAX_FORWARDS, num.into()); | |
} | |
_ => {} | |
} | |
} | |
if let Some(upgrade_header) = req.headers().get(header::UPGRADE) { | |
info!("Upgrading: {}", upgrade_header.to_str().unwrap_or_default()); | |
let mut proxy_req = Request::builder() | |
.uri(req.uri()) | |
.body(Body::empty()) | |
.unwrap(); | |
*proxy_req.headers_mut() = req.headers().clone(); | |
return match http_client.request(proxy_req).await { | |
Ok(client_res) if client_res.status() == StatusCode::SWITCHING_PROTOCOLS => { | |
let mut res = Response::new(Body::empty()); | |
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; | |
*res.headers_mut() = client_res.headers().clone(); | |
match hyper::upgrade::on(client_res).await { | |
Ok(upgraded_client) => { | |
spawn(async move { | |
if let Ok(upgraded_server) = hyper::upgrade::on(&mut req).await { | |
match tunnel(upgraded_client, upgraded_server).await { | |
Ok((a, b)) => info!("Tunneled bytes written {} read {}", a, b), | |
Err(err) => error!("Websocket tunnel error: {}", err), | |
}; | |
} | |
}); | |
Ok(res) | |
}, | |
Err(err) => { | |
Response::builder() | |
.status(StatusCode::BAD_GATEWAY) | |
.body(Body::from(format!("Bad gateway {}", err))) | |
} | |
} | |
}, | |
Ok(client_res) => Ok(client_res), | |
Err(err) => { | |
error!("Upgrade error: {}", err); | |
Response::builder() | |
.status(StatusCode::BAD_GATEWAY) | |
.body(Body::from(format!("Bad gateway {}", err))) | |
}, | |
} | |
} | |
if req.version() == http::version::Version::HTTP_2 { | |
info!("HTTP2 connection"); | |
return match http2_client.request(req).await { | |
Ok(res) => Ok(res), | |
Err(err) => { | |
error!("HTTP2 proxy error: {}", err); | |
Response::builder() | |
.status(StatusCode::BAD_GATEWAY) | |
.body(Body::from(format!("Bad gateway {}", err))) | |
} | |
} | |
} | |
match http_client.request(req).await { | |
Ok(res) => Ok(res), | |
Err(err) => { | |
warn!("Proxy error: {}", err); | |
Response::builder() | |
.status(StatusCode::BAD_GATEWAY) | |
.body(Body::from(format!("Bad gateway {}", err))) | |
} | |
} | |
} | |
} | |
})) | |
} | |
})) | |
.await | |
.unwrap(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment