Skip to content

Instantly share code, notes, and snippets.

@knobunc
Created June 6, 2018 19:01
Show Gist options
  • Save knobunc/5282dfd23e77f6216232fa6d1c6ef7b8 to your computer and use it in GitHub Desktop.
Save knobunc/5282dfd23e77f6216232fa6d1c6ef7b8 to your computer and use it in GitHub Desktop.
package main
import (
"fmt"
"log"
"net"
"sync"
"time"
"github.com/miekg/dns"
)
// main builds a DNS cache from the system resolver configuration and resolves
// a single sample domain. Any failure terminates the process and reports the
// underlying error instead of a placeholder message.
func main() {
	dnsInfo, err := NewDNS("/etc/resolv.conf")
	if err != nil {
		log.Fatalf("creating DNS cache: %v", err)
	}
	// Add performs the actual lookup; its error is the only signal that the
	// query failed, so it must not be dropped.
	if err := dnsInfo.Add("www.amazon.com"); err != nil {
		log.Fatalf("resolving %q: %v", "www.amazon.com", err)
	}
}
// defaultTTL is the refresh interval used when a lookup yields no usable TTL.
const defaultTTL = 30 * time.Minute
// dnsValue is the cached lookup result kept in DNS.dnsMap for one domain name.
type dnsValue struct {
// All IPv4 addresses for a given domain name, collected from the A records
// returned by the configured nameservers.
ips []net.IP
// Time-to-live value from non-authoritative/cached name server for the domain;
// the smallest TTL among the A records that were used.
ttl time.Duration
// Holds (last dns lookup time + ttl), tells when to refresh IPs next time.
nextQueryTime time.Time
}
// DNS caches the resolved IPv4 addresses and TTLs of a set of domain names.
// It contains a mutex, so it is handled through a pointer (see NewDNS).
type DNS struct {
// Protects dnsMap operations; every exported method takes it before touching dnsMap.
lock sync.Mutex
// Holds dns name and its corresponding information, keyed by domain name.
dnsMap map[string]dnsValue
// DNS resolvers: the IPv4 entries of the resolver configuration, each either
// a bare address or an address with a port.
nameservers []string
// DNS port from the resolver configuration, used for nameservers that do not
// carry their own port.
port string
}
// NewDNS creates an empty DNS cache whose nameservers and port come from the
// resolv.conf-style file at resolverConfigFile. Only IPv4 nameservers are kept.
//
// It returns an error naming the file and the underlying cause when the
// configuration cannot be read or yields no configuration.
func NewDNS(resolverConfigFile string) (*DNS, error) {
	config, err := dns.ClientConfigFromFile(resolverConfigFile)
	if err != nil {
		return nil, fmt.Errorf("cannot read resolver config %q: %v", resolverConfigFile, err)
	}
	if config == nil {
		return nil, fmt.Errorf("resolver config %q produced no configuration", resolverConfigFile)
	}

	return &DNS{
		dnsMap:      map[string]dnsValue{},
		nameservers: filterIPv4Servers(config.Servers),
		port:        config.Port,
	}, nil
}
// Size reports how many domain names are currently tracked in the cache.
func (d *DNS) Size() int {
	d.lock.Lock()
	count := len(d.dnsMap)
	d.lock.Unlock()
	return count
}
// Get returns a snapshot of the cached entry for the given domain name.
// An unknown name yields the zero dnsValue. The returned ips slice is a fresh
// copy, so appending to or reordering it does not affect the cache.
func (d *DNS) Get(dns string) dnsValue {
	d.lock.Lock()
	defer d.lock.Unlock()

	cached, found := d.dnsMap[dns]
	if !found {
		return dnsValue{}
	}

	ipsCopy := make([]net.IP, len(cached.ips))
	copy(ipsCopy, cached.ips)
	return dnsValue{
		ips:           ipsCopy,
		ttl:           cached.ttl,
		nextQueryTime: cached.nextQueryTime,
	}
}
// Add starts tracking the given domain name and resolves it immediately.
// The entry is first reset to an empty value; if the lookup fails the entry is
// removed again and the lookup error is returned.
func (d *DNS) Add(dns string) error {
	d.lock.Lock()
	defer d.lock.Unlock()

	// updateOne requires the key to exist, so seed it before resolving.
	d.dnsMap[dns] = dnsValue{}
	if err, _ := d.updateOne(dns); err != nil {
		delete(d.dnsMap, dns)
		return err
	}
	return nil
}
// Update re-resolves every domain name currently tracked in the cache.
//
// The returned error is nil when every lookup succeeded; otherwise it combines
// the failures of all domains that could not be refreshed, each prefixed with
// its domain name. A failure for one domain does not stop the remaining ones
// from being refreshed. The boolean is true when at least one domain's
// addresses changed.
func (d *DNS) Update() (error, bool) {
	d.lock.Lock()
	defer d.lock.Unlock()

	var combined error
	changed := false
	for dns := range d.dnsMap {
		err, updated := d.updateOne(dns)
		if err != nil {
			// Keep going, but remember the failure so the caller sees it.
			err = fmt.Errorf("updating %q: %v", dns, err)
			if combined == nil {
				combined = err
			} else {
				combined = fmt.Errorf("%v; %v", combined, err)
			}
			continue
		}
		changed = changed || updated
	}
	return combined, changed
}
// updateOne refreshes the cached entry of a single domain name. The caller must
// hold d.lock and the name must already be present in dnsMap.
//
// It returns the lookup error (leaving the entry untouched) or nil, plus
// whether the resolved addresses differ from the previously cached ones.
func (d *DNS) updateOne(dns string) (error, bool) {
	previous, found := d.dnsMap[dns]
	if !found {
		// Should not happen, all operations on dnsMap are synchronized by d.lock
		return fmt.Errorf("DNS value not found in dnsMap for domain: %q", dns), false
	}

	freshIPs, freshTTL, err := d.getIPsAndMinTTL(dns)
	if err != nil {
		return err, false
	}
	log.Printf("dns %s ttl %d\n", dns, freshTTL)

	changed := !ipsEqual(previous.ips, freshIPs)
	d.dnsMap[dns] = dnsValue{
		ips:           freshIPs,
		ttl:           freshTTL,
		nextQueryTime: time.Now().Add(freshTTL),
	}
	return nil, changed
}
// getIPsAndMinTTL asks every configured nameserver for the A records of domain
// and returns the collected addresses together with the smallest non-zero TTL
// seen across all of those records.
//
// A nameserver configured without a port is dialed on d.port. The first
// exchange error or non-success response code aborts the lookup and is
// returned with a nil address list. When no A record supplies a non-zero TTL
// (including when there are no A records at all), the condition is logged and
// defaultTTL is returned.
func (d *DNS) getIPsAndMinTTL(domain string) ([]net.IP, time.Duration, error) {
	ips := []net.IP{}
	var minTTL uint32 // 0 means "no TTL seen yet"

	for _, server := range d.nameservers {
		msg := new(dns.Msg)
		msg.SetQuestion(dns.Fqdn(domain), dns.TypeA)

		dialServer := server
		if _, _, err := net.SplitHostPort(server); err != nil {
			// The configured address has no port; use the resolver config port.
			dialServer = net.JoinHostPort(server, d.port)
		}

		c := new(dns.Client)
		c.Timeout = 2 * time.Second
		in, _, err := c.Exchange(msg, dialServer)
		if err != nil {
			return nil, defaultTTL, err
		}
		if in == nil {
			continue
		}
		if in.Rcode != dns.RcodeSuccess {
			return nil, defaultTTL, fmt.Errorf("failed to get a valid answer: %v", in)
		}

		for i, a := range in.Answer {
			// Debug trace of every answer record received.
			fmt.Printf("GOT ANSWER %d %#v\n", i, a)
			record, isA := a.(*dns.A)
			if !isA {
				continue
			}
			ips = append(ips, record.A)
			fmt.Printf(" USING ANSWER %d %#v %d %d\n", i, a, minTTL, record.Hdr.Ttl)
			if minTTL == 0 || record.Hdr.Ttl < minTTL {
				minTTL = record.Hdr.Ttl
			}
		}
	}

	// A uint32 count of seconds always fits in a time.Duration, so no parsing
	// (and no parse error) is involved.
	ttl := time.Duration(minTTL) * time.Second
	if minTTL == 0 {
		// Log and fall back rather than exiting: a domain without a usable TTL
		// must not take the whole process down.
		log.Printf("Invalid TTL value for domain: %q, defaulting ttl=%s", domain, defaultTTL.String())
		ttl = defaultTTL
	}
	return removeDuplicateIPs(ips), ttl, nil
}
// GetMinQueryTime returns the earliest nextQueryTime among all cached entries,
// i.e. when the next refresh is due. The boolean is false (and the time is the
// zero value) when the cache is empty.
func (d *DNS) GetMinQueryTime() (time.Time, bool) {
	d.lock.Lock()
	defer d.lock.Unlock()

	var (
		earliest time.Time
		found    bool
	)
	for _, entry := range d.dnsMap {
		if found && !entry.nextQueryTime.Before(earliest) {
			continue
		}
		earliest, found = entry.nextQueryTime, true
	}
	return earliest, found
}
// ipsEqual reports whether oldips and newips contain the same addresses with
// the same multiplicity, irrespective of order.
//
// Addresses are compared by their textual form, so the 4-byte and 16-byte
// encodings of one IPv4 address count as the same address.
func ipsEqual(oldips, newips []net.IP) bool {
	if len(oldips) != len(newips) {
		return false
	}

	// Count every old address, then consume those counts with the new ones.
	// With equal lengths, never running out of a count means the lists match
	// exactly. Merely checking that each old address occurs somewhere in the
	// new list would wrongly accept [a a] versus [a b].
	remaining := make(map[string]int, len(oldips))
	for _, ip := range oldips {
		remaining[ip.String()]++
	}
	for _, ip := range newips {
		key := ip.String()
		if remaining[key] == 0 {
			return false
		}
		remaining[key]--
	}
	return true
}
// filterIPv4Servers returns, in their original order and original spelling,
// the entries of servers whose host part is a literal IPv4 address. Entries
// may be a bare address or "host:port". Anything that does not parse as an IP,
// or parses as a non-IPv4 address, is dropped. The result is never nil.
func filterIPv4Servers(servers []string) []string {
	kept := make([]string, 0)
	for _, candidate := range servers {
		// Strip a port if there is one; otherwise treat the whole entry as the host.
		host := candidate
		if h, _, err := net.SplitHostPort(candidate); err == nil {
			host = h
		}

		parsed := net.ParseIP(host)
		if parsed == nil || parsed.To4() == nil {
			continue
		}
		kept = append(kept, candidate)
	}
	return kept
}
// removeDuplicateIPs returns ips with repeated addresses removed, keeping the
// first occurrence of each address and preserving the original order.
//
// Addresses are compared by their textual form, so the 4-byte and 16-byte
// encodings of one IPv4 address count as duplicates. Lists with fewer than two
// elements are returned unchanged; otherwise a new slice is returned and the
// input is not modified.
func removeDuplicateIPs(ips []net.IP) []net.IP {
	if len(ips) < 2 {
		return ips
	}

	seen := make(map[string]struct{}, len(ips))
	unique := make([]net.IP, 0, len(ips))
	for _, ip := range ips {
		key := ip.String()
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		unique = append(unique, ip)
	}
	return unique
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment