Skip to content

Instantly share code, notes, and snippets.

@warriorpaw
Last active November 12, 2023 09:56
Show Gist options
  • Save warriorpaw/d2db2cd128e4c33fec10c42cfd80ef6c to your computer and use it in GitHub Desktop.
Save warriorpaw/d2db2cd128e4c33fec10c42cfd80ef6c to your computer and use it in GitHub Desktop.
package main
import (
"crypto/md5"
"encoding/binary"
"encoding/json"
"flag"
"fmt"
"math/rand"
"net"
"os"
"os/signal"
"sync"
"time"
)
// errorPanic is a no-op when err is nil. Otherwise it prints a
// "Panic : "-prefixed message built from format and a to stdout, then panics
// with err itself.
func errorPanic(err error, format string, a ...interface{}) {
	if err == nil {
		return
	}
	fmt.Printf("Panic : "+format, a...)
	panic(err)
}
// errorPrint writes an "Error : "-prefixed message to stdout.
//
// When arguments are supplied, format is interpreted as a fmt format string.
// Without arguments the message is printed verbatim. Feeding an argument-less
// message to Printf would turn any literal '%' into a "%!verb(MISSING)"
// sequence, and go vet rejects the non-constant format string.
func errorPrint(format string, a ...interface{}) {
	if len(a) == 0 {
		fmt.Print("Error : " + format)
		return
	}
	fmt.Printf("Error : "+format, a...)
}
// udpConnAlloc caches bound UDP sockets by local address. Config entries that
// name the same local IpPort therefore share one *net.UDPConn. It has no
// locking, so it is not safe for concurrent use.
type udpConnAlloc struct {
	cache map[IpPort]*net.UDPConn // local address -> bound socket
}

// gUdpConnAlloc is the process-wide socket cache used by main and the
// constructors it calls.
var gUdpConnAlloc udpConnAlloc

// getUdpConn returns the socket already bound to addr, or binds a new one
// (via the package-level getUdpConn, which panics on failure) and caches it.
// The cache map is created on first use, so the zero value is ready to use.
func (a *udpConnAlloc) getUdpConn(addr IpPort) *net.UDPConn {
	if c, ok := a.cache[addr]; ok {
		return c
	}
	if a.cache == nil {
		a.cache = make(map[IpPort]*net.UDPConn)
	}
	c := getUdpConn(addr.Ip, addr.Port)
	a.cache[addr] = c
	return c
}

// stop closes every cached socket. Close errors are deliberately ignored;
// this runs at shutdown, where nothing useful can be done with them.
func (a *udpConnAlloc) stop() {
	for _, c := range a.cache {
		c.Close()
	}
}
// getUdpConn binds a UDP socket on ip:port and panics (through errorPanic) if
// the bind fails.
//
// An empty or unparsable ip makes net.ParseIP return nil, and ListenUDP then
// listens on all local addresses.
// NOTE(review): a typo in a configured IP therefore silently widens the bind.
func getUdpConn(ip string, port int) *net.UDPConn {
	local := &net.UDPAddr{Port: port, IP: net.ParseIP(ip)}
	conn, err := net.ListenUDP("udp", local)
	errorPanic(err, "creat udp conn %s : %d error \n", ip, port)
	return conn
}
// IpPort is a JSON-configurable UDP endpoint. The Ip string is passed to
// net.ParseIP, which yields nil for an empty or invalid value.
// Values are also used as map keys by udpConnAlloc.
type IpPort struct {
	// Ip is an IP literal such as "0.0.0.0" or "8.8.8.8".
	Ip string `json:"ip"`
	// Port is the UDP port number.
	Port int `json:"port"`
}
// TunnelConfig describes one tunnel path. Packets are sent from a socket bound
// to LocalPort toward RemotePort. Anything that arrives on LocalPort is treated
// as tunnel traffic. Tunnels naming the same LocalPort share one socket
// through gUdpConnAlloc.
type TunnelConfig struct {
	LocalPort IpPort `json:"local_port"`
	RemotePort IpPort `json:"remote_port"`
}
// Config is the JSON configuration loaded by main (path given by the -c flag).
type Config struct {
	// ListenPort is the client-side socket that application clients send to.
	// It is not read on the server side.
	ListenPort IpPort `json:"listen_port"`
	// RemotePort is, on the server side, the address all client traffic is
	// forwarded to. It is not read on the client side.
	RemotePort IpPort `json:"remote_port"`
	// Tunnels lists the parallel paths packets are sent over.
	Tunnels []TunnelConfig `json:"tunnels"`
	// Timeout is the idle limit in seconds. A client (client side) or
	// sendSocket (server side) whose last activity is older than this is
	// dropped.
	Timeout int64 `json:"timeout"`
	// TimeoutCheck is the period, in seconds, of the idle cleanup. On the
	// server side it is also the period of tunnel heartbeats.
	TimeoutCheck int64 `json:"timeout_check"`
	// ServerSide selects server mode; false runs the client side.
	ServerSide bool `json:"server_side"`
}
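/*
Example conf.json for the client side ("server_side": false).
"timeout" and "timeout_check" are in seconds. Each tunnel binds its
"local_port" and sends to its "remote_port". The top-level "listen_port" is
used only on the client side, and the top-level "remote_port" only on the
server side.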
{
"listen_port":{
"ip":"0.0.0.0",
"port":10102
},
"remote_port":{
"ip":"",
"port":0
},
"tunnels":[
{
"local_port":{
"ip":"0.0.0.0",
"port":11111
},
"remote_port":{
"ip":"8.8.8.8",
"port":53
}
}
],
"timeout":30,
"timeout_check":5,
"server_side":false
}
*/
// portToId hashes a UDP address (IP + port) into a 64-bit client id.
//
// The id is MD5(ip || big-endian uint16 port [|| zone]), with the two 8-byte
// halves of the digest XOR-folded together.
//
// IPv4 addresses, including IPv4-mapped IPv6 ones, are hashed in their 4-byte
// form. IPv6 addresses are hashed in their 16-byte form; To4 alone returns nil
// for them, which would make every IPv6 peer on a given port collide. A
// non-empty zone is mixed in so link-local peers on different interfaces stay
// distinct.
func portToId(port *net.UDPAddr) uint64 {
	ip := port.IP.To4()
	if ip == nil {
		ip = port.IP.To16()
	}
	h := md5.New()
	h.Write(ip)
	var portBytes [2]byte
	binary.BigEndian.PutUint16(portBytes[:], uint16(port.Port))
	h.Write(portBytes[:])
	if port.Zone != "" {
		h.Write([]byte(port.Zone))
	}
	sum := h.Sum(nil)
	for i := 0; i < 8; i++ {
		sum[i] ^= sum[i+8]
	}
	return binary.BigEndian.Uint64(sum[:8])
}
// clientInfo records, on the client side, where an application client sends
// from and when it was last seen.
type clientInfo struct {
	clientAddr *net.UDPAddr
	// lastActive is a Unix timestamp in seconds (time.Now().Unix()).
	lastActive int64
}
// dataInfo is a packet handed between goroutines.
type dataInfo struct {
	// addr is the application client's address on the client side.
	// Packets coming from a server-side sendSocket leave it nil.
	addr *net.UDPAddr
	// data is a pooled buffer. Its first 8 bytes hold (or are reserved for)
	// the id header; the UDP payload follows.
	data []byte
	// id is the client id on the server side; the client side passes 0.
	id uint64
}
// tunnel is one path: a local socket and the remote address it sends to.
type tunnel struct {
	// local may be shared with other tunnels configured on the same address.
	local *net.UDPConn
	remote *net.UDPAddr
	// buffer queues outgoing packets for this tunnel's handleSend goroutine.
	buffer chan []byte
}
// windowCounter tracks bytes per one-second window. It reports whether the
// last passWindowLimit completed windows together reached sizeLimit bytes.
// The verdict is recomputed only when a new second begins. Within one second,
// add keeps returning the verdict from the last rollover.
type windowCounter struct {
	currentWindow     int64         // Unix second currently being filled
	currentWindowSize int           // bytes added during currentWindow
	passWindowLimit   int64         // number of past windows that count
	pastWindowsSize   map[int64]int // Unix second -> bytes of a past window
	sizeLimit         int           // byte threshold over the past windows
	needLimit         bool          // most recently computed verdict
}

// newWindowCounter returns a counter that flags traffic once the four most
// recent completed seconds together carried at least 4*524288 bytes (2 MiB).
func newWindowCounter() *windowCounter {
	const windows = 4
	return &windowCounter{
		passWindowLimit: windows,
		pastWindowsSize: make(map[int64]int),
		sizeLimit:       windows * 524288,
		needLimit:       false,
	}
}

// add records packageSize bytes in the current second and reports whether
// traffic should be limited.
func (c *windowCounter) add(packageSize int) bool {
	now := time.Now().Unix()
	if now == c.currentWindow {
		c.currentWindowSize += packageSize
		return c.needLimit
	}

	// A new second started: archive the finished window and open a new one.
	c.pastWindowsSize[c.currentWindow] = c.currentWindowSize
	c.currentWindow, c.currentWindowSize = now, packageSize

	// Drop windows that fell out of range and total the ones that remain.
	oldest := now - c.passWindowLimit
	total := 0
	for second, size := range c.pastWindowsSize {
		if second < oldest {
			delete(c.pastWindowsSize, second)
			continue
		}
		total += size
	}
	c.needLimit = total >= c.sizeLimit
	return c.needLimit
}
// tunnelHandler sends outgoing packets over the configured tunnels. It also
// merges everything the tunnels receive into a single channel.
type tunnelHandler struct {
	// tunnels holds one entry per Config.Tunnels element, in config order.
	tunnels []*tunnel
	// tunnelRecv collects packets read from every tunnel's local socket.
	tunnelRecv chan []byte
	// memPool supplies and recycles packet buffers (main's pool hands out
	// 1500-byte slices).
	memPool *sync.Pool
	// windowCounter decides when send stops duplicating large packets and
	// switches to round-robin.
	windowCounter *windowCounter
	// blanceCount is the round-robin cursor used by send.
	blanceCount uint64
}
// send forwards one packet (id header included) over the tunnels.
//
// Normally every tunnel gets the packet, trading bandwidth for redundancy.
// All but the last tunnel receive a pooled copy, and the last tunnel takes
// data itself. Once windowCounter reports that recent traffic reached its
// limit, packets larger than 1340 bytes go round-robin to a single tunnel
// each. Smaller packets are still duplicated.
//
// Ownership of data passes to the tunnel send queues; handleSend returns the
// buffers to memPool. With no tunnels configured, the packet is dropped and
// its buffer returned to the pool. This avoids a divide-by-zero panic in the
// round-robin branch.
func (h *tunnelHandler) send(data []byte) {
	if len(h.tunnels) == 0 {
		h.memPool.Put(data)
		return
	}
	dataSize := len(data)
	if h.windowCounter.add(dataSize) && dataSize > 1340 {
		h.tunnels[h.blanceCount%uint64(len(h.tunnels))].buffer <- data
		h.blanceCount++
		return
	}
	last := len(h.tunnels) - 1
	for i, t := range h.tunnels {
		if i == last {
			t.buffer <- data
			break
		}
		// Pooled buffers keep their full capacity, so re-slicing to
		// dataSize is safe even if the pooled slice was short.
		dup := h.memPool.Get().([]byte)[:dataSize]
		copy(dup, data)
		t.buffer <- dup
	}
}
// hb queues an 8-byte random packet on every tunnel. Both receive loops
// discard packets shorter than 40 bytes, so these only generate traffic on
// each path. Presumably the purpose is keeping NAT/firewall mappings alive
// (unconfirmed).
func (h *tunnelHandler) hb() {
	for _, t := range h.tunnels {
		buf := h.memPool.Get().([]byte)[:1500]
		binary.BigEndian.PutUint64(buf, rand.Uint64())
		t.buffer <- buf[:8]
	}
}
// handleSend drains tunnel pos's queue. It writes each packet to the tunnel's
// remote address and returns the buffer to the pool. Write errors are logged
// and the packet is dropped. The loop ends only if the queue is closed, which
// nothing in the visible code does.
func (h *tunnelHandler) handleSend(pos int) {
	t := h.tunnels[pos]
	for data := range t.buffer {
		if _, err := t.local.WriteToUDP(data, t.remote); err != nil {
			errorPrint("tunnelHandler handleSend .WriteToUDP error %+v \n", err)
		}
		h.memPool.Put(data)
	}
}
// handleRecv reads packets from tunnel pos's local socket into pooled
// buffers and pushes them onto the shared tunnelRecv channel. The sender's
// address is not checked. On the first read error (for example, once the
// socket is closed) it logs and returns.
func (h *tunnelHandler) handleRecv(pos int) {
	conn := h.tunnels[pos].local
	for {
		buf := h.memPool.Get().([]byte)[:1500]
		n, _, err := conn.ReadFromUDP(buf)
		if err != nil {
			errorPrint("error in tunnelHandler handleRecv %+v \n", err)
			return
		}
		h.tunnelRecv <- buf[:n]
	}
}
// newTunnelHandler builds one tunnel per c.Tunnels entry. Each local socket is
// bound through gUdpConnAlloc, so entries with the same local address share a
// socket. It then starts a receive and a send goroutine per tunnel.
// Each tunnel's send queue holds 1000 packets. The merged receive channel
// holds 1000 packets per tunnel.
func newTunnelHandler(c Config, m *sync.Pool) *tunnelHandler {
	tunnels := make([]*tunnel, len(c.Tunnels))
	for i, tc := range c.Tunnels {
		dst := tc.RemotePort
		tunnels[i] = &tunnel{
			local:  gUdpConnAlloc.getUdpConn(tc.LocalPort),
			remote: &net.UDPAddr{IP: net.ParseIP(dst.Ip), Port: dst.Port},
			buffer: make(chan []byte, 1000),
		}
	}
	h := &tunnelHandler{
		tunnels:       tunnels,
		tunnelRecv:    make(chan []byte, 1000*len(c.Tunnels)),
		memPool:       m,
		windowCounter: newWindowCounter(),
		blanceCount:   0,
	}
	for i := range tunnels {
		go h.handleRecv(i)
		go h.handleSend(i)
	}
	return h
}
// clientSide is the client end of the tunnel. It accepts datagrams from
// application clients on the listen socket and tags each with a per-sender
// id. It sends them through the tunnels and routes tunnel replies back to the
// matching sender by id.
type clientSide struct {
	memPool *sync.Pool
	// recvConn is the listen socket. It is used both to receive from and to
	// reply to application clients.
	recvConn *net.UDPConn
	// clientMap maps an id (portToId of the sender) to that sender. It is
	// only accessed from the handleClients goroutine.
	clientMap map[uint64]*clientInfo
	// recvData carries datagrams from handleRecv to handleClients.
	recvData chan dataInfo
	// sendData carries replies from handleClients to handleSend.
	sendData chan dataInfo
	tunnels *tunnelHandler
	// stopChan is closed by main to stop handleSend and handleClients.
	stopChan chan interface{}
	// timeoutCheck is the cleanup period in seconds.
	timeoutCheck int64
	// timeout is the idle limit in seconds.
	timeout int64
}
// newclientSide binds (or reuses) the listen socket from c.ListenPort. It
// wires the client side to tunnel handler t and stop channel s, then starts
// its send, receive and dispatch goroutines.
func newclientSide(c Config, m *sync.Pool, t *tunnelHandler, s chan interface{}) *clientSide {
	cs := &clientSide{
		memPool:      m,
		recvConn:     gUdpConnAlloc.getUdpConn(c.ListenPort),
		clientMap:    make(map[uint64]*clientInfo),
		recvData:     make(chan dataInfo, 1000),
		sendData:     make(chan dataInfo, 1000),
		tunnels:      t,
		stopChan:     s,
		timeoutCheck: c.TimeoutCheck,
		timeout:      c.Timeout,
	}
	go cs.handleSend()
	go cs.handleRecv()
	go cs.handleClients()
	return cs
}
// getClientId returns the id for addr. A new sender is registered with the
// current time. A known sender just has its lastActive refreshed; its stored
// address is left as is.
func (c *clientSide) getClientId(addr *net.UDPAddr) uint64 {
	id := portToId(addr)
	now := time.Now().Unix()
	info, known := c.clientMap[id]
	if !known {
		c.clientMap[id] = &clientInfo{addr, now}
		return id
	}
	info.lastActive = now
	return id
}
// handleRecv reads datagrams from application clients on the listen socket.
// Each payload lands at offset 8 of a pooled 1500-byte buffer, leaving room
// for the id header that handleClients fills in. The buffer and the sender
// address are queued on recvData. It logs and returns on the first read
// error, for example once main closes the socket.
func (c *clientSide) handleRecv() {
	for {
		buf := c.memPool.Get().([]byte)[:1500]
		n, from, err := c.recvConn.ReadFromUDP(buf[8:])
		if err != nil {
			errorPrint("error in downLink listenConn.ReadFromUDP %+v \n", err)
			return
		}
		c.recvData <- dataInfo{addr: from, data: buf[:8+n], id: 0}
	}
}
// handleSend writes replies from sendData back to application clients,
// skipping the 8-byte id header, and returns each buffer to the pool. Write
// errors are logged and the packet is dropped. It exits when stopChan is
// closed or when sendData is closed.
func (c *clientSide) handleSend() {
	for {
		var di dataInfo
		var ok bool
		select {
		case <-c.stopChan:
			return
		case di, ok = <-c.sendData:
		}
		if !ok {
			errorPrint("handleClients handleSend broken, exit\n")
			return
		}
		if _, err := c.recvConn.WriteToUDP(di.data[8:], di.addr); err != nil {
			errorPrint("handleClients handleSend .WriteToUDP error %+v \n", err)
		}
		c.memPool.Put(di.data)
	}
}
// getClientAddr returns the address registered under id and refreshes its
// lastActive timestamp. It returns nil for unknown or already-expired ids.
func (c *clientSide) getClientAddr(id uint64) *net.UDPAddr {
	info, found := c.clientMap[id]
	if !found {
		return nil
	}
	info.lastActive = time.Now().Unix()
	return info.clientAddr
}
// cleanClientMap forgets clients that have been idle for more than c.timeout
// seconds.
func (c *clientSide) cleanClientMap() {
	cutoff := time.Now().Unix() - c.timeout
	for id, info := range c.clientMap {
		if info.lastActive < cutoff {
			delete(c.clientMap, id)
		}
	}
}
// xorData8 XORs the first 8 bytes of data, in place, with the first 8 bytes
// of mark. Applying it twice with the same mark restores data. It panics if
// either slice is shorter than 8 bytes.
func xorData8(data []byte, mark []byte) {
	const width = 8
	for i := 0; i < width; i++ {
		data[i] ^= mark[i]
	}
}
// handleClients is the client side's dispatch loop and the only goroutine
// that touches clientMap.
//
//   - From application clients (recvData): buffers shorter than 40 bytes
//     (8-byte header plus payload) go back to the pool. Otherwise the header
//     is set to the sender's id, XOR-masked with buffer bytes 24..31, and the
//     packet is sent through the tunnels.
//   - From the tunnels (tunnelRecv): packets shorter than 40 bytes go back to
//     the pool. Otherwise the header is unmasked and the id looked up. Known
//     ids are queued on sendData; unknown ones are returned to the pool.
//   - Every timeoutCheck seconds, idle clients are cleaned up.
//
// It exits when stopChan is closed or either input channel is closed.
func (c *clientSide) handleClients() {
	period := time.Duration(c.timeoutCheck) * time.Second
	cleanup := time.NewTimer(period)
	pool := c.memPool
	for {
		select {
		case di, ok := <-c.recvData:
			if !ok {
				errorPrint("handleClients recvData broken, exit\n")
				return
			}
			if len(di.data) < 40 {
				pool.Put(di.data)
				continue
			}
			header := di.data[:8]
			binary.BigEndian.PutUint64(header, c.getClientId(di.addr))
			xorData8(header, di.data[24:32])
			c.tunnels.send(di.data)
		case pkt, ok := <-c.tunnels.tunnelRecv:
			if !ok {
				errorPrint("handleClients tunnelRecv broken, exit\n")
				return
			}
			if len(pkt) < 40 {
				pool.Put(pkt)
				continue
			}
			header := pkt[:8]
			xorData8(header, pkt[24:32])
			addr := c.getClientAddr(binary.BigEndian.Uint64(header))
			if addr == nil {
				pool.Put(pkt)
				continue
			}
			c.sendData <- dataInfo{addr, pkt, 0}
		case <-cleanup.C:
			c.cleanClientMap()
			cleanup.Reset(period)
		case <-c.stopChan:
			return
		}
	}
}
// sendSocket is the server side's per-client upstream socket. It is an
// ephemeral UDP socket that forwards one client id's traffic to the
// configured remote. Replies are fed back into serverSide.recvData, tagged
// with that id.
type sendSocket struct {
	memPool *sync.Pool
	// sendConn is bound to 0.0.0.0 on a system-chosen port.
	sendConn *net.UDPConn
	// remote is serverSide.remote.
	remote *net.UDPAddr
	// sendData queues packets (8-byte header included) for handleSend.
	sendData chan []byte
	// recvData points at serverSide.recvData.
	recvData *chan dataInfo
	// lastActive is a Unix timestamp in seconds, refreshed by serverSide.
	lastActive int64
	// cid is the client id this socket serves.
	cid uint64
}
// newsendSocket opens an ephemeral UDP socket for client id (panicking if
// the bind fails) and starts its receive and send goroutines.
func newsendSocket(s *serverSide, id uint64) *sendSocket {
	sock := &sendSocket{
		memPool:    s.memPool,
		sendConn:   getUdpConn("0.0.0.0", 0),
		remote:     s.remote,
		sendData:   make(chan []byte, 1000),
		recvData:   &s.recvData,
		lastActive: time.Now().Unix(),
		cid:        id,
	}
	go sock.handleRecv()
	go sock.handleSend()
	return sock
}
// stop releases the socket. Closing sendConn makes handleRecv's read fail, so
// that goroutine exits. Closing sendData ends handleSend's range loop. The
// caller must not send on sendData after stop.
// NOTE(review): sendConn is closed before sendData, so any packets still
// queued are written to a closed socket and logged as errors. The caller
// (cleanSendtMap) only stops sockets idle past the timeout, so the queue is
// likely empty.
func (s *sendSocket) stop() {
	s.sendConn.Close()
	close(s.sendData)
}
// handleSend forwards queued packets to the remote, stripping the 8-byte id
// header, and returns each buffer to the pool. Write errors are logged and
// the packet is dropped. It exits once stop closes sendData.
func (s *sendSocket) handleSend() {
	for pkt := range s.sendData {
		_, err := s.sendConn.WriteToUDP(pkt[8:], s.remote)
		if err != nil {
			errorPrint("handleClients handleSend .WriteToUDP error %+v \n", err)
		}
		s.memPool.Put(pkt)
	}
}
// handleRecv reads replies from the remote into pooled buffers at offset 8,
// leaving room for the id header. Each reply is queued on the server's
// recvData, tagged with this socket's client id. It returns silently on any
// read error, including the one caused by stop closing the socket.
func (s *sendSocket) handleRecv() {
	for {
		buf := s.memPool.Get().([]byte)[:1500]
		n, _, err := s.sendConn.ReadFromUDP(buf[8:])
		if err != nil {
			return
		}
		*s.recvData <- dataInfo{addr: nil, data: buf[:n+8], id: s.cid}
	}
}
// serverSide is the server end of the tunnel. It receives id-tagged packets
// from the tunnels and forwards each client id's traffic to remote through
// that id's own sendSocket. The remote's replies go back over the tunnels.
type serverSide struct {
	memPool *sync.Pool
	// remote is where all client traffic is forwarded (Config.RemotePort).
	remote *net.UDPAddr
	tunnels *tunnelHandler
	// socketMap maps client id to its sendSocket. It is only accessed from
	// the handleclients goroutine.
	socketMap map[uint64]*sendSocket
	// recvData receives replies from every sendSocket, tagged with the id.
	recvData chan dataInfo
	// stopChan is closed by main to stop handleclients.
	stopChan chan interface{}
	// timeoutCheck is the cleanup and heartbeat period in seconds.
	timeoutCheck int64
	// timeout is the idle limit in seconds.
	timeout int64
}
// newServerSide wires the server side to tunnel handler t and stop channel s.
// Traffic is forwarded to c.RemotePort. The reply channel is sized at 1000
// packets per tunnel. It then starts the dispatch goroutine.
func newServerSide(c Config, m *sync.Pool, t *tunnelHandler, s chan interface{}) *serverSide {
	dst := c.RemotePort
	srv := &serverSide{
		memPool:      m,
		remote:       &net.UDPAddr{IP: net.ParseIP(dst.Ip), Port: dst.Port},
		tunnels:      t,
		socketMap:    make(map[uint64]*sendSocket),
		recvData:     make(chan dataInfo, 1000*len(c.Tunnels)),
		stopChan:     s,
		timeout:      c.Timeout,
		timeoutCheck: c.TimeoutCheck,
	}
	go srv.handleclients()
	return srv
}
// getSendSocket returns the sendSocket for id. An existing socket has its
// lastActive refreshed; an unknown id gets a new socket.
// NOTE(review): tunnel packets are accepted from any sender and there is no
// cap here, so anyone able to reach a tunnel port can make the server open
// sockets by sending new ids.
func (s *serverSide) getSendSocket(id uint64) *sendSocket {
	sock, exists := s.socketMap[id]
	if !exists {
		sock = newsendSocket(s, id)
		s.socketMap[id] = sock
		return sock
	}
	sock.lastActive = time.Now().Unix()
	return sock
}
// cleanSendtMap stops and forgets every sendSocket that has been idle for
// more than s.timeout seconds.
func (s *serverSide) cleanSendtMap() {
	cutoff := time.Now().Unix() - s.timeout
	for id, sock := range s.socketMap {
		if sock.lastActive >= cutoff {
			continue
		}
		sock.stop()
		delete(s.socketMap, id)
	}
}
// handleclients is the server side's dispatch loop and the only goroutine
// that touches socketMap. It sends one heartbeat on every tunnel, then loops:
//
//   - From the tunnels: packets shorter than 40 bytes go back to the pool.
//     Otherwise the header is unmasked with buffer bytes 24..31, and the
//     packet is queued on the sendSocket for that id (created on demand).
//   - From sendSockets (recvData): replies for ids still in socketMap
//     refresh the socket, get the masked id header, and go out through the
//     tunnels. Others are returned to the pool.
//   - Every timeoutCheck seconds: another heartbeat, then idle-socket cleanup.
//
// It exits when stopChan is closed or either input channel is closed.
func (s *serverSide) handleclients() {
	period := time.Duration(s.timeoutCheck) * time.Second
	tick := time.NewTimer(period)
	pool := s.memPool
	s.tunnels.hb()
	for {
		select {
		case pkt, ok := <-s.tunnels.tunnelRecv:
			if !ok {
				errorPrint("serverSide tunnelRecv broken, exit\n")
				return
			}
			if len(pkt) < 40 {
				pool.Put(pkt)
				continue
			}
			xorData8(pkt[:8], pkt[24:32])
			sock := s.getSendSocket(binary.BigEndian.Uint64(pkt[:8]))
			sock.sendData <- pkt
		case di, ok := <-s.recvData:
			if !ok {
				errorPrint("serverSide recvData broken, exit\n")
				return
			}
			sock, known := s.socketMap[di.id]
			if !known {
				pool.Put(di.data)
				continue
			}
			sock.lastActive = time.Now().Unix()
			binary.BigEndian.PutUint64(di.data[:8], di.id)
			xorData8(di.data[:8], di.data[24:32])
			s.tunnels.send(di.data)
		case <-tick.C:
			s.tunnels.hb()
			s.cleanSendtMap()
			tick.Reset(period)
		case <-s.stopChan:
			return
		}
	}
}
// load reads the JSON config at path f into c. It panics (through errorPanic)
// if the file cannot be read or parsed. Values are not validated; for
// example, zero tunnels or a zero timeout_check are accepted as-is.
func (c *Config) load(f string) {
	raw, err := os.ReadFile(f)
	errorPanic(err, "open file %s error %+v \n", f, err)
	err = json.Unmarshal(raw, c)
	errorPanic(err, "Unmarshal file %s error %+v \n", f, err)
}
// main loads the config named by -c (default ./conf.json) and starts the
// tunnels plus either the server or the client side. It then blocks until an
// interrupt signal arrives. On shutdown it closes the shared stop channel and
// every cached socket.
func main() {
	var configPath string
	flag.StringVar(&configPath, "c", "./conf.json", "config file path")
	flag.Parse()

	gUdpConnAlloc.cache = make(map[IpPort]*net.UDPConn)
	var cfg Config
	cfg.load(configPath)

	// Every packet buffer in the program comes from this pool.
	pool := &sync.Pool{New: func() interface{} { return make([]byte, 1500) }}
	tunnels := newTunnelHandler(cfg, pool)
	stop := make(chan interface{}, 2)
	if cfg.ServerSide {
		newServerSide(cfg, pool, tunnels, stop)
	} else {
		newclientSide(cfg, pool, tunnels, stop)
	}

	interrupted := make(chan os.Signal, 2)
	signal.Notify(interrupted, os.Interrupt)
	<-interrupted
	close(stop)
	gUdpConnAlloc.stop()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment