Skip to content

Instantly share code, notes, and snippets.

@felix021
Created November 21, 2020 08:12
Show Gist options
  • Save felix021/7f9d05fa1fd9f8f62cbce9edbdb19253 to your computer and use it in GitHub Desktop.
Save felix021/7f9d05fa1fd9f8f62cbce9edbdb19253 to your computer and use it in GitHub Desktop.
Minimal SOCKS5 proxy implementation in Go
package main
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
)
// main listens for SOCKS5 clients on TCP port 1080 on all interfaces and
// serves every accepted connection in its own goroutine. Accept errors are
// logged and the loop keeps accepting.
func main() {
	server, err := net.Listen("tcp", ":1080")
	if err != nil {
		fmt.Printf("Listen failed: %v\n", err)
		return
	}
	for {
		client, err := server.Accept()
		if err != nil {
			// Terminate the line so consecutive errors don't run together.
			fmt.Printf("Accept failed: %v\n", err)
			continue
		}
		go process(client)
	}
}
// process serves one SOCKS5 client connection: it negotiates the auth method,
// handles the CONNECT request, and then hands both connections to
// Socks5Forward. If either handshake step fails, the failure is printed and
// client is closed here.
func process(client net.Conn) {
	abort := func(label string, cause error) {
		fmt.Println(label, cause)
		client.Close()
	}

	if err := Socks5Auth(client); err != nil {
		abort("auth error:", err)
		return
	}

	target, err := Socks5Connect(client)
	if err != nil {
		abort("connect error:", err)
		return
	}

	Socks5Forward(client, target)
}
// Socks5Auth performs the SOCKS5 method-selection handshake (RFC 1928 §3)
// on client.
//
// Only the "no authentication required" method (0x00) is supported. If the
// client's METHODS list contains it, the server selects it and returns nil.
// Otherwise the server replies with 0xFF ("no acceptable methods") and returns
// an error; the caller is then expected to close the connection.
func Socks5Auth(client net.Conn) error {
	buf := make([]byte, 256)

	// Read VER and NMETHODS.
	if _, err := io.ReadFull(client, buf[:2]); err != nil {
		return fmt.Errorf("reading header: %w", err)
	}
	ver, nMethods := int(buf[0]), int(buf[1])
	if ver != 5 {
		return errors.New("invalid version")
	}

	// Read the METHODS list. NMETHODS is a single byte, so it always fits in buf.
	methods := buf[:nMethods]
	if _, err := io.ReadFull(client, methods); err != nil {
		return fmt.Errorf("reading methods: %w", err)
	}

	// Pick "no authentication" only if the client actually offered it.
	method := byte(0xFF)
	for _, m := range methods {
		if m == 0x00 {
			method = 0x00
			break
		}
	}

	if _, err := client.Write([]byte{0x05, method}); err != nil {
		return fmt.Errorf("write rsp: %w", err)
	}
	if method == 0xFF {
		return errors.New("no acceptable auth method")
	}
	return nil
}
// Socks5Connect reads a SOCKS5 request (RFC 1928 §4) from client, dials the
// requested TCP destination and writes the success reply.
//
// Only the CONNECT command (0x01) is supported. The destination may be an
// IPv4 address (ATYP 0x01), a domain name (ATYP 0x03) or an IPv6 address
// (ATYP 0x04). When the request is SOCKS5 but cannot be served (unsupported
// command or address type, empty domain name, dial failure), a reply with
// the matching non-zero REP code is written to client on a best-effort basis
// before the error is returned. client is never closed here; on success the
// caller also owns the returned connection.
func Socks5Connect(client net.Conn) (net.Conn, error) {
	buf := make([]byte, 256)

	// reply writes a response carrying the given REP code. BND.ADDR and
	// BND.PORT are always reported as 0.0.0.0:0.
	reply := func(rep byte) error {
		_, err := client.Write([]byte{0x05, rep, 0x00, 0x01, 0, 0, 0, 0, 0, 0})
		return err
	}

	// Request header: VER, CMD, RSV, ATYP.
	if _, err := io.ReadFull(client, buf[:4]); err != nil {
		return nil, fmt.Errorf("read header: %w", err)
	}
	ver, cmd, atyp := buf[0], buf[1], buf[3]
	if ver != 5 {
		// Not a SOCKS5 peer, so a SOCKS5 reply would be meaningless.
		return nil, errors.New("invalid ver/cmd")
	}
	if cmd != 1 {
		_ = reply(0x07) // command not supported; best effort, we fail anyway
		return nil, fmt.Errorf("invalid ver/cmd: unsupported command %d", cmd)
	}

	var host string
	switch atyp {
	case 1: // IPv4
		ip := buf[:net.IPv4len]
		if _, err := io.ReadFull(client, ip); err != nil {
			return nil, fmt.Errorf("invalid IPv4: %w", err)
		}
		host = net.IP(ip).String()
	case 3: // domain name, prefixed by a one-byte length
		if _, err := io.ReadFull(client, buf[:1]); err != nil {
			return nil, fmt.Errorf("invalid hostname: %w", err)
		}
		addrLen := int(buf[0])
		if addrLen == 0 {
			// An empty host would make net.Dial connect to the local machine.
			_ = reply(0x01)
			return nil, errors.New("invalid hostname: empty")
		}
		if _, err := io.ReadFull(client, buf[:addrLen]); err != nil {
			return nil, fmt.Errorf("invalid hostname: %w", err)
		}
		host = string(buf[:addrLen])
	case 4: // IPv6
		ip := buf[:net.IPv6len]
		if _, err := io.ReadFull(client, ip); err != nil {
			return nil, fmt.Errorf("invalid IPv6: %w", err)
		}
		host = net.IP(ip).String()
	default:
		_ = reply(0x08) // address type not supported
		return nil, fmt.Errorf("invalid atyp %d", atyp)
	}

	if _, err := io.ReadFull(client, buf[:2]); err != nil {
		return nil, fmt.Errorf("read port: %w", err)
	}
	port := binary.BigEndian.Uint16(buf[:2])

	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" would not.
	dest, err := net.Dial("tcp", net.JoinHostPort(host, fmt.Sprint(port)))
	if err != nil {
		_ = reply(0x01) // general SOCKS server failure
		return nil, fmt.Errorf("dial dst: %w", err)
	}
	if err := reply(0x00); err != nil {
		dest.Close()
		return nil, fmt.Errorf("write rsp: %w", err)
	}
	return dest, nil
}
// Socks5Forward relays data between client and target in both directions.
// It returns immediately; the copying runs in background goroutines, which
// take ownership of both connections.
//
// When one direction reaches EOF, the EOF is passed on by half-closing the
// destination (CloseWrite, when the connection supports it). The other
// direction can therefore still deliver its remaining data, e.g. the
// response to a request whose sender has shut down its write side. Both
// connections are closed once both directions have finished. If a copy
// fails, or the destination cannot be half-closed, both connections are
// closed immediately so the opposite copy unblocks too.
func Socks5Forward(client, target net.Conn) {
	done := make(chan struct{}, 2)

	forward := func(dst, src net.Conn) {
		defer func() { done <- struct{}{} }()
		_, err := io.Copy(dst, src)
		if err == nil {
			// src hit EOF: signal it to dst's peer but keep reading from dst.
			if cw, ok := dst.(interface{ CloseWrite() error }); ok && cw.CloseWrite() == nil {
				return
			}
		}
		dst.Close()
		src.Close()
	}

	go forward(target, client)
	go forward(client, target)
	go func() {
		<-done
		<-done
		client.Close()
		target.Close()
	}()
}
@icetech233
Copy link

Socks5Forward 这里可以稍微修改一下


// relay copies data between left and right in both directions and blocks
// until both copies have returned. When one direction ends, the connection
// feeding the other direction gets a 5-second read deadline, so the
// remaining copy cannot block forever. Deadline-expiry errors are not
// reported. Otherwise the right→left error is returned first, then the
// left→right error, else nil. Neither connection is closed here.
func relay(left, right net.Conn) error {
	const grace = 5 * time.Second
	var wg sync.WaitGroup
	var leftToRight error

	wg.Add(1)
	go func() {
		defer wg.Done()
		_, leftToRight = io.Copy(right, left)
		right.SetReadDeadline(time.Now().Add(grace))
	}()

	_, rightToLeft := io.Copy(left, right)
	left.SetReadDeadline(time.Now().Add(grace))
	wg.Wait()

	for _, e := range []error{rightToLeft, leftToRight} {
		if e != nil && !errors.Is(e, os.ErrDeadlineExceeded) {
			return e
		}
	}
	return nil
}

没看懂

@rayepeng
Copy link

rayepeng commented Apr 9, 2023

Socks5Forward 这里可以稍微修改一下

// relay pumps bytes both ways between left and right and waits for both
// pumps to finish. As soon as a pump returns, the connection driving the
// opposite pump gets a 5-second read deadline. Errors caused by that
// deadline are ignored. Any other error is returned, with the right→left
// error taking precedence. The connections are left open for the caller.
func relay(left, right net.Conn) error {
	wait := 5 * time.Second
	fromLeft := make(chan error, 1)

	go func() {
		_, copyErr := io.Copy(right, left)
		right.SetReadDeadline(time.Now().Add(wait))
		fromLeft <- copyErr
	}()

	_, fromRight := io.Copy(left, right)
	left.SetReadDeadline(time.Now().Add(wait))
	leftErr := <-fromLeft

	ignorable := func(e error) bool {
		return e == nil || errors.Is(e, os.ErrDeadlineExceeded)
	}
	switch {
	case !ignorable(fromRight):
		return fromRight
	case !ignorable(leftErr):
		return leftErr
	default:
		return nil
	}
}

没看懂

我理解这里就是增加了个错误处理吧，不过为啥 `_, err = io.Copy(left, right)` 这里不开一个协程？

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment