Skip to content

Instantly share code, notes, and snippets.

@traetox
Created October 6, 2021 22:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save traetox/52123e7e8234e58b02c369b3d251f1c8 to your computer and use it in GitHub Desktop.
Save traetox/52123e7e8234e58b02c369b3d251f1c8 to your computer and use it in GitHub Desktop.
mass DNS resolver
package main
import (
"bufio"
"flag"
"fmt"
"net"
"os"
"strconv"
"strings"
"time"
"github.com/miekg/dns"
)
// maxRecursion is the deepest CNAME hop recurseResolveRunner will follow:
// a call whose depth exceeds this value returns without issuing a query.
const maxRecursion int = 4
// defaultTimeout is the default for each of the dial/read/write deadlines.
const defaultTimeout = 2 * time.Second

// Query target and batching options.
var (
	// port is the port appended to the nameserver address.
	port = flag.Int("port", 53, "port number to use")
	// laddr, when non-empty, is parsed as the local IP to bind the socket to.
	laddr = flag.String("laddr", "", "local address to use")
	// batchSize is how many names are sent before replies are read back.
	batchSize = flag.Int("batch-size", 25, "number of inflight requests")
)

// Per-operation network deadlines.
var (
	timeoutDial  = flag.Duration("timeout-dial", defaultTimeout, "Dial timeout")
	timeoutRead  = flag.Duration("timeout-read", defaultTimeout, "Read timeout")
	timeoutWrite = flag.Duration("timeout-write", defaultTimeout, "Write timeout")
)
// main resolves the A records of every name listed (one per line) in the
// files given on the command line. Queries are sent in batches of -batch-size
// over a single UDP socket. Afterwards main follows CNAME targets and prints
// each owner name with its addresses.
//
// A positional argument starting with '@' selects the nameserver (the last
// one wins). Without one, the first server in /etc/resolv.conf is used.
func main() {
	// maxBatchFailures bounds how many batches in a row may fail before the
	// remaining names are abandoned. Without it, names the server never
	// answers (or a persistent socket error) would cycle through the queue
	// forever.
	const maxBatchFailures = 5

	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: %s [options] [@server] file...\n", os.Args[0])
		flag.PrintDefaults()
	}
	flag.Parse()

	var (
		nameserver string
		files      []string
	)
	for _, arg := range flag.Args() {
		switch {
		case arg == "":
			// Ignore empty arguments rather than indexing into them.
		case strings.HasPrefix(arg, "@"):
			nameserver = arg
		default:
			// Anything else is a file of names to resolve.
			files = append(files, arg)
		}
	}
	if len(files) == 0 {
		fmt.Fprintf(os.Stderr, "no files specified\n")
		os.Exit(2)
	}

	var qname []string
	for _, f := range files {
		lines, err := readFile(f)
		if err != nil {
			fmt.Fprintf(os.Stderr, "failed to scan %s: %v\n", f, err)
			continue
		}
		qname = append(qname, lines...)
	}

	if nameserver == "" {
		conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
			os.Exit(2)
		}
		if len(conf.Servers) == 0 {
			fmt.Fprintln(os.Stderr, "no nameservers listed in /etc/resolv.conf")
			os.Exit(2)
		}
		nameserver = "@" + conf.Servers[0]
	}
	nameserver = strings.TrimPrefix(nameserver, "@")
	// If the nameserver came from /etc/resolv.conf it may already be wrapped
	// in [ and ], which breaks net.ParseIP. Strip them so the address is not
	// mistaken for a hostname and fully qualified.
	if len(nameserver) >= 2 && nameserver[0] == '[' && nameserver[len(nameserver)-1] == ']' {
		nameserver = nameserver[1 : len(nameserver)-1]
	}
	if nameserver == "" {
		fmt.Fprintln(os.Stderr, "empty nameserver")
		os.Exit(2)
	}
	if net.ParseIP(nameserver) != nil {
		nameserver = net.JoinHostPort(nameserver, strconv.Itoa(*port))
	} else {
		nameserver = dns.Fqdn(nameserver) + ":" + strconv.Itoa(*port)
	}

	c := new(dns.Client)
	c.Net = "udp"
	c.DialTimeout = *timeoutDial
	c.ReadTimeout = *timeoutRead
	c.WriteTimeout = *timeoutWrite
	if *laddr != "" {
		ip := net.ParseIP(*laddr)
		if ip == nil {
			fmt.Fprintf(os.Stderr, "invalid local address %q\n", *laddr)
			os.Exit(2)
		}
		c.Dialer = &net.Dialer{Timeout: c.DialTimeout, LocalAddr: &net.UDPAddr{IP: ip}}
	}
	// dial opens a fresh socket to the nameserver, honouring -laddr if set.
	dial := func() (net.Conn, error) {
		if c.Dialer != nil {
			return c.Dialer.Dial(c.Net, nameserver)
		}
		return net.DialTimeout(c.Net, nameserver, *timeoutDial)
	}

	// A single message is reused for every query; handle and recurseResolve
	// overwrite its Question and Id before each send.
	m := &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Opcode:             dns.OpcodeQuery,
			RecursionDesired:   true,
			RecursionAvailable: true,
		},
		Compress: true,
		Question: make([]dns.Question, 1),
	}

	co := new(dns.Conn)
	conn, err := dial()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Dialing %s failed: %v\n", nameserver, err)
		os.Exit(1)
	}
	co.Conn = conn
	// co.Conn is swapped on redial (and left nil if a redial fails), so close
	// whichever socket is current when main returns.
	defer func() {
		if co.Conn != nil {
			co.Close()
		}
	}()

	mp := make(map[string]resolves, len(qname))
	connected := true
	failures := 0
	for len(qname) > 0 {
		var set []string
		set, qname = peel(qname, *batchSize)
		missed, err := handle(mp, set, co, m)
		// Names that got no answer go to the back of the queue for a retry.
		qname = append(qname, missed...)
		if err == nil {
			failures = 0
			continue
		}
		failures++
		if failures >= maxBatchFailures {
			fmt.Fprintf(os.Stderr, "giving up on %d names after %d consecutive failed batches: %v\n",
				len(qname), failures, err)
			break
		}
		if isTimeout(err) {
			// Replace the socket so late replies to the timed-out queries are
			// not read back as answers in the next batch. Close the old socket
			// instead of leaking it.
			co.Close()
			co.Conn = nil
			nc, derr := dial()
			if derr != nil {
				fmt.Fprintf(os.Stderr, "redialing %s failed: %v\n", nameserver, derr)
				connected = false
				break
			}
			co.Conn = nc
		}
	}

	// Follow CNAME targets so aliases end up with addresses too. This needs a
	// live socket, so it is skipped if the last redial failed.
	if connected {
		for k, v := range mp {
			for _, cn := range v.CNames {
				if ips := recurseResolve(cn, co, m); len(ips) > 0 {
					v.IPs = append(v.IPs, ips...)
					v.CNames = nil
				}
			}
			mp[k] = v
		}
	}
	for k, v := range mp {
		fmt.Println(k, v.IPs)
	}
}
// isTimeout reports whether err is itself a net.Error whose Timeout method
// returns true. A nil error, or an error of any other concrete type, is not
// treated as a timeout. Wrapped errors are not unwrapped.
func isTimeout(err error) bool {
	netErr, isNetErr := err.(net.Error)
	return isNetErr && netErr.Timeout()
}
// resolves accumulates what has been learned about one DNS name. In the
// result map it is keyed by the owner name taken from the answer records.
type resolves struct {
// IPs holds the addresses from A records whose owner is this name.
IPs []net.IP
// CNames holds CNAME targets whose owner is this name.
CNames []string
}
// peel splits base into a leading batch (set) of at most max names and the
// remainder (r). A non-positive max takes all of base as one batch; the
// original code returned an empty batch there, and the caller looped forever.
//
// The batch's capacity is clipped to its length (three-index slice). Appending
// to it therefore allocates a fresh array instead of overwriting the first
// elements of r, which share the same backing array.
func peel(base []string, max int) (set []string, r []string) {
	if max <= 0 || len(base) <= max {
		return base, nil
	}
	return base[:max:max], base[max:]
}
func handle(mp map[string]resolves, set []string, co *dns.Conn, m *dns.Msg) (missed []string, err error) {
inflight := make(map[uint16]string, len(set))
for i := range set {
m.Question[0] = dns.Question{Name: dns.Fqdn(set[i]), Qtype: dns.TypeA, Qclass: dns.ClassINET}
m.Id = dns.Id()
inflight[m.Id] = set[i]
co.SetWriteDeadline(time.Now().Add(*timeoutWrite))
if err = co.WriteMsg(m); err != nil {
missed = set
continue
}
}
var r *dns.Msg
for i := 0; i < len(set); i++ {
co.SetReadDeadline(time.Now().Add(*timeoutRead))
if r, err = co.ReadMsg(); err != nil {
break
} else if err = processResponse(mp, r, co, m); err != nil {
break
}
delete(inflight, r.Id)
}
for _, v := range inflight {
missed = append(missed, v)
}
return
}
// processResponse merges the A and CNAME records found in r's answer section
// into mp, keyed by each record's owner name. A records with a nil address
// and CNAME records with an empty target are skipped. co and m are not used.
// The returned error is always nil.
func processResponse(mp map[string]resolves, r *dns.Msg, co *dns.Conn, m *dns.Msg) (err error) {
	for _, rr := range r.Answer {
		switch rec := rr.(type) {
		case *dns.A:
			if rec.A == nil {
				continue
			}
			owner := rec.Hdr.Name
			entry := mp[owner]
			entry.IPs = append(entry.IPs, rec.A)
			mp[owner] = entry
		case *dns.CNAME:
			if rec.Target == "" {
				continue
			}
			owner := rec.Hdr.Name
			entry := mp[owner]
			entry.CNames = append(entry.CNames, rec.Target)
			mp[owner] = entry
		}
	}
	return nil
}
func readFile(fname string) (lines []string, err error) {
var f *os.File
if f, err = os.OpenFile(fname, os.O_RDONLY, os.ModePerm); err != nil {
return
}
defer f.Close()
sc := bufio.NewScanner(f)
for sc.Scan() {
if ln := strings.TrimSpace(sc.Text()); len(ln) > 0 {
lines = append(lines, ln)
}
}
err = sc.Err()
return
}
// makeQuestion builds an IN/A question for the first name in qname (fully
// qualified) and returns it together with the remaining names. An empty
// qname yields the zero Question and a nil remainder.
func makeQuestion(qname []string) (qs dns.Question, r []string) {
	if len(qname) > 0 {
		head, tail := qname[0], qname[1:]
		qs = dns.Question{Name: dns.Fqdn(head), Qtype: dns.TypeA, Qclass: dns.ClassINET}
		r = tail
	}
	return qs, r
}
func recurseResolve(nm string, co *dns.Conn, m *dns.Msg) []net.IP {
return recurseResolveRunner(nm, co, m, 0, nil)
}
// recurseResolveRunner sends one IN/A query for nm over co, reusing m as the
// outgoing message, and appends every A-record address in the matching reply
// to inIPs.
//
// Replies whose ID does not match the query are skipped, because the socket
// may still deliver late answers to earlier queries. If the reply carries
// CNAME records but no A records, each CNAME target is queried in turn with
// depth+1; calls deeper than maxRecursion return immediately. Any write or
// read error ends the walk quietly and returns what was collected so far.
func recurseResolveRunner(nm string, co *dns.Conn, m *dns.Msg, depth int, inIPs []net.IP) (out []net.IP) {
	out = inIPs
	if depth > maxRecursion {
		return out
	}

	m.Question[0] = dns.Question{Name: dns.Fqdn(nm), Qtype: dns.TypeA, Qclass: dns.ClassINET}
	m.Id = dns.Id()
	id := m.Id
	co.SetWriteDeadline(time.Now().Add(*timeoutWrite))
	if err := co.WriteMsg(m); err != nil {
		return out
	}

	var r *dns.Msg
	for {
		co.SetReadDeadline(time.Now().Add(*timeoutRead))
		reply, err := co.ReadMsg()
		if err != nil {
			return out
		}
		if reply.Id == id {
			r = reply
			break
		}
		// A reply to some other query on this socket; it does not answer nm.
	}

	var targets []string
	gotAddress := false
	for _, rr := range r.Answer {
		switch rec := rr.(type) {
		case *dns.A:
			if rec.A != nil {
				out = append(out, rec.A)
				gotAddress = true
			}
		case *dns.CNAME:
			if rec.Target != "" {
				targets = append(targets, rec.Target)
			}
		}
	}
	// If the reply already carried A records (typically the end of the CNAME
	// chain), querying the targets again would only add the same addresses a
	// second time.
	if !gotAddress {
		for _, target := range targets {
			out = recurseResolveRunner(target, co, m, depth+1, out)
		}
	}
	return out
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment