Skip to content

Instantly share code, notes, and snippets.

@Quentin-M
Created January 4, 2018 00:45
Show Gist options
  • Save Quentin-M/b8a6aa1742afab260fec733dc141d788 to your computer and use it in GitHub Desktop.
Save Quentin-M/b8a6aa1742afab260fec733dc141d788 to your computer and use it in GitHub Desktop.
package main
import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	crand "crypto/rand"
	"crypto/subtle"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"fmt"
	"log"
	"math/big"
	"math/rand"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/spf13/viper"
)
// init seeds math/rand and configures viper.
//
// Configuration is looked up in a file named "bearer-rproxy" (searched in the
// working directory, then /etc/). It is also looked up in environment variables
// prefixed with "BRP", where "." and "-" in keys are mapped to "_". Defaults are
// registered for the listen address, auto-TLS, the graceful shutdown timeout and
// the backend request timeout.
//
// A missing configuration file is not an error: environment variables and the
// defaults still apply. A file that exists but cannot be read or parsed aborts
// start-up.
func init() {
	rand.Seed(time.Now().UTC().UnixNano())
	viper.SetConfigName("bearer-rproxy")
	viper.AddConfigPath(".")
	viper.AddConfigPath("/etc/")
	viper.SetEnvPrefix("BRP")
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_"))
	viper.AutomaticEnv()
	viper.SetDefault("listen.address", "0.0.0.0:9010")
	viper.SetDefault("listen.auto-tls", "true")
	viper.SetDefault("listen.graceful-timeout", "5s")
	viper.SetDefault("backend.request-timeout", "5s")

	// The config name/paths above are only consulted by ReadInConfig; without
	// this call no configuration file would ever be loaded.
	if err := viper.ReadInConfig(); err != nil {
		if _, notFound := err.(viper.ConfigFileNotFoundError); !notFound {
			log.Fatalf("failed to read configuration file: %v", err)
		}
	}
}
// main builds the proxy server from configuration and serves it in the
// background. Serving uses TLS when tlsConfig returns a configuration, plain
// HTTP otherwise. main then blocks until waitSigterm returns. Finally it shuts
// the server down, waiting at most listen.graceful-timeout for in-flight
// requests to complete.
func main() {
	// Maximum time a client may take to send its request headers. Without a
	// bound, a client can keep a connection open indefinitely by trickling
	// headers (Slowloris).
	const readHeaderTimeout = 10 * time.Second

	// Build server.
	s := &http.Server{
		Addr:              viper.GetString("listen.address"),
		Handler:           handlersChain(),
		TLSConfig:         tlsConfig(),
		ReadHeaderTimeout: readHeaderTimeout,
		// A non-nil, empty TLSNextProto map disables automatic HTTP/2 over TLS.
		TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0),
	}

	// Listen and Serve.
	go func() {
		log.Printf("listening on %s (TLS: %v)\n", s.Addr, s.TLSConfig != nil)
		if s.TLSConfig != nil {
			// The certificate comes from s.TLSConfig, so no files are passed.
			if err := s.ListenAndServeTLS("", ""); err != http.ErrServerClosed {
				log.Fatalf("failed to listen/serve: %v", err)
			}
		} else {
			if err := s.ListenAndServe(); err != http.ErrServerClosed {
				log.Fatalf("failed to listen/serve: %v", err)
			}
		}
	}()

	// Wait for SIGTERM signal, to attempt a graceful shutdown.
	waitSigterm()
	gracefulTimeout := viper.GetDuration("listen.graceful-timeout")
	ctx, cancel := context.WithTimeout(context.Background(), gracefulTimeout)
	defer cancel()
	log.Printf("stopping gracefully (%v)\n", gracefulTimeout)
	if err := s.Shutdown(ctx); err != nil {
		log.Printf("failed to stop gracefully: %v\n", err)
	}
}
// selfSignedCertificate generates a fresh ECDSA P-256 key pair and a
// self-signed X.509 certificate for it. The certificate is valid for TLS server
// authentication for ten years from now. Both are returned as a tls.Certificate.
//
// The certificate carries no DNS names or IP addresses (see the TODO below) and
// is signed by its own key, so clients cannot verify it against a CA or a host
// name. They have to disable verification or pin the certificate.
func selfSignedCertificate() (tls.Certificate, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to create pub/priv key pair: %v", err)
	}
	privDERBytes, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to marshal pub/priv key pair in DER: %v", err)
	}

	// Random serial number in [0, 2^128).
	snLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	sn, err := crand.Int(crand.Reader, snLimit)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to generate serial number: %v", err)
	}

	now := time.Now()
	certTmpl := x509.Certificate{
		SerialNumber: sn,
		Subject: pkix.Name{
			Organization:       []string{"BitMEX"},
			OrganizationalUnit: []string{"Bearer RProxy"},
		},
		Issuer: pkix.Name{
			Organization:       []string{"BitMEX"},
			OrganizationalUnit: []string{"Bearer RProxy"},
		},
		NotBefore:   now,
		NotAfter:    now.AddDate(10, 0, 0),
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		// KeyEncipherment belongs to RSA key transport. With an ECDSA key, the
		// (EC)DHE key exchange is authenticated by a signature, so only
		// DigitalSignature applies.
		KeyUsage:              x509.KeyUsageDigitalSignature,
		BasicConstraintsValid: true,
	}
	// TODO: IP Addresses / DNS Names
	certDERBytes, err := x509.CreateCertificate(crand.Reader, &certTmpl, &certTmpl, &priv.PublicKey, priv)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to generate certificate: %v", err)
	}

	// Round-trip through PEM so tls.X509KeyPair can verify that the key matches
	// the certificate.
	var pubPEM, privPEM bytes.Buffer
	if err := pem.Encode(&pubPEM, &pem.Block{Type: "CERTIFICATE", Bytes: certDERBytes}); err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to PEM-encode certificate: %v", err)
	}
	if err := pem.Encode(&privPEM, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privDERBytes}); err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to PEM-encode private key: %v", err)
	}
	cert, err := tls.X509KeyPair(pubPEM.Bytes(), privPEM.Bytes())
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("failed to parse generated certificate: %v", err)
	}
	return cert, nil
}
// tlsConfig returns the server's TLS configuration. When listen.auto-tls is
// false it returns nil, and the server then serves plain HTTP.
//
// With auto-tls enabled, a fresh self-signed certificate is generated at every
// start-up by selfSignedCertificate. Its key is ECDSA (P-256), so the TLS 1.2
// cipher suites are restricted to ECDHE_ECDSA suites: *_RSA suites can never be
// negotiated with an ECDSA certificate. TLS 1.3 suites are not configurable in
// crypto/tls and are unaffected. Exits the process if certificate generation
// fails.
func tlsConfig() *tls.Config {
	if !viper.GetBool("listen.auto-tls") {
		log.Print("auto-tls not specified, tls is disabled (plain-text bearer!)")
		return nil
	}

	// Generate a self-signed certificate.
	certificate, err := selfSignedCertificate()
	if err != nil {
		log.Fatalf("failed to create self-signed certificate: %v", err)
	}

	// Create and return an associated TLS configuration.
	return &tls.Config{
		Certificates: []tls.Certificate{certificate},
		MinVersion:   tls.VersionTLS12,
		// X25519 is fast and widely supported; the NIST curves remain as fallbacks.
		CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP521, tls.CurveP384, tls.CurveP256},
		// Ignored since Go 1.18 (the server always chooses); kept for older toolchains.
		PreferServerCipherSuites: true,
		CipherSuites: []uint16{
			tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, // cURL 7.54, LibreSSL/2.0.20
		},
	}
}
// handlersChain assembles the request pipeline. From the outside in, a request
// passes:
//   - the bearer-token check (bearerHandler),
//   - a per-request deadline of backend.request-timeout (http.TimeoutHandler),
//   - the reverse proxy to the selected backend (rproxyHandler).
func handlersChain() http.Handler {
	proxy := rproxyHandler()
	deadline := viper.GetDuration("backend.request-timeout")
	bounded := http.TimeoutHandler(proxy, deadline, "backend timed-out / failed")
	return bearerHandler(bounded)
}
// parseRoutes reads backend.routes, a list of "<route name>=<route URL>"
// entries, and returns the parsed target URL keyed by route name.
//
// Only the first "=" separates the name from the URL. Target URLs may therefore
// carry query strings, e.g. "api=http://10.0.0.1:8080/v1?key=value".
//
// Route names must be non-empty and may not contain "/", because they are
// matched against the first segment of the request path. Target URLs must be
// absolute (have both a scheme and a host), since the proxy copies both into
// the outgoing request.
//
// An empty route list is an error. On any error the returned map is nil.
func parseRoutes() (map[string]*url.URL, error) {
	r := viper.GetStringSlice("backend.routes")
	if len(r) == 0 {
		return nil, errors.New("no routes specified, exiting")
	}

	routes := make(map[string]*url.URL, len(r))
	for _, rd := range r {
		rds := strings.SplitN(rd, "=", 2)
		if len(rds) != 2 {
			return nil, fmt.Errorf("could not parse route %q: format is <route name>=<route URL>", rd)
		}
		name, rawURL := rds[0], rds[1]
		if name == "" {
			return nil, fmt.Errorf("could not parse route %q: route name cannot be empty", rd)
		}
		if strings.Contains(name, "/") {
			return nil, fmt.Errorf("could not parse route %q: route name cannot contain /", rd)
		}

		rTargetURL, err := url.Parse(rawURL)
		if err != nil {
			return nil, fmt.Errorf("could not parse route URL %q: %v", rd, err)
		}
		if rTargetURL.Scheme == "" || rTargetURL.Host == "" {
			return nil, fmt.Errorf("could not parse route URL %q: URL must be absolute (scheme://host[/path])", rd)
		}

		routes[name] = rTargetURL
		log.Printf("loaded route: %v -> %v\n", name, rawURL)
	}
	return routes, nil
}
// rproxyHandler returns a handler that reverse-proxies each request to the
// backend chosen by the first segment of the request path. "/<route>/rest" is
// sent to that route's target URL: "/<route>" is stripped, the remainder is
// joined onto the target's path, and the target's query string is prepended to
// the request's. Paths without a known route are answered with
// 502 Bad Gateway.
//
// Exits the process (log.Fatal) if the configured routes cannot be parsed.
func rproxyHandler() http.Handler {
	routes, err := parseRoutes()
	if err != nil {
		log.Fatal(err)
	}

	// Build one ReverseProxy per route up front, instead of allocating a new
	// proxy and Director closure on every request.
	proxies := make(map[string]*httputil.ReverseProxy, len(routes))
	for name, target := range routes {
		proxies[name] = newRouteProxy(name, target)
	}

	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		// Only the first path segment is needed to select the route.
		sPath := strings.SplitN(req.URL.Path, "/", 3)
		if len(sPath) <= 1 {
			http.Error(rw, "no route target specified", http.StatusBadGateway)
			return
		}
		rp, ok := proxies[sPath[1]]
		if !ok {
			http.Error(rw, "unknown route target", http.StatusBadGateway)
			return
		}
		rp.ServeHTTP(rw, req)
	})
}

// newRouteProxy returns a reverse proxy that rewrites requests under
// "/<name>" so that they target the given URL.
func newRouteProxy(name string, target *url.URL) *httputil.ReverseProxy {
	prefix := "/" + name
	return &httputil.ReverseProxy{
		Director: func(req *http.Request) {
			req.URL.Scheme = target.Scheme
			req.URL.Host = target.Host
			// Present the backend's host, not the proxy's, in the Host header.
			req.Host = req.URL.Host
			req.URL.Path = singleJoiningSlash(target.Path, strings.TrimPrefix(req.URL.Path, prefix))
			targetQuery := target.RawQuery
			if targetQuery == "" || req.URL.RawQuery == "" {
				req.URL.RawQuery = targetQuery + req.URL.RawQuery
			} else {
				req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
			}
		},
	}
}
// bearerHandler wraps h with bearer-token authentication against the token
// configured at listen.token.
//
// A request is accepted when its Authorization header equals the token, with an
// optional leading "Bearer " removed first. The header is then deleted, so the
// proxy's token is not forwarded to the backend. All other requests get
// 401 Unauthorized.
//
// If no token is configured, authentication is disabled: h is returned
// unwrapped and requests (including any Authorization header) pass through
// untouched.
func bearerHandler(h http.Handler) http.Handler {
	bearer := viper.GetString("listen.token")
	if len(bearer) == 0 {
		log.Print("no bearer token specified, bearer auth is disabled")
		return h
	}

	expected := []byte(bearer)
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		given := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
		// Constant-time comparison: response timing does not reveal how much of
		// the token matched. A length mismatch still returns early, which leaks
		// only the token's length.
		if subtle.ConstantTimeCompare([]byte(given), expected) != 1 {
			http.Error(rw, "bearer token is invalid", http.StatusUnauthorized)
			return
		}
		req.Header.Del("Authorization")
		h.ServeHTTP(rw, req)
	})
}
// singleJoiningSlash concatenates the URL paths a and b with exactly one "/"
// between them. If b is empty, a is returned unchanged, without a slash being
// added.
func singleJoiningSlash(a, b string) string {
	if b == "" {
		return a
	}
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")
	if endsWithSlash && startsWithSlash {
		// Both sides supply a slash: drop b's.
		return a + strings.TrimPrefix(b, "/")
	}
	if !endsWithSlash && !startsWithSlash {
		// Neither side supplies one: add it.
		return a + "/" + b
	}
	// Exactly one side already supplies the separator.
	return a + b
}
// waitSigterm blocks until the process receives SIGTERM or SIGINT
// (os.Interrupt, e.g. Ctrl-C), so that both trigger the graceful shutdown
// path.
//
// Before returning, it stops relaying those signals to its channel. If nothing
// else is listening for them, their default behavior is restored, so a second
// signal during a slow shutdown terminates the process immediately.
func waitSigterm() {
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, syscall.SIGTERM, os.Interrupt)
	defer signal.Stop(stop)
	<-stop
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment