Skip to content

Instantly share code, notes, and snippets.

@bacher09
Last active March 17, 2019 16:15
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bacher09/53049ecd6aff01fde4e4 to your computer and use it in GitHub Desktop.
Save bacher09/53049ecd6aff01fde4e4 to your computer and use it in GitHub Desktop.
package main
import (
"encoding/binary"
"flag"
"fmt"
"log"
"net"
"sync"
"time"
)
const (
// XonMSS is the size, in bytes, of the buffers that Run and replyLoop
// allocate for reading datagrams.
XonMSS = 1500
// ExpireTime is how long a client may go without sending a datagram
// before cleanupConnections closes its upstream connection.
ExpireTime = 60 * time.Second
)
// connTrackKey identifies a client endpoint (IP address plus port) as a plain
// comparable struct so that it can be used as a map key. The approach is
// borrowed from Docker's UDP proxy.
type connTrackKey struct {
// IPHigh holds the first 8 bytes of a 16-byte IP; newConnTrackKey sets it
// to 0 for a 4-byte IP.
IPHigh uint64
// IPLow holds the last 8 bytes of a 16-byte IP, or the whole 4-byte IP.
IPLow uint64
// Port is the UDP port of the endpoint.
Port int
}
// UDPProxy relays datagrams between clients that talk to a local listening
// socket and a single remote server, using one upstream connection per client.
type UDPProxy struct {
// listenConn is the socket client datagrams are read from and replies are
// written to.
listenConn *net.UDPConn
// listenAddr is the address that was passed to net.ListenUDP.
listenAddr *net.UDPAddr
// remoteAddr is the server every upstream connection is dialled to.
remoteAddr *net.UDPAddr
// connections maps a client endpoint to its upstream connection.
connections connTrackMap
// connectionsMutex guards connections.
connectionsMutex sync.RWMutex
// connectionsTimeouts maps a connTrackKey to the time.Time at which the
// last datagram from that client was read by Run.
connectionsTimeouts sync.Map
// stopChannel is closed by Close to make cleanupConnections return.
stopChannel chan struct{}
}
// newConnTrackKey builds the map key that identifies the client endpoint addr.
//
// The IP is normalised with To4 first, so the 4-byte form and the 16-byte
// IPv4-mapped form of the same IPv4 address produce the same key. Any other
// 16-byte address is split into its upper and lower 64 bits. An address whose
// IP is neither 4 nor 16 bytes long (for example a nil IP) yields a key with a
// zero address and the given port instead of panicking on the slice access.
func newConnTrackKey(addr *net.UDPAddr) *connTrackKey {
	if ip4 := addr.IP.To4(); ip4 != nil {
		return &connTrackKey{
			IPHigh: 0,
			IPLow:  uint64(binary.BigEndian.Uint32(ip4)),
			Port:   addr.Port,
		}
	}

	ip16 := addr.IP.To16()
	if ip16 == nil {
		// Malformed IP: there are no address bytes that can be read safely.
		return &connTrackKey{Port: addr.Port}
	}
	return &connTrackKey{
		IPHigh: binary.BigEndian.Uint64(ip16[:8]),
		IPLow:  binary.BigEndian.Uint64(ip16[8:]),
		Port:   addr.Port,
	}
}
// connTrackMap maps a client endpoint key to the UDP connection that was
// dialled to the remote server on behalf of that client.
type connTrackMap map[connTrackKey]*net.UDPConn
// CompareAddr reports whether first and second denote the same UDP endpoint:
// the ports must match and the IPs must be equal according to net.IP.Equal.
// Both arguments must be non-nil.
func CompareAddr(first *net.UDPAddr, second *net.UDPAddr) bool {
	if first.Port != second.Port {
		return false
	}
	return first.IP.Equal(second.IP)
}
// CreateUDPProxy opens a UDP listening socket on listenAddr and returns a
// proxy that will relay traffic to remoteAddr once Run is called. The error
// from net.ListenUDP is returned unchanged when the socket cannot be opened.
func CreateUDPProxy(listenAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*UDPProxy, error) {
	conn, err := net.ListenUDP("udp", listenAddr)
	if err != nil {
		return nil, err
	}

	return &UDPProxy{
		listenConn:  conn,
		listenAddr:  listenAddr,
		remoteAddr:  remoteAddr,
		connections: connTrackMap{},
		stopChannel: make(chan struct{}),
	}, nil
}
// Run starts the idle-connection sweeper and then forwards every datagram
// read from the listening socket to the remote server. The first datagram
// from a client dials a dedicated upstream connection and starts a replyLoop
// goroutine for it. Run returns when reading from the listening socket fails.
func (proxy *UDPProxy) Run() {
	go proxy.cleanupConnections()

	packet := make([]byte, XonMSS)
	for {
		size, src, err := proxy.listenConn.ReadFromUDP(packet)
		if err != nil {
			return
		}

		// Record activity first so the sweeper sees this client as alive.
		key := newConnTrackKey(src)
		proxy.connectionsTimeouts.Store(*key, time.Now())

		proxy.connectionsMutex.RLock()
		upstream, known := proxy.connections[*key]
		proxy.connectionsMutex.RUnlock()

		if !known {
			if upstream, err = net.DialUDP("udp", nil, proxy.remoteAddr); err != nil {
				// Could not reach the server; drop this datagram.
				continue
			}

			proxy.connectionsMutex.Lock()
			proxy.connections[*key] = upstream
			proxy.connectionsMutex.Unlock()

			go proxy.replyLoop(upstream, src, key)
		}

		// Forwarding is best effort; a failed write just loses the datagram.
		upstream.Write(packet[:size])
	}
}
// cleanupConnections expires idle clients every 20 seconds until stopChannel
// is closed.
//
// The loop waits on a ticker rather than sleeping inside a select default
// branch, so a close of stopChannel is noticed immediately instead of only
// after the current 20-second sleep has finished.
func (proxy *UDPProxy) cleanupConnections() {
	ticker := time.NewTicker(20 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case now := <-ticker.C:
			proxy.expireIdleConnections(now)
		case <-proxy.stopChannel:
			fmt.Print("stoped")
			return
		}
	}
}

// expireIdleConnections closes and forgets every client whose last recorded
// activity is more than ExpireTime before now.
func (proxy *UDPProxy) expireIdleConnections(now time.Time) {
	proxy.connectionsTimeouts.Range(func(key, value interface{}) bool {
		addr, lastSeen := key.(connTrackKey), value.(time.Time)
		if now.Sub(lastSeen) <= ExpireTime {
			return true
		}

		proxy.connectionsMutex.Lock()
		if clientConn, hit := proxy.connections[addr]; hit {
			// Closing the connection makes its replyLoop goroutine exit.
			clientConn.Close()
			delete(proxy.connections, addr)
		}
		proxy.connectionsMutex.Unlock()

		proxy.connectionsTimeouts.Delete(addr)
		return true
	})
}
// replyLoop copies datagrams arriving on clientConn (the upstream connection
// dialled for one client) back to clientAddr through the listening socket.
// It runs until reading from clientConn fails, which is what happens once the
// connection is closed by the sweeper or by Close. On exit it closes
// clientConn and removes its entry from the connection table.
//
// addrKey is the table key under which clientConn was registered.
func (proxy *UDPProxy) replyLoop(clientConn *net.UDPConn, clientAddr *net.UDPAddr, addrKey *connTrackKey) {
	defer func() {
		proxy.connectionsMutex.Lock()
		// Remove the entry only if it still refers to this connection. After
		// an eviction, Run may already have registered a replacement under
		// the same key; deleting that would leave the replacement untracked
		// and therefore never closed.
		if proxy.connections[*addrKey] == clientConn {
			delete(proxy.connections, *addrKey)
		}
		proxy.connectionsMutex.Unlock()
		clientConn.Close()
	}()

	buffer := make([]byte, XonMSS)
	for {
		n, addr, err := clientConn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Print("stoping goroutine")
			return
		}
		if !CompareAddr(addr, proxy.remoteAddr) {
			// received response from wrong ip, ignore it
			continue
		}
		// Delivery is best effort; a failed write just loses the datagram.
		proxy.listenConn.WriteToUDP(buffer[:n], clientAddr)
	}
}
// Close shuts the proxy down: it signals the sweeper to stop, closes the
// listening socket (which makes Run return) and closes every upstream
// connection still present in the connection table. It must be called at
// most once, because it closes stopChannel.
func (proxy *UDPProxy) Close() {
	close(proxy.stopChannel)
	proxy.listenConn.Close()

	proxy.connectionsMutex.Lock()
	for key := range proxy.connections {
		proxy.connections[key].Close()
	}
	proxy.connectionsMutex.Unlock()
}
// main parses the -port and -server flags, starts a UDP proxy listening on
// the given port on all interfaces, and relays traffic to the given server
// until the listening socket fails. Any invalid flag or setup error
// terminates the program through log.Fatal.
func main() {
	port := flag.Uint("port", 26000, "listen udp port")
	server := flag.String("server", "", "server where traffic should go")
	flag.Parse()

	// port is unsigned, so only zero and values above 65535 are invalid.
	if *port == 0 || *port > 65535 {
		log.Fatal("Port value should be in range: [1, 65535]")
	}
	if len(*server) < 2 {
		log.Fatal("Bad server address")
	}

	listenAddr := &net.UDPAddr{Port: int(*port)}
	remoteAddr, err := net.ResolveUDPAddr("udp", *server)
	if err != nil {
		log.Fatal(err)
	}

	// A bind failure (for example, port already in use) returns a nil proxy;
	// stop here instead of dereferencing it in Run and Close.
	proxy, err := CreateUDPProxy(listenAddr, remoteAddr)
	if err != nil {
		log.Fatal(err)
	}
	defer proxy.Close()
	proxy.Run()
}
import asyncio
import socket
class BaseProto:
    """Common base for the proxy's datagram protocols.

    Remembers the transport handed over by asyncio and sets the IP_TOS
    option on the transport's underlying socket.
    """

    def __init__(self):
        self.transport = None

    def connection_made(self, transport):
        """Store the transport and set IP_TOS on its socket."""
        self.transport = transport
        # NOTE(review): 0x40 << 2 evaluates to 0x100, which does not fit in the
        # 8-bit TOS field -- confirm which DSCP/TOS value was intended.
        transport.get_extra_info('socket').setsockopt(
            socket.IPPROTO_IP, socket.IP_TOS, 0x40 << 2
        )

    def connection_lost(self, exc):
        """Nothing to release when the transport goes away."""
        return None
class UdpProxyConnection(BaseProto):
    """Upstream side of the proxy for a single client.

    ``client`` is the address of the client this connection belongs to and
    ``server`` is the listening protocol that owns the client-facing
    transport.
    """

    def __init__(self, client, server):
        super().__init__()
        self.client = client
        self.server = server

    def datagram_received(self, data, addr):
        """Hand a datagram received on this connection back to the client."""
        owner = self.server
        # Traffic in this direction also counts as activity for the client.
        owner.update_timeout(self.client)
        owner.transport.sendto(data, self.client)
class UdpProxyServerProtocol(BaseProto):
    """Client-facing side of the UDP proxy.

    For every client address a dedicated datagram endpoint connected to
    ``remote_addr`` is created on first contact; datagrams from the client
    are forwarded through it. Clients without traffic for more than
    ``MAX_TIMEOUT`` seconds are dropped by a check that runs every 20 seconds.

    ``loop`` is the event loop used for scheduling and for creating the
    upstream endpoints. ``remote_addr`` optionally overrides the class-level
    default server address for this instance.
    """

    MAX_TIMEOUT = 60
    remote_addr = ("109.169.18.153", 26000)

    def __init__(self, loop, remote_addr=None):
        super().__init__()
        # addr -> task that resolves to the client's UdpProxyConnection
        self.clients = {}
        # addr -> loop.time() of the last datagram seen for that client
        self.timeouts = {}
        self.loop = loop
        if remote_addr is not None:
            self.remote_addr = remote_addr

    def connection_made(self, transport):
        """Store the transport and schedule the first idle-client check."""
        super().connection_made(transport)
        self.loop.call_later(20, self.check_timeouts)

    def datagram_received(self, data, addr):
        """Forward a client datagram, opening an upstream connection if needed."""
        if addr not in self.clients:
            # ``asyncio.async`` is a syntax error since Python 3.7; tasks are
            # created on the loop this protocol was given.
            self.clients[addr] = self.loop.create_task(self.new_connection(addr))
        self.loop.create_task(self.deliver_msg(self.clients[addr], data))

    async def new_connection(self, addr):
        """Open the upstream endpoint for ``addr`` and return its protocol."""
        try:
            _, proto = await self.loop.create_datagram_endpoint(
                lambda: UdpProxyConnection(addr, self),
                remote_addr=self.remote_addr
            )
        except Exception:
            # A failed attempt never gets a timeout entry, so check_timeouts
            # would never evict it; drop it so the next datagram retries.
            self.clients.pop(addr, None)
            raise
        self.update_timeout(addr)
        return proto

    async def deliver_msg(self, proto, msg):
        """Wait for the upstream connection to be ready, then send ``msg``."""
        conn = await proto
        self.update_timeout(conn.client)
        # The endpoint is already connected to remote_addr, so no explicit
        # destination is needed.
        conn.transport.sendto(msg)

    def check_timeouts(self):
        """Schedule removal of every idle client, then re-arm the check."""
        current_time = self.loop.time()
        expired = [
            addr for addr, last_seen in self.timeouts.items()
            if last_seen + self.MAX_TIMEOUT < current_time
        ]
        for addr in expired:
            self.loop.create_task(self.remove_connection(addr))
        self.loop.call_later(20, self.check_timeouts)

    def update_timeout(self, addr):
        """Record that traffic for ``addr`` was seen just now."""
        self.timeouts[addr] = self.loop.time()

    async def remove_connection(self, addr):
        """Close the upstream connection of ``addr`` and forget the client."""
        pending = self.clients.get(addr)
        if pending is not None:
            conn = await pending
            if conn is not None:
                conn.transport.close()
        # pop() tolerates an entry that has already been removed.
        self.timeouts.pop(addr, None)
        self.clients.pop(addr, None)
# Start the proxy: listen for clients on UDP port 27000 on all IPv4
# interfaces and serve until the process is interrupted.
loop = asyncio.get_event_loop()
listen = loop.create_datagram_endpoint(
    lambda: UdpProxyServerProtocol(loop),
    local_addr=('0.0.0.0', 27000),
)
transport, protocol = loop.run_until_complete(listen)
try:
    loop.run_forever()
except KeyboardInterrupt:
    # Ctrl-C ends the serving loop; fall through to the shutdown below.
    pass
transport.close()
loop.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment