Skip to content

Instantly share code, notes, and snippets.

@ethereumdegen
Created June 11, 2023 15:42
Show Gist options
  • Save ethereumdegen/17b6f35c4f49e191044a2701f67fa555 to your computer and use it in GitHub Desktop.
Save ethereumdegen/17b6f35c4f49e191044a2701f67fa555 to your computer and use it in GitHub Desktop.
A websocket server + client implementation
use std::sync::Arc ;
use tokio::sync::{Mutex,RwLock};
mod websocket_messages;
use websocket_messages::{
SocketMessage,
SocketMessageDestination,
InboundMessage,
OutboundMessage
};
/*
Starts the websocket server on localhost:9000, connects a client to it, and sends a "hello world" text message.
You can also use a utility such as 'websocat' to send messages to the server.
*/
#[tokio::main]
async fn main() -> std::io::Result<()> {
let websocket_client = Arc::new( Mutex::new( WebsocketClient::new() ) );
let websocket_server = Arc::new( Mutex::new( WebsocketServer::new() ) );
let server_url:String = "localhost:9000".to_string();
websocket_server.lock().await.start_in_thread(Some(server_url));
client_socket_conn.lock().await.connect("ws://localhost:9000".to_string()).await;
if let Err(e) = client_socket_conn {
println!("Error connectiong to socket server {}", e);
}else {
println!("Connected to socket server");
let msg = SocketMessage::Text("hello world".to_string());
client_socket_conn.lock().await.send_message( msg ).await
}
}
use futures_util::{ StreamExt, SinkExt};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tokio_tungstenite::{WebSocketStream,MaybeTlsStream};
use tokio::net::TcpStream;
use std::sync::Arc ;
use tokio::sync::{Mutex};
use std::thread;
use tokio::runtime::Runtime;
use crossbeam_channel::{ Receiver, Sender};
use crate::util::websocket_messages::SocketMessage;
use super::websocket_messages::InboundMessage;
/// One client-side websocket connection, split into independently usable write and read halves.
pub struct Connection {
    /// Write half, used by [`Connection::send_message`].
    write: futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
    /// Read half. `Some` until `start_listening_on_new_thread` takes it, so it can be
    /// consumed exactly once (acts like a one-time lock).
    read: Option<futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
    /// Identifier attached to every `InboundMessage` produced by this connection.
    pub socket_connection_uuid: String,
}

impl Connection {
    /// Consumes the read half and forwards every received frame into `sender_channel`.
    ///
    /// The forwarding loop runs on a new OS thread with its own Tokio runtime. It stops when:
    /// - the stream ends,
    /// - a read error occurs, or
    /// - every receiver of `sender_channel` has been dropped.
    ///
    /// # Panics
    /// Panics if the read half was already consumed by an earlier call.
    pub fn start_listening_on_new_thread(&mut self, sender_channel: Sender<InboundMessage>) {
        let mut read = self.read.take().expect("The read stream has already been consumed.");
        let socket_connection_uuid = self.socket_connection_uuid.clone();

        thread::spawn(move || {
            let rt = match Runtime::new() {
                Ok(rt) => rt,
                Err(e) => {
                    eprintln!("Failed to create tokio runtime for listener: {:?}", e);
                    return;
                }
            };
            // NOTE(review): the read half was created under the caller's runtime. Tokio I/O
            // objects stay registered with the reactor that created them, so this presumably
            // relies on that runtime still running. Confirm.
            rt.block_on(async {
                while let Some(message_result) = read.next().await {
                    match message_result {
                        Ok(message) => {
                            let inbound_msg = InboundMessage {
                                socket_connection_uuid: socket_connection_uuid.clone(),
                                message: SocketMessage::from_message(message),
                            };
                            // A send error means every receiver is gone. Nobody can consume
                            // further messages, so stop instead of panicking the thread.
                            if sender_channel.send(inbound_msg).is_err() {
                                eprintln!("Inbound channel closed; stopping listener");
                                break;
                            }
                        }
                        Err(e) => {
                            eprintln!("Error while reading message: {:?}", e);
                            break;
                        }
                    }
                }
                // NOTE(review): the owning Connection is not notified when this loop exits.
            });
        });
    }

    /// Sends `message` over the write half.
    ///
    /// Failures are logged to stderr rather than returned.
    pub async fn send_message(&mut self, message: SocketMessage) {
        println!("sending message out of conn");
        match self.write.send(message.to_message()).await {
            Ok(()) => println!("tried to send a msg out of websocket client conn "),
            Err(e) => eprintln!("Failed to send a msg out of websocket client conn: {:?}", e),
        }
    }
}
/// Websocket client that owns at most one [`Connection`] at a time.
#[derive(Default)]
pub struct WebsocketClient {
    /// The active connection; `None` until `connect` succeeds.
    pub connection: Option<Connection>,
}

impl WebsocketClient {
    /// Creates a client with no connection.
    pub fn new() -> Self {
        Self { connection: None }
    }

    /// Connects to `connect_addr`, a websocket URL such as `ws://localhost:9000`.
    ///
    /// Retries once per second until the handshake succeeds. The resulting [`Connection`]
    /// is stored under a fresh v4 uuid, replacing any previous connection.
    ///
    /// # Errors
    /// Returns an `ErrorKind::InvalidInput` error if `connect_addr` is not a valid URL.
    /// Connection failures are never returned; they are logged and retried indefinitely.
    pub async fn connect(&mut self, connect_addr: String) -> std::io::Result<()> {
        // A malformed URL can never succeed on retry, so report it instead of panicking.
        let url = url::Url::parse(&connect_addr)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;

        loop {
            match connect_async(url.clone()).await {
                Ok((ws_stream, _)) => {
                    println!("WebSocket handshake has been successfully completed");
                    let (write, read) = ws_stream.split();
                    self.connection = Some(Connection {
                        write,
                        read: Some(read),
                        socket_connection_uuid: uuid::Uuid::new_v4().to_string(),
                    });
                    return Ok(());
                }
                Err(e) => {
                    println!("Failed to connect ({}), retrying in 1 second...", e);
                    tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
                }
            }
        }
    }

    /// Starts forwarding received messages into `sender_channel`.
    ///
    /// See `Connection::start_listening_on_new_thread`. Logs and does nothing if not connected.
    pub fn listen_on_new_thread(&mut self, sender_channel: Sender<InboundMessage>) {
        match &mut self.connection {
            Some(conn) => conn.start_listening_on_new_thread(sender_channel),
            None => println!("Could not start listening! No connection :( "),
        }
    }

    /// Sends `message` over the current connection.
    ///
    /// Logs and does nothing if not connected.
    pub async fn send_message(&mut self, message: SocketMessage) {
        match &mut self.connection {
            Some(conn) => conn.send_message(message).await,
            None => println!("Could not send message! No connection :( "),
        }
    }
}
use serde::{Serialize,Deserialize};
use tokio_tungstenite::tungstenite::Message;
use uuid;
//Create custom message types for your app to serialize and send through. See the (currently commented-out) struct 'WrappedMessage' below.
//use shared::net::message_types::{ ClientMessage, ServerMessage };
//SocketMessage mirrors tungstenite's 'Message' but derives serde traits so it can be serialized (a workaround -- there might be a better way).
/// Serializable mirror of tungstenite's `Message`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum SocketMessage {
    Text(String),
    Binary(Vec<u8>),
    /// Any frame type without a dedicated variant (e.g. Ping/Pong).
    Unknown,
    Close,
}

impl SocketMessage {
    /// Converts a tungstenite frame.
    ///
    /// Text, Binary and Close map to their variants; everything else becomes `Unknown`.
    pub fn from_message(msg: Message) -> Self {
        match msg {
            Message::Text(inner) => SocketMessage::Text(inner),
            Message::Binary(inner) => SocketMessage::Binary(inner.into_iter().collect()),
            Message::Close(_) => SocketMessage::Close,
            _ => SocketMessage::Unknown,
        }
    }

    /// Converts back into a tungstenite frame.
    ///
    /// `Close` becomes a Close frame with no close code/reason, so the peer sees a real close
    /// request rather than a text frame. `Unknown` has no wire form and is sent as the
    /// text "Unknown!".
    pub fn to_message(&self) -> Message {
        match self {
            SocketMessage::Text(inner) => Message::Text(inner.to_string()),
            SocketMessage::Binary(inner) => Message::Binary(inner.to_vec()),
            SocketMessage::Close => Message::Close(None),
            SocketMessage::Unknown => Message::Text("Unknown!".to_string()),
        }
    }

    /// Human-readable rendering for logging.
    ///
    /// Binary payloads are shown with `{:?}`; `Close` and `Unknown` both render as "Unknown!".
    // NOTE(review): returning an error for non-text variants may be preferable.
    pub fn to_string(&self) -> String {
        match self {
            SocketMessage::Text(inner) => inner.to_string(),
            SocketMessage::Binary(inner) => format!("{:?}", inner),
            _ => "Unknown!".to_string(),
        }
    }
}
/// Which server-side connections an outbound message should be delivered to.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum SocketMessageDestination {
    /// Every connected client.
    All,
    /// A single client, identified by its connection uuid.
    ClientConnection(String),
    /// Every member of the named room.
    Room(String),
}
/// Destination used by the (currently commented-out) `WrappedMessage` envelope.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum WrappedMessageDestination {
    /// Every recipient.
    All,
    /// A single client, identified by its client uuid.
    Client(String),
    /// Every member of the named room.
    Room(String),
    /// A reply to an earlier message, identified by that message's uuid.
    ResponseToMsg(String),
    /// Server destination.
    /// NOTE(review): the original comment hints at a server-type payload
    /// (`EcosystemServer`) that is not implemented.
    Server,
}
/// A message queued for sending, together with where it should go.
#[derive(Serialize, Deserialize, Clone)]
pub struct OutboundMessage {
    /// Recipients of the message.
    pub destination: SocketMessageDestination,
    /// Payload to send.
    pub message: SocketMessage,
}
/// A message received from a websocket peer, tagged with the connection it arrived on.
#[derive(Serialize, Deserialize, Clone)]
pub struct InboundMessage {
    /// Identifier of the connection the message arrived on.
    pub socket_connection_uuid: String,
    /// Decoded payload.
    pub message: SocketMessage,
}

impl InboundMessage {
    /// Wraps a raw tungstenite `Message`, converting it with [`SocketMessage::from_message`].
    pub fn new(socket_connection_uuid: String, msg: Message) -> Self {
        let decoded = SocketMessage::from_message(msg);
        InboundMessage {
            socket_connection_uuid,
            message: decoded,
        }
    }
}
/*
#[derive(Serialize, Deserialize,Debug ,Clone)]
pub struct WrappedMessage {
pub destination: WrappedMessageDestination,
pub message_uuid: String, //used so we can respond to it
pub contents: WrappedMessageContents // the contents and the From info
}
impl WrappedMessage {
pub fn from_stringified( raw_msg_string:String ) -> Result<Self, serde_json::Error> {
let message:WrappedMessage = serde_json::from_str( &raw_msg_string )?;
Ok(message)
}
pub fn to_stringified(&self) -> Result<String, serde_json::Error> {
let message_string = serde_json::to_string(&self)?;
Ok(message_string)
}
pub fn wrap(
destination:WrappedMessageDestination,
contents:WrappedMessageContents
) -> Self {
let message_uuid = uuid::Uuid::new_v4().to_string();
let wrapped_message = WrappedMessage{
message_uuid,
destination,
contents
};
wrapped_message
}
}
#[derive(Serialize, Deserialize,Debug ,Clone)]
#[serde(tag = "msg_type", content = "data")]
pub enum WrappedMessageContents {
ClientMsg(ClientMessage),
ServerMsg(ServerMessage)
}
impl WrappedMessageContents {
pub fn from_stringified( raw_msg_string:String ) -> Result<Self, serde_json::Error> {
let message:WrappedMessageContents = serde_json::from_str( &raw_msg_string )?;
Ok(message)
}
}
*/
use futures_util::StreamExt;
use futures_util::stream::SplitSink;
use futures_util::future::join_all;
use tokio_tungstenite::WebSocketStream;
use tokio::net::{TcpListener, TcpStream};
use futures::SinkExt;
use std::collections::HashMap;
use std::thread;
use tokio::sync::RwLock;
use std::sync::{Arc};
use tokio::sync:: Mutex;
use tokio_tungstenite::tungstenite::Message;
use shared::util::rand::generate_random_uuid;
use std::collections::HashSet;
use crossbeam_channel::{ unbounded, Receiver, Sender, TryRecvError};
use super::websocket_messages::{
SocketMessage,
SocketMessageDestination,
InboundMessage,
OutboundMessage
};
/// Live client connections, keyed by connection uuid.
type ClientsMap = Arc<RwLock<HashMap<String, ClientConnection>>>;
/// Room name -> set of member client connection uuids.
type RoomsMap = Arc<RwLock<HashMap<String, HashSet<String>>>>;
/// Shared, lockable write half of a server-side websocket.
type TxSink = Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>;
/// Same type as [`TxSink`]; not referenced elsewhere in the visible module.
type RxSink = TxSink;
/// Server-side handle to one connected websocket client.
///
/// Clones share the same write half through the `Arc`.
#[derive(Clone)]
pub struct ClientConnection {
    /// Identifier generated when the handle is created.
    pub client_socket_uuid: String,
    /// Peer address of the client.
    pub addr: String,
    /// Write half behind an async mutex so concurrent senders are serialized.
    pub tx_sink: TxSink,
}

impl ClientConnection {
    /// Wraps the write half of a freshly accepted websocket and assigns it a random uuid.
    pub fn new(addr: String, client_tx: SplitSink<WebSocketStream<tokio::net::TcpStream>, Message>) -> Self {
        let client_socket_uuid = generate_random_uuid();
        let tx_sink = Arc::new(Mutex::new(client_tx));
        ClientConnection {
            client_socket_uuid,
            addr,
            tx_sink,
        }
    }

    /// Sends one frame to this client, holding the sink lock for the duration of the send.
    pub async fn send_message(&self, msg: Message) -> Result<(), tokio_tungstenite::tungstenite::error::Error> {
        let mut sink = self.tx_sink.lock().await;
        sink.send(msg).await
    }
}
/// Websocket server state shared by the accept loop, the outbound dispatcher and callers.
pub struct WebsocketServer {
    /// Live client connections, keyed by connection uuid.
    clients: ClientsMap,
    /// Room name -> set of member client connection uuids.
    rooms: RoomsMap,
    /// Sending end cloned into every client connection; carries messages received from clients.
    global_recv_tx: Sender<InboundMessage>,
    /// Receiving end for messages received from clients (handed out by `get_recv_channel`).
    global_recv_rx: Receiver<InboundMessage>,
    /// Sending end for queuing outbound messages (handed out by `get_send_channel`).
    global_send_tx: Sender<OutboundMessage>,
    /// Receiving end drained by the outbound dispatch loop.
    global_send_rx: Receiver<OutboundMessage>,
}
impl WebsocketServer {
pub fn new() -> Self {
let (global_recv_tx, global_recv_rx): (Sender<InboundMessage>, Receiver<InboundMessage>) = unbounded();
let (global_send_tx, global_send_rx): (Sender<OutboundMessage>, Receiver<OutboundMessage>) = unbounded();
Self {
clients: Arc::new(RwLock::new(HashMap::new())),
rooms: Arc::new(RwLock::new(HashMap::new())),
global_recv_tx,
global_recv_rx,
global_send_tx,
global_send_rx,
}
}
pub fn start_in_thread(&mut self, url: Option<String>) ->
std::io::Result< std::thread::JoinHandle<()> > {
let clients = Arc::clone(&self.clients);
let rooms = Arc::clone(&self.rooms);
let global_recv_channel = self.global_recv_tx.clone();
let global_send_channel = self.global_send_rx.clone();
let accept_connections_thread = thread::spawn(move || { //use a non-tokio thread here
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(async {
let addr: String = url.unwrap_or_else(|| "127.0.0.1:8080".to_string());
// Create the event loop and TCP listener we'll accept connections on.
let try_socket = TcpListener::bind(&addr).await;
let listener = try_socket.expect("Failed to bind");
println!("Listening on: {}", addr);
let accept_connections = Self::try_accept_new_connections( Arc::clone(&clients), listener,global_recv_channel );
let send_outbound_messages = Self::try_send_outbound_messages(
Arc::clone(&clients),
Arc::clone(&rooms),
global_send_channel
);
tokio::try_join!(accept_connections,send_outbound_messages);
});
});
println!("Started websocket server");
Ok( accept_connections_thread )
}
pub async fn start(&mut self, url:Option<String>) -> std::io::Result<()> {
let clients = Arc::clone(&self.clients);
let rooms = Arc::clone(&self.rooms);
let global_recv_channel = self.global_recv_tx.clone();
let global_send_channel = self.global_send_rx.clone();
let addr: String = url.unwrap_or_else(|| "127.0.0.1:8080".to_string());
// Create the event loop and TCP listener we'll accept connections on.
let try_socket = TcpListener::bind(&addr).await;
let listener = try_socket.expect("Failed to bind");
println!("Listening on: {}", addr);
let accept_connections = Self::try_accept_new_connections( Arc::clone(&clients), listener,global_recv_channel );
let send_outbound_messages = Self::try_send_outbound_messages(
Arc::clone(&clients) ,
Arc::clone(&rooms),
global_send_channel
);
tokio::try_join!(accept_connections, send_outbound_messages);
Ok(())
}
//recv'd client messages are fed into here
pub fn get_recv_channel(&self) -> Receiver<InboundMessage> {
self.global_recv_rx.clone()
}
pub fn get_send_channel(&self) -> Sender<OutboundMessage> {
self.global_send_tx.clone()
}
pub async fn send_outbound_message(&self, msg:OutboundMessage) {
Self::broadcast(
Arc::clone(&self.clients),
Arc::clone(&self.rooms),
msg ).await;
}
async fn get_cloned_clients(clients: &ClientsMap) -> Vec<ClientConnection> {
let clients_map = clients.read().await;
clients_map.values().cloned().collect()
}
async fn get_cloned_clients_in_room(clients: &ClientsMap, rooms: &RoomsMap, room_name: String ) -> Vec<ClientConnection> {
let client_connection_uuids = Vec::new();
let rooms = rooms.read().await;
match rooms.get(&room_name) {
Some(uuid_set) => {}
None => {}
}
return Self::get_cloned_clients_filtered(clients, client_connection_uuids).await;
}
async fn get_cloned_clients_filtered(clients: &ClientsMap, client_connection_uuids: Vec<String> ) -> Vec<ClientConnection> {
let clients_map = clients.read().await;
let mut filtered_clients: Vec<ClientConnection> = Vec::new();
for uuid in client_connection_uuids {
if let Some(client_conn) = clients_map.get(&uuid) {
filtered_clients.push(client_conn.clone());
}
}
filtered_clients
}
async fn get_cloned_client_specific(clients: &ClientsMap, client_connection_uuid: String ) -> Vec<ClientConnection> {
let clients_map = clients.read().await;
let mut filtered_clients: Vec<ClientConnection> = Vec::new();
if let Some(client_conn) = clients_map.get(&client_connection_uuid) {
filtered_clients.push(client_conn.clone());
}
filtered_clients
}
pub async fn try_send_outbound_messages(
clients_map: ClientsMap,
rooms_map: RoomsMap,
global_send_rx: Receiver<OutboundMessage>
) -> std::io::Result<()> {
loop {
match global_send_rx.try_recv() {
Ok(msg) => {
// let message = msg;
let clients_map = Arc::clone(&clients_map);
let rooms_map = Arc::clone(&rooms_map);
println!("try send outbound message 2 " );
Self::broadcast(clients_map, rooms_map, msg).await;
}
Err(TryRecvError::Empty) => {
// No messages available right now, sleep for a short duration
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
Err(TryRecvError::Disconnected) => break,
}
}
Ok(())
}
pub async fn add_client_to_room(&self, client_connection_uuid:String, room_name: String ) {
let mut rooms = self.rooms.write().await;
let room_clients = rooms.entry(room_name).or_insert_with(HashSet::new);
room_clients.insert(client_connection_uuid);
}
pub async fn remove_client_from_room(&self, client_connection_uuid:String, room_name: String ) {
let mut rooms = self.rooms.write().await;
if let Some(room_clients) = rooms.get_mut(&room_name) {
room_clients.remove(&client_connection_uuid);
// Optionally, you can remove the room if it's now empty
if room_clients.is_empty() {
rooms.remove(&room_name);
}
}
}
pub async fn broadcast(
clients_map: ClientsMap,
rooms_map:RoomsMap,
outbound_message: OutboundMessage
) {
println!("broadcasting msg: {} ", outbound_message.message.to_string() );
let socket_message = outbound_message.message;
let client_connections = match outbound_message.destination {
SocketMessageDestination::All => Self::get_cloned_clients(&clients_map).await,
SocketMessageDestination::Room(room_name) => Self::get_cloned_clients_in_room(&clients_map,&rooms_map,room_name).await,
SocketMessageDestination::ClientConnection(client_connection_uuid) => Self::get_cloned_client_specific(&clients_map,client_connection_uuid).await,
// MessageDestination::ResponseToMsg(msg_uuid) => {},
// MessageDestination::Server => {}
};
Self::broadcast_to_connections(client_connections, socket_message).await;
}
pub async fn broadcast_to_connections( connections: Vec<ClientConnection>, socket_message: SocketMessage) {
let message = socket_message.to_message();
//Could cause thread lock issue !?
let send_futures: Vec<_> = {
connections
.iter()
.map(|client| {
let message = message.clone();
client.send_message(message)
})
.collect()
};
let results = join_all(send_futures).await;
for result in results {
if let Err(err) = result {
eprintln!("Failed to send a message: {}", err);
}
}
}
pub async fn try_accept_new_connections(
clients_map: ClientsMap,
listener: TcpListener,
global_recv_tx: Sender<InboundMessage>
) -> std::io::Result<()> {
while let Ok((stream, _)) = listener.accept().await {
let clients_map = Arc::clone(&clients_map);
tokio::spawn(Self::accept_connection(clients_map, stream, global_recv_tx.clone()));
}
Ok(())
}
async fn accept_connection(
clients: ClientsMap,
raw_stream: TcpStream,
global_socket_tx: Sender<InboundMessage>
) {
let addr = raw_stream
.peer_addr()
.expect("connected streams should have a peer address")
.to_string();
let ws_stream = tokio_tungstenite::accept_async(raw_stream)
.await
.expect("Error during the websocket handshake occurred");
println!("New WebSocket connection: {}", addr);
let ( client_tx, mut client_rx) = ws_stream.split(); //this is how i can read and write to this client
let new_client_connection = ClientConnection::new( addr.clone(), client_tx );
let client_uuid = new_client_connection.client_socket_uuid.clone();
clients.write().await.insert(
new_client_connection.client_socket_uuid.clone(),
new_client_connection
);
//in this new thread for the socket connection, recv'd messages are constantly collected
while let Some(msg) = client_rx.next().await {
match msg {
Ok(msg) => {
if msg.is_text() || msg.is_binary() {
let data = msg.clone().into_data();
println!("Received a message from {}: {:?}", addr, data);
// here you can consume your messages
let client_msg = InboundMessage::new(
client_uuid.clone(),
msg
);
global_socket_tx.send( client_msg );
}
}
Err(e) => {
eprintln!(
"an error occurred while processing incoming messages: {:?}",
e
);
break;
}
}
}
// Remove the client from the map once it has disconnected.
clients.write().await.remove(&addr);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment