Skip to content

Instantly share code, notes, and snippets.

@ebroder
Last active February 10, 2016 23:22
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save ebroder/ae9299e0078094211bde to your computer and use it in GitHub Desktop.
Save ebroder/ae9299e0078094211bde to your computer and use it in GitHub Desktop.
package main
// The MIT License
//
// Copyright (c) 2014- Stripe, Inc. (https://stripe.com)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
import (
"flag"
"fmt"
"log"
"net"
"net/http"
"time"
"github.com/elazarl/goproxy"
"github.com/stripe/go-einhorn/einhorn"
)
// privateNetworks holds the RFC 1918 IPv4 ranges plus the IPv6 unique-local
// range fc00::/7. isPrivateNetwork checks destinations against it. It is
// populated once by init.
var privateNetworks []net.IPNet

// connectTimeout bounds how long dial waits for an outbound connection.
// main sets it from the -timeout flag.
var connectTimeout time.Duration

// init parses the literal CIDR list into privateNetworks. A parse failure can
// only come from a typo in the list below, so it is treated as fatal.
func init() {
	privateNetworkStrings := []string{
		"10.0.0.0/8",
		"172.16.0.0/12",
		"192.168.0.0/16",
		"fc00::/7",
	}
	privateNetworks = make([]net.IPNet, 0, len(privateNetworkStrings))
	for _, cidr := range privateNetworkStrings {
		// Named ipNet rather than net so the net package is not shadowed
		// inside the loop body.
		_, ipNet, err := net.ParseCIDR(cidr)
		if err != nil {
			log.Fatalf("parsing private network %q: %v", cidr, err)
		}
		privateNetworks = append(privateNetworks, *ipNet)
	}
}
// isPrivateNetwork reports whether ip lies inside any of the ranges in
// privateNetworks.
func isPrivateNetwork(ip net.IP) bool {
	for i := range privateNetworks {
		if privateNetworks[i].Contains(ip) {
			return true
		}
	}
	return false
}
// safeResolve resolves addr ("host:port") on the given TCP network. It returns
// the resolved address in host:port form (TCPAddr.String). If the destination
// is one the proxy must not reach, it returns an error instead.
//
// Refused destinations:
//   - a missing IP: an empty host such as ":80" resolves to a nil IP, and
//     net.Dial treats an empty host as the local system;
//   - loopback, unspecified and link-local unicast addresses;
//   - anything isPrivateNetwork matches.
//
// Callers should dial the returned string rather than addr, so the address
// that was checked is the address that gets connected to.
func safeResolve(network, addr string) (string, error) {
	resolved, err := net.ResolveTCPAddr(network, addr)
	if err != nil {
		return "", err
	}
	ip := resolved.IP
	// The len(ip) == 0 test must be explicit. A nil IP matches none of the
	// predicates below, so ":port" would otherwise be approved and then
	// dialed on this host.
	if len(ip) == 0 ||
		ip.IsLoopback() ||
		ip.IsUnspecified() ||
		ip.IsLinkLocalUnicast() ||
		isPrivateNetwork(ip) {
		return "", fmt.Errorf("host %s resolves to illegal IP %s", addr, ip)
	}
	return resolved.String(), nil
}
// dial is installed as the proxy transport's Dial function. It passes addr
// through safeResolve. If the destination is allowed, it connects to the
// address safeResolve returned (not the original addr) and gives up after
// connectTimeout.
func dial(network, addr string) (net.Conn, error) {
	target, vetErr := safeResolve(network, addr)
	if vetErr != nil {
		return nil, vetErr
	}
	conn, err := net.DialTimeout(network, target, connectTimeout)
	return conn, err
}
// errorResponse builds a 503 Service Unavailable plain-text response for req.
// The body is err's message followed by a newline. The response uses the
// request's HTTP version, and the same message is repeated in the
// X-Smokescreen-Error header.
func errorResponse(req *http.Request, err error) *http.Response {
	msg := err.Error()
	resp := goproxy.NewResponse(req, goproxy.ContentTypeText, http.StatusServiceUnavailable, msg+"\n")
	resp.ProtoMajor, resp.ProtoMinor = req.ProtoMajor, req.ProtoMinor
	resp.Header.Add("X-Smokescreen-Error", msg)
	return resp
}
// findListener returns the socket the proxy should serve on.
//
// When the process is an einhorn worker, it takes einhorn listener 0 and then
// calls einhorn.Ack (presumably to tell the einhorn master the worker is up;
// confirm against go-einhorn docs). Otherwise it listens on TCP port
// defaultPort on all interfaces.
//
// On any error the returned listener is nil.
func findListener(defaultPort int) (net.Listener, error) {
	if !einhorn.IsWorker() {
		return net.Listen("tcp", fmt.Sprintf(":%d", defaultPort))
	}
	listener, err := einhorn.GetListener(0)
	if err != nil {
		return nil, err
	}
	if err := einhorn.Ack(); err != nil {
		// Don't hand back a listener alongside an error. The caller won't
		// serve on it, so release the socket here. The Close error is
		// ignored because the Ack failure is the error worth reporting.
		_ = listener.Close()
		return nil, err
	}
	return listener, nil
}
func main() {
proxy := goproxy.NewProxyHttpServer()
proxy.Verbose = true
proxy.Tr.Dial = dial
proxy.OnRequest().HandleConnectFunc(func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) {
resolved, err := safeResolve("tcp", host)
if err != nil {
ctx.Resp = errorResponse(ctx.Req, err)
return goproxy.RejectConnect, ""
}
return goproxy.OkConnect, resolved
})
proxy.OnResponse().DoFunc(func(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response {
if resp == nil && ctx.Error != nil {
resp = errorResponse(ctx.Req, ctx.Error)
}
return resp
})
var port int
flag.IntVar(&port, "port", 4750, "Port to bind on")
flag.DurationVar(&connectTimeout, "timeout", time.Duration(10)*time.Second, "Time to wait while connecting")
flag.Parse()
listener, err := findListener(port)
if err != nil {
log.Fatal(err)
}
log.Fatal(http.Serve(listener, proxy))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment