Skip to content

Instantly share code, notes, and snippets.

@paulo-raca
Last active June 12, 2023 16:16
Show Gist options
  • Save paulo-raca/f9e7c629326becbfe5c7ad7848db19e9 to your computer and use it in GitHub Desktop.
Save paulo-raca/f9e7c629326becbfe5c7ad7848db19e9 to your computer and use it in GitHub Desktop.
Support for `Forwarded`, `X-Forwarded` and `X-Real-IP` in Go ReverseProxy

This code ensures that the `X-Forwarded-*` headers (and their relatives) are updated correctly when proxies are chained: the host, protocol and "for" values must be inherited from the previous proxies.

Of course, there are three competing (and similar) ways to receive this information from the previous proxy, so this code tries to be compatible with all of them:

  • Forwarded (new, nicer and standardized in rfc7239)
  • X-Forwarded-[For/Host/Proto], the legacy but de-facto standard
  • X-Real-IP, a less common alternative that carries only the first entry of X-Forwarded-For (`X-Forwarded-For[0]`, i.e. the originating client).

Also, reading them from the incoming request requires caution: malicious clients may try to inject these headers, so inbound values should only be trusted when we expect a proxy in front of us to set them.

Likewise, this code is capable of producing all of the formats, but you should enable only those expected by the next proxy (otherwise the next proxy may update one of the headers but not the others, leaving them inconsistent).

It tries to be lenient with its inputs (e.g., IPv6 addresses may be written with or without `[]`) and strict with its outputs (anything unparseable becomes `unknown` or `0.0.0.0`).

package main
import (
	"net"
	"net/http"
	"net/http/httputil"
	"regexp"
	"strings"

	"github.com/samber/lo"
	"github.com/theckman/httpforwarded"
)
// Configuration
// ConfigForwarded controls forwarding-header handling in SetForwardedHeaders:
// which header families are written to the outbound request (SetOutbound) and
// which inbound ones are trusted and carried over from the previous proxy
// (PreserveInbound).
type ConfigForwarded struct {
	// SetOutbound selects the header families written to the next hop.
	SetOutbound ConfigEnabledForwardedHeaders
	// PreserveInbound selects the inbound header families whose values are
	// trusted and inherited.
	PreserveInbound ConfigEnabledForwardedHeaders
}

// ConfigEnabledForwardedHeaders holds one on/off switch per family of
// forwarding headers.
type ConfigEnabledForwardedHeaders struct {
	Forwarded  bool // RFC 7239 "Forwarded"
	XForwarded bool // X-Forwarded-For / X-Forwarded-Host / X-Forwarded-Proto
	XRealIP    bool // X-Real-IP
}

// Any reports whether at least one header family is enabled.
func (h ConfigEnabledForwardedHeaders) Any() bool {
	for _, enabled := range [...]bool{h.Forwarded, h.XForwarded, h.XRealIP} {
		if enabled {
			return true
		}
	}
	return false
}
// SetForwardedHeaders is a replacement for ProxyRequest.SetXForwarded() that
// can emit any combination of the Forwarded (RFC 7239), X-Forwarded-* and
// X-Real-IP headers on r.Out.
//
// The forwarding headers are always removed from r.Out first, because they
// may have been copied from the inbound request, where a client could have
// forged them. If cfg.SetOutbound enables nothing, the function stops there.
//
// Otherwise the hop information is assembled as follows:
//   - for: the inbound chain (only from families trusted in
//     cfg.PreserveInbound), followed by the client IP from r.In.RemoteAddr.
//   - by: the inbound Forwarded "by" chain (if trusted), followed by the local
//     address the inbound connection was accepted on.
//   - host: the trusted inbound value, or r.In.Host.
//   - proto: the trusted inbound value, or "http"/"https" depending on r.In.TLS.
//
// Only the header families enabled in cfg.SetOutbound are written.
func SetForwardedHeaders(r *httputil.ProxyRequest, cfg ConfigForwarded) {
	// Cleanup proxy forward headers.
	// Some of them are automatically copied from the request,
	// and we don't want them.
	r.Out.Header.Del("Forwarded")
	r.Out.Header.Del("X-Forwarded-Host")
	r.Out.Header.Del("X-Forwarded-Proto")
	r.Out.Header.Del("X-Forwarded-For")
	r.Out.Header.Del("X-Real-IP")

	if !cfg.SetOutbound.Any() {
		return
	}

	var forwardedFor, forwardedBy []string
	forwardedHost := r.In.Host
	forwardedProto := "http"
	if r.In.TLS != nil {
		forwardedProto = "https"
	}

	// Preserve data from inbound headers, but only from the header families
	// that are configured as trustworthy.
	if cfg.PreserveInbound.Any() {
		forwardedHeader, err := httpforwarded.Parse(r.In.Header.Values("Forwarded"))
		if err != nil {
			// Malformed Forwarded header: ignore it rather than failing the request.
			// log.Warning("malformed Forwarded header: %w", err)
			forwardedHeader = map[string][]string{}
		}
		// X-Forwarded-* are commonly sent as a single comma-separated line
		// ("X-Forwarded-For: client, proxy1"). Split them into elements,
		// otherwise a whole list would be treated as one unparseable address.
		xForwardedHost := splitForwardedList(r.In.Header.Values("X-Forwarded-Host"))
		xForwardedProto := splitForwardedList(r.In.Header.Values("X-Forwarded-Proto"))
		xForwardedFor := splitForwardedList(r.In.Header.Values("X-Forwarded-For"))
		xRealIP := splitForwardedList(r.In.Header.Values("X-Real-IP"))

		// "for": retrieve from either "Forwarded", "X-Forwarded-For" or "X-Real-IP"
		// TODO: Fix conflicts using topologic sort
		switch {
		case cfg.PreserveInbound.Forwarded && len(forwardedHeader["for"]) > 0:
			forwardedFor = forwardedHeader["for"]
		case cfg.PreserveInbound.XForwarded && len(xForwardedFor) > 0:
			forwardedFor = xForwardedFor
		case cfg.PreserveInbound.XRealIP && len(xRealIP) > 0:
			forwardedFor = []string{xRealIP[0]}
		}

		// "by": There is no X- equivalent
		if cfg.PreserveInbound.Forwarded && len(forwardedHeader["by"]) > 0 {
			forwardedBy = forwardedHeader["by"]
		}

		// host: retrieve from either "Forwarded" or "X-Forwarded-Host";
		// the first list element is the host the original client asked for.
		switch {
		case cfg.PreserveInbound.Forwarded && len(forwardedHeader["host"]) > 0:
			forwardedHost = forwardedHeader["host"][0]
		case cfg.PreserveInbound.XForwarded && len(xForwardedHost) > 0:
			forwardedHost = xForwardedHost[0]
		}

		// proto: retrieve from either "Forwarded" or "X-Forwarded-Proto"
		switch {
		case cfg.PreserveInbound.Forwarded && len(forwardedHeader["proto"]) > 0:
			forwardedProto = forwardedHeader["proto"][0]
		case cfg.PreserveInbound.XForwarded && len(xForwardedProto) > 0:
			forwardedProto = xForwardedProto[0]
		}
	}

	// "for" of this hop: the peer that connected to us, without its port.
	clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
	if err != nil {
		clientIP = r.In.RemoteAddr
	}
	forwardedFor = append(forwardedFor, clientIP)

	// "by" of this hop: the local address (port preserved) the inbound
	// connection arrived on. r.Out.RemoteAddr must not be used here: when
	// ReverseProxy clones the inbound request into r.Out, RemoteAddr still
	// holds the *client's* address, not ours.
	forwardedBy = append(forwardedBy, forwardedLocalAddr(r.In))

	if cfg.SetOutbound.Forwarded {
		r.Out.Header.Set("Forwarded", httpforwarded.Format(map[string][]string{
			"for": lo.Map(forwardedFor, func(value string, _ int) string {
				return formatNodeIdentifier(value, true)
			}),
			"by": lo.Map(forwardedBy, func(value string, _ int) string {
				return formatNodeIdentifier(value, true)
			}),
			"host":  {forwardedHost},
			"proto": {forwardedProto},
		}))
	}
	if cfg.SetOutbound.XForwarded {
		r.Out.Header.Set("X-Forwarded-Host", forwardedHost)
		r.Out.Header.Set("X-Forwarded-Proto", forwardedProto)
		// Emit one comma-separated line: http.Header.Get (and many other
		// consumers) only look at the first X-Forwarded-For line.
		r.Out.Header.Set("X-Forwarded-For", strings.Join(lo.Map(forwardedFor, func(value string, _ int) string {
			return formatNodeIdentifier(value, false)
		}), ", "))
	}
	if cfg.SetOutbound.XRealIP {
		// X-Real-IP carries only the originating client, i.e. the first hop.
		// forwardedFor is never empty here: clientIP was appended above.
		r.Out.Header.Set("X-Real-IP", formatNodeIdentifier(forwardedFor[0], false))
	}
}

// splitForwardedList flattens header lines that may each carry a
// comma-separated list (e.g. "X-Forwarded-For: 203.0.113.7, 10.0.0.1") into
// individual elements, trimming whitespace and dropping empty elements.
func splitForwardedList(lines []string) []string {
	var items []string
	for _, line := range lines {
		for _, item := range strings.Split(line, ",") {
			if item = strings.TrimSpace(item); item != "" {
				items = append(items, item)
			}
		}
	}
	return items
}

// forwardedLocalAddr returns the local address (typically "ip:port") on which
// the connection carrying req was accepted, as stored by net/http's server
// under http.LocalAddrContextKey, or "" when that information is missing.
func forwardedLocalAddr(req *http.Request) string {
	if addr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
		return addr.String()
	}
	return ""
}
// LegacyUnknownHopIP is returned by formatNodeIdentifier in legacy
// (non-RFC 7239) mode for any hop that is not a parseable IP address.
const LegacyUnknownHopIP = "0.0.0.0"

var (
	// numericPortRegex matches a decimal port of one to five digits.
	numericPortRegex = regexp.MustCompile(`^\d{1,5}$`)
	// obfuscatedIdentifierRegex matches RFC 7239 obfuscated node/port
	// identifiers: "_" followed by letters, digits, ".", "_" or "-".
	obfuscatedIdentifierRegex = regexp.MustCompile(`^_[a-zA-Z0-9._-]+$`)
)

// formatNodeIdentifier normalizes a hop identifier for output
// (see https://www.rfc-editor.org/rfc/rfc7239.html#section-6).
//
// Accepted inputs are: a bare IPv4/IPv6 address (IPv6 with or without
// brackets); such an address followed by ":port", where port is numeric or an
// obfuscated "_identifier"; or an obfuscated "_identifier" on its own.
//
// With newSyntax (Forwarded header) the result is the canonical address
// (IPv6 in brackets), optionally followed by ":port" when the port is valid,
// an obfuscated identifier, or "unknown". Without newSyntax (legacy headers)
// ports are dropped and anything that is not an IP becomes LegacyUnknownHopIP.
func formatNodeIdentifier(value string, newSyntax bool) string {
	// IPv4, IPv6
	if addr := formatIP(value, newSyntax); addr != "" {
		return addr
	}

	// IPv4:port, [IPv6]:port — the port follows the last colon.
	if sep := strings.LastIndexByte(value, ':'); sep >= 0 {
		host, port := value[:sep], value[sep+1:]
		if addr := formatIP(host, newSyntax); addr != "" {
			// The port is either numeric or an _obfuscated_identifier;
			// anything else is silently dropped.
			portOK := numericPortRegex.MatchString(port) || obfuscatedIdentifierRegex.MatchString(port)
			if newSyntax && portOK {
				return addr + ":" + port
			}
			return addr
		}
	}

	switch {
	case !newSyntax:
		return LegacyUnknownHopIP
	case obfuscatedIdentifierRegex.MatchString(value):
		// _obfuscated_value
		return value
	default:
		// Unknown / unparseable
		return "unknown"
	}
}

// formatIP returns the canonical text of value when it is an IPv4 or IPv6
// address, optionally wrapped in "[...]", and "" otherwise. With newSyntax,
// IPv6 results are enclosed in brackets and 0.0.0.0 is reported as "unknown".
func formatIP(value string, newSyntax bool) string {
	ip := net.ParseIP(value)
	if ip == nil && len(value) >= 2 && value[0] == '[' && value[len(value)-1] == ']' {
		// Lenient input: accept a bracketed literal such as "[::1]".
		ip = net.ParseIP(value[1 : len(value)-1])
	}
	if ip == nil {
		return ""
	}

	text := ip.String()
	switch {
	case !newSyntax:
		return text
	case text == LegacyUnknownHopIP:
		return "unknown"
	case strings.ContainsRune(text, ':'):
		// This is an IPv6 address and must be enclosed in brackets.
		return "[" + text + "]"
	default:
		return text
	}
}
// main is intentionally empty so that this standalone snippet builds as
// package main; the helpers above are presumably driven from a ReverseProxy
// Rewrite function (they take *httputil.ProxyRequest) — confirm in the caller.
func main() {}
package gate
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestFormatIP checks formatIP for both the RFC 7239 ("new") and the legacy
// output syntax, including bracketed IPv6 input and unparseable values.
func TestFormatIP(t *testing.T) {
	t.Parallel()

	// Each entry: input, expected new-syntax output, expected legacy output.
	cases := [][3]string{
		{"1.2.3.4", "1.2.3.4", "1.2.3.4"},
		{"::", "[::]", "::"},
		{"[::]", "[::]", "::"},
		{"0001:0002:0:0:0:0:0003:0004", "[1:2::3:4]", "1:2::3:4"},
		{"[0001:0002:0:0:0:0:0003:0004]", "[1:2::3:4]", "1:2::3:4"},
		{"1.2.3", "", ""},
		{"foo::bar", "", ""},
		{"[foo::bar]", "", ""},
		{"", "", ""},
	}
	for _, c := range cases {
		// Subtests name the failing input, which a plain loop would hide.
		t.Run(c[0], func(t *testing.T) {
			assertformatIP(t, c[0], c[1], c[2])
		})
	}
}
// TestFormatHop checks formatNodeIdentifier for both the RFC 7239 ("new")
// and the legacy output syntax. It covers bare addresses, address:port pairs,
// numeric and obfuscated ports, obfuscated node names, and unparseable input.
func TestFormatHop(t *testing.T) {
	t.Parallel()

	// Each entry: input, expected new-syntax output, expected legacy output.
	cases := [][3]string{
		{"0.0.0.0", "unknown", "0.0.0.0"},
		{"1.2.3.4", "1.2.3.4", "1.2.3.4"},
		{"1.2.3.4:", "1.2.3.4", "1.2.3.4"},
		{"1.2.3.4:_", "1.2.3.4", "1.2.3.4"},
		{"1.2.3.4:port", "1.2.3.4", "1.2.3.4"},
		{"1.2.3.4:5617", "1.2.3.4:5617", "1.2.3.4"},
		{"1.2.3.4:123456", "1.2.3.4", "1.2.3.4"},
		{"1.2.3.4:_port", "1.2.3.4:_port", "1.2.3.4"},
		{"1.2.3.4:_😊", "1.2.3.4", "1.2.3.4"},
		{"::1", "[::1]", "::1"},
		{"::1:", "[::1]", "::1"},
		{"::1:5617", "[::1:5617]", "::1:5617"},
		{"[::1]:5617", "[::1]:5617", "::1"},
		{"unknown", "unknown", "0.0.0.0"},
		{"foobar", "unknown", "0.0.0.0"},
		{"_foobar", "_foobar", "0.0.0.0"},
		{"_😊", "unknown", "0.0.0.0"},
		// Edge cases not covered originally.
		{"", "unknown", "0.0.0.0"},                          // empty input
		{"[::1]", "[::1]", "::1"},                           // bracketed IPv6, no port
		{"[::1]:_abc", "[::1]:_abc", "::1"},                 // IPv6 with obfuscated port
		{"::ffff:1.2.3.4", "1.2.3.4", "1.2.3.4"},            // IPv4-mapped IPv6
		{"0.0.0.0:80", "unknown:80", "0.0.0.0"},             // unspecified address with port
	}
	for _, c := range cases {
		// Subtests name the failing input, which a plain loop would hide.
		t.Run(c[0], func(t *testing.T) {
			assertFormatHop(t, c[0], c[1], c[2])
		})
	}
}
func assertFormatHop(t *testing.T, value, expectedNew, expectedOld string) {
t.Helper()
assert.Equal(t, expectedNew, formatNodeIdentifier(value, true))
assert.Equal(t, expectedOld, formatNodeIdentifier(value, false))
}
func assertformatIP(t *testing.T, value, expectedNew, expectedOld string) {
t.Helper()
assert.Equal(t, expectedNew, formatIP(value, true))
assert.Equal(t, expectedOld, formatIP(value, false))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment