Created
January 30, 2017 23:03
-
-
Save groob/ea563ea1f3092449cd75eeb78213cd83 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"crypto/tls" | |
"flag" | |
"io" | |
"log" | |
"net" | |
"net/http" | |
"net/http/httputil" | |
"net/url" | |
"os" | |
"path/filepath" | |
"strconv" | |
"strings" | |
"sync" | |
"time" | |
"github.com/gorilla/websocket" | |
"github.com/acme/cloudops/env" | |
"github.com/acme/cloudops/version" | |
) | |
func authMW(authURL, upstreamURL string) http.HandlerFunc { | |
return func(w http.ResponseWriter, r *http.Request) { | |
// if osqueryd don't check auth. nodeKey is checked by the server | |
if !strings.HasPrefix(r.URL.Path, "/api/v1/osquery") { | |
resp, err := authRequest(authURL+"/oauth2/auth", r) | |
if err != nil { | |
http.Error(w, err.Error(), http.StatusInternalServerError) | |
return | |
} | |
if resp.StatusCode != http.StatusAccepted { | |
http.RedirectHandler("https://"+r.Host+"/oauth2/sign_in", | |
http.StatusTemporaryRedirect).ServeHTTP(w, r) | |
return | |
} | |
} | |
// proxy websocket requests | |
if isWebsocket(r) { | |
handleWebsocket(w, r) | |
return | |
} | |
// proxy regular http requests | |
var uri string | |
if upstreamURL != "" { | |
uri = upstreamURL | |
} else { | |
uri = upstreamFromHost(r.Host) | |
} | |
u, err := url.Parse(uri) | |
if err != nil { | |
log.Printf("upstream: %s\n", uri) | |
http.Error(w, err.Error(), http.StatusInternalServerError) | |
return | |
} | |
r.Header.Set("X-Forwarded-Proto", "https") | |
p := httputil.NewSingleHostReverseProxy(u) | |
p.ServeHTTP(w, r) | |
} | |
} | |
// isWebsocket reports whether r asks for a websocket upgrade.
//
// That means its Connection header lists the "upgrade" token and its Upgrade
// header lists "websocket". Both headers are comma-separated token lists
// compared case-insensitively (RFC 6455 §4.2.1, RFC 7230 §6.1), so requests
// such as "Connection: keep-alive, Upgrade" are recognised too.
func isWebsocket(r *http.Request) bool {
	return headerHasToken(r.Header, "Connection", "upgrade") &&
		headerHasToken(r.Header, "Upgrade", "websocket")
}

// headerHasToken reports whether any comma-separated element of header name
// in h equals token, ignoring case and surrounding whitespace.
func headerHasToken(h http.Header, name, token string) bool {
	for _, value := range h[http.CanonicalHeaderKey(name)] {
		for _, part := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(part), token) {
				return true
			}
		}
	}
	return false
}
func handleWebsocket(w http.ResponseWriter, r *http.Request) { | |
var ( | |
ConnectionHeaderKey = http.CanonicalHeaderKey("connection") | |
SetCookieHeaderKey = http.CanonicalHeaderKey("set-cookie") | |
UpgradeHeaderKey = http.CanonicalHeaderKey("upgrade") | |
WSKeyHeaderKey = http.CanonicalHeaderKey("sec-websocket-key") | |
WSProtocolHeaderKey = http.CanonicalHeaderKey("sec-websocket-protocol") | |
WSVersionHeaderKey = http.CanonicalHeaderKey("sec-websocket-version") | |
HandshakeHeaders = []string{ConnectionHeaderKey, UpgradeHeaderKey, WSVersionHeaderKey, WSKeyHeaderKey} | |
UpgradeHeaders = []string{SetCookieHeaderKey, WSProtocolHeaderKey} | |
) | |
// Copy request headers and remove websocket handshaking headers | |
// before submitting to the upstream server | |
upstreamHeader := http.Header{} | |
for key, _ := range r.Header { | |
copyHeader(&upstreamHeader, r.Header, key) | |
} | |
for _, header := range HandshakeHeaders { | |
delete(upstreamHeader, header) | |
} | |
upstreamHeader.Set("Host", r.Host) | |
// Connect upstream | |
upstreamHost := strings.TrimPrefix(upstreamFromHost(r.Host)+":80", "http://") | |
upstreamAddr := upstreamWSURL(*r.URL, upstreamHost, "http").String() | |
upstream, upstreamResp, err := websocket.DefaultDialer.Dial(upstreamAddr, upstreamHeader) | |
if err != nil { | |
if upstreamResp != nil { | |
log.Printf("dialing upstream websocket failed with code %d: %v", upstreamResp.StatusCode, err) | |
} else { | |
log.Printf("dialing upstream websocket failed: %v", err) | |
} | |
http.Error(w, "websocket unavailable", http.StatusServiceUnavailable) | |
return | |
} | |
defer upstream.Close() | |
// Pass websocket handshake response headers to the upgrader | |
upgradeHeader := http.Header{} | |
copyHeaders(&upgradeHeader, upstreamResp.Header, UpgradeHeaders) | |
// Upgrade the client connection without validating the origin | |
upgrader := websocket.Upgrader{ | |
CheckOrigin: func(r *http.Request) bool { return true }, | |
} | |
client, err := upgrader.Upgrade(w, r, upgradeHeader) | |
if err != nil { | |
log.Printf("couldn't upgrade websocket request: %v", err) | |
http.Error(w, "websocket upgrade failed", http.StatusServiceUnavailable) | |
return | |
} | |
// Wire both sides together and close when finished | |
var wg sync.WaitGroup | |
cp := func(dst, src *websocket.Conn) { | |
defer wg.Done() | |
_, err := io.Copy(dst.UnderlyingConn(), src.UnderlyingConn()) | |
var closeMessage []byte | |
if err != nil { | |
closeMessage = websocket.FormatCloseMessage(websocket.CloseProtocolError, err.Error()) | |
} else { | |
closeMessage = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye") | |
} | |
// Attempt to close the connection properly | |
dst.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(2*time.Second)) | |
src.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(2*time.Second)) | |
} | |
wg.Add(2) | |
go cp(upstream, client) | |
go cp(client, upstream) | |
wg.Wait() | |
} | |
// upstreamWSURL builds the websocket URL to dial upstream from the request
// URL r. The result keeps r's path, query and user info, replaces the host
// with upstreamHost and drops the fragment. The scheme becomes "ws" for an
// "http" upstreamScheme and "wss" for "https"; any other upstreamScheme
// leaves r's scheme untouched. r is passed by value, so the caller's URL is
// never modified.
func upstreamWSURL(r url.URL, upstreamHost, upstreamScheme string) *url.URL {
	out := new(url.URL)
	*out = r
	out.Host = upstreamHost
	out.Fragment = ""
	if upstreamScheme == "https" {
		out.Scheme = "wss"
	} else if upstreamScheme == "http" {
		out.Scheme = "ws"
	}
	return out
}
// copyHeaders appends to dst every non-empty value that src holds under each
// of the given keys. The keys index src directly, so they should already be
// in canonical form (see http.CanonicalHeaderKey).
func copyHeaders(dst *http.Header, src http.Header, headers []string) {
	for _, key := range headers {
		values := src[key]
		for i := range values {
			if values[i] == "" {
				continue
			}
			dst.Add(key, values[i])
		}
	}
}
// copyHeader appends to dst each value that src holds under the key header.
// Empty strings are skipped, but whitespace-only values are copied as-is.
// header indexes src directly, so it should be in canonical form.
func copyHeader(dst *http.Header, src http.Header, header string) {
	values := src[header]
	for i := 0; i < len(values); i++ {
		if v := values[i]; v != "" {
			dst.Add(header, v)
		}
	}
}
// map upstream to k8s svc | |
func upstreamFromHost(host string) string { | |
host = prependZ(host) | |
dashed := strings.Replace(host, ".", "-", -1) | |
upstream := "http://" + dashed + ".default.svc.cluster.local" | |
return upstream | |
} | |
// prependZ makes host usable as the basis of a k8s service name.
//
// k8s does not allow service names that start with a digit. When the first
// label of host starts with one, "zzzzz" (five z's) is prepended; other hosts
// are returned unchanged. Dots are left in place. For example,
// "100.pr.acme.net" becomes "zzzzz100.pr.acme.net", which upstreamFromHost
// then turns into the service name "zzzzz100-pr-acme-net".
func prependZ(host string) string {
	label := host
	if i := strings.IndexByte(host, '.'); i >= 0 {
		label = host[:i]
	}
	// Atoi keeps the original rule for fully numeric labels, including signed
	// ones. The digit check also catches labels Atoi rejects but k8s still
	// refuses, such as "1a" or numbers too large for an int.
	_, err := strconv.Atoi(label)
	startsWithDigit := label != "" && label[0] >= '0' && label[0] <= '9'
	if err == nil || startsWithDigit {
		return "zzzzz" + host
	}
	return host
}
func main() { | |
var ( | |
flTLS = flag.Bool("tls", env.Bool("USE_TLS", true), "use tls") | |
flHTTPAddress = flag.String("http.address", env.String("HTTP_ADDRESS", "0.0.0.0:443"), "http listen address") | |
flCert = flag.String("tls.certificate", env.String("TLS_CERTIFICATE", "secrets/tls.crt"), "tls server certificate") | |
flKey = flag.String("tls.key", env.String("TLS_KEY", "secrets/tls.key"), "tls server certificate private key") | |
flCertDir = flag.String("tls.certificates", env.String("TLS_CERTIFICATES", ""), "path to tls server certificate directory") | |
flAuthProxyURL = flag.String("oauthproxy.url", env.String("OAUTH_PROXY_URL", ""), "url of the oauth2proxy backend") | |
flUpstreamURL = flag.String("upstream.url", env.String("UPSTREAM_URL", ""), "url of the backend service you want to proxy") | |
) | |
flag.Parse() | |
http.Handle("/", authMW(*flAuthProxyURL, *flUpstreamURL)) | |
http.Handle("/oauth2/", authBackend(*flAuthProxyURL)) | |
http.Handle("/-/version", version.Handler()) | |
if *flTLS && *flCertDir == "" { | |
log.Fatal(http.ListenAndServeTLS(*flHTTPAddress, *flCert, *flKey, nil)) | |
} else if *flCertDir != "" { | |
certs, err := loadCertDir(*flCertDir) | |
if err != nil { | |
log.Fatal(err) | |
} | |
srv := http.Server{ | |
Addr: *flHTTPAddress, | |
} | |
log.Fatal(ListenAndServeTLSSNI(&srv, certs)) | |
} else { | |
log.Fatal(http.ListenAndServe(*flHTTPAddress, nil)) | |
} | |
} | |
// authBackend returns a handler that reverse-proxies requests (the /oauth2/
// routes) to the oauth2 proxy at authURL.
//
// Before forwarding, it sets X-Auth-Request-Redirect to
// https://oauth.acme.net/oauth2/callback/<request host>. Presumably the oauth2
// proxy uses this as the post-login destination; confirm against its config.
// If authURL cannot be parsed, every request gets a 500 carrying the parse
// error.
func authBackend(authURL string) http.HandlerFunc {
	// Parse once and reuse one proxy; httputil.ReverseProxy is safe for
	// concurrent use, so there is no need to rebuild it per request.
	u, parseErr := url.Parse(authURL)
	var proxy *httputil.ReverseProxy
	if parseErr == nil {
		proxy = httputil.NewSingleHostReverseProxy(u)
	}
	return func(w http.ResponseWriter, r *http.Request) {
		r.Header.Set("X-Auth-Request-Redirect", "https://oauth.acme.net/oauth2/callback/"+r.Host)
		if parseErr != nil {
			http.Error(w, parseErr.Error(), http.StatusInternalServerError)
			return
		}
		proxy.ServeHTTP(w, r)
	}
}
// authClient performs the sub-requests made by authRequest. Unlike
// http.DefaultClient it has a timeout, so an unresponsive oauth2 proxy cannot
// hold every incoming request open indefinitely.
var authClient = &http.Client{Timeout: 10 * time.Second}

// authRequest sends a GET to url carrying a copy of all headers of the
// incoming request upstream (cookies included) and returns the response.
// The caller must close the response body.
func authRequest(url string, upstream *http.Request) (*http.Response, error) {
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	// Copy rather than share upstream.Header. The caller goes on to modify and
	// proxy the incoming request, and a RoundTripper may keep using req until
	// the response body is closed.
	req.Header = make(http.Header, len(upstream.Header))
	for key, values := range upstream.Header {
		req.Header[key] = append([]string(nil), values...)
	}
	return authClient.Do(req)
}
// loadCertDir walks the tree rooted at path (e.g. /secrets) and builds the
// certificate list for ListenAndServeTLSSNI.
//
// Every entry whose name ends in ".crt" is loaded as a key pair with the file
// "tls.key" in the same directory. The first walk or load error aborts the
// scan and is returned. If there are no .crt files, the result is nil.
//
// NOTE(review): every .crt is paired with tls.key, so a directory that also
// holds an unrelated certificate (a CA bundle, say) makes loading fail.
func loadCertDir(path string) ([]tls.Certificate, error) {
	var loaded []tls.Certificate
	err := filepath.Walk(path, func(file string, info os.FileInfo, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case !strings.HasSuffix(info.Name(), ".crt"):
			return nil
		}
		pair, err := tls.LoadX509KeyPair(file, filepath.Join(filepath.Dir(file), "tls.key"))
		if err != nil {
			return err
		}
		loaded = append(loaded, pair)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return loaded, nil
}
// ListenAndServeTLSSNI listens on TCP srv.Addr and serves TLS connections with
// srv. Certificates are taken from certs, indexed by name with
// BuildNameToCertificate so a certificate can be chosen per client SNI name.
// It returns the net.Listen error, or otherwise whatever srv.Serve returns.
// srv.TLSConfig is not consulted.
//
// NOTE(review): BuildNameToCertificate is deprecated since Go 1.14, where
// crypto/tls selects from Certificates on its own. Kept for identical behaviour.
func ListenAndServeTLSSNI(srv *http.Server, certs []tls.Certificate) error {
	ln, err := net.Listen("tcp", srv.Addr)
	if err != nil {
		return err
	}
	cfg := &tls.Config{Certificates: certs}
	cfg.BuildNameToCertificate()
	return srv.Serve(tls.NewListener(ln, cfg))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment