|
use std::collections::VecDeque;
use std::future::Future;
use std::net::TcpListener;
use std::os::unix::io::AsRawFd;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::mpsc::{channel, Receiver, Sender, SyncSender};
use std::thread;
use std::time::Duration;

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
// NOTE(review): `tokio::prelude` was removed before tokio 1.x, which the rest of this
// file targets (`tokio::time::sleep`, `tokio::signal::unix`); this line will likely
// need to be dropped when building against tokio 1.x — confirm against Cargo.toml.
use tokio::prelude::*;
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::RwLock;
|
|
|
type ConnectionHandler = Box<dyn FnMut(TcpStream) + Send + 'static>; |
|
|
|
struct Server { |
|
listener: TcpListener, |
|
shutdown: Arc<RwLock<bool>>, |
|
connection_tx: SyncSender<TcpStream>, |
|
handle_connection: ConnectionHandler, |
|
} |
|
|
|
impl Server { |
|
async fn new(address: &str) -> Result<Server, Box<dyn std::error::Error>> { |
|
let listener = TcpListener::bind(address)?; |
|
|
|
let (connection_tx, connection_rx) = unbounded_channel(); |
|
let shutdown = Arc::new(RwLock::new(false)); |
|
|
|
let handle_connection = Box::new(move |stream: TcpStream| { |
|
let shutdown = Arc::clone(&shutdown); |
|
let connection_tx = connection_tx.clone(); |
|
|
|
async move { |
|
let mut stream = TcpStream::from_std(stream); |
|
|
|
// Add your logic for handling incoming connections here |
|
let mut buf = vec![0; 1024]; |
|
stream.read_exact(&mut buf).await.unwrap(); |
|
let welcome_message = format!("Welcome to my TCP server!\n"); |
|
stream.write_all(welcome_message.as_bytes()).await.unwrap(); |
|
tokio::time::sleep(Duration::from_secs(5)).await; |
|
let goodbye_message = format!("Goodbye!\n"); |
|
stream.write_all(goodbye_message.as_bytes()).await.unwrap(); |
|
|
|
connection_tx.send(stream).unwrap(); |
|
} |
|
}); |
|
|
|
Ok(Server { |
|
listener, |
|
shutdown, |
|
connection_tx, |
|
handle_connection, |
|
}) |
|
} |
|
|
|
async fn start(mut self) -> Result<(), Box<dyn std::error::Error>> { |
|
let mut tasks = VecDeque::new(); |
|
|
|
// Start the listener thread |
|
let shutdown = Arc::clone(&self.shutdown); |
|
tasks.push_back(tokio::spawn(async move { |
|
Self::accept_connections(shutdown, self.listener).await; |
|
})); |
|
|
|
// Start the connection handler thread |
|
let connection_rx = self.connection_tx.clone(); |
|
tasks.push_back(tokio::spawn(async move { |
|
Self::handle_connections(connection_rx, self.handle_connection).await; |
|
})); |
|
|
|
// Wait for all tasks to complete |
|
while let Some(task) = tasks.pop_front() { |
|
task.await?; |
|
} |
|
|
|
Ok(()) |
|
} |
|
|
|
async fn accept_connections(shutdown: Arc<RwLock<bool>>, listener: TcpListener) { |
|
let mut tasks = VecDeque::new(); |
|
|
|
// Wait for the shutdown signal |
|
let mut interval = tokio::time::interval(Duration::from_millis(100)); |
|
loop { |
|
let mut shutdown_locked = shutdown.write().await; |
|
if *shutdown_locked { |
|
break; |
|
} |
|
|
|
// Accept incoming connections |
|
match listener.accept().await { |
|
Ok((stream, _)) => { |
|
let handle_connection = self.handle_connection.clone(); |
|
let connection_tx = self.connection_tx.clone(); |
|
let task = tokio::spawn(async move { |
|
handle_connection(stream); |
|
connection_tx.send(stream).unwrap(); |
|
}); |
|
tasks.push_back(task); |
|
} |
|
Err(e) => { |
|
eprintln!("Error accepting connection: {}", e); |
|
} |
|
} |
|
|
|
// Check for shutdown signal again |
|
interval.tick().await; |
|
} |
|
|
|
// Wait for all tasks to complete |
|
while let Some(task) = tasks.pop_front() { |
|
task.await?; |
|
} |
|
} |
|
|
|
async fn handle_connections(mut connection_rx: Receiver<TcpStream>, handle_connection: ConnectionHandler) { |
|
while let Ok(stream) = connection_rx.recv().await { |
|
let handle_connection = handle_connection.clone(); |
|
let task = tokio::spawn(async move { |
|
handle_connection(stream); |
|
}); |
|
tasks.push_back(task); |
|
} |
|
} |
|
} |
|
|
|
#[tokio::main] |
|
async fn main() -> Result<(), Box<dyn std::error::Error>> { |
|
let mut server = Server::new("127.0.0.1:8080").await?; |
|
|
|
server.start().await?; |
|
|
|
// Wait for a SIGINT or SIGTERM signal to gracefully shut down the server |
|
let mut sig_set = std::collections::HashSet::new(); |
|
sig_set.insert(syscall::SIGINT); |
|
sig_set.insert(syscall::SIGTERM); |
|
let mut sig_set = sig_set.into_iter().collect::<Vec<_>>(); |
|
let mut sig_set = tokio::signal::unix::signal(sig_set) |
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; |
|
|
|
sig_set.recv().await?; |
|
|
|
println!("Shutting down server..."); |
|
server.shutdown.write().await.replace(true); |
|
server.listener.shutdown().await?; |
|
|
|
Ok(()) |
|
} |