Created
June 6, 2018 19:01
-
-
Save knobunc/5282dfd23e77f6216232fa6d1c6ef7b8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"log" | |
"net" | |
"sync" | |
"time" | |
"github.com/miekg/dns" | |
) | |
func main() { | |
dnsInfo, err := NewDNS("/etc/resolv.conf") | |
if err != nil { | |
log.Fatal("xxx") | |
} | |
dnsInfo.Add("www.amazon.com") | |
} | |
// defaultTTL is the refresh interval used when a lookup fails or its answers
// do not yield a usable (non-zero) TTL.
const defaultTTL = 30 * time.Minute
// dnsValue is the cached resolution state kept for one domain name.
type dnsValue struct {
	ips           []net.IP      // IPv4 addresses returned by the most recent A lookup
	ttl           time.Duration // TTL derived from that lookup's answers
	nextQueryTime time.Time     // lookup time + ttl: when this entry is next due for refresh
}
type DNS struct { | |
// Protects dnsMap operations | |
lock sync.Mutex | |
// Holds dns name and its corresponding information | |
dnsMap map[string]dnsValue | |
// DNS resolvers | |
nameservers []string | |
// DNS port | |
port string | |
} | |
func NewDNS(resolverConfigFile string) (*DNS, error) { | |
config, err := dns.ClientConfigFromFile(resolverConfigFile) | |
if err != nil || config == nil { | |
return nil, fmt.Errorf("ASDASD") | |
} | |
return &DNS{ | |
dnsMap: map[string]dnsValue{}, | |
nameservers: filterIPv4Servers(config.Servers), | |
port: config.Port, | |
}, nil | |
} | |
func (d *DNS) Size() int { | |
d.lock.Lock() | |
defer d.lock.Unlock() | |
return len(d.dnsMap) | |
} | |
func (d *DNS) Get(dns string) dnsValue { | |
d.lock.Lock() | |
defer d.lock.Unlock() | |
data := dnsValue{} | |
if res, ok := d.dnsMap[dns]; ok { | |
data.ips = make([]net.IP, len(res.ips)) | |
copy(data.ips, res.ips) | |
data.ttl = res.ttl | |
data.nextQueryTime = res.nextQueryTime | |
} | |
return data | |
} | |
func (d *DNS) Add(dns string) error { | |
d.lock.Lock() | |
defer d.lock.Unlock() | |
d.dnsMap[dns] = dnsValue{} | |
err, _ := d.updateOne(dns) | |
if err != nil { | |
delete(d.dnsMap, dns) | |
} | |
return err | |
} | |
func (d *DNS) Update() (error, bool) { | |
d.lock.Lock() | |
defer d.lock.Unlock() | |
errList := []error{} | |
changed := false | |
for dns := range d.dnsMap { | |
err, updated := d.updateOne(dns) | |
if err != nil { | |
errList = append(errList, err) | |
continue | |
} | |
if updated { | |
changed = true | |
} | |
} | |
return nil, changed | |
} | |
func (d *DNS) updateOne(dns string) (error, bool) { | |
res, ok := d.dnsMap[dns] | |
if !ok { | |
// Should not happen, all operations on dnsMap are synchronized by d.lock | |
return fmt.Errorf("DNS value not found in dnsMap for domain: %q", dns), false | |
} | |
ips, minTTL, err := d.getIPsAndMinTTL(dns) | |
if err != nil { | |
return err, false | |
} | |
log.Printf("dns %s ttl %d\n", dns, minTTL) | |
changed := false | |
if !ipsEqual(res.ips, ips) { | |
changed = true | |
} | |
res.ips = ips | |
res.ttl = minTTL | |
res.nextQueryTime = time.Now().Add(res.ttl) | |
d.dnsMap[dns] = res | |
return nil, changed | |
} | |
func (d *DNS) getIPsAndMinTTL(domain string) ([]net.IP, time.Duration, error) { | |
ips := []net.IP{} | |
var minTTL uint32 | |
for _, server := range d.nameservers { | |
msg := new(dns.Msg) | |
msg.SetQuestion(dns.Fqdn(domain), dns.TypeA) | |
dialServer := server | |
if _, _, err := net.SplitHostPort(server); err != nil { | |
dialServer = net.JoinHostPort(server, d.port) | |
} | |
c := new(dns.Client) | |
c.Timeout = 2 * time.Second | |
in, _, err := c.Exchange(msg, dialServer) | |
if err != nil { | |
return nil, defaultTTL, err | |
} | |
if in != nil && in.Rcode != dns.RcodeSuccess { | |
return nil, defaultTTL, fmt.Errorf("failed to get a valid answer: %v", in) | |
} | |
if in != nil && len(in.Answer) > 0 { | |
for i, a := range in.Answer { | |
fmt.Printf("GOT ANSWER %d %#v\n", i, a) | |
switch t := a.(type) { | |
case *dns.A: | |
ips = append(ips, t.A) | |
fmt.Printf(" USING ANSWER %d %#v %d %d\n", i, a, minTTL, t.Hdr.Ttl) | |
if minTTL == 0 || t.Hdr.Ttl < minTTL { | |
minTTL = t.Hdr.Ttl | |
} | |
} | |
} | |
} | |
} | |
ttl, err := time.ParseDuration(fmt.Sprintf("%ds", minTTL)) | |
if err != nil || minTTL == 0 { | |
log.Fatal(fmt.Errorf("Invalid TTL value for domain: %q, err: %v, defaulting ttl=%s", domain, err, defaultTTL.String())) | |
ttl = defaultTTL | |
} | |
return removeDuplicateIPs(ips), ttl, nil | |
} | |
func (d *DNS) GetMinQueryTime() (time.Time, bool) { | |
d.lock.Lock() | |
defer d.lock.Unlock() | |
timeSet := false | |
var minTime time.Time | |
for _, res := range d.dnsMap { | |
if (timeSet == false) || res.nextQueryTime.Before(minTime) { | |
timeSet = true | |
minTime = res.nextQueryTime | |
} | |
} | |
return minTime, timeSet | |
} | |
// ipsEqual reports whether oldips and newips hold the same addresses with the
// same multiplicities, ignoring order. Addresses compare like net.IP.Equal,
// so the 4-byte and 16-byte forms of an IPv4 address match.
//
// The previous length-plus-membership check wrongly treated e.g. [a, a] and
// [a, b] as equal. This version counts occurrences instead, which is also
// O(n) rather than O(n^2).
func ipsEqual(oldips, newips []net.IP) bool {
	if len(oldips) != len(newips) {
		return false
	}

	// key canonicalises valid IPs to their 16-byte form. Invalid-length IPs
	// keep their raw bytes; those keys are never 16 bytes long, so they
	// cannot collide with valid ones.
	key := func(ip net.IP) string {
		if v6 := ip.To16(); v6 != nil {
			return string(v6)
		}
		return string(ip)
	}

	counts := make(map[string]int, len(oldips))
	for _, ip := range oldips {
		counts[key(ip)]++
	}
	for _, ip := range newips {
		k := key(ip)
		if counts[k] == 0 {
			return false
		}
		counts[k]--
	}
	return true
}
// filterIPv4Servers keeps only the entries of servers whose host part parses
// as an IPv4 address (including IPv4-mapped IPv6 forms, per net.IP.To4).
// Entries may be bare addresses or host:port pairs and are returned unchanged,
// port included. Unparseable entries are dropped. The result is never nil.
func filterIPv4Servers(servers []string) []string {
	kept := []string{}
	for _, entry := range servers {
		host, _, splitErr := net.SplitHostPort(entry)
		if splitErr != nil {
			host = entry
		}
		ip := net.ParseIP(host)
		if ip == nil || ip.To4() == nil {
			continue
		}
		kept = append(kept, entry)
	}
	return kept
}
// removeDuplicateIPs returns ips with repeated addresses removed. It keeps
// the first occurrence of each address and preserves order. Addresses compare
// like net.IP.Equal, so the 4-byte and 16-byte forms of one IPv4 address are
// duplicates. Inputs with fewer than two elements are returned as-is.
func removeDuplicateIPs(ips []net.IP) []net.IP {
	if len(ips) < 2 {
		return ips
	}

	seen := make(map[string]struct{}, len(ips))
	out := make([]net.IP, 0, len(ips))
	for _, ip := range ips {
		// Canonicalise valid IPs to 16 bytes. Invalid-length IPs keep their
		// raw bytes; those keys are never 16 bytes, so they cannot collide.
		k := string(ip)
		if v6 := ip.To16(); v6 != nil {
			k = string(v6)
		}
		if _, dup := seen[k]; dup {
			continue
		}
		seen[k] = struct{}{}
		out = append(out, ip)
	}
	return out
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment