Skip to content

Instantly share code, notes, and snippets.

@staticgc
Created January 30, 2021 06:31
Show Gist options
  • Save staticgc/f3df4ba3499d00b51c13ac4d898f3b5e to your computer and use it in GitHub Desktop.
Save staticgc/f3df4ba3499d00b51c13ac4d898f3b5e to your computer and use it in GitHub Desktop.
Bandwidth control in hyper
use std::task::{Context, Poll};
use std::pin::Pin;
use hyper::client::{connect::{Connection, Connected}};
use tokio::net::{TcpStream};
use tokio::io::{ReadBuf, AsyncWrite, AsyncRead};
use async_speed_limit::{limiter::{Limiter, Consume}, clock::{StandardClock}};
use futures_util::future::Future;
//use anyhow::Error;
use std::io::Error;
use pin_project::pin_project;
/// A `TcpStream` wrapper whose `poll_write` is throttled by a [`Limiter`].
#[pin_project]
pub struct MyStream {
    /// The underlying socket; structurally pinned so it can be polled
    /// through `project()`.
    #[pin]
    io: TcpStream,
    /// Limiter charged with the byte count of each successful `poll_write`.
    lim: Limiter,
    /// Delay charged for the most recent write, awaited before the next
    /// write is attempted; `None` when no delay is outstanding.
    lim_fut: Option<Consume<StandardClock, ()>>,
}
impl MyStream {
    /// Opens a TCP connection to `addr` and wraps it in a `MyStream` that
    /// charges its writes to `lim`.
    ///
    /// `addr` is handed to `TcpStream::connect` unchanged (for example a
    /// `"host:port"` string).
    ///
    /// # Errors
    /// Propagates any I/O error from `TcpStream::connect`.
    pub async fn connect(addr: &str, lim: Limiter) -> Result<MyStream, Error> {
        println!("connecting ... {}", addr);
        let socket = TcpStream::connect(addr).await?;
        println!("connected");
        Ok(Self {
            io: socket,
            lim_fut: None,
            lim,
        })
    }
}
/// Write side of `MyStream`, throttled by `lim`.
///
/// Each successful write of `n > 0` bytes stores `lim.consume(n)` in
/// `lim_fut`. The next write first polls that future and returns
/// `Poll::Pending` until it completes. Bytes are therefore paid for after
/// they are sent, and the very first write is never delayed.
impl AsyncWrite for MyStream {
    /// Writes `buf` to the socket once any outstanding limiter delay has
    /// elapsed.
    ///
    /// The pending limiter future is polled with `cx`, so this task is woken
    /// when the delay finishes.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, Error>> {
        let this = self.project();
        // Pay off the delay charged by the previous write before sending more.
        if let Some(fut) = this.lim_fut.as_mut() {
            if Pin::new(fut).poll(cx).is_pending() {
                return Poll::Pending;
            }
            *this.lim_fut = None;
        }
        let res = this.io.poll_write(cx, buf);
        if let Poll::Ready(Ok(sz)) = res {
            println!("sz = {}", sz);
            if sz > 0 {
                // Charge the bytes just written; the resulting delay is
                // awaited at the start of the next write.
                *this.lim_fut = Some(this.lim.consume(sz));
            }
        }
        res
    }

    /// Always `false`. Vectored writes are funneled through `poll_write`
    /// (see `poll_write_vectored`), so callers gain nothing by batching.
    fn is_write_vectored(&self) -> bool {
        false
    }

    /// Writes the first non-empty slice of `bufs` via the throttled
    /// `poll_write`.
    ///
    /// Previously this delegated straight to the socket, which let vectored
    /// writes bypass the limiter entirely. This mirrors tokio's default
    /// `poll_write_vectored` fallback.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, Error>> {
        let buf = bufs
            .iter()
            .find(|b| !b.is_empty())
            .map_or(&[][..], |b| &**b);
        self.poll_write(cx, buf)
    }

    /// Flushes the inner socket; not throttled.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.project().io.poll_flush(cx)
    }

    /// Shuts down the write half of the inner socket; not throttled.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.project().io.poll_shutdown(cx)
    }
}
/// Read side of `MyStream`: passes straight through to the inner socket.
///
/// NOTE(review): only writes are charged to `lim`; reads are not rate
/// limited. Confirm this is intended for a bandwidth-control stream.
impl AsyncRead for MyStream {
    /// Reads from the inner `TcpStream` into `buf` without throttling.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<(), Error>> {
        // Project through the pin (as the write side does) instead of using
        // `get_mut`, so this impl does not require `MyStream: Unpin`.
        self.project().io.poll_read(cx, buf)
    }
}
/// Lets hyper use `MyStream` as a client connection.
impl Connection for MyStream {
    /// Reports default connection metadata.
    ///
    /// Nothing from the inner socket is attached; this is exactly
    /// `Connected::new()`.
    fn connected(&self) -> Connected {
        Connected::new()
    }
}
use hyper::{self, service::Service, Uri};
use std::future::Future;
use std::task::{Context, Poll};
use std::pin::Pin;
use async_speed_limit::{limiter::{Limiter}};
use crate::conn::MyStream;
use std::io::Error;
/// A hyper connector that opens plain TCP connections wrapped in
/// [`MyStream`].
///
/// Every connection receives a clone of `lim`. NOTE(review): this assumes
/// cloned `Limiter`s share one budget, so the limit applies across all
/// connections. Confirm against the `async_speed_limit` docs.
#[derive(Clone)]
pub struct MyConnector {
    /// Rate limiter handed, by clone, to each new connection.
    lim: Limiter,
}
impl MyConnector {
pub fn with_limit(limit: f64) -> Self {
MyConnector {
lim: Limiter::new(limit),
}
}
async fn new_conn(u: &Uri, lim: Limiter) -> Result<MyStream, Error> {
println!("new_conn");
let host = u.host().unwrap();
let port = u.port_u16().unwrap();
let addr = format!("{}:{}", host, port);
let io = MyStream::connect(&addr, lim).await?;
println!("new_conn - done");
Ok(io)
}
}
/// Makes `MyConnector` usable as a hyper client connector. Each call dials
/// the destination URI and yields a [`MyStream`] carrying a clone of `lim`.
impl Service<Uri> for MyConnector {
    type Response = MyStream;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready; no back-pressure is applied before dialing.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Starts connecting to `dst` using a clone of this connector's limiter.
    fn call(&mut self, dst: Uri) -> Self::Future {
        let limiter = self.lim.clone();
        let connecting = async move { MyConnector::new_conn(&dst, limiter).await };
        Box::pin(connecting)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment