
@jniltinho
Last active March 30, 2024 00:55
TCP Server Golang


Build a concurrent TCP server in Go with graceful shutdown, including unit tests

TCP (Transmission Control Protocol) is a protocol that provides reliable, ordered, and error-checked delivery of data between applications running on hosts communicating over an IP network. A TCP server is a program that listens for incoming TCP connections and handles them by providing services or sending data to the client. In this article, we will discuss how to implement a TCP server in Golang that is concurrent and supports graceful shutdown. We will also provide an example of a unit test for this implementation.
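To see the basic listen/accept/respond cycle before the full server, here is a minimal sketch that uses only the standard net package. It serves a single connection and then stops, so it is not concurrent. The helper names serveOnce and greetingFrom and the "hello" greeting are illustrative assumptions, not part of the server built in this article.

```go
package main

import (
	"bufio"
	"fmt"
	"net"
)

// serveOnce (hypothetical helper) accepts one connection, writes a greeting,
// and closes the connection.
func serveOnce(ln net.Listener) error {
	conn, err := ln.Accept() // blocks until a client connects
	if err != nil {
		return err
	}
	defer conn.Close()
	_, err = fmt.Fprintf(conn, "hello\n")
	return err
}

// greetingFrom (hypothetical helper) listens on a free localhost port,
// connects to it as a client, and returns the first line the server sends.
func greetingFrom() (string, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0") // port 0: the OS picks a free port
	if err != nil {
		return "", err
	}
	defer ln.Close()

	go serveOnce(ln) // the server side runs in its own goroutine

	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		return "", err
	}
	defer conn.Close()
	return bufio.NewReader(conn).ReadString('\n')
}
```

Listening on port 0 lets the operating system choose an unused port, which ln.Addr() then reports; the same trick is handy in tests to avoid clashing with a server already bound to a fixed port.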

If you’re new to socket programming in Go, you may want to check out my previous article on the basics of socket programming in Go, which covers the fundamental concepts and includes some sample code.

TCP Server Implementation

The TCP server implementation consists of the following steps:

1. Create a net.Listener that listens for incoming connections on a specified address using the net.Listen function.
2. Start two goroutines: one to accept new connections and another to dispatch them to handlers.
3. In the acceptConnections goroutine, use a for loop and a select statement to check for the shutdown signal, accept incoming connections on the listener, and send them over a channel.
4. In the handleConnections goroutine, use a for loop and a select statement to receive connections from the channel and handle each one in its own goroutine.
5. In the handleConnection function, process the incoming connection, for example by sending data to the client.
6. Implement graceful shutdown by creating a shutdown channel that signals to the goroutines that they should stop processing connections. Close the listener and wait for the goroutines to finish using a sync.WaitGroup.
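Step 6 relies on a property of net.Listener that is easy to miss: closing the listener wakes up a goroutine that is blocked in Accept, which then returns an error. The sketch below checks this in isolation. It assumes Go 1.16 or later, where that error matches net.ErrClosed via errors.Is; the function name acceptAfterClose is hypothetical.

```go
package main

import (
	"errors"
	"net"
	"time"
)

// acceptAfterClose (hypothetical helper) blocks a goroutine in Accept, closes
// the listener, and reports whether Accept failed with net.ErrClosed.
func acceptAfterClose() (bool, error) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return false, err
	}

	errc := make(chan error, 1)
	go func() {
		_, err := ln.Accept() // no client ever connects, so this blocks
		errc <- err
	}()

	time.Sleep(50 * time.Millisecond) // give the goroutine a chance to block in Accept
	ln.Close()                        // the same call Stop uses to wake acceptConnections

	select {
	case err := <-errc:
		return errors.Is(err, net.ErrClosed), nil
	case <-time.After(time.Second):
		return false, errors.New("accept did not return after close")
	}
}
```

Checking errors.Is(err, net.ErrClosed) is one way an accept loop can tell a deliberate shutdown apart from a transient accept error.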

The complete code for the TCP server implementation appears at the end of this page.

In this example, the server creates two goroutines to handle incoming connections and manage them concurrently. The acceptConnections method listens for new connections and sends them over a channel, while the handleConnections method receives these connections and handles them in separate goroutines.
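The hand-off over an unbuffered channel has one subtle failure mode: if the receiving goroutine has already returned during shutdown, a plain send blocks forever and the sender never reaches wg.Done. A common remedy is to select on the send and the shutdown channel together. The sketch below shows the pattern with a hypothetical handOff helper, using plain int values as a stand-in for net.Conn.

```go
package main

// handOff (hypothetical helper) sends v on out unless shutdown is closed first.
// It reports whether the value was delivered, so the caller can clean up
// (for example, close a net.Conn) when it was not.
func handOff(out chan<- int, shutdown <-chan struct{}, v int) bool {
	select {
	case out <- v:
		return true // a receiver took the value
	case <-shutdown:
		return false // shutting down; nobody may ever receive
	}
}
```

In an accept loop, the false branch is the place to close the connection and return.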

The server also implements graceful shutdown: Stop closes the shutdown channel to signal the goroutines that they should stop processing connections, closes the listener so that a pending Accept call returns, and then waits on the WaitGroup, for at most one second, for the goroutines to finish before the program exits.

To start the server, main creates a server with newServer and calls its Start method. To stop the server, main waits for a SIGINT or SIGTERM signal and then calls the Stop method.

Note that this code is a simple example, and you may need to modify it to suit your specific use case.

Unit Test Implementation

To test the TCP server implementation, we will use Go's standard testing package. We will create a test case that starts the server, connects to it using a TCP client, and verifies that the client receives the welcome and goodbye messages the server sends.

The code for the unit test appears at the end of this page, after the server code.

In the TestServer function, we start the TCP server using the newServer function and connect to it with a TCP client using net.Dial. We then read the welcome message from the connection, read the goodbye message that the handler sends about five seconds later, and compare each one with the expected string, reporting any mismatch with t.Errorf.

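To run the test, we can use the go test command in the terminal. Because the handler pauses for five seconds before saying goodbye, the elapsed time that go test reports will be at least five seconds:

$ go test
PASS
ok   command-line-arguments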

Conclusion

In this article, we discussed how to implement a TCP server in Golang that is concurrent and supports graceful shutdown. We also provided an example of a unit test for this implementation. The implementation uses goroutines to handle incoming connections and to shut down gracefully, and the unit test verifies that a connected client receives the expected messages. This implementation can be used as a starting point for building more complex TCP servers in Golang.


## The Rust port of the server (after the Go files) uses the Tokio runtime for asynchronous I/O.
## Add the following dependency to the [dependencies] section of your Cargo.toml file:
[dependencies]
tokio = { version = "1", features = ["full"] }
package main

import (
	"fmt"
	"net"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
)

type server struct {
	wg         sync.WaitGroup
	listener   net.Listener
	shutdown   chan struct{} // closed by Stop to tell the goroutines to exit
	connection chan net.Conn // hands accepted connections to handleConnections
}

func newServer(address string) (*server, error) {
	listener, err := net.Listen("tcp", address)
	if err != nil {
		return nil, fmt.Errorf("failed to listen on address %s: %w", address, err)
	}
	return &server{
		listener:   listener,
		shutdown:   make(chan struct{}),
		connection: make(chan net.Conn),
	}, nil
}

func (s *server) acceptConnections() {
	defer s.wg.Done()
	for {
		select {
		case <-s.shutdown:
			return
		default:
			// Accept blocks until a client connects or Stop closes the listener.
			conn, err := s.listener.Accept()
			if err != nil {
				// After Stop, the next iteration sees the closed shutdown channel.
				continue
			}
			// Hand the connection off, but never block forever if
			// handleConnections has already exited during shutdown.
			select {
			case s.connection <- conn:
			case <-s.shutdown:
				conn.Close()
				return
			}
		}
	}
}

func (s *server) handleConnections() {
	defer s.wg.Done()
	for {
		select {
		case <-s.shutdown:
			return
		case conn := <-s.connection:
			// Track each handler so Stop can wait for in-flight connections.
			s.wg.Add(1)
			go s.handleConnection(conn)
		}
	}
}

func (s *server) handleConnection(conn net.Conn) {
	defer s.wg.Done()
	defer conn.Close()
	// Add your logic for handling incoming connections here
	fmt.Fprintf(conn, "Welcome to my TCP server!\n")
	time.Sleep(5 * time.Second)
	fmt.Fprintf(conn, "Goodbye!\n")
}

func (s *server) Start() {
	s.wg.Add(2)
	go s.acceptConnections()
	go s.handleConnections()
}

func (s *server) Stop() {
	close(s.shutdown)  // signal the goroutines to stop
	s.listener.Close() // unblock the pending Accept call

	// Wait for the goroutines to finish, but give up after one second.
	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(time.Second):
		fmt.Println("Timed out waiting for connections to finish.")
	}
}

func main() {
	s, err := newServer(":8080")
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	s.Start()

	// Wait for a SIGINT or SIGTERM signal to gracefully shut down the server
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	<-sigChan

	fmt.Println("Shutting down server...")
	s.Stop()
	fmt.Println("Server stopped.")
}
package main

// Unit test for the server above; save it in a file ending in _test.go
// in the same package as the server.

import (
	"io"
	"net"
	"testing"
	"time"
)

func TestServer(t *testing.T) {
	// Start the server on a free localhost port so the test does not clash
	// with anything already listening on :8080.
	s, err := newServer("127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	s.Start()
	defer s.Stop()

	// Connect to the server
	conn, err := net.Dial("tcp", s.listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	// Fail instead of hanging if the server never answers.
	if err := conn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
		t.Fatal(err)
	}

	// io.ReadFull blocks until the whole message has arrived (the goodbye
	// message comes about five seconds after the welcome); a single
	// conn.Read may legally return fewer bytes.
	for _, expected := range []string{"Welcome to my TCP server!\n", "Goodbye!\n"} {
		actual := make([]byte, len(expected))
		if _, err := io.ReadFull(conn, actual); err != nil {
			t.Fatal(err)
		}
		if string(actual) != expected {
			t.Errorf("expected %q, but got %q", expected, string(actual))
		}
	}
}
// Rust port of the Go server above, built on the Tokio runtime.
use std::time::Duration;

use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::watch;
use tokio::task::JoinSet;

struct Server {
    listener: TcpListener,
    // Changes to `true` when the server should stop accepting connections.
    shutdown: watch::Receiver<bool>,
}

impl Server {
    async fn new(address: &str, shutdown: watch::Receiver<bool>) -> std::io::Result<Server> {
        let listener = TcpListener::bind(address).await?;
        Ok(Server { listener, shutdown })
    }

    // Accepts connections until shutdown is signalled, handling each one in its
    // own task, then waits up to one second for in-flight handlers to finish.
    async fn run(mut self) {
        let mut tasks = JoinSet::new();
        loop {
            tokio::select! {
                _ = self.shutdown.changed() => break,
                accepted = self.listener.accept() => match accepted {
                    Ok((stream, _)) => {
                        tasks.spawn(async move {
                            if let Err(e) = Self::handle_connection(stream).await {
                                eprintln!("Error handling connection: {e}");
                            }
                        });
                    }
                    Err(e) => eprintln!("Error accepting connection: {e}"),
                },
            }
        }
        // Close the listener so no new connections are accepted.
        drop(self.listener);

        let wait_all = async { while tasks.join_next().await.is_some() {} };
        if tokio::time::timeout(Duration::from_secs(1), wait_all).await.is_err() {
            println!("Timed out waiting for connections to finish.");
        }
    }

    async fn handle_connection(mut stream: TcpStream) -> std::io::Result<()> {
        // Add your logic for handling incoming connections here
        stream.write_all(b"Welcome to my TCP server!\n").await?;
        tokio::time::sleep(Duration::from_secs(5)).await;
        stream.write_all(b"Goodbye!\n").await?;
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let server = Server::new("127.0.0.1:8080", shutdown_rx).await?;
    let server_task = tokio::spawn(server.run());

    // Wait for a SIGINT or SIGTERM signal to gracefully shut down the server
    let mut sigint = signal(SignalKind::interrupt())?;
    let mut sigterm = signal(SignalKind::terminate())?;
    tokio::select! {
        _ = sigint.recv() => {}
        _ = sigterm.recv() => {}
    }

    println!("Shutting down server...");
    let _ = shutdown_tx.send(true);
    server_task.await?;
    println!("Server stopped.");
    Ok(())
}