Skip to content

Instantly share code, notes, and snippets.

@003random
Created December 8, 2019 20:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 003random/9e6acb6a33785e4717e1b245e58a0726 to your computer and use it in GitHub Desktop.
Save 003random/9e6acb6a33785e4717e1b245e58a0726 to your computer and use it in GitHub Desktop.
Golang SSRF Protection & Prevention Using a Whitelist
package main
import (
"context"
"errors"
"log"
"net"
"net/http"
"strings"
"time"
)
var (
ranges []string = []string{
// CIDR of xyz,
// only allow IP addresses in this range
"84.65.0.0/24",
}
whitelist []net.IPNet = createWhitelist()
)
// main demonstrates an SSRF-hardened HTTP client: every outbound connection
// is made through safeDialContext, which only connects to whitelisted IPs.
func main() {
	tr := &http.Transport{
		// Proxy is deliberately left nil (no proxy). With a proxy, DialContext
		// would connect to the proxy, and the whitelist check would vet the
		// proxy's address instead of the real target.
		DialContext: safeDialContext,
	}
	client := &http.Client{
		Transport: tr,
		Timeout:   5 * time.Second,
	}

	req, err := http.NewRequest(http.MethodGet, "http://84.65.0.6", nil)
	if err != nil {
		log.Println(err)
		return
	}
	resp, err := client.Do(req)
	if err != nil {
		log.Println(err)
		return
	}
	// Always close the body so the underlying connection is released.
	defer resp.Body.Close()
}

// safeDialContext is an http.Transport DialContext hook. It does three things:
//   - resolves the target host itself;
//   - refuses the connection unless every resolved address passes isAllowed;
//   - dials the vetted IP directly.
//
// Dialing the IP rather than the hostname means a second DNS lookup cannot
// substitute a different, non-whitelisted address between check and connect.
//
// ctx governs both the DNS lookup and the dial, so request cancellation and
// the client timeout are honoured. Returns the first successful connection,
// or the error from the last failed dial attempt.
func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	// The whitelist is expressed as IP ranges, so only IP stream networks make
	// sense here (http.Transport dials "tcp"); refuse e.g. "unix" outright.
	if !strings.HasPrefix(network, "tcp") {
		return nil, errors.New("network not allowed")
	}

	// SplitHostPort understands bracketed IPv6 literals such as "[::1]:80".
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}

	ips, err := net.DefaultResolver.LookupHost(ctx, host)
	if err != nil {
		return nil, err
	}
	if len(ips) == 0 {
		// Never hand http.Transport a (nil, nil) result.
		return nil, errors.New("no addresses found for " + host)
	}

	// Vet every address before dialing any of them, so the decision does not
	// depend on the order in which the resolver returned them.
	for _, ip := range ips {
		if !isAllowed(ip) {
			return nil, errors.New("IP not allowed")
		}
	}

	var d net.Dialer
	var lastErr error
	for _, ip := range ips {
		conn, err := d.DialContext(ctx, network, net.JoinHostPort(ip, port))
		if err == nil {
			return conn, nil
		}
		lastErr = err
	}
	return nil, lastErr
}
// createWhitelist parses every CIDR string in ranges into a net.IPNet, in
// order. It is used to initialise the package-level whitelist, so an invalid
// entry is a programming/configuration error. In that case it panics with the
// offending value instead of dereferencing a nil *net.IPNet.
func createWhitelist() []net.IPNet {
	nets := make([]net.IPNet, 0, len(ranges))
	for _, cidr := range ranges {
		_, n, err := net.ParseCIDR(cidr)
		if err != nil {
			panic("invalid whitelist CIDR " + cidr + ": " + err.Error())
		}
		nets = append(nets, *n)
	}
	return nets
}
// isAllowed reports whether sIP, a textual IPv4 or IPv6 address, lies inside
// at least one network in whitelist. Strings that do not parse as an IP
// address are never allowed.
func isAllowed(sIP string) bool {
	// Parse once up front rather than once per whitelist entry.
	ip := net.ParseIP(sIP)
	if ip == nil {
		return false
	}
	for i := range whitelist {
		if whitelist[i].Contains(ip) {
			return true
		}
	}
	return false
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment