Skip to content

Instantly share code, notes, and snippets.

@marcobacis
Created May 5, 2024 09:01
Show Gist options
  • Save marcobacis/4ba9d0885862239d9ed7482983047472 to your computer and use it in GitHub Desktop.
Save marcobacis/4ba9d0885862239d9ed7482983047472 to your computer and use it in GitHub Desktop.
Simple HTTP forwarding proxy in Rust using actix-web and reqwest
[package]
name = "lb"
version = "0.1.0"
edition = "2021"
[lib]
name = "lb"
path = "lib.rs"
[dependencies]
actix-web = "4.5.1"
clap = "4.5.2"
reqwest = "0.11.25"
tokio = { version="1.36.0", features = ["macros", "rt-multi-thread"] }
[dev-dependencies]
wiremock = "0.6.0"
use std::{fmt::Display};
use actix_web::{
http::header::ContentType,
web::{self},
App, HttpRequest, HttpResponse, HttpServer, ResponseError,
};
use reqwest::Client;
struct LoadBalancer {
port: u16,
servers: Vec<String>,
}
struct AppState {
servers: Vec<String>,
}
impl LoadBalancer {
pub fn new(port: u16, servers: Vec<String>) -> Self {
LoadBalancer { port, servers }
}
pub fn uri(&self) -> String {
format!("http://127.0.0.1:{}", self.port)
}
pub async fn run(&self) {
let data = web::Data::new(AppState {
servers: self.servers.clone(),
});
HttpServer::new(move || {
App::new()
.default_service(web::to(Self::handler))
// We add the initial instance of our shared app state
.app_data(data.clone())
})
.bind(("127.0.0.1", self.port))
.unwrap()
.run()
.await
.unwrap();
}
async fn handler(
req: HttpRequest,
data: web::Data<AppState>,
bytes: web::Bytes,
) -> Result<HttpResponse, Error> {
let server = data.servers[0].clone();
let uri = format!("{}{}", server, req.uri());
let client = Client::new();
let request_builder = client
.request(req.method().clone(), uri)
.headers(req.headers().into())
.body(bytes);
let response = request_builder.send().await?;
let mut response_builder = HttpResponse::build(response.status());
for h in response.headers().iter() {
response_builder.append_header(h);
}
let body = response.bytes().await?;
Ok(response_builder.body(body))
}
}
#[derive(Debug)]
struct Error {
inner: reqwest::Error,
}
impl Display for Error {
    /// Renders as `Forwarding error: <client error>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { inner } = self;
        write!(f, "Forwarding error: {inner}")
    }
}
impl From<reqwest::Error> for Error {
fn from(value: reqwest::Error) -> Self {
Error { inner: value }
}
}
impl ResponseError for Error {
    /// Maps the failure to a gateway-appropriate status:
    /// - `504 Gateway Timeout` when the upstream request timed out;
    /// - `500 Internal Server Error` when the outgoing request could not be built;
    /// - `502 Bad Gateway` for any other failure talking to the upstream.
    fn status_code(&self) -> actix_web::http::StatusCode {
        use actix_web::http::StatusCode;

        if self.inner.is_timeout() {
            StatusCode::GATEWAY_TIMEOUT
        } else if self.inner.is_builder() {
            StatusCode::INTERNAL_SERVER_ERROR
        } else {
            StatusCode::BAD_GATEWAY
        }
    }

    /// Responds with the error's `Display` text as a `text/plain` body.
    ///
    /// The content type was previously `text/html`, although the body is not HTML.
    // NOTE(review): the text embeds the reqwest error, which can include the
    // internal upstream URL; consider a generic body for untrusted clients.
    fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
        HttpResponse::build(self.status_code())
            .insert_header(ContentType::plaintext())
            .body(self.to_string())
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use reqwest::{Client, StatusCode};
    use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};

    use crate::LoadBalancer;

    /// Maximum time to wait for the load balancer to start accepting connections.
    const STARTUP_TIMEOUT: Duration = Duration::from_secs(5);

    /// Asks the OS for a currently unused local port.
    ///
    /// The probe listener is dropped before returning, so another process could in
    /// principle take the port before the balancer binds it. That race is acceptable
    /// in tests and avoids clashing on a hard-coded port such as 8080.
    fn free_port() -> u16 {
        std::net::TcpListener::bind("127.0.0.1:0")
            .and_then(|listener| listener.local_addr())
            .map(|addr| addr.port())
            .expect("failed to reserve a free local port")
    }

    /// Polls until something accepts TCP connections on `127.0.0.1:port`.
    ///
    /// Only a TCP connect is attempted, never an HTTP request: a request would be
    /// forwarded to the mock and count against its `expect(1)`.
    async fn wait_until_listening(port: u16) {
        let deadline = tokio::time::Instant::now() + STARTUP_TIMEOUT;
        while tokio::net::TcpStream::connect(("127.0.0.1", port))
            .await
            .is_err()
        {
            assert!(
                tokio::time::Instant::now() < deadline,
                "load balancer did not start listening on port {port} within {STARTUP_TIMEOUT:?}"
            );
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
    }

    #[tokio::test]
    async fn test_get_root() {
        // Mock upstream server, used to check that the request gets forwarded.
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_string("backend"))
            .expect(1)
            .mount(&mock_server)
            .await;

        // The unit under test: the load balancer itself.
        let port = free_port();
        let server = LoadBalancer::new(port, vec![mock_server.uri()]);
        let server_uri = server.uri();
        tokio::spawn(async move { server.run().await });
        wait_until_listening(port).await;

        // The response must come from the mock backend, not from the load balancer.
        let response = Client::new().get(server_uri).send().await.unwrap();
        assert_eq!(StatusCode::OK, response.status());
        assert_eq!("backend", response.text().await.unwrap());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment