Last active
June 16, 2022 23:52
-
-
Save tfutada/477924a7e02c83cec0344110a02fc3df to your computer and use it in GitHub Desktop.
Rate limiting with Rust, hyper, and Redis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use anyhow::Result; | |
use chrono::prelude::*; | |
use http::HeaderValue; | |
use std::net::IpAddr; | |
use std::{convert::Infallible, env, net::SocketAddr}; | |
use hyper::server::conn::AddrStream; | |
use hyper::service::{make_service_fn, service_fn}; | |
use hyper::{Body, Request, Response, Server, StatusCode}; | |
extern crate r2d2; | |
extern crate redis; | |
use r2d2_redis::redis::Commands; | |
use r2d2_redis::RedisConnectionManager; | |
const RATE_LIMIT: i32 = 10; | |
const CONNECTION_POOL_SIZE: u32 = 15; | |
type Pool = r2d2::Pool<RedisConnectionManager>; | |
#[derive(Clone)] | |
struct AppContext { | |
pool: Pool, | |
upstream: String, | |
} | |
async fn handle( | |
ctx: AppContext, | |
client_ip: IpAddr, | |
req: Request<Body>, | |
) -> Result<Response<Body>, Infallible> { | |
println!("{:?}", req); | |
let x_forward = req.headers().get("x-forwarded-for"); | |
let client_ip_addr: String = match x_forward { | |
None => client_ip.to_string(), | |
Some(x) => x.to_str().unwrap_or("").into(), | |
}; | |
// return true if it exceeds the rate limit. NB. This is a sync call. | |
match exceed_rate_limit(ctx.pool.clone(), client_ip_addr) { | |
Ok(x) if x => { | |
return Ok(Response::builder() | |
.status(StatusCode::TOO_MANY_REQUESTS) // status=429 | |
.body(Body::empty()) | |
.expect("failed to get Result")); | |
} | |
Err(e) => { | |
println!("{:?}", e); | |
return Ok(Response::builder() | |
.status(StatusCode::INTERNAL_SERVER_ERROR) // status=500 | |
.body(Body::empty()) | |
.expect("failed to get Result")); | |
} | |
_ => { | |
// no problem. Let's move on. | |
// You can do auth check, logging and so on. | |
} | |
}; | |
// Lastly, dispatch a request to the upstream server asynchronously. | |
match hyper_reverse_proxy::call(client_ip, &ctx.upstream, req).await { | |
Ok(response) => Ok(response), // status could not be 200. | |
Err(e) => { | |
let msg = format!("{:?}", e); | |
println!("{:?}", msg); | |
Ok(Response::builder() | |
.status(StatusCode::BAD_GATEWAY) // error occurred in proxy side. | |
.body(Body::from(msg)) | |
.unwrap()) | |
} | |
} | |
} | |
#[tokio::main(flavor = "multi_thread", worker_threads = 15)] | |
async fn main() { | |
let redis_endpoint = env::var("REDIS_ENDPOINT").expect("plz set REDIS_ENDPOINT"); | |
let redis_url = format!("redis://{}", redis_endpoint); | |
let manager = | |
RedisConnectionManager::new(redis_url).expect("failed to create a connection manager"); | |
let pool = r2d2::Pool::builder() | |
.max_size(CONNECTION_POOL_SIZE) | |
.build(manager) | |
.expect("pool"); | |
let upstream = env::var("UPSTREAM_ENDPOINT").expect("plz set UPSTREAM_ENDPOINT"); | |
let context = AppContext { pool, upstream }; | |
// core logic of HTTP server. | |
let make_svc = make_service_fn(move |conn: &AddrStream| { | |
let remote_addr = conn.remote_addr().ip(); // Remote address could be LB or local loopback | |
let context = context.clone(); | |
// a func to process each request. | |
let service = service_fn(move |req| handle(context.clone(), remote_addr, req)); | |
async move { Ok::<_, Infallible>(service) } | |
}); | |
let bind_addr = format!("0.0.0.0:{}", get_port()); | |
let addr: SocketAddr = bind_addr.parse().expect("failed to parse ip:port."); | |
let server = Server::bind(&addr).serve(make_svc); | |
println!("Running server on {:?}", addr); | |
if let Err(e) = server.await { | |
eprintln!("server error: {}", e); | |
} | |
} | |
// returns true if it exceeds the rate limit. | |
fn exceed_rate_limit(pool: Pool, ipaddr: String) -> Result<bool> { | |
let ut = unix_time(); | |
let key = format!("RATELIMIT:{ipaddr}:{ut}"); | |
let mut con = pool.get()?; | |
let count: i32 = con.incr(&key, 1)?; | |
let _: () = con.expire(&key, 10)?; // delete the key in 10s | |
let ret = count > RATE_LIMIT; | |
println!("Redis key>{} count>{} Above the limit?>{}", key, count, ret); | |
Ok(ret) | |
} | |
fn unix_time() -> i64 { | |
let dt: DateTime<Utc> = Utc::now(); | |
return dt.timestamp(); | |
} | |
/// Returns the TCP port this gateway (the hyper reverse proxy) listens on.
///
/// Reads `PORT` from the environment, ignoring surrounding whitespace. Falls
/// back to 8080 when `PORT` is unset. Also falls back to 8080 when `PORT` is
/// set but is not a valid `u16` (or not valid Unicode). In that case a warning
/// goes to stderr, so a misconfigured port is not silently replaced.
fn get_port() -> u16 {
    const DEFAULT_PORT: u16 = 8080;
    match env::var("PORT") {
        Ok(raw) => raw.trim().parse::<u16>().unwrap_or_else(|_| {
            eprintln!("invalid PORT {:?}; falling back to {}", raw, DEFAULT_PORT);
            DEFAULT_PORT
        }),
        Err(env::VarError::NotPresent) => DEFAULT_PORT,
        Err(e) => {
            eprintln!("unreadable PORT ({}); falling back to {}", e, DEFAULT_PORT);
            DEFAULT_PORT
        }
    }
}
fn debug_request(req: Request<Body>) -> Result<Response<Body>, Infallible> { | |
let body_str = format!("{:?}", req); | |
Ok(Response::new(Body::from(body_str))) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment