Skip to content

Instantly share code, notes, and snippets.

@khafatech
Created September 24, 2017 19:31
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save khafatech/169e665826c050e1252c2295cbef15a1 to your computer and use it in GitHub Desktop.
Save khafatech/169e665826c050e1252c2295cbef15a1 to your computer and use it in GitHub Desktop.
Simple dns server in go
package main
// go build dns.go && ./dns
// dig +qr +tries=1 +time=1 @localhost -p 9000 foo.com
import (
"bytes"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"net"
)
const (
// HeaderSize is the length in bytes of the DNS message header: decodeRequest
// treats everything after the first HeaderSize bytes as the question section.
HeaderSize = 12
)
// DNSHeader is the fixed-size header at the start of a DNS message.
// decodeHeader and serializeResponse hand it to encoding/binary in big-endian
// order, so the order and width of the fields below define the wire layout
// (six uint16 values, 12 bytes in total).
type DNSHeader struct {
// Id is the message identifier; handleRequest copies it from the request
// into the response.
Id uint16
// Flags carries the header flag bits as one 16-bit value (handleRequest
// sends 0x8000 in responses).
Flags uint16
// Queries is the number of entries in the question section; decodeRequest
// only decodes a query when it is greater than zero.
Queries uint16
// Answers is the number of answer records; handleRequest sets it to 1 when
// it sends a single record.
Answers uint16
// Auths and Additional are always written as 0 by this program.
// NOTE(review): presumably the authority and additional record counts
// (NSCOUNT / ARCOUNT in RFC 1035) - confirm.
Auths uint16
Additional uint16
}
// DNSQuery is the question decoded by decodeQuery from the bytes that follow
// the header.
type DNSQuery struct {
// Name is the queried name as returned by parseRequestName: labels joined
// with dots, without a trailing dot.
Name []byte
NameRaw []byte // name as received: the query bytes up to and including the first zero byte
// Type and Class are the two big-endian uint16 values read right after the
// name.
Type uint16
Class uint16
}
// DNSResourceRecord is a single record of a response. Its Bytes method writes
// the fields in declaration order, multi-byte integers in big-endian order.
type DNSResourceRecord struct {
// Name is written to the wire verbatim, so it must already be in wire
// format; handleRequest passes the request's NameRaw here.
Name []byte
// Type and Class are the record type and class codes (newARecord uses 1
// for both).
Type uint16
Class uint16
// TTL is written as a 32-bit value; newARecord sets it to 0.
TTL uint32
// Rdatalen is written as stored and is not derived from len(Rdata), so the
// two must be kept consistent by whoever builds the record.
Rdatalen uint16
// Rdata is the record payload (the 4 address bytes for an A record).
Rdata []byte
}
// RType is a uint16 code for a DNS record type (A, AAAA, MX, etc.).
// NOTE(review): nothing in the visible source uses this type; the structs in
// this file store record types as plain uint16 - confirm whether it is needed.
type RType uint16
// DNSRequest is the summary of an incoming message that decodeRequest hands to
// handleRequest: the header's id plus the fields of the decoded query.
type DNSRequest struct {
// Id is the identifier taken from the request header.
Id uint16
// Type is the query type taken from the decoded query.
Type uint16
// Name is the queried name in dotted form; handleRequest uses it as the
// lookup key.
Name []byte
// NameRaw is the name as it appeared in the request; handleRequest reuses
// it as the name of the answer record.
NameRaw []byte
}
// newARecord builds an A (IPv4 address) record for name with a TTL of 0.
//
// name must already be in wire format because it is written to the response
// unchanged. ip may be given in either the 4-byte or the 16-byte form (the
// latter is what net.ParseIP returns for an IPv4 address); it is normalised to
// 4 bytes so that Rdata and Rdatalen always agree. If ip is not an IPv4
// address its bytes are used as they are and Rdatalen reflects their length.
func newARecord(name []byte, ip net.IP) DNSResourceRecord {
	rdata := []byte(ip)
	if v4 := ip.To4(); v4 != nil {
		rdata = v4
	}
	return DNSResourceRecord{
		Name:     name,
		Type:     1, // A
		Class:    1, // IN
		TTL:      0,
		Rdatalen: uint16(len(rdata)),
		Rdata:    rdata,
	}
}
// decodeRequest parses a raw DNS message into a DNSRequest.
//
// It decodes the header and then the query that follows it. A non-nil error is
// returned when the header cannot be decoded, when the message consists of a
// header only, when the header announces no queries, or when the query itself
// is malformed; the returned DNSRequest is the zero value in all those cases.
func decodeRequest(request_bytes []byte) (DNSRequest, error) {
	var request DNSRequest
	header, err := decodeHeader(request_bytes)
	if err != nil {
		fmt.Println("failed decoding header:", err)
		return request, err
	}
	// FIXME - check counts from header
	if len(request_bytes) <= HeaderSize {
		// Nothing follows the header, so there is no question to answer.
		// This must be reported as an error: err is nil at this point and
		// returning it would present the zero request as a success.
		fmt.Println("only header")
		return request, fmt.Errorf("request contains only a header")
	}
	if header.Queries == 0 {
		return request, fmt.Errorf("No queries in request")
	}
	// Only the first query is decoded, whatever count the header announces.
	query, err := decodeQuery(request_bytes[HeaderSize:])
	if err != nil {
		fmt.Println("error decoding query:", err)
		return request, err
	}
	return DNSRequest{
		Id:      header.Id,
		Name:    query.Name,
		NameRaw: query.NameRaw,
		Type:    query.Type,
	}, nil
}
// decodeHeader reads a DNSHeader from the start of b, interpreting every field
// as a big-endian uint16, and prints the decoded header. If b is too short the
// error from binary.Read is returned together with the untouched zero header.
func decodeHeader(b []byte) (DNSHeader, error) {
	var hdr DNSHeader
	if err := binary.Read(bytes.NewReader(b), binary.BigEndian, &hdr); err != nil {
		return hdr, err
	}
	fmt.Printf("Header: %#v", hdr)
	fmt.Printf("request id: %d\n", hdr.Id)
	return hdr, nil
}
// decodeQuery decodes the query found at the start of b (the bytes that follow
// the message header): the name, then the type and class as big-endian uint16
// values. It prints a hex dump and the decoded fields as it goes. On a read
// error the query decoded so far is returned alongside the error.
func decodeQuery(b []byte) (DNSQuery, error) {
	fmt.Printf("Query:\n%s", hex.Dump(b))

	r := bytes.NewReader(b)
	q := DNSQuery{Name: parseRequestName(r)}

	// The raw name runs up to and including the first zero byte of b.
	end := bytes.IndexByte(b, 0) + 1
	q.NameRaw = b[:end]
	fmt.Printf("name: '%s' len(query.NameRaw): %d\n", q.Name, len(q.NameRaw))

	// Type and class follow the name, in that order.
	for _, field := range []*uint16{&q.Type, &q.Class} {
		if err := binary.Read(r, binary.BigEndian, field); err != nil {
			return q, err
		}
	}
	fmt.Printf("type: %d, class: %d\n", q.Type, q.Class)
	return q, nil
}
// parseRequestNameSimple converts a name made of length-prefixed labels
// terminated by a zero byte (for example "\x03foo\x03com\x00") into its dotted
// form ("foo.com"). Bytes after the terminator are ignored.
//
// The result is a newly allocated slice; b is never modified. A name made of
// the terminator alone yields an empty, non-nil slice. An error (with a nil
// result) is returned when b is empty, when a label length points past the end
// of b, or when the terminating zero byte is missing.
func parseRequestNameSimple(b []byte) ([]byte, error) {
	if len(b) == 0 {
		return nil, errors.New("Name must not be empty")
	}
	name := []byte{}
	pos := 0
	for {
		if pos >= len(b) {
			return nil, errors.New("Name is missing its terminating zero byte")
		}
		// Widen to int before doing arithmetic: a length byte of 255 would
		// wrap around to 0 if 1 were added to it as a byte.
		count := int(b[pos])
		if count == 0 {
			return name, nil
		}
		end := pos + 1 + count
		if end > len(b) {
			return nil, fmt.Errorf("Wrong count: %d", count)
		}
		if len(name) > 0 {
			name = append(name, '.')
		}
		// Appending to a separate slice (not to a sub-slice of b) keeps the
		// length bytes of the input intact.
		name = append(name, b[pos+1:end]...)
		pos = end
	}
}
// parseRequestName reads a name made of length-prefixed labels from reader and
// returns it in dotted form without the trailing dot ("foo.com"). The reader
// is left positioned just after the name's terminating zero byte, so the
// caller can go on reading the fields that follow.
//
// A name without labels (the terminator alone) yields an empty slice. At most
// 512 bytes of dotted name are read.
func parseRequestName(reader *bytes.Reader) []byte {
	buf := make([]byte, 512)
	// blockReader reports the end of the name as EOF, so ReadFull returning an
	// error is the normal way for a name to end; what matters is how many
	// bytes were actually produced. A nil error means buf filled up before the
	// end of the name was reached.
	n, err := io.ReadFull(&blockReader{r: reader}, buf)
	if err == nil {
		fmt.Println("name does not fit in", len(buf), "bytes, truncated")
	}
	// Only the first n bytes are name data; the rest of buf is zero padding
	// and must not be returned.
	name := buf[:n]
	fmt.Println("len name:", len(name))
	// remove final dot
	// see TruncateAtFinalSlash() in https://blog.golang.org/slices
	if i := bytes.LastIndex(name, []byte(".")); i >= 0 {
		name = name[:i]
	}
	return name
}
// blockReader turns a sequence of length-prefixed labels read from r into a
// flat byte stream in which every label is followed by a dot
// ("\x03foo\x03com\x00" reads as "foo.com."). A zero length byte ends the
// stream.
type blockReader struct {
	r     *bytes.Reader
	slice []byte    // part of the current label (plus dot) not yet handed out
	tmp   [256]byte // backing storage for slice: up to 255 label bytes plus the dot
}

// Read implements io.Reader.
//
// It returns io.EOF when the zero length byte is reached or when r is
// exhausted between labels, and io.ErrUnexpectedEOF when r ends in the middle
// of a label. After a failed read no partial label is kept, so later calls
// never return stale bytes.
//
// from https://blog.golang.org/gif-decoder-exercise-in-go-interfaces
func (b *blockReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if len(b.slice) == 0 {
		blockLen, err := b.r.ReadByte()
		if err != nil {
			return 0, err
		}
		if blockLen == 0 {
			return 0, io.EOF
		}
		// Fill a local slice first and publish it to b.slice only once the
		// whole label has been read.
		label := b.tmp[:blockLen]
		if _, err = io.ReadFull(b.r, label); err != nil {
			if err == io.EOF {
				// The length byte promised data that is not there: that is
				// a truncated label, not a clean end of the name.
				err = io.ErrUnexpectedEOF
			}
			return 0, err
		}
		// blockLen is at most 255, so the dot always fits inside tmp.
		b.slice = append(label, '.')
	}
	n := copy(p, b.slice)
	b.slice = b.slice[n:]
	return n, nil
}
// serializeResponse returns the wire form of a response: header encoded
// field by field in big-endian order, followed by the bytes of rr. An error
// while encoding the header is printed and the record is still appended.
func serializeResponse(header DNSHeader, rr DNSResourceRecord) []byte {
	var out bytes.Buffer
	if err := binary.Write(&out, binary.BigEndian, header); err != nil {
		fmt.Println("error serializing header:", err)
	}
	out.Write(rr.Bytes())
	return out.Bytes()
}
// Bytes returns the record in wire form: Name as stored, then Type, Class,
// TTL and Rdatalen as big-endian integers, then Rdata as stored.
func (rr *DNSResourceRecord) Bytes() []byte {
	// The four integer fields always occupy 2+2+4+2 = 10 bytes.
	var fixed [10]byte
	binary.BigEndian.PutUint16(fixed[0:2], rr.Type)
	binary.BigEndian.PutUint16(fixed[2:4], rr.Class)
	binary.BigEndian.PutUint32(fixed[4:8], rr.TTL)
	binary.BigEndian.PutUint16(fixed[8:10], rr.Rdatalen)

	out := make([]byte, 0, len(rr.Name)+len(fixed)+len(rr.Rdata))
	out = append(out, rr.Name...)
	out = append(out, fixed[:]...)
	out = append(out, rr.Rdata...)
	return out
}
// handleRequest answers request if its name is one of the built-in hosts.
//
// For a known name it sends clientAddr a response made of a header carrying
// the request's id and a single A record; for an unknown name nothing is
// sent. A failure to send is printed rather than silently dropped.
func handleRequest(pc net.PacketConn, request DNSRequest, clientAddr net.Addr) {
	hosts := map[string]string{
		"foo.com":      "10.0.0.1",
		"reddit.com":   "10.0.0.2",
		"any.thing.io": "192.178.0.3",
	}
	ipStr, ok := hosts[string(request.Name)]
	if !ok {
		return
	}
	fmt.Println("found ip:", ipStr)

	// The record name is written to the wire unchanged, so the name exactly
	// as it appeared in the request is reused.
	responseRR := newARecord(request.NameRaw, net.ParseIP(ipStr).To4())
	fmt.Printf("response: %#v\n", responseRR)

	responseHeader := DNSHeader{
		Id:         request.Id,
		Flags:      0x8000, // only the top bit set: marks the message as a response
		Queries:    0,
		Answers:    1,
		Auths:      0,
		Additional: 0,
	}
	responseBytes := serializeResponse(responseHeader, responseRR)
	fmt.Printf("Response:\n%s", hex.Dump(responseBytes))

	if _, err := pc.WriteTo(responseBytes, clientAddr); err != nil {
		fmt.Println("error sending response:", err)
	}
}
// main listens for UDP packets on localhost:9000 and treats each one as a DNS
// request: it dumps the packet, decodes it and lets handleRequest answer it.
// Packets that cannot be decoded are logged and skipped.
func main() {
	port := 9000
	pc, err := net.ListenPacket("udp", fmt.Sprintf("localhost:%d", port))
	if err != nil {
		log.Fatal(err)
	}
	log.Println("listening on port", port)

	// Close the socket when main returns. The loop below only ends through
	// log.Fatal, which exits the process without running deferred calls, so
	// this matters only if a normal return path is added later.
	defer pc.Close()

	buffer := make([]byte, 1024)
	for {
		n, addr, err := pc.ReadFrom(buffer)
		if err != nil {
			log.Fatal(err)
		}
		requestBytes := buffer[:n]
		fmt.Printf("request length %d\n", len(requestBytes))
		fmt.Printf("Request:\n%s", hex.Dump(requestBytes))

		request, err := decodeRequest(requestBytes)
		if err != nil {
			// Do not pass a zero-valued request on to the handler.
			log.Println("skipping request that could not be decoded:", err)
			continue
		}
		handleRequest(pc, request, addr)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment