Skip to content

Instantly share code, notes, and snippets.

@gaukas
Last active December 7, 2023 23:08
Show Gist options
  • Save gaukas/66fbaf3faf0e4ee83cb84e302c0a547a to your computer and use it in GitHub Desktop.
Save gaukas/66fbaf3faf0e4ee83cb84e302c0a547a to your computer and use it in GitHub Desktop.
Close a TCP connection with RST or FIN+ACK at will
//go:build unix
package main
import (
"errors"
"fmt"
"io"
"net"
"os"
"syscall"
"time"
)
// main is a self-checking demo: it starts the upper-casing echo server on
// :38000, opens three client connections to it, and verifies that the server
// can terminate a connection with a RST ("rst" command) or with a clean
// FIN ("fin" command) while leaving the other connections untouched.
// Any unexpected result panics.
func main() {
	const port = 38000

	// TCP listening on :38000
	tcpLis, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4zero, Port: port})
	if err != nil {
		panic(err)
	}
	go acceptTCP(tcpLis)
	time.Sleep(1 * time.Second) // presumably gives the accept goroutine time to start

	tcpConn1 := dialEcho(port)
	tcpConn2 := dialEcho(port)
	tcpConn3 := dialEcho(port)
	conns := []*net.TCPConn{tcpConn1, tcpConn2, tcpConn3}

	// Send something on all three connections, then check every echo.
	for _, c := range conns {
		mustWrite(c, "hello")
	}
	buf := make([]byte, 1024)
	for _, c := range conns {
		expectEcho(c, buf, "HELLO")
	}

	// Ask the server to RST the first connection and observe the reset.
	mustWrite(tcpConn1, "rst")
	expectReset(tcpConn1, buf)

	// Ask the server to FIN the second connection. The server waits a second
	// before closing, so the extra "hello" below is sitting unread in its
	// receive buffer when it closes; seeing EOF (not a reset) shows that the
	// server drained that data before closing.
	mustWrite(tcpConn2, "fin")
	time.Sleep(10 * time.Millisecond)
	_, _ = tcpConn2.Write([]byte("hello")) // best effort; only the read result below matters
	n, err := tcpConn2.Read(buf)
	if err == nil {
		panic("expected closed, but got data: " + string(buf[:n]))
	}
	if !errors.Is(err, io.EOF) {
		panic("expected EOF, but got other error: " + err.Error())
	}

	// The third connection should still be alive.
	mustWrite(tcpConn3, "hello")
	expectEcho(tcpConn3, buf, "HELLO")

	// Ask the server to RST the third connection and observe the reset.
	mustWrite(tcpConn3, "rst")
	expectReset(tcpConn3, buf)

	// The third connection should now be unusable for writing. Check the
	// error returned by this write itself (not the earlier read error). The
	// reset was already reported to the read above, so the kernel may report
	// the write failure as EPIPE instead of ECONNRESET; accept either.
	_, err = tcpConn3.Write([]byte("hello"))
	if err == nil {
		panic("expected closed, but got data")
	}
	if !errors.Is(err, syscall.ECONNRESET) && !errors.Is(err, syscall.EPIPE) {
		panic("expected ECONNRESET or EPIPE, but got other error: " + err.Error())
	}
}

// dialEcho opens a TCP connection to the local server on port, panicking on failure.
func dialEcho(port int) *net.TCPConn {
	conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4zero, Port: port})
	if err != nil {
		panic(err)
	}
	return conn
}

// mustWrite sends msg on conn, panicking if the write fails.
func mustWrite(conn *net.TCPConn, msg string) {
	if _, err := conn.Write([]byte(msg)); err != nil {
		panic(err)
	}
}

// expectEcho performs one read on conn into buf and panics unless it yields exactly want.
func expectEcho(conn *net.TCPConn, buf []byte, want string) {
	n, err := conn.Read(buf)
	if err != nil {
		panic(err)
	}
	if string(buf[:n]) != want {
		panic("unexpected response")
	}
}

// expectReset performs one read on conn into buf and panics unless it fails with ECONNRESET.
func expectReset(conn *net.TCPConn, buf []byte) {
	_, err := conn.Read(buf)
	if err == nil {
		panic("expected connection reset, but got data")
	}
	if !errors.Is(err, syscall.ECONNRESET) {
		panic("expected ECONNRESET, but got other error: " + err.Error())
	}
}
// acceptTCP accepts connections on tcpLis until AcceptTCP returns an error,
// serving each accepted connection on its own goroutine via handleTCPConn.
// The listener is closed exactly once, when the function returns.
func acceptTCP(tcpLis *net.TCPListener) {
	defer tcpLis.Close()
	for {
		tcpConn, err := tcpLis.AcceptTCP()
		if err != nil {
			// Any accept error ends the loop; the deferred Close releases the listener.
			return
		}
		go handleTCPConn(tcpConn)
	}
}
// handleTCPConn serves a single connection. Each read is treated as one
// message: its ASCII lowercase letters are upper-cased, and then
//   - "RST" closes the connection via rstConn,
//   - "FIN" waits one second and closes the connection via finConnClean,
//   - anything else is echoed back (upper-cased) to the client.
//
// The function returns once the connection has been closed or a read/write
// fails; the deferred Close covers every path where it is still open.
func handleTCPConn(tcpConn *net.TCPConn) {
	defer tcpConn.Close()

	buf := make([]byte, 1024)
	for {
		n, err := tcpConn.Read(buf)
		if err != nil {
			return
		}

		// Capitalize in place. Only 'a'..'z' are shifted; every other byte
		// (digits, newlines, already-upper-case letters, ...) is left as is.
		msg := buf[:n]
		for i, b := range msg {
			if 'a' <= b && b <= 'z' {
				msg[i] = b - ('a' - 'A')
			}
		}

		switch string(msg) {
		case "RST":
			if err := rstConn(tcpConn); err != nil {
				fmt.Println(err)
			}
			return
		case "FIN":
			// Wait for the client's next frame to arrive so that there is
			// unread data pending when the close starts.
			time.Sleep(1 * time.Second)
			if err := finConnClean(tcpConn); err != nil {
				fmt.Println(err)
			}
			return
		default:
			// Echo the capitalized message back to the client.
			if _, err := tcpConn.Write(msg); err != nil {
				return
			}
		}
	}
}
// rstConn closes tcpConn abortively: it enables SO_LINGER with a zero timeout
// on the underlying socket and then closes it, so that the peer sees a
// connection reset instead of an orderly shutdown.
//
// It returns the first error encountered. If the socket option cannot be
// set, the connection is NOT closed (closing it would not produce the
// intended reset) and the error is returned to the caller.
func rstConn(tcpConn *net.TCPConn) error {
	fmt.Printf("Closing TCP connection from %s with RST...\n", tcpConn.RemoteAddr().String())

	// Get the raw conn to reach the file descriptor.
	rawConn, err := tcpConn.SyscallConn()
	if err != nil {
		return err
	}

	// Control only reports failures of the control mechanism itself; the
	// result of the setsockopt call has to be carried out of the callback.
	var sockErr error
	err = rawConn.Control(func(fd uintptr) {
		sockErr = syscall.SetsockoptLinger(int(fd), syscall.SOL_SOCKET, syscall.SO_LINGER, &syscall.Linger{
			Onoff:  1,
			Linger: 0,
		})
	})
	if err != nil {
		return err
	}
	if sockErr != nil {
		return fmt.Errorf("setting SO_LINGER: %w", sockErr)
	}

	return tcpConn.Close()
}
// finConn closes tcpConn gracefully: it shuts down the write side of the
// underlying socket (SHUT_WR) and then closes the connection.
//
// The connection is closed even if the shutdown call fails; in that case the
// shutdown error is returned. Unlike finConnClean, pending unread input is
// not drained first.
func finConn(tcpConn *net.TCPConn) error {
	fmt.Printf("Closing TCP connection from %s with FIN/FIN+ACK...\n", tcpConn.RemoteAddr().String())

	// Get the raw conn to reach the file descriptor.
	rawConn, err := tcpConn.SyscallConn()
	if err != nil {
		return err
	}

	// Control only reports failures of the control mechanism itself; the
	// result of the shutdown call has to be carried out of the callback.
	var shutErr error
	err = rawConn.Control(func(fd uintptr) {
		shutErr = syscall.Shutdown(int(fd), syscall.SHUT_WR)
	})
	if err != nil {
		return err
	}

	closeErr := tcpConn.Close()
	if shutErr != nil {
		return fmt.Errorf("shutting down write side: %w", shutErr)
	}
	return closeErr
}
// finConnClean closes tcpConn after first draining any input that is still
// pending on it, so that the close is not performed with unread data left
// in the receive buffer.
//
// Draining reads with a 10ms deadline per read and stops when a read times
// out or when the peer has already closed its side (io.EOF); both mean there
// is nothing left to flush. Any other read error, or a failure to set the
// deadline, is returned and the connection is left for the caller to close.
func finConnClean(tcpConn *net.TCPConn) error {
	fmt.Printf("Closing TCP connection from %s with FIN/FIN+ACK...\n", tcpConn.RemoteAddr().String())

	// Flush the read buffer.
	buf := make([]byte, 65535)
	for {
		if err := tcpConn.SetReadDeadline(time.Now().Add(10 * time.Millisecond)); err != nil {
			return err
		}
		n, err := tcpConn.Read(buf)
		if err != nil {
			if errors.Is(err, io.EOF) || errors.Is(err, syscall.ETIMEDOUT) || errors.Is(err, os.ErrDeadlineExceeded) {
				break
			}
			return err
		}
		fmt.Printf("Flushed %d bytes\n", n)
	}
	fmt.Println("Flushed all bytes")

	return tcpConn.Close()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment