Skip to content

Instantly share code, notes, and snippets.

@marcobacis
Created May 22, 2024 16:00
Show Gist options
  • Save marcobacis/477191a3b908c27fd334137cf4b5ee1d to your computer and use it in GitHub Desktop.
Save marcobacis/477191a3b908c27fd334137cf4b5ee1d to your computer and use it in GitHub Desktop.
Load Balancer in Rust - Part 2
use std::{
fmt::Display,
sync::atomic::{AtomicUsize, Ordering},
};
use actix_web::{
http::header::ContentType,
web::{self, Data},
App, HttpRequest, HttpResponse, HttpServer, ResponseError,
};
use async_trait::async_trait;
use reqwest::Client;
pub struct LoadBalancer {
port: u16,
data: Data<AppState>,
}
struct AppState {
client: Client,
policy: Box<SafeRoutingPolicy>,
}
impl LoadBalancer {
pub fn new(port: u16, policy: Box<SafeRoutingPolicy>) -> Self {
LoadBalancer {
port,
data: web::Data::new(AppState {
client: Client::new(),
policy,
}),
}
}
pub fn uri(&self) -> String {
format!("http://127.0.0.1:{}", self.port)
}
pub async fn run(&self) {
let data = self.data.clone();
HttpServer::new(move || {
App::new()
// Healthcheck endpoint always returning 200 OK
.route("/health", web::get().to(HttpResponse::Ok))
.default_service(web::to(Self::handler))
// We add the initial instance of our shared app state
.app_data(data.clone())
})
.bind(("127.0.0.1", self.port))
.unwrap()
.run()
.await
.unwrap();
}
async fn handler(
req: HttpRequest,
data: web::Data<AppState>,
bytes: web::Bytes,
) -> Result<HttpResponse, Error> {
let server = data.policy.next(&req).await;
let uri = format!("{}{}", server, req.uri());
let request_builder = data
.client
.request(req.method().clone(), uri)
.headers(req.headers().into())
.body(bytes);
let response = request_builder.send().await?;
let mut response_builder = HttpResponse::build(response.status());
for h in response.headers().iter() {
response_builder.append_header(h);
}
let body = response.bytes().await?;
Ok(response_builder.body(body))
}
}
#[derive(Debug)]
pub struct Error {
inner: reqwest::Error,
}
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Forwarding error: {}", self.inner)
}
}
impl From<reqwest::Error> for Error {
fn from(value: reqwest::Error) -> Self {
Error { inner: value }
}
}
impl ResponseError for Error {
    /// Maps forwarding failures to gateway status codes.
    ///
    /// - `504 Gateway Timeout` when the upstream request timed out.
    /// - `500 Internal Server Error` when the request could not even be built
    ///   (a local configuration problem, e.g. a malformed server URI).
    /// - `502 Bad Gateway` for every other upstream failure.
    fn status_code(&self) -> reqwest::StatusCode {
        if self.inner.is_timeout() {
            reqwest::StatusCode::GATEWAY_TIMEOUT
        } else if self.inner.is_builder() {
            reqwest::StatusCode::INTERNAL_SERVER_ERROR
        } else {
            reqwest::StatusCode::BAD_GATEWAY
        }
    }

    /// Renders the `Display` message as the response body, using `status_code`.
    fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
        HttpResponse::build(self.status_code())
            // The body is the plain `Display` text, not markup, so don't label it HTML.
            .insert_header(ContentType::plaintext())
            .body(self.to_string())
    }
}
/// A [`RoutingPolicy`] trait object that is `Send + Sync`, so it can live in
/// the `web::Data` app state shared with every server worker.
pub type SafeRoutingPolicy = dyn RoutingPolicy + Send + Sync;
/// Strategy for choosing which upstream server handles a request.
#[async_trait]
pub trait RoutingPolicy {
    /// Returns the base URI of the server that should receive `req`.
    ///
    /// The handler appends `req.uri()` directly to the returned string, so it
    /// should not end with a trailing slash.
    async fn next(&self, req: &HttpRequest) -> String;
}
/// Routing policy that cycles through a fixed list of servers in order.
pub struct RoundRobinPolicy {
    /// Position in `servers` that the next routing decision will use.
    idx: AtomicUsize,
    /// Base URIs of the upstream servers, in rotation order.
    servers: Vec<String>,
}
impl RoundRobinPolicy {
    /// Creates a policy that rotates through `servers`, starting with the first.
    ///
    /// `servers` should be non-empty: asking an empty policy for a server panics.
    pub fn new(servers: Vec<String>) -> Self {
        Self {
            idx: AtomicUsize::new(0),
            // We already own the Vec; store it directly instead of copying it.
            servers,
        }
    }
}
#[async_trait]
impl RoutingPolicy for RoundRobinPolicy {
    /// Returns the next server in round-robin order; the request is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the policy holds no servers, since there is nowhere to route.
    async fn next(&self, _request: &HttpRequest) -> String {
        let len = self.servers.len();
        // Without this guard `len - 1`-style arithmetic underflows and the
        // failure surfaces as an unrelated overflow/unwrap panic.
        assert!(len > 0, "RoundRobinPolicy has no servers to route to");
        // Atomically advance the cursor, wrapping to 0 after the last server.
        // The closure always returns `Some`, so `fetch_update` cannot fail. It
        // yields the value *before* the update, i.e. the index to use now.
        let idx = self
            .idx
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |i| Some((i + 1) % len))
            .unwrap_or_else(|current| current);
        // `idx` is always in `0..len` because every stored value is reduced mod `len`.
        self.servers[idx].clone()
    }
}
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use reqwest::{Client, StatusCode};
    use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};

    use crate::{LoadBalancer, RoundRobinPolicy};

    #[tokio::test]
    async fn test_get_root() {
        // Setup a mock upstream server, to test that the request gets forwarded
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_string("backend"))
            .expect(1)
            .mount(&mock_server)
            .await;
        let client = Client::new();
        // The class under test, the load balancer itself
        let policy = Box::new(RoundRobinPolicy::new(vec![mock_server.uri()]));
        let server = LoadBalancer::new(free_port(), policy);
        let server_uri = server.uri();
        tokio::spawn(async move { server.run().await });
        wait_server_up(&client, &server_uri, 3).await;
        // Check that we receive response from the mock backend
        // (and not from the load balancer)
        let response = client.get(server_uri).send().await.unwrap();
        assert_eq!(StatusCode::OK, response.status());
        assert_eq!("backend", response.text().await.unwrap());
    }

    #[tokio::test]
    async fn test_round_robin_three_servers() {
        let mocks = [
            MockServer::start().await,
            MockServer::start().await,
            MockServer::start().await,
        ];
        // Each backend answers with its own 1-based position so the rotation is visible
        for (mock, body) in mocks.iter().zip(["1", "2", "3"]) {
            Mock::given(method("GET"))
                .respond_with(ResponseTemplate::new(200).set_body_string(body))
                .mount(mock)
                .await;
        }
        let client = Client::new();
        let mock_uris: Vec<_> = mocks.iter().map(|mock| mock.uri()).collect();
        // Spawn server
        let policy = Box::new(RoundRobinPolicy::new(mock_uris));
        let server = LoadBalancer::new(free_port(), policy);
        let server_uri = server.uri();
        tokio::spawn(async move { server.run().await });
        wait_server_up(&client, &server_uri, 3).await;
        // Send requests, expect round-robin responses that wrap back to the first
        for expected in ["1", "2", "3", "1"] {
            let response = client.get(&server_uri).send().await.unwrap();
            assert_eq!(StatusCode::OK, response.status());
            assert_eq!(expected, response.text().await.unwrap());
        }
    }

    /// Asks the OS for a currently unused local port.
    ///
    /// The probe listener is dropped before returning, so another process could
    /// grab the port in between. That is still far less collision-prone than a
    /// hard-coded port such as 8080.
    fn free_port() -> u16 {
        std::net::TcpListener::bind("127.0.0.1:0")
            .and_then(|listener| listener.local_addr())
            .map(|addr| addr.port())
            .expect("failed to reserve a free local port")
    }

    /// Polls `<uri>/health` until it answers, giving up after roughly
    /// `max_retries` seconds.
    ///
    /// # Panics
    ///
    /// Panics if the server does not respond within that time budget.
    pub async fn wait_server_up(client: &Client, uri: &str, max_retries: usize) {
        // Poll frequently so tests proceed as soon as the server is bound,
        // while keeping the original overall budget of one second per retry.
        const POLL_INTERVAL: Duration = Duration::from_millis(50);
        let health_uri = format!("{}/health", uri);
        let deadline = Instant::now() + Duration::from_secs(max_retries as u64);
        loop {
            if client.get(&health_uri).send().await.is_ok() {
                return;
            }
            if Instant::now() >= deadline {
                panic!("Server didn't start...");
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment