Skip to content

Instantly share code, notes, and snippets.

@liamwh
Last active July 6, 2023 07:30
Show Gist options
  • Save liamwh/21afe1d90c9a2d33cb2fe024b6edaa4f to your computer and use it in GitHub Desktop.
Save liamwh/21afe1d90c9a2d33cb2fe024b6edaa4f to your computer and use it in GitHub Desktop.
Axum Websocket example
use std::{net::SocketAddr, sync::Arc};
use axum::{
debug_handler,
extract::{ws::WebSocket, ConnectInfo, Path, Query, State, WebSocketUpgrade},
response::{IntoResponse, Json},
};
use cloudevents::{AttributesReader, Event};
use serde::Deserialize;
use tokio::sync::broadcast::Receiver;
use tracing::instrument;
use crate::{
application::INTERVAL_CREATED_EVENT_TYPE,
domain::interval::{Interval, IntervalId, IntervalRepository},
presentation::{appstate::AppState, interval::IntervalApiError},
};
/// The handler for the HTTP request (this gets called when the HTTP GET lands at the start
/// of websocket negotiation). After this completes, the actual switching from HTTP to
/// websocket protocol will occur.
/// This is the last point where we can extract TCP/IP metadata such as IP address of the client
/// as well as things from HTTP headers such as user-agent of the browser etc.
#[debug_handler]
pub async fn interval_ws_handler(
ws: WebSocketUpgrade,
State(app_state): State<Arc<AppState>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
let interval_receiver = app_state.interval_receiver.resubscribe();
ws.on_upgrade(move |socket| handle_interval_socket(socket, addr, interval_receiver))
}
#[instrument(skip(interval_receiver))]
async fn handle_interval_socket(
mut websocket: WebSocket,
addr: SocketAddr,
mut interval_receiver: Receiver<Event>,
) {
loop {
let received_interval_event = interval_receiver.recv().await.unwrap();
if received_interval_event.ty() != INTERVAL_CREATED_EVENT_TYPE {
tracing::warn!("Received an event that was not an interval created event");
continue;
}
let interval: Interval = received_interval_event.try_into().unwrap();
let interval_view: IntervalView = interval.into();
let data = serde_json::to_string(&interval_view).unwrap();
if let Err(e) = websocket.send(data.into()).await {
tracing::error!("Error sending WebSocket message: {}", e);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment