Created
May 22, 2024 16:00
-
-
Save marcobacis/477191a3b908c27fd334137cf4b5ee1d to your computer and use it in GitHub Desktop.
Load Balancer in Rust - Part 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
fmt::Display, | |
sync::atomic::{AtomicUsize, Ordering}, | |
}; | |
use actix_web::{ | |
http::header::ContentType, | |
web::{self, Data}, | |
App, HttpRequest, HttpResponse, HttpServer, ResponseError, | |
}; | |
use async_trait::async_trait; | |
use reqwest::Client; | |
pub struct LoadBalancer { | |
port: u16, | |
data: Data<AppState>, | |
} | |
struct AppState { | |
client: Client, | |
policy: Box<SafeRoutingPolicy>, | |
} | |
impl LoadBalancer { | |
pub fn new(port: u16, policy: Box<SafeRoutingPolicy>) -> Self { | |
LoadBalancer { | |
port, | |
data: web::Data::new(AppState { | |
client: Client::new(), | |
policy, | |
}), | |
} | |
} | |
pub fn uri(&self) -> String { | |
format!("http://127.0.0.1:{}", self.port) | |
} | |
pub async fn run(&self) { | |
let data = self.data.clone(); | |
HttpServer::new(move || { | |
App::new() | |
// Healthcheck endpoint always returning 200 OK | |
.route("/health", web::get().to(HttpResponse::Ok)) | |
.default_service(web::to(Self::handler)) | |
// We add the initial instance of our shared app state | |
.app_data(data.clone()) | |
}) | |
.bind(("127.0.0.1", self.port)) | |
.unwrap() | |
.run() | |
.await | |
.unwrap(); | |
} | |
async fn handler( | |
req: HttpRequest, | |
data: web::Data<AppState>, | |
bytes: web::Bytes, | |
) -> Result<HttpResponse, Error> { | |
let server = data.policy.next(&req).await; | |
let uri = format!("{}{}", server, req.uri()); | |
let request_builder = data | |
.client | |
.request(req.method().clone(), uri) | |
.headers(req.headers().into()) | |
.body(bytes); | |
let response = request_builder.send().await?; | |
let mut response_builder = HttpResponse::build(response.status()); | |
for h in response.headers().iter() { | |
response_builder.append_header(h); | |
} | |
let body = response.bytes().await?; | |
Ok(response_builder.body(body)) | |
} | |
} | |
#[derive(Debug)] | |
pub struct Error { | |
inner: reqwest::Error, | |
} | |
impl Display for Error { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
write!(f, "Forwarding error: {}", self.inner) | |
} | |
} | |
impl From<reqwest::Error> for Error { | |
fn from(value: reqwest::Error) -> Self { | |
Error { inner: value } | |
} | |
} | |
impl ResponseError for Error { | |
fn status_code(&self) -> reqwest::StatusCode { | |
reqwest::StatusCode::INTERNAL_SERVER_ERROR | |
} | |
fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> { | |
HttpResponse::build(self.status_code()) | |
.insert_header(ContentType::html()) | |
.body(self.to_string()) | |
} | |
} | |
pub type SafeRoutingPolicy = dyn RoutingPolicy + Sync + Send; | |
#[async_trait] | |
pub trait RoutingPolicy { | |
async fn next(&self, request: &HttpRequest) -> String; | |
} | |
/// Routing policy that hands requests to each upstream in turn.
pub struct RoundRobinPolicy {
    /// Upstream base URLs, in the order they are cycled through.
    servers: Vec<String>,
    /// Atomic cursor used to pick the next upstream without taking a lock.
    idx: std::sync::atomic::AtomicUsize,
}
impl RoundRobinPolicy { | |
pub fn new(servers: Vec<String>) -> Self { | |
Self { | |
idx: AtomicUsize::new(0), | |
servers: servers.clone(), | |
} | |
} | |
} | |
#[async_trait] | |
impl RoutingPolicy for RoundRobinPolicy { | |
async fn next(&self, _request: &HttpRequest) -> String { | |
let servers = &self.servers; | |
let max_server_idx = servers.len() - 1; | |
// Update index | |
let idx = self | |
.idx | |
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |idx| match idx { | |
x if x >= max_server_idx => Some(0), | |
c => Some(c + 1), | |
}) | |
.unwrap_or_default(); | |
// Return next server to forward the request to | |
servers.get(idx).unwrap().clone() | |
} | |
} | |
#[cfg(test)]
mod tests {
    use reqwest::{Client, StatusCode};
    use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};

    use crate::{LoadBalancer, RoundRobinPolicy};

    // NOTE(review): each test binds a fixed local port (8080, 8082); the tests
    // fail if another process already holds one of them.

    #[tokio::test]
    async fn test_get_root() {
        // Setup a mock upstream server, to test that the request gets forwarded
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_string("backend"))
            .expect(1)
            .mount(&mock_server)
            .await;
        let client = Client::new();

        // The class under test, the load balancer itself
        let policy = Box::new(RoundRobinPolicy::new(vec![mock_server.uri()]));
        let server = LoadBalancer::new(8080, policy);
        let server_uri = server.uri();
        tokio::spawn(async move { server.run().await });

        // Wait for the server to be up (will fix this later)
        wait_server_up(&client, &server_uri, 3).await;

        // The body must come from the mock backend, not the load balancer.
        assert_body(&client, &server_uri, "backend").await;
    }

    #[tokio::test]
    async fn test_round_robin_three_servers() {
        let mocks = [
            MockServer::start().await,
            MockServer::start().await,
            MockServer::start().await,
        ];
        // Each upstream answers with its own 1-based position.
        for (mock, body) in mocks.iter().zip(["1", "2", "3"]) {
            Mock::given(method("GET"))
                .respond_with(ResponseTemplate::new(200).set_body_string(body))
                .mount(mock)
                .await;
        }
        let client = Client::new();
        let mock_uris: Vec<_> = mocks.iter().map(|mock| mock.uri()).collect();

        // Spawn server
        let policy = Box::new(RoundRobinPolicy::new(mock_uris));
        let server = LoadBalancer::new(8082, policy);
        let server_uri = server.uri();
        tokio::spawn(async move { server.run().await });
        wait_server_up(&client, &server_uri, 3).await;

        // Four requests: one full rotation plus a wrap back to the first upstream.
        for expected in ["1", "2", "3", "1"] {
            assert_body(&client, &server_uri, expected).await;
        }
    }

    /// Sends `GET uri` and asserts a `200 OK` whose body is exactly `expected`.
    async fn assert_body(client: &Client, uri: &str, expected: &str) {
        let response = client.get(uri).send().await.unwrap();
        assert_eq!(StatusCode::OK, response.status());
        assert_eq!(expected, response.text().await.unwrap());
    }

    /// Polls `<uri>/health` up to `max_retries` times, one second apart, and
    /// panics if no attempt gets a response.
    pub async fn wait_server_up(client: &Client, uri: &str, max_retries: usize) {
        let health_uri = format!("{}/health", uri);
        for _ in 0..max_retries {
            let response = client.get(&health_uri).send().await;
            if response.is_ok() {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
        panic!("Server didn't start...");
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment