Skip to content

Instantly share code, notes, and snippets.

@uzytkownik
Created March 16, 2018 00:20
Show Gist options
  • Save uzytkownik/edd622a9334058de97bd20c658738e91 to your computer and use it in GitHub Desktop.
Save uzytkownik/edd622a9334058de97bd20c658738e91 to your computer and use it in GitHub Desktop.
Block a WebSocket stream for a period of time
extern crate actix;
extern crate actix_web;
extern crate tokio;
use actix::*;
use actix_web::*;
use std::cell::RefCell;
use std::rc::Rc;
use std::time::Duration;
use std::vec::Vec;
use tokio::prelude::*;
/// State shared between a `Blocker` and every stream it gates.
#[derive(Debug)]
struct BlockState {
    /// Number of outstanding holds; gated streams are paused while this is non-zero.
    blocked: usize,
    /// Tasks that polled a gated stream while it was paused and are waiting to be woken.
    tasks: Vec<task::Task>,
}
/// Handle that pauses and resumes every stream wrapped with `Blocker::gate`.
#[derive(Debug)]
struct Blocker {
    /// Hold counter and waiting tasks, shared through an `Rc` with each gated stream.
    state: Rc<RefCell<BlockState>>,
}
/// Stream wrapper that does not poll its inner stream while the owning `Blocker`
/// has any outstanding holds.
#[derive(Debug)]
struct Blocked<S> {
    /// The wrapped stream.
    val: S,
    /// State shared with the `Blocker` that created this wrapper.
    state: Rc<RefCell<BlockState>>,
}
impl Blocker {
    /// Creates an open gate: no holds and no waiting tasks.
    fn new() -> Self {
        Blocker {
            state: Rc::new(RefCell::new(BlockState {
                blocked: 0,
                tasks: Vec::new(),
            })),
        }
    }

    /// Wraps `t` so that it is paused whenever this `Blocker` has outstanding holds.
    ///
    /// Any number of values may be gated; they all share this blocker's state.
    fn gate<T>(&self, t: T) -> Blocked<T> {
        Blocked {
            val: t,
            state: Rc::clone(&self.state),
        }
    }

    /// Takes one hold on the gate. Gated streams stop polling their inner stream
    /// until every hold has been released.
    ///
    /// Holds nest: each call must be balanced by exactly one `unblock`.
    fn block(&self) {
        self.state.borrow_mut().blocked += 1;
    }

    /// Releases one hold previously taken with `block`.
    ///
    /// When the last hold is released, every task that was parked by a gated
    /// stream is notified so it polls again.
    ///
    /// # Panics
    ///
    /// Panics if there is no outstanding hold to release (more `unblock` than
    /// `block` calls). Without this check the counter would wrap to `usize::MAX`
    /// in release builds and the stream would stay blocked forever.
    fn unblock(&self) {
        let tasks = {
            let mut st = self.state.borrow_mut();
            let remaining = st
                .blocked
                .checked_sub(1)
                .expect("Blocker::unblock called without a matching Blocker::block");
            st.blocked = remaining;
            if remaining > 0 {
                return;
            }
            std::mem::replace(&mut st.tasks, Vec::new())
        };
        // The `RefCell` borrow ends with the scope above, before any task is
        // notified. Code triggered by `notify` that borrows the shared state
        // therefore cannot hit a double-borrow panic.
        for t in tasks {
            t.notify();
        }
    }
}
impl<S: Stream> Stream for Blocked<S> {
    type Item = S::Item;
    type Error = S::Error;

    /// Polls the wrapped stream unless the shared gate is closed.
    ///
    /// While any hold is active, the inner stream is not polled at all. Instead,
    /// the current task is recorded (at most once) so that `Blocker::unblock` can
    /// wake it when the last hold is released, and `NotReady` is returned.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        {
            let mut st = self.state.borrow_mut();
            if st.blocked > 0 {
                // The same task can poll many times while blocked (every
                // wake-up of the owning context re-polls the stream). Register
                // it only once, so that `tasks` does not grow without bound and
                // the task is not notified repeatedly.
                if !st.tasks.iter().any(|t| t.will_notify_current()) {
                    st.tasks.push(task::current());
                }
                return Ok(Async::NotReady);
            }
        }
        // The borrow is released before delegating, so the inner stream runs
        // with the shared state unborrowed.
        self.val.poll()
    }
}
/// Actor serving a single WebSocket connection (one is created per request in `index`).
#[derive(Debug)]
struct WSHandler {
    /// Gate on this connection's incoming message stream.
    blocker: Blocker,
}
/// `WSHandler` runs inside a WebSocket context. That context is what the handler
/// uses for `run_later` timers and what the gated input stream is attached to.
impl Actor for WSHandler {
    type Context = ws::WebsocketContext<Self>;
}
impl StreamHandler<ws::Message, ws::ProtocolError> for WSHandler {
    /// Pauses the incoming stream after every message.
    ///
    /// The message content is ignored. Each message takes one hold on the gate
    /// and schedules its release five seconds later, so reading resumes only
    /// once five seconds have passed since the most recent message.
    fn handle(&mut self, _msg: ws::Message, ctx: &mut Self::Context) {
        let pause = Duration::new(5, 0);
        self.blocker.block();
        ctx.run_later(pause, |actor, _| actor.blocker.unblock());
    }
}
/// Performs the WebSocket handshake and starts a `WSHandler` actor for the connection.
///
/// The incoming message stream is wrapped in a gate whose `Blocker` is owned by
/// the actor, so the handler can pause reading from the socket.
fn index(req: HttpRequest) -> Result<HttpResponse, Error> {
    let mut response = ws::handshake(&req)?;
    let blocker = Blocker::new();
    let gated = blocker.gate(ws::WsStream::new(req.clone()));
    let handler = WSHandler { blocker };
    let mut ctx = ws::WebsocketContext::new(req, handler);
    ctx.add_stream(gated);
    Ok(response.body(ctx)?)
}
/// Serves the WebSocket endpoint `/` on 127.0.0.1:8088.
fn main() {
    let addr = "127.0.0.1:8088";
    let server = HttpServer::new(|| Application::new().resource("/", |r| r.f(index)));
    server
        .bind(addr)
        .expect("Can not bind to 127.0.0.1:8088")
        .run();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment