Skip to content

Instantly share code, notes, and snippets.

@tfutada
Last active June 16, 2022 23:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tfutada/477924a7e02c83cec0344110a02fc3df to your computer and use it in GitHub Desktop.
Save tfutada/477924a7e02c83cec0344110a02fc3df to your computer and use it in GitHub Desktop.
Rate limiting with Rust, hyper, and Redis
use anyhow::Result;
use chrono::prelude::*;
use http::HeaderValue;
use std::net::IpAddr;
use std::{convert::Infallible, env, net::SocketAddr};
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server, StatusCode};
extern crate r2d2;
extern crate redis;
use r2d2_redis::redis::Commands;
use r2d2_redis::RedisConnectionManager;
/// Maximum number of requests one client may make per rate-limit window.
/// The window is one wall-clock second: the Redis key built in
/// `exceed_rate_limit` embeds the current Unix time in whole seconds.
const RATE_LIMIT: i32 = 10;
/// Maximum number of Redis connections held by the r2d2 pool (passed to `max_size`).
const CONNECTION_POOL_SIZE: u32 = 15;
/// r2d2 pool of Redis connections shared by every request handler.
type Pool = r2d2::Pool<RedisConnectionManager>;
/// Shared state handed to each request; cloned once per connection and once per request.
#[derive(Clone)]
struct AppContext {
/// Redis connection pool used by the rate limiter.
pool: Pool,
/// Upstream endpoint that allowed requests are proxied to (from `UPSTREAM_ENDPOINT`).
upstream: String,
}
async fn handle(
ctx: AppContext,
client_ip: IpAddr,
req: Request<Body>,
) -> Result<Response<Body>, Infallible> {
println!("{:?}", req);
let x_forward = req.headers().get("x-forwarded-for");
let client_ip_addr: String = match x_forward {
None => client_ip.to_string(),
Some(x) => x.to_str().unwrap_or("").into(),
};
// return true if it exceeds the rate limit. NB. This is a sync call.
match exceed_rate_limit(ctx.pool.clone(), client_ip_addr) {
Ok(x) if x => {
return Ok(Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS) // status=429
.body(Body::empty())
.expect("failed to get Result"));
}
Err(e) => {
println!("{:?}", e);
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR) // status=500
.body(Body::empty())
.expect("failed to get Result"));
}
_ => {
// no problem. Let's move on.
// You can do auth check, logging and so on.
}
};
// Lastly, dispatch a request to the upstream server asynchronously.
match hyper_reverse_proxy::call(client_ip, &ctx.upstream, req).await {
Ok(response) => Ok(response), // status could not be 200.
Err(e) => {
let msg = format!("{:?}", e);
println!("{:?}", msg);
Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY) // error occurred in proxy side.
.body(Body::from(msg))
.unwrap())
}
}
}
/// Entry point: puts the Redis-backed rate limiter in front of the upstream
/// service and serves HTTP until the server future resolves.
///
/// Configuration is read from the environment:
/// * `REDIS_ENDPOINT` (required) - Redis address, prefixed with `redis://`.
/// * `UPSTREAM_ENDPOINT` (required) - where allowed requests are proxied.
/// * `PORT` (optional) - listen port, resolved by `get_port`.
///
/// Panics if a required variable is missing or the Redis pool cannot be built.
#[tokio::main(flavor = "multi_thread", worker_threads = 15)]
async fn main() {
    let pool = {
        let endpoint = env::var("REDIS_ENDPOINT").expect("plz set REDIS_ENDPOINT");
        let manager = RedisConnectionManager::new(format!("redis://{}", endpoint))
            .expect("failed to create a connection manager");
        r2d2::Pool::builder()
            .max_size(CONNECTION_POOL_SIZE)
            .build(manager)
            .expect("pool")
    };
    let context = AppContext {
        pool,
        upstream: env::var("UPSTREAM_ENDPOINT").expect("plz set UPSTREAM_ENDPOINT"),
    };

    // One service per accepted connection; every request on that connection
    // sees the peer address captured here.
    let make_svc = make_service_fn(move |conn: &AddrStream| {
        // The peer may be a load balancer or local loopback rather than the end client.
        let peer_ip = conn.remote_addr().ip();
        let per_conn_ctx = context.clone();
        let service = service_fn(move |req| handle(per_conn_ctx.clone(), peer_ip, req));
        async move { Ok::<_, Infallible>(service) }
    });

    let addr: SocketAddr = format!("0.0.0.0:{}", get_port())
        .parse()
        .expect("failed to parse ip:port.");
    let server = Server::bind(&addr).serve(make_svc);
    println!("Running server on {:?}", addr);
    match server.await {
        Ok(()) => {}
        Err(e) => eprintln!("server error: {}", e),
    }
}
// returns true if it exceeds the rate limit.
fn exceed_rate_limit(pool: Pool, ipaddr: String) -> Result<bool> {
let ut = unix_time();
let key = format!("RATELIMIT:{ipaddr}:{ut}");
let mut con = pool.get()?;
let count: i32 = con.incr(&key, 1)?;
let _: () = con.expire(&key, 10)?; // delete the key in 10s
let ret = count > RATE_LIMIT;
println!("Redis key>{} count>{} Above the limit?>{}", key, count, ret);
Ok(ret)
}
fn unix_time() -> i64 {
let dt: DateTime<Utc> = Utc::now();
return dt.timestamp();
}
/// Returns the TCP port this gateway (the hyper reverse proxy) listens on.
///
/// Reads the `PORT` environment variable and falls back to 8080 when it is
/// unset. A value that is present but is not a valid `u16` (or is not valid
/// Unicode) also falls back to 8080. In that case a warning goes to stderr, so
/// a misconfigured deployment does not silently bind the default port.
fn get_port() -> u16 {
    const DEFAULT_PORT: u16 = 8080;
    match env::var("PORT") {
        Ok(raw) => match raw.parse::<u16>() {
            Ok(port) => port,
            Err(e) => {
                eprintln!("invalid PORT {:?} ({}); using {}", raw, e, DEFAULT_PORT);
                DEFAULT_PORT
            }
        },
        Err(env::VarError::NotPresent) => DEFAULT_PORT,
        Err(e) => {
            eprintln!("unreadable PORT ({}); using {}", e, DEFAULT_PORT);
            DEFAULT_PORT
        }
    }
}
/// Debugging helper: answers with the request's `Debug` representation as the
/// body and the default response status. It is not referenced anywhere in the
/// visible code.
fn debug_request(req: Request<Body>) -> Result<Response<Body>, Infallible> {
    let dump = format!("{:?}", req);
    Ok(Response::new(dump.into()))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment