Skip to content

Instantly share code, notes, and snippets.

@killwing
Created December 13, 2015 11:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save killwing/5a503f6df3301d8f054e to your computer and use it in GitHub Desktop.
Save killwing/5a503f6df3301d8f054e to your computer and use it in GitHub Desktop.
simple proxy
package main
import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"sync"
"time"
)
// TIMEOUT is the single time limit used throughout this proxy: it bounds
// the outbound dial in tunnelReq, the HTTP client in proxyReq, and the
// server's read and write timeouts in main.
const TIMEOUT = 10 * time.Second
func handleReq(w http.ResponseWriter, r *http.Request) {
if r.Method != "CONNECT" {
proxyReq(w, r)
} else {
tunnelReq(w, r)
}
}
// tunnelReq serves a CONNECT request: it hijacks the client connection,
// dials r.URL.Host over TCP and relays bytes in both directions until either
// side stops, then closes both connections.
//
// If the ResponseWriter cannot be hijacked, or hijacking fails, a 500 is sent
// through the normal HTTP path. If the destination cannot be dialed, a raw
// "500 Internal Server Error" status line is written on the hijacked
// connection.
func tunnelReq(w http.ResponseWriter, r *http.Request) {
	fmt.Printf("recv tunnel req: %+v\n", r.URL.String())
	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
		return
	}
	conn, bufrw, err := hj.Hijack()
	if err != nil {
		fmt.Println("hijack err:", err.Error())
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer conn.Close()

	// The server's ReadTimeout/WriteTimeout leave deadlines on the hijacked
	// connection. Clear them, otherwise the tunnel (or even the error reply
	// after a slow dial) is cut off once they expire.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		fmt.Println("clear deadline err:", err)
		return
	}

	connDest, err := net.DialTimeout("tcp", r.URL.Host, TIMEOUT)
	if err != nil {
		fmt.Println("dial err:", err.Error())
		bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n\r\n")
		// Without a flush the status line stays in the buffer and the
		// client never sees it.
		if err := bufrw.Flush(); err != nil {
			fmt.Println("write err:", err)
		}
		return
	}
	defer connDest.Close()

	bufrw.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n")
	// bufio.Writer errors are sticky, so Flush also reports a failed
	// WriteString.
	if err := bufrw.Flush(); err != nil {
		fmt.Println("write err:", err)
		return
	}

	var wg sync.WaitGroup
	// relay copies src to dst, then closes both connections so that the
	// opposite direction unblocks as well.
	relay := func(dst io.Writer, src io.Reader, msg string) {
		defer wg.Done()
		if _, err := io.Copy(dst, src); err != nil {
			fmt.Println("copy err:", err)
		}
		conn.Close()
		connDest.Close()
		fmt.Println(msg)
	}

	wg.Add(2)
	// Read the client side through the hijacked buffered reader: it may
	// already hold bytes the client sent right after the CONNECT header.
	go relay(connDest, bufrw.Reader, "src -> dest close")
	go relay(conn, connDest, "dest -> src close")
	wg.Wait()
	fmt.Println("disconnect tunnel")
}
// proxyReq forwards a non-CONNECT request to the URL named in the request
// line and relays the upstream status, headers and body back to the client.
//
// Failures are reported to the client instead of being silently dropped:
// 400 Bad Request if the outbound request cannot be constructed, and
// 502 Bad Gateway if the upstream cannot be reached or its body cannot be
// read.
func proxyReq(w http.ResponseWriter, r *http.Request) {
	fmt.Printf("recv proxy req: %+v\n", r.URL.String())
	client := &http.Client{Timeout: TIMEOUT}
	req, e := http.NewRequest(r.Method, r.URL.String(), r.Body)
	if e != nil {
		fmt.Println("create request err: ", e)
		http.Error(w, e.Error(), http.StatusBadRequest)
		return
	}
	req.Header = r.Header
	// Carry the known body length over so the upstream request keeps its
	// Content-Length instead of falling back to chunked encoding.
	req.ContentLength = r.ContentLength

	resp, e := client.Do(req)
	if e != nil {
		fmt.Println("do client err: ", e)
		http.Error(w, e.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Read the whole upstream body before committing the status line, so a
	// failed read can still be turned into a proper error response.
	body, e := ioutil.ReadAll(resp.Body)
	if e != nil {
		fmt.Println("read body err:", e)
		http.Error(w, e.Error(), http.StatusBadGateway)
		return
	}

	for k, v := range resp.Header {
		w.Header()[k] = v
	}
	w.WriteHeader(resp.StatusCode)
	if _, e = w.Write(body); e != nil {
		fmt.Println("write body err:", e)
		return
	}
	fmt.Println("end proxy req")
}
// main starts the proxy on port 12345 with TIMEOUT applied to both reads and
// writes, and prints the error ListenAndServe returns when the server stops.
func main() {
	srv := &http.Server{
		Addr:         ":12345",
		Handler:      http.HandlerFunc(handleReq),
		ReadTimeout:  TIMEOUT,
		WriteTimeout: TIMEOUT,
	}
	if err := srv.ListenAndServe(); err != nil {
		fmt.Println("ListenAndServe: ", err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment