Created
October 6, 2021 22:34
-
-
Save traetox/52123e7e8234e58b02c369b3d251f1c8 to your computer and use it in GitHub Desktop.
mass DNS resolver
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bufio" | |
"flag" | |
"fmt" | |
"net" | |
"os" | |
"strconv" | |
"strings" | |
"time" | |
"github.com/miekg/dns" | |
) | |
// maxRecursion is the largest depth value at which recurseResolveRunner will
// still send a query while following a chain of CNAME records.
const maxRecursion int = 4
// Command-line options. Each variable points at the value registered with the
// flag package and is only meaningful after flag.Parse has run.

// port is the nameserver port queries are sent to.
var port = flag.Int("port", 53, "port number to use")

// laddr is the optional local IP address to send queries from.
var laddr = flag.String("laddr", "", "local address to use")

// batchSize is the number of queries written before replies are read.
var batchSize = flag.Int("batch-size", 25, "number of inflight requests")

// timeoutDial bounds how long connecting to the nameserver may take.
var timeoutDial = flag.Duration("timeout-dial", 2*time.Second, "Dial timeout")

// timeoutRead is the deadline applied to each read of a reply.
var timeoutRead = flag.Duration("timeout-read", 2*time.Second, "Read timeout")

// timeoutWrite is the deadline applied to each write of a query.
var timeoutWrite = flag.Duration("timeout-write", 2*time.Second, "Write timeout")
func main() { | |
flag.Usage = func() { | |
fmt.Fprintf(os.Stderr, "Usage: %s [options] [@server] [qtype...] [qclass...] [name ...]\n", os.Args[0]) | |
flag.PrintDefaults() | |
} | |
var ( | |
qname []string | |
files []string | |
) | |
flag.Parse() | |
var nameserver string | |
for _, arg := range flag.Args() { | |
// If it starts with @ it is a nameserver | |
if arg[0] == '@' { | |
nameserver = arg | |
continue | |
} | |
// Anything else is a qname | |
files = append(files, arg) | |
} | |
if len(files) == 0 { | |
fmt.Fprintf(os.Stderr, "no files specified\n") | |
return | |
} | |
for _, f := range files { | |
if lines, err := readFile(f); err != nil { | |
fmt.Fprintf(os.Stderr, "failed to scan %s: %v\n", f, err) | |
} else { | |
qname = append(qname, lines...) | |
} | |
} | |
if len(nameserver) == 0 { | |
conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") | |
if err != nil { | |
fmt.Fprintln(os.Stderr, err) | |
os.Exit(2) | |
} | |
nameserver = "@" + conf.Servers[0] | |
} | |
nameserver = string([]byte(nameserver)[1:]) // chop off @ | |
// if the nameserver is from /etc/resolv.conf the [ and ] are already | |
// added, thereby breaking net.ParseIP. Check for this and don't | |
// fully qualify such a name | |
if nameserver[0] == '[' && nameserver[len(nameserver)-1] == ']' { | |
nameserver = nameserver[1 : len(nameserver)-1] | |
} | |
if i := net.ParseIP(nameserver); i != nil { | |
nameserver = net.JoinHostPort(nameserver, strconv.Itoa(*port)) | |
} else { | |
nameserver = dns.Fqdn(nameserver) + ":" + strconv.Itoa(*port) | |
} | |
c := new(dns.Client) | |
c.Net = "udp" | |
c.DialTimeout = *timeoutDial | |
c.ReadTimeout = *timeoutRead | |
c.WriteTimeout = *timeoutWrite | |
if *laddr != "" { | |
c.Dialer = &net.Dialer{Timeout: c.DialTimeout} | |
ip := net.ParseIP(*laddr) | |
c.Dialer.LocalAddr = &net.UDPAddr{IP: ip} | |
} | |
m := &dns.Msg{ | |
MsgHdr: dns.MsgHdr{ | |
Opcode: dns.OpcodeQuery, | |
RecursionDesired: true, | |
RecursionAvailable:true, | |
}, | |
Compress: true, | |
Question: make([]dns.Question, 1), | |
} | |
co := new(dns.Conn) | |
var err error | |
if c.Dialer != nil { | |
co.Conn, err = c.Dialer.Dial(c.Net, nameserver) | |
} else { | |
co.Conn, err = net.DialTimeout(c.Net, nameserver, *timeoutDial) | |
} | |
if err != nil { | |
fmt.Fprintf(os.Stderr, "Dialing "+nameserver+" failed: "+err.Error()+"\n") | |
return | |
} | |
defer co.Close() | |
var set []string | |
mp := make(map[string]resolves, len(qname)) | |
for len(qname) > 0 { | |
set, qname = peel(qname, *batchSize) | |
if set, err = handle(mp, set, co, m); err != nil { | |
if len(set) > 0 { | |
qname = append(qname, set...) | |
} | |
//check if this is a timeout error, if so, redial and continue | |
if isTimeout(err) { | |
if c.Dialer != nil { | |
co.Conn, err = c.Dialer.Dial(c.Net, nameserver) | |
} else { | |
co.Conn, err = net.DialTimeout(c.Net, nameserver, *timeoutDial) | |
} | |
if err != nil { | |
break | |
} | |
} | |
} | |
} | |
//go do cname recursions | |
for k, v := range mp { | |
for _, cn := range v.CNames { | |
if ips := recurseResolve(cn, co, m); len(ips) > 0 { | |
v.IPs = append(v.IPs, ips...) | |
v.CNames = nil | |
} | |
} | |
mp[k] = v | |
} | |
for k, v := range mp { | |
fmt.Println(k, v.IPs) | |
} | |
} | |
// isTimeout reports whether err is a net.Error whose Timeout method returns
// true. A nil error, or one that does not implement net.Error, yields false.
func isTimeout(err error) bool {
	// The assertion fails (without panicking) when err is nil.
	netErr, isNetErr := err.(net.Error)
	return isNetErr && netErr.Timeout()
}
// resolves collects what the answers seen so far say about one owner name.
type resolves struct {
	// IPs holds the addresses taken from A records for the name.
	IPs []net.IP

	// CNames holds the targets of CNAME records for the name; main resolves
	// them after the batch phase.
	CNames []string
}
// peel splits base into a leading batch of at most max entries (set) and the
// remaining entries (r). When base holds fewer than max entries it is returned
// whole and r is nil.
//
// A max below 1 is treated as 1 so that a caller looping until the remainder
// is empty always makes progress.
//
// Both results still refer to base's backing array, but set's capacity is
// capped at its length: appending to the batch allocates a new array instead
// of overwriting the first entries of r.
func peel(base []string, max int) (set []string, r []string) {
	if max < 1 {
		max = 1
	}
	if len(base) < max {
		return base, nil
	}
	return base[:max:max], base[max:]
}
func handle(mp map[string]resolves, set []string, co *dns.Conn, m *dns.Msg) (missed []string, err error) { | |
inflight := make(map[uint16]string, len(set)) | |
for i := range set { | |
m.Question[0] = dns.Question{Name: dns.Fqdn(set[i]), Qtype: dns.TypeA, Qclass: dns.ClassINET} | |
m.Id = dns.Id() | |
inflight[m.Id] = set[i] | |
co.SetWriteDeadline(time.Now().Add(*timeoutWrite)) | |
if err = co.WriteMsg(m); err != nil { | |
missed = set | |
continue | |
} | |
} | |
var r *dns.Msg | |
for i := 0; i < len(set); i++ { | |
co.SetReadDeadline(time.Now().Add(*timeoutRead)) | |
if r, err = co.ReadMsg(); err != nil { | |
break | |
} else if err = processResponse(mp, r, co, m); err != nil { | |
break | |
} | |
delete(inflight, r.Id) | |
} | |
for _, v := range inflight { | |
missed = append(missed, v) | |
} | |
return | |
} | |
func processResponse(mp map[string]resolves, r *dns.Msg, co *dns.Conn, m *dns.Msg) (err error) { | |
for _, v := range r.Answer { | |
if a, ok := v.(*dns.A); ok { | |
nm := a.Hdr.Name | |
if ip := a.A; ip != nil { | |
if curr, ok := mp[nm]; ok { | |
curr.IPs = append(curr.IPs, ip) | |
mp[nm] = curr | |
} else { | |
mp[nm] = resolves { | |
IPs: []net.IP{ip}, | |
} | |
} | |
} | |
} else if cn, ok := v.(*dns.CNAME); ok { | |
if cn.Target != `` { | |
nm := cn.Hdr.Name | |
if curr, ok := mp[nm]; ok { | |
curr.CNames = append(curr.CNames, cn.Target) | |
mp[nm] = curr | |
} else { | |
mp[nm] = resolves { | |
CNames: []string{cn.Target}, | |
} | |
} | |
} | |
} | |
} | |
return | |
} | |
// readFile opens fname read-only and returns its lines with surrounding
// whitespace trimmed, leaving out lines that are empty after trimming.
//
// If the file cannot be opened, lines is nil. If scanning stops with an
// error, that error is returned along with the lines collected up to then.
func readFile(fname string) (lines []string, err error) {
	f, err := os.Open(fname)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		ln := strings.TrimSpace(scanner.Text())
		if len(ln) == 0 {
			continue
		}
		lines = append(lines, ln)
	}
	return lines, scanner.Err()
}
func makeQuestion(qname []string) (qs dns.Question, r []string) { | |
if len(qname) == 0 { | |
return | |
} | |
qs = dns.Question{ | |
Name: dns.Fqdn(qname[0]), | |
Qtype: dns.TypeA, | |
Qclass: dns.ClassINET, | |
} | |
r = qname[1:] | |
return | |
} | |
func recurseResolve(nm string, co *dns.Conn, m *dns.Msg) []net.IP { | |
return recurseResolveRunner(nm, co, m, 0, nil) | |
} | |
func recurseResolveRunner(nm string, co *dns.Conn, m *dns.Msg, depth int, inIPs []net.IP) (out []net.IP) { | |
var r *dns.Msg | |
var err error | |
out = inIPs | |
if depth > maxRecursion { | |
return | |
} | |
//do linear query | |
m.Question[0] = dns.Question{Name: dns.Fqdn(nm), Qtype: dns.TypeA, Qclass: dns.ClassINET} | |
m.Id = dns.Id() | |
co.SetWriteDeadline(time.Now().Add(*timeoutWrite)) | |
if err = co.WriteMsg(m); err != nil { | |
return | |
} | |
co.SetReadDeadline(time.Now().Add(*timeoutRead)) | |
if r, err = co.ReadMsg(); err != nil { | |
return | |
} | |
//iterate over responses | |
for _, v := range r.Answer { | |
switch a := v.(type) { | |
case *dns.A: | |
if a.A != nil { | |
out = append(out, a.A) | |
} | |
case *dns.CNAME: | |
if a.Target != `` { | |
out = recurseResolveRunner(a.Target, co, m, depth+1, out) | |
} | |
} | |
} | |
return | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment