Created
May 5, 2024 09:01
-
-
Save marcobacis/4ba9d0885862239d9ed7482983047472 to your computer and use it in GitHub Desktop.
Simple HTTP forward in Rust using actix and reqwest
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "lb"
version = "0.1.0"
edition = "2021"

[lib]
name = "lb"
path = "lib.rs"

[dependencies]
actix-web = "4.5.1"
clap = "4.5.2"
reqwest = "0.11.25"
tokio = { version="1.36.0", features = ["macros", "rt-multi-thread"] }

[dev-dependencies]
wiremock = "0.6.0"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{fmt::Display}; | |
use actix_web::{ | |
http::header::ContentType, | |
web::{self}, | |
App, HttpRequest, HttpResponse, HttpServer, ResponseError, | |
}; | |
use reqwest::Client; | |
/// Configuration of the HTTP forwarder: where it listens and where it forwards.
#[derive(Debug, Clone)]
struct LoadBalancer {
    /// TCP port on `127.0.0.1` the forwarder listens on.
    port: u16,
    /// Base URLs of the backend servers requests are forwarded to.
    servers: Vec<String>,
}
/// State shared with the request handler through actix's application data.
#[derive(Debug, Clone)]
struct AppState {
    /// Base URLs of the backend servers requests are forwarded to.
    servers: Vec<String>,
}
impl LoadBalancer { | |
pub fn new(port: u16, servers: Vec<String>) -> Self { | |
LoadBalancer { port, servers } | |
} | |
pub fn uri(&self) -> String { | |
format!("http://127.0.0.1:{}", self.port) | |
} | |
pub async fn run(&self) { | |
let data = web::Data::new(AppState { | |
servers: self.servers.clone(), | |
}); | |
HttpServer::new(move || { | |
App::new() | |
.default_service(web::to(Self::handler)) | |
// We add the initial instance of our shared app state | |
.app_data(data.clone()) | |
}) | |
.bind(("127.0.0.1", self.port)) | |
.unwrap() | |
.run() | |
.await | |
.unwrap(); | |
} | |
async fn handler( | |
req: HttpRequest, | |
data: web::Data<AppState>, | |
bytes: web::Bytes, | |
) -> Result<HttpResponse, Error> { | |
let server = data.servers[0].clone(); | |
let uri = format!("{}{}", server, req.uri()); | |
let client = Client::new(); | |
let request_builder = client | |
.request(req.method().clone(), uri) | |
.headers(req.headers().into()) | |
.body(bytes); | |
let response = request_builder.send().await?; | |
let mut response_builder = HttpResponse::build(response.status()); | |
for h in response.headers().iter() { | |
response_builder.append_header(h); | |
} | |
let body = response.bytes().await?; | |
Ok(response_builder.body(body)) | |
} | |
} | |
#[derive(Debug)] | |
struct Error { | |
inner: reqwest::Error, | |
} | |
impl Display for Error { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
write!(f, "Forwarding error: {}", self.inner) | |
} | |
} | |
impl From<reqwest::Error> for Error { | |
fn from(value: reqwest::Error) -> Self { | |
Error { inner: value } | |
} | |
} | |
impl ResponseError for Error { | |
fn status_code(&self) -> reqwest::StatusCode { | |
reqwest::StatusCode::INTERNAL_SERVER_ERROR | |
} | |
fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> { | |
HttpResponse::build(self.status_code()) | |
.insert_header(ContentType::html()) | |
.body(self.to_string()) | |
} | |
} | |
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use reqwest::{Client, StatusCode};
    use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};

    use crate::LoadBalancer;

    /// Maximum number of refused connections tolerated while the server starts.
    const MAX_CONNECT_ATTEMPTS: u32 = 100;
    /// Pause between two connection attempts.
    const RETRY_DELAY: Duration = Duration::from_millis(50);

    #[tokio::test]
    async fn test_get_root() {
        // Set up a mock upstream server, to test that the request gets forwarded.
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_string("backend"))
            .expect(1)
            .mount(&mock_server)
            .await;

        let client = Client::new();

        // Ask the OS for a free port instead of hard-coding one, so the test
        // does not fail when the port is already taken. The probe listener is
        // dropped inside the closure, which releases the port again.
        let port = std::net::TcpListener::bind(("127.0.0.1", 0))
            .and_then(|listener| listener.local_addr())
            .map(|addr| addr.port())
            .expect("failed to reserve a free local port");

        // The type under test, the load balancer itself.
        let server = LoadBalancer::new(port, vec![mock_server.uri()]);
        let server_uri = server.uri();
        tokio::spawn(async move { server.run().await });

        // Retry while the connection is refused rather than sleeping for a
        // fixed time. Refused attempts never reach the mock backend, so its
        // `expect(1)` still sees exactly one request.
        let mut attempts = 0;
        let response = loop {
            match client.get(server_uri.as_str()).send().await {
                Ok(response) => break response,
                Err(err) if err.is_connect() && attempts < MAX_CONNECT_ATTEMPTS => {
                    attempts += 1;
                    tokio::time::sleep(RETRY_DELAY).await;
                }
                Err(err) => panic!("request to the load balancer failed: {err}"),
            }
        };

        // Check that we receive the response from the mock backend
        // (and not from the load balancer).
        assert_eq!(StatusCode::OK, response.status());
        assert_eq!("backend", response.text().await.unwrap());
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment