Skip to content

Instantly share code, notes, and snippets.

@groob
Created January 30, 2017 23:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save groob/ea563ea1f3092449cd75eeb78213cd83 to your computer and use it in GitHub Desktop.
Save groob/ea563ea1f3092449cd75eeb78213cd83 to your computer and use it in GitHub Desktop.
package main
import (
"crypto/tls"
"flag"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/acme/cloudops/env"
"github.com/acme/cloudops/version"
)
// authMW returns the main handler of the proxy.
//
// Every request whose path does not start with /api/v1/osquery is first
// checked against the oauth2 proxy at authURL+"/oauth2/auth"; any status other
// than 202 Accepted redirects the client to https://<host>/oauth2/sign_in.
// Authorized requests are then proxied: websocket upgrades through
// handleWebsocket, everything else over HTTP to upstreamURL or, when
// upstreamURL is empty, to the in-cluster service derived from r.Host.
//
// NOTE(review): handleWebsocket routes by r.Host only, so websocket traffic
// ignores upstreamURL — confirm that is intended.
func authMW(authURL, upstreamURL string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// osqueryd requests skip the oauth check: their node key is checked by the server.
		if !strings.HasPrefix(r.URL.Path, "/api/v1/osquery") {
			resp, err := authRequest(authURL+"/oauth2/auth", r)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			// Only the status code is needed. Closing the body releases the
			// connection to the auth backend; previously one leaked per request.
			resp.Body.Close()
			if resp.StatusCode != http.StatusAccepted {
				http.Redirect(w, r, "https://"+r.Host+"/oauth2/sign_in", http.StatusTemporaryRedirect)
				return
			}
		}

		// Proxy websocket requests.
		if isWebsocket(r) {
			handleWebsocket(w, r)
			return
		}

		// Proxy regular HTTP requests.
		uri := upstreamURL
		if uri == "" {
			uri = upstreamFromHost(r.Host)
		}
		u, err := url.Parse(uri)
		if err != nil {
			log.Printf("parsing upstream %q: %v", uri, err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		r.Header.Set("X-Forwarded-Proto", "https")
		httputil.NewSingleHostReverseProxy(u).ServeHTTP(w, r)
	}
}
// isWebsocket reports whether r is a websocket opening handshake: its
// Connection header lists the "upgrade" token and its Upgrade header lists the
// "websocket" token. Both headers are comma-separated, case-insensitive token
// lists, so "keep-alive, Upgrade" (as some browsers send) is accepted as well
// as a bare "Upgrade".
func isWebsocket(r *http.Request) bool {
	return headerHasToken(r.Header, "Connection", "upgrade") &&
		headerHasToken(r.Header, "Upgrade", "websocket")
}

// headerHasToken reports whether any value of the header name in h contains
// token as one of its comma-separated elements, compared case-insensitively.
func headerHasToken(h http.Header, name, token string) bool {
	for _, value := range h[http.CanonicalHeaderKey(name)] {
		for _, elem := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(elem), token) {
				return true
			}
		}
	}
	return false
}
// handleWebsocket proxies a websocket connection from the client that sent r
// to the in-cluster service derived from r.Host by upstreamFromHost, dialed as
// ws://<service>:80 with r's path and query.
//
// The upstream handshake happens first, so a failure there is reported to the
// client as 503 before its own connection is upgraded. Once both sides are
// upgraded, the raw byte streams are copied in both directions. When either
// direction ends, a close frame is sent to the peer that is still connected
// and both connections are closed.
//
// NOTE(review): copying UnderlyingConn bypasses the websocket library's
// buffered reader; any bytes it may already have buffered after the upstream
// handshake response would be lost — confirm, or proxy at message level.
func handleWebsocket(w http.ResponseWriter, r *http.Request) {
	// Handshake headers of the client's request; they are not forwarded
	// upstream because the upstream handshake is made by the dialer.
	handshakeHeaders := []string{
		http.CanonicalHeaderKey("connection"),
		http.CanonicalHeaderKey("upgrade"),
		http.CanonicalHeaderKey("sec-websocket-version"),
		http.CanonicalHeaderKey("sec-websocket-key"),
	}
	// Headers of the upstream handshake response passed back to the client.
	upgradeHeaders := []string{
		http.CanonicalHeaderKey("set-cookie"),
		http.CanonicalHeaderKey("sec-websocket-protocol"),
	}

	// Copy the request headers, minus the handshake headers, and pass the
	// client's Host along to the upstream server.
	upstreamHeader := http.Header{}
	for key := range r.Header {
		copyHeader(&upstreamHeader, r.Header, key)
	}
	for _, header := range handshakeHeaders {
		delete(upstreamHeader, header)
	}
	upstreamHeader.Set("Host", r.Host)

	// Connect upstream.
	upstreamHost := strings.TrimPrefix(upstreamFromHost(r.Host)+":80", "http://")
	upstreamAddr := upstreamWSURL(*r.URL, upstreamHost, "http").String()
	upstream, upstreamResp, err := websocket.DefaultDialer.Dial(upstreamAddr, upstreamHeader)
	if err != nil {
		if upstreamResp != nil {
			log.Printf("dialing upstream websocket failed with code %d: %v", upstreamResp.StatusCode, err)
		} else {
			log.Printf("dialing upstream websocket failed: %v", err)
		}
		http.Error(w, "websocket unavailable", http.StatusServiceUnavailable)
		return
	}
	defer upstream.Close()

	// Pass the upstream handshake response headers to the upgrader.
	upgradeHeader := http.Header{}
	copyHeaders(&upgradeHeader, upstreamResp.Header, upgradeHeaders)

	// Upgrade the client connection without validating the origin.
	// NOTE(review): accepting every Origin lets any web page open a websocket
	// here with the visitor's auth cookies (cross-site websocket hijacking);
	// confirm this is intended before relying on cookie auth for websockets.
	upgrader := websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool { return true },
	}
	client, err := upgrader.Upgrade(w, r, upgradeHeader)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error, so a
		// second http.Error here would only write a superfluous response.
		log.Printf("couldn't upgrade websocket request: %v", err)
		return
	}
	defer client.Close()

	// Wire both sides together and close when finished.
	var (
		wg        sync.WaitGroup
		closeOnce sync.Once
	)
	pipe := func(dst, src *websocket.Conn) {
		defer wg.Done()
		_, err := io.Copy(dst.UnderlyingConn(), src.UnderlyingConn())
		closeMessage := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye")
		if err != nil {
			closeMessage = websocket.FormatCloseMessage(websocket.CloseProtocolError, err.Error())
		}
		// Best effort: tell the side that is still connected that we are done.
		// Only dst is written here; the other goroutine may be copying raw
		// bytes into src right now and a control frame would interleave with them.
		_ = dst.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(2*time.Second))
		// Close both connections so the copy in the other direction, which may
		// be blocked reading from a peer that never hangs up, returns too.
		closeOnce.Do(func() {
			client.Close()
			upstream.Close()
		})
	}
	wg.Add(2)
	go pipe(upstream, client)
	go pipe(client, upstream)
	wg.Wait()
}
// upstreamWSURL returns a copy of the request URL r retargeted at upstreamHost.
// upstreamScheme "http" becomes "ws" and "https" becomes "wss"; any other value
// leaves r's scheme untouched. Path, query and userinfo are kept and the
// fragment is cleared.
func upstreamWSURL(r url.URL, upstreamHost, upstreamScheme string) *url.URL {
	// r arrives by value, so editing target never affects the caller's URL
	// (userinfo is carried over by this copy as well).
	target := r
	target.Host = upstreamHost
	target.Fragment = ""
	if upstreamScheme == "http" {
		target.Scheme = "ws"
	} else if upstreamScheme == "https" {
		target.Scheme = "wss"
	}
	return &target
}
// copyHeaders copies every header named in headers from src into *dst,
// following the same rules as copyHeader.
func copyHeaders(dst *http.Header, src http.Header, headers []string) {
	for _, header := range headers {
		copyHeader(dst, src, header)
	}
}

// copyHeader appends to *dst every value of header in src that is neither
// empty nor blank (whitespace only).
//
// header is looked up as given first and, if absent, in canonical form, so a
// name such as "set-cookie" still finds "Set-Cookie" (dst.Add canonicalizes
// the key when storing anyway). A nil *dst map is allocated on first use
// instead of panicking.
func copyHeader(dst *http.Header, src http.Header, header string) {
	values, ok := src[header]
	if !ok {
		values = src[http.CanonicalHeaderKey(header)]
	}
	for _, value := range values {
		if strings.TrimSpace(value) == "" {
			continue
		}
		if *dst == nil {
			*dst = http.Header{}
		}
		dst.Add(header, value)
	}
}
// upstreamFromHost maps the public host of a request to the URL of the
// Kubernetes service in the "default" namespace that serves it, e.g.
// "foo.pr.acme.net" -> "http://foo-pr-acme-net.default.svc.cluster.local".
// A ":port" suffix on host is dropped first, because it cannot be part of a
// service name.
func upstreamFromHost(host string) string {
	if h, _, err := net.SplitHostPort(host); err == nil {
		host = h
	}
	host = prependZ(host)
	dashed := strings.Replace(host, ".", "-", -1)
	return "http://" + dashed + ".default.svc.cluster.local"
}

// prependZ handles the k8s rule that service names may not start with a
// number: when host starts with a digit, `zzzzz` (5 z's) is prepended.
// For example 100.pr.acme.net becomes zzzzz100.pr.acme.net, which
// upstreamFromHost then turns into the service name zzzzz100-pr-acme-net.
func prependZ(host string) string {
	if host == "" {
		return host
	}
	// A single character parses as an integer exactly when it is a digit 0-9,
	// so this checks the first character rather than requiring the whole
	// first label to be numeric (which missed hosts like "1abc.pr.acme.net").
	if _, err := strconv.Atoi(host[:1]); err == nil {
		return "zzzzz" + host
	}
	return host
}
// main parses the flags (each defaulting to an environment variable),
// registers the handlers on the default mux and serves until the listener
// fails, in one of three modes:
//   - -tls.certificates set: HTTPS with every certificate loaded from that directory (SNI);
//   - otherwise -tls (true by default): HTTPS with the -tls.certificate/-tls.key pair;
//   - otherwise: plain HTTP.
func main() {
	var (
		tlsEnabled    = flag.Bool("tls", env.Bool("USE_TLS", true), "use tls")
		listenAddr    = flag.String("http.address", env.String("HTTP_ADDRESS", "0.0.0.0:443"), "http listen address")
		certFile      = flag.String("tls.certificate", env.String("TLS_CERTIFICATE", "secrets/tls.crt"), "tls server certificate")
		keyFile       = flag.String("tls.key", env.String("TLS_KEY", "secrets/tls.key"), "tls server certificate private key")
		certDir       = flag.String("tls.certificates", env.String("TLS_CERTIFICATES", ""), "path to tls server certificate directory")
		oauthProxyURL = flag.String("oauthproxy.url", env.String("OAUTH_PROXY_URL", ""), "url of the oauth2proxy backend")
		upstreamURL   = flag.String("upstream.url", env.String("UPSTREAM_URL", ""), "url of the backend service you want to proxy")
	)
	flag.Parse()

	http.Handle("/", authMW(*oauthProxyURL, *upstreamURL))
	http.Handle("/oauth2/", authBackend(*oauthProxyURL))
	http.Handle("/-/version", version.Handler())

	switch {
	case *certDir != "":
		certs, err := loadCertDir(*certDir)
		if err != nil {
			log.Fatal(err)
		}
		log.Fatal(ListenAndServeTLSSNI(&http.Server{Addr: *listenAddr}, certs))
	case *tlsEnabled:
		log.Fatal(http.ListenAndServeTLS(*listenAddr, *certFile, *keyFile, nil))
	default:
		log.Fatal(http.ListenAndServe(*listenAddr, nil))
	}
}
// authBackend returns a handler that reverse-proxies requests to the oauth2
// proxy at authURL. Before forwarding, it sets X-Auth-Request-Redirect to
// https://oauth.acme.net/oauth2/callback/<request host>, presumably the
// post-login target the oauth2 proxy should send the user to.
//
// authURL is parsed and the reverse proxy is built once, not on every request.
// If authURL does not parse, every request is answered with 500 and the parse
// error, as before.
func authBackend(authURL string) http.HandlerFunc {
	u, err := url.Parse(authURL)
	if err != nil {
		return func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
	}
	proxy := httputil.NewSingleHostReverseProxy(u)
	return func(w http.ResponseWriter, r *http.Request) {
		r.Header.Set("X-Auth-Request-Redirect", "https://oauth.acme.net/oauth2/callback/"+r.Host)
		proxy.ServeHTTP(w, r)
	}
}
// authClient performs the auth sub-requests. Unlike http.DefaultClient it has
// a timeout, so an unresponsive auth backend fails requests instead of
// blocking them indefinitely.
var authClient = &http.Client{Timeout: 10 * time.Second}

// authRequest sends a GET to url carrying a copy of upstream's headers (its
// cookies in particular) and returns the auth backend's response. On success
// the caller must close resp.Body.
func authRequest(url string, upstream *http.Request) (*http.Response, error) {
	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, err
	}
	// Copy rather than alias upstream.Header: the same request is proxied to
	// the backend afterwards, so nothing done to req.Header may leak into it.
	for key, values := range upstream.Header {
		req.Header[key] = append([]string(nil), values...)
	}
	// Connection and Upgrade are hop-by-hop headers of the client's connection;
	// forwarding them would turn the auth check of a websocket request into an
	// upgrade request of its own.
	req.Header.Del("Connection")
	req.Header.Del("Upgrade")
	return authClient.Do(req)
}
// loadCertDir walks the directory tree rooted at path (e.g. /secrets) and, for
// every entry whose name ends in ".crt", loads it together with the file
// "tls.key" from the same directory as one tls.Certificate. The first walk or
// load error stops the walk and is returned with no certificates. The result
// is meant for ListenAndServeTLSSNI.
func loadCertDir(path string) ([]tls.Certificate, error) {
	var loaded []tls.Certificate
	visit := func(file string, fi os.FileInfo, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case !strings.HasSuffix(fi.Name(), ".crt"):
			return nil
		}
		pair, loadErr := tls.LoadX509KeyPair(file, filepath.Join(filepath.Dir(file), "tls.key"))
		if loadErr != nil {
			return loadErr
		}
		loaded = append(loaded, pair)
		return nil
	}
	if err := filepath.Walk(path, visit); err != nil {
		return nil, err
	}
	return loaded, nil
}
// ListenAndServeTLSSNI listens on srv.Addr and serves HTTPS with certs. The
// TLS handshake chooses among certs by the client's SNI server name.
// srv.TLSConfig is not consulted. It returns the listen error, or whatever
// srv.Serve returns once serving stops.
//
// NOTE(review): the TLS config sets no NextProtos, so HTTP/2 is presumably not
// negotiated on this listener — confirm whether that matters.
func ListenAndServeTLSSNI(srv *http.Server, certs []tls.Certificate) error {
	ln, err := net.Listen("tcp", srv.Addr)
	if err != nil {
		return err
	}
	cfg := &tls.Config{Certificates: certs}
	cfg.BuildNameToCertificate()
	return srv.Serve(tls.NewListener(ln, cfg))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment