Skip to content

Instantly share code, notes, and snippets.

@Pzixel
Created May 23, 2024 11:02
Show Gist options
  • Save Pzixel/3928af2b3dce9baf2ece569beedc14a8 to your computer and use it in GitHub Desktop.
Save Pzixel/3928af2b3dce9baf2ece569beedc14a8 to your computer and use it in GitHub Desktop.
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::Waker;
use std::time::Duration;
use futures_util::SinkExt;
use futures_util::StreamExt;
use reqwest::header::HeaderName;
use tokio_tungstenite;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
/// One message from the pricing feed.
///
/// Produced by decoding the JSON payload of a websocket text frame.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Update {
    /// The message payload.
    ///
    /// NOTE(review): its meaning/units are defined by the feed — document once confirmed.
    pub data: u32,
}
/// Connection settings for the pricing websocket.
///
/// `Debug` is implemented by hand so that formatting a `Config` (e.g. in a log line) never
/// reveals `api_password`.
#[derive(serde::Deserialize, Clone)]
pub struct Config {
    /// Websocket URL to connect to.
    pub pricing_url: String,
    /// Value sent in the `name` request header.
    pub auth_name: String,
    /// Value sent in the `Authorization` request header; treated as a secret.
    pub api_password: String,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("pricing_url", &self.pricing_url)
            .field("auth_name", &self.auth_name)
            // Never print the credential itself.
            .field("api_password", &format_args!("<redacted>"))
            .finish()
    }
}
fn get_stream(config: Config, ping_interval: Duration, pong_timout: Duration, reconnect_delay: Duration) -> impl futures_core::stream::Stream<Item = Update> {
async_stream::stream! {
// async {
loop {
let mut request = config.pricing_url.as_str().into_client_request().unwrap();
request.headers_mut().insert(
HeaderName::from_static("name"),
config.auth_name.parse().unwrap(),
);
request.headers_mut().insert(
reqwest::header::AUTHORIZATION,
config.api_password.parse().unwrap(),
);
let ws_stream = match tokio_tungstenite::connect_async(request).await {
Ok((ws_stream, _)) => ws_stream,
Err(e) => {
tracing::warn!("[wss] Cannot connect to RPC endpoint '{}': {}. Waiting and reconnecting", &config.pricing_url, e);
tokio::time::sleep(reconnect_delay).await;
continue;
}
};
tracing::info!("[] Connected to {}", config.pricing_url);
let (mut write, mut read) = ws_stream.split();
let should_keep_pinging = Arc::new(AtomicBool::new(true));
{
let should_keep_pinging = Arc::clone(&should_keep_pinging);
tokio::spawn({
async move {
while should_keep_pinging.load(std::sync::atomic::Ordering::Relaxed) {
tokio::time::sleep(ping_interval).await;
if let Err(e) = write.send(tokio_tungstenite::tungstenite::Message::Ping(Vec::new())).await {
tracing::error!("[] Error sending ping for subscription: {}. Expecting disconnect to be discovered in {}s", e, pong_timout.as_secs_f32());
}
};
}
});
}
loop {
let msg = match tokio::time::timeout(pong_timout, read.next()).await {
Ok(Some(Ok(msg))) => msg,
Ok(Some(Err(e))) => {
tracing::warn!("[] Error reading from stream: {}", e);
break;
}
Ok(None) => {
tracing::warn!("[] Stream closed");
break;
}
Err(e) => {
tracing::warn!("[] Pong timeout of {}s has been elasped: {}", pong_timout.as_secs_f32(), e);
break;
}
};
let msg = match msg {
tokio_tungstenite::tungstenite::Message::Text(x) => x,
tokio_tungstenite::tungstenite::Message::Ping(_) => {
tracing::info!("[] Received ping");
// TODO: send pong back?
continue;
},
tokio_tungstenite::tungstenite::Message::Pong(_) => {
tracing::info!("[] Received pong");
continue;
}
x => {
tracing::warn!("[] Received unexpected message: {:?}. Breaking listening", x);
break;
}
};
let Ok(response) = serde_json::from_str::<Update>(&msg) else {
tracing::error!("[] Error parsing following response as NewHeadsResponse: {}", msg);
continue;
};
yield response;
}
should_keep_pinging.store(false, std::sync::atomic::Ordering::Relaxed);
tracing::info!("[] Reconnecting");
}
}
}
/// A stream of [`Update`]s that hands out everything received since the previous poll as one batch.
///
/// A background tokio task drives [`get_stream`] and appends each update to a shared buffer.
/// Every successful `poll_next` drains that buffer into a single `Vec<Update>`, so a slow
/// consumer receives larger batches rather than a backlog of items. The buffer is unbounded.
/// Dropping this value aborts the background task.
pub struct BufferedMessages {
    /// Handle of the background task that fills `shared_state`.
    future: tokio::task::JoinHandle<()>,
    /// Buffer and wake-up bookkeeping shared with the background task.
    shared_state: Arc<Mutex<SharedState>>,
}

/// State shared between the background producer task and the consuming stream.
struct SharedState {
    /// Updates received but not yet returned by `poll_next`.
    buffer: Vec<Update>,
    /// Waker from the most recent `poll_next` that found the buffer empty.
    waker: Option<Waker>,
    /// Set once the producer task has stopped; no further updates will be buffered.
    finished: bool,
}

/// Marks the shared state as finished and wakes the consumer when dropped.
///
/// It lives inside the producer task, so it is dropped however that task ends: normally,
/// aborted, or panicking. Without it, a panicking producer (e.g. on an invalid `pricing_url`)
/// would leave the consumer waiting forever.
struct FinishOnDrop(Arc<Mutex<SharedState>>);

impl Drop for FinishOnDrop {
    fn drop(&mut self) {
        let waker = {
            // Tolerate poisoning: panicking inside `drop` during unwinding would abort.
            let mut state = self.0.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            state.finished = true;
            state.waker.take()
        };
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

impl BufferedMessages {
    /// Spawns the background task that connects with `config` and buffers incoming updates.
    ///
    /// `ping_interval`, `pong_timout` and `reconnect_delay` are forwarded to [`get_stream`].
    ///
    /// # Panics
    /// Panics if called outside a tokio runtime, because it uses `tokio::spawn`.
    pub fn new(config: Config, ping_interval: Duration, pong_timout: Duration, reconnect_delay: Duration) -> Self {
        let shared_state = Arc::new(Mutex::new(SharedState {
            buffer: Vec::new(),
            waker: None,
            finished: false,
        }));
        let future = tokio::spawn({
            let shared_state = Arc::clone(&shared_state);
            async move {
                // Declared first so that it is dropped last, after the stream, however the task ends.
                let _finish = FinishOnDrop(Arc::clone(&shared_state));
                let mut stream = Box::pin(get_stream(config, ping_interval, pong_timout, reconnect_delay));
                while let Some(update) = stream.next().await {
                    let waker = {
                        let mut state = shared_state.lock().unwrap();
                        state.buffer.push(update);
                        state.waker.take()
                    };
                    // Wake after releasing the lock so the consumer does not immediately contend on it.
                    if let Some(waker) = waker {
                        waker.wake();
                    }
                }
            }
        });
        Self {
            future,
            shared_state,
        }
    }
}

impl futures_core::stream::Stream for BufferedMessages {
    type Item = Vec<Update>;

    /// Returns all buffered updates at once.
    ///
    /// - `Pending`: the buffer is empty and the producer is still running.
    /// - `Ready(None)`: the producer has stopped and every buffered update has been delivered.
    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
        let mut state = self.shared_state.lock().unwrap();
        if !state.buffer.is_empty() {
            std::task::Poll::Ready(Some(std::mem::take(&mut state.buffer)))
        } else if state.finished {
            // The producer is gone: end the stream instead of waiting for a wake-up that never comes.
            std::task::Poll::Ready(None)
        } else {
            state.waker = Some(cx.waker().clone());
            std::task::Poll::Pending
        }
    }
}

impl Drop for BufferedMessages {
    /// Stops the background task; the websocket stream it owns is dropped with it.
    fn drop(&mut self) {
        tracing::info!("[] Dropping BufferedMessages");
        self.future.abort();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment