Last active
December 7, 2023 23:08
-
-
Save gaukas/66fbaf3faf0e4ee83cb84e302c0a547a to your computer and use it in GitHub Desktop.
Close a TCP connection with RST or FIN+ACK at will
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//go:build unix | |
package main | |
import ( | |
"errors" | |
"fmt" | |
"io" | |
"net" | |
"os" | |
"syscall" | |
"time" | |
) | |
func main() { | |
// TCP Listening on :38000 | |
tcpLis, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4zero, Port: 38000}) | |
if err != nil { | |
panic(err) | |
} | |
go acceptTCP(tcpLis) | |
time.Sleep(1 * time.Second) | |
// first TCP connection | |
tcpConn1, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4zero, Port: 38000}) | |
if err != nil { | |
panic(err) | |
} | |
// second TCP connection | |
tcpConn2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4zero, Port: 38000}) | |
if err != nil { | |
panic(err) | |
} | |
// third TCP connection | |
tcpConn3, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4zero, Port: 38000}) | |
if err != nil { | |
panic(err) | |
} | |
// send something on all three connections | |
if _, err := tcpConn1.Write([]byte("hello")); err != nil { | |
panic(err) | |
} | |
if _, err := tcpConn2.Write([]byte("hello")); err != nil { | |
panic(err) | |
} | |
if _, err := tcpConn3.Write([]byte("hello")); err != nil { | |
panic(err) | |
} | |
// check the received data | |
buf := make([]byte, 1024) | |
n, err := tcpConn1.Read(buf) | |
if err != nil { | |
panic(err) | |
} | |
if string(buf[:n]) != "HELLO" { | |
panic("unexpected response") | |
} | |
n, err = tcpConn2.Read(buf) | |
if err != nil { | |
panic(err) | |
} | |
if string(buf[:n]) != "HELLO" { | |
panic("unexpected response") | |
} | |
n, err = tcpConn3.Read(buf) | |
if err != nil { | |
panic(err) | |
} | |
if string(buf[:n]) != "HELLO" { | |
panic("unexpected response") | |
} | |
// ask the first connection to RST | |
if _, err := tcpConn1.Write([]byte("rst")); err != nil { | |
panic(err) | |
} | |
// read from the first connection | |
n, err = tcpConn1.Read(buf) | |
if err == nil { | |
panic("expected connection reset, but got data") | |
} | |
if !errors.Is(err, syscall.ECONNRESET) { | |
panic("expected ECONNRESET, but got other error: " + err.Error()) | |
} | |
// ask the second connection to FIN | |
if _, err := tcpConn2.Write([]byte("fin")); err != nil { | |
panic(err) | |
} | |
time.Sleep(10 * time.Millisecond) | |
tcpConn2.Write([]byte("hello")) // send another frame so it was supposed to be RST | |
// read from the second connection | |
n, err = tcpConn2.Read(buf) | |
if err == nil { | |
panic("expected closed, but got data: " + string(buf[:n])) | |
} | |
if !errors.Is(err, io.EOF) { | |
panic("expected EOF, but got other error: " + err.Error()) | |
} | |
// the third connection should still be alive | |
if _, err := tcpConn3.Write([]byte("hello")); err != nil { | |
panic(err) | |
} | |
n, err = tcpConn3.Read(buf) | |
if err != nil { | |
panic(err) | |
} | |
if string(buf[:n]) != "HELLO" { | |
panic("unexpected response") | |
} | |
// ask the third connection to RST | |
if _, err := tcpConn3.Write([]byte("rst")); err != nil { | |
panic(err) | |
} | |
// read from the third connection | |
n, err = tcpConn3.Read(buf) | |
if err == nil { | |
panic("expected connection reset, but got data") | |
} | |
if !errors.Is(err, syscall.ECONNRESET) { | |
panic("expected ECONNRESET, but got other error: " + err.Error()) | |
} | |
// the third connection should be closed | |
if _, err := tcpConn3.Write([]byte("hello")); err == nil { | |
panic("expected closed, but got data") | |
} | |
if !errors.Is(err, syscall.ECONNRESET) { | |
panic("expected ECONNRESET, but got other error: " + err.Error()) | |
} | |
} | |
func acceptTCP(tcpLis *net.TCPListener) { | |
defer tcpLis.Close() | |
for { | |
tcpConn, err := tcpLis.AcceptTCP() | |
if err != nil { | |
tcpLis.Close() | |
return | |
} | |
go handleTCPConn(tcpConn) | |
} | |
} | |
func handleTCPConn(tcpConn *net.TCPConn) { | |
defer tcpConn.Close() | |
for { | |
buf := make([]byte, 1024) | |
n, err := tcpConn.Read(buf) | |
if err != nil { | |
return | |
} | |
// echo the capitolized string back to the client | |
cap := make([]byte, n) | |
for i := 0; i < n; i++ { | |
cap[i] = buf[i] - 32 | |
} | |
if string(cap) == "RST" { | |
if err := rstConn(tcpConn); err != nil { | |
fmt.Println(err) | |
} | |
} else if string(cap) == "FIN" { | |
time.Sleep(1 * time.Second) // wait for next frame to come so RST will be sent by default | |
if err := finConnClean(tcpConn); err != nil { | |
fmt.Println(err) | |
} | |
} else { | |
_, err = tcpConn.Write(cap) | |
if err != nil { | |
return | |
} | |
continue | |
} | |
} | |
} | |
// rstConn closes tcpConn abortively, so the peer receives RST instead of a
// FIN handshake.
//
// SetLinger(0) sets SO_LINGER to {on, 0 seconds}; with that option the OS
// discards any unsent or unacknowledged data and resets the connection on
// Close. If the option cannot be set, the socket is still closed (normally,
// i.e. without the guaranteed RST) and the setsockopt error is returned.
func rstConn(tcpConn *net.TCPConn) error {
	fmt.Printf("Closing TCP connection from %s with RST...\n", tcpConn.RemoteAddr().String())
	if err := tcpConn.SetLinger(0); err != nil {
		tcpConn.Close() // release the descriptor anyway; the linger error is the one worth reporting
		return fmt.Errorf("setting SO_LINGER: %w", err)
	}
	return tcpConn.Close()
}
// finConn closes tcpConn gracefully. It first half-closes the write side
// (CloseWrite, i.e. shutdown(SHUT_WR), which sends FIN) and then releases the
// socket. If the half-close fails, the socket is still closed and the
// shutdown error is returned.
//
// Caveat: unlike finConnClean, this does not drain data the peer has already
// sent; if such data is still unread, the final Close may answer with RST.
func finConn(tcpConn *net.TCPConn) error {
	fmt.Printf("Closing TCP connection from %s with FIN/FIN+ACK...\n", tcpConn.RemoteAddr().String())
	if err := tcpConn.CloseWrite(); err != nil {
		tcpConn.Close() // release the descriptor anyway; the shutdown error is the one worth reporting
		return fmt.Errorf("shutting down write side: %w", err)
	}
	return tcpConn.Close()
}
// finConnClean closes tcpConn gracefully (FIN/FIN+ACK) after draining any
// data the peer has already sent.
//
// Closing a socket whose receive buffer still holds unread bytes makes the
// kernel answer with RST instead of FIN. So this first reads and discards
// incoming data until one of two things happens:
//   - nothing arrives within a 10ms window;
//   - the peer signals end-of-stream (io.EOF).
//
// Only then does it close the connection. The connection is closed on every
// return path; any other read or deadline error is returned after closing.
func finConnClean(tcpConn *net.TCPConn) error {
	fmt.Printf("Closing TCP connection from %s with FIN/FIN+ACK...\n", tcpConn.RemoteAddr().String())
	// Flush the read buffer.
	buf := make([]byte, 65535)
	for {
		// Re-arm the deadline each round: it defines the "quiet period"
		// after which the buffer is considered drained.
		if err := tcpConn.SetReadDeadline(time.Now().Add(10 * time.Millisecond)); err != nil {
			tcpConn.Close()
			return err
		}
		n, err := tcpConn.Read(buf)
		if err != nil {
			if errors.Is(err, syscall.ETIMEDOUT) || errors.Is(err, os.ErrDeadlineExceeded) {
				break
			}
			if errors.Is(err, io.EOF) {
				// The peer already sent FIN, so nothing more can arrive and
				// the buffer is empty: proceed to a normal close.
				break
			}
			tcpConn.Close()
			return err
		}
		fmt.Printf("Flushed %d bytes\n", n)
	}
	fmt.Println("Flushed all bytes")
	return tcpConn.Close()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment