Skip to content

Instantly share code, notes, and snippets.

@nsmaciej
Last active March 6, 2017 21:19
Show Gist options
  • Save nsmaciej/bffc4335875a675f8b2be0bb1eb1fcfa to your computer and use it in GitHub Desktop.
Simple DNS Server
extern crate futures;
extern crate tokio_core;
extern crate byteorder;
extern crate itertools;
use std::io::prelude::*;
use std::net::{SocketAddr, IpAddr};
use std::io::{self, Cursor, Error, ErrorKind};
use byteorder::{ReadBytesExt, WriteBytesExt};
use futures::Stream;
use tokio_core::reactor::Core;
use tokio_core::net::{UdpSocket, UdpCodec};
/// Whether replies set the Authoritative Answer (AA) flag in header byte 2.
const ACT_AUTHORITIVE: bool = true;
/// TTL written into every answer record; 0 tells resolvers not to cache it.
const DEFAULT_TTL: u32 = 0; // Do not cache
/// Address returned as the answer for every query that passes validation.
const RESPONSE_IP: &'static str = "127.0.0.1";
/// Byte order for all multi-byte wire fields (DNS uses network order, i.e. big-endian).
type NetOrd = byteorder::BigEndian;
/// Stateless codec passed to `UdpSocket::framed`: decodes datagrams into
/// `Request`s and encodes `Response`s back into datagrams.
#[derive(Debug)]
struct Dns;
/// The fixed 12-byte DNS message header, plus the peer address it belongs to.
#[derive(Debug)]
struct Header {
/// Client address: set from the datagram source when decoding and used as
/// the reply destination when encoding.
addr: SocketAddr,
/// Raw header bytes 0-3 in wire order: message ID (bytes 0-1) followed by
/// the two flag bytes (2-3).
bits: [u8; 4],
/// Section counts in header order: questions, answers, authority, additional.
counts: [u16; 4],
}
/// A single question (when decoding) or answer (when encoding) entry.
#[derive(Debug)]
struct Record {
/// Domain name with labels joined by `.`.
domain: String,
/// `[TYPE, CLASS]`; `read_request` only accepts 1 / 1 (A record, Internet class).
info: [u16; 2],
}
/// A decoded query.
#[derive(Debug)]
struct Request {
header: Header,
record: Record,
/// True when the packet parsed but was rejected by validation; the reply
/// then carries an error code and no answer.
error: bool,
}
/// A reply ready to be encoded by the `Dns` codec.
#[derive(Debug)]
struct Response {
header: Header,
/// Echo of the question; encoded as the answer's name/type/class.
record: Record,
/// TTL for the answer record.
ttl: u32,
/// Answer address; its octets become the RDATA.
ip: IpAddr,
/// When true, only the header is encoded (no answer record).
error: bool,
}
impl Header {
    /// Converts a query header into the header of our reply.
    ///
    /// The message ID (bytes 0-1) is kept so the client can match the reply
    /// to its query. When `error` is true the RCODE is set to 1 (Format Error)
    /// and no answer is announced; otherwise exactly one answer is announced.
    fn answer(mut self, error: bool) -> Header {
        // RFC 1035 section 4.1.1: RD "may be set in a query and is copied into
        // the response", so preserve the client's bit 0 instead of clearing it.
        let recursion_desired = self.bits[2] & 0b0000_0001;
        // QR=1 (response), OPCODE=0 (standard query), AA per ACT_AUTHORITIVE,
        // TC=0 (not truncated), RD copied from the query.
        self.bits[2] = 0b1_0000_0_0_0 | ((ACT_AUTHORITIVE as u8) << 2) | recursion_desired;
        // RA=0 (no recursion), Z=0, AD=0, CD=0 (RFC 4035 bits), RCODE=1 on error.
        self.bits[3] = 0b0_0_0_0_0000 | (error as u8);
        // No question section in the reply; one answer unless this is an error.
        self.counts = [0, !error as u16, 0, 0];
        self
    }

    /// Returns the 16-bit message ID, stored big-endian in header bytes 0-1.
    fn id(&self) -> u16 {
        // The high byte must move up a full byte (8 bits), not 2 bits.
        ((self.bits[0] as u16) << 8) | (self.bits[1] as u16)
    }
}
/// Reads a domain name in DNS wire format from `buf`.
///
/// The name is a sequence of length-prefixed labels terminated by a zero
/// byte; the labels are returned joined with `.`. The root name (a lone zero
/// byte) yields an empty string.
///
/// # Errors
/// * `InvalidData` if a length byte has either of its two high bits set
///   (compression pointer or reserved/extended label type), if a label is not
///   valid UTF-8, or if the encoded name exceeds 255 octets.
/// * Any error from the reader, e.g. `UnexpectedEof` on a truncated name.
fn read_domain<R: Read>(mut buf: R) -> io::Result<String> {
    // RFC 1035 section 2.3.4: at most 255 octets, counting length bytes and the terminator.
    const MAX_NAME_LEN: usize = 255;
    let mut labels = Vec::new();
    // Start at 1 to account for the terminating zero-length label.
    let mut wire_len = 1;
    loop {
        let mut len_byte = [0u8; 1];
        buf.read_exact(&mut len_byte)?;
        let len = len_byte[0];
        if len == 0 {
            break;
        }
        if len & 0xc0 != 0 {
            // Compression pointer (0b11) or a reserved/extended label type.
            return Err(Error::new(ErrorKind::InvalidData, "unsupported label type"));
        }
        wire_len += 1 + len as usize;
        if wire_len > MAX_NAME_LEN {
            return Err(Error::new(ErrorKind::InvalidData, "domain name too long"));
        }
        let mut label = vec![0; len as usize];
        buf.read_exact(&mut label)?;
        let label = String::from_utf8(label)
            .map_err(|_| Error::new(ErrorKind::InvalidData, "utf-8 error"))?;
        labels.push(label);
    }
    Ok(labels.join("."))
}
/// Writes `domain` to `buf` in DNS wire format.
///
/// Each `.`-separated label is written as a length byte followed by its
/// bytes, and a zero byte terminates the name. Empty labels are skipped, so
/// the root name `""` encodes as the single byte `0` and a trailing dot does
/// not emit a premature terminator.
///
/// Labels are expected to be at most 63 bytes (as `read_domain` guarantees);
/// longer labels would not form a valid DNS name.
///
/// # Panics
/// Panics if writing to `buf` fails (which cannot happen for a `Vec<u8>`).
fn write_domain<W: Write>(mut buf: W, domain: String) {
    for label in domain.split('.').filter(|label| !label.is_empty()) {
        // `as u8` is safe given the <= 63-byte label precondition above.
        buf.write_all(&[label.len() as u8]).unwrap();
        buf.write_all(label.as_bytes()).unwrap();
    }
    buf.write_all(&[0]).unwrap(); // End: zero-length root label
}
/// Parses one DNS query received from `src` out of `buf`.
///
/// Reads the 12-byte header and a single question. A query that parses but
/// is unsupported comes back as `Ok` with `error` set, so the caller can still
/// send an error reply. Truncated or undecodable packets return `Err`.
///
/// # Errors
/// Propagates reader errors (e.g. `UnexpectedEof` on a short packet) and the
/// `InvalidData` errors produced by `read_domain`.
fn read_request<R: ReadBytesExt>(src: &SocketAddr, mut buf: R) -> io::Result<Request> {
    macro_rules! r_u16 {
        () => { buf.read_u16::<NetOrd>()? }
    }
    let mut bits = [0; 4];
    buf.read_exact(&mut bits)?;
    // Struct-literal fields evaluate in source order, matching the wire layout.
    let mut req = Request {
        header: Header {
            bits: bits,
            addr: *src,
            counts: [r_u16!(), r_u16!(), r_u16!(), r_u16!()],
        },
        record: Record {
            domain: read_domain(&mut buf)?,
            info: [r_u16!(), r_u16!()],
        },
        error: false,
    };
    // Marks the request as rejected and returns it so an error reply is sent.
    macro_rules! ensure {
        ($x:expr, $msg:expr) => {
            if $x {
                println!("error: {}", $msg);
                req.error = true;
                return Ok(req)
            }
        }
    }
    ensure!(req.record.info[0] != 1, "only A record supported");
    ensure!(req.record.info[1] != 1, "only internet class supported");
    // Exactly one question was parsed above, so QDCOUNT must agree (0 is invalid too).
    ensure!(req.header.counts[0] != 1, "exactly one question supported");
    let mut rest = Vec::new();
    buf.read_to_end(&mut rest)?;
    // NOTE(review): this also rejects queries carrying additional records
    // (e.g. an EDNS OPT record), since they are not parsed here.
    ensure!(!rest.is_empty(), "excessive bytes");
    Ok(req)
}
impl UdpCodec for Dns {
type In = Option<Request>;
type Out = Response;
fn decode(&mut self, src: &SocketAddr, buf: &[u8]) -> io::Result<Option<Request>> {
match read_request(src, Cursor::new(buf)) {
Ok(val) => Ok(Some(val)),
Err(err) => {
println!("error: {}", err);
Ok(None)
}
}
}
fn encode(&mut self, res: Response, mut buf: &mut Vec<u8>) -> SocketAddr {
buf.write_all(&res.header.bits).unwrap();
for cnt in &res.header.counts {
buf.write_u16::<NetOrd>(*cnt).unwrap();
}
if !res.error {
write_domain(&mut buf, res.record.domain);
buf.write_u16::<NetOrd>(res.record.info[0]).unwrap();
buf.write_u16::<NetOrd>(res.record.info[1]).unwrap();
buf.write_u32::<NetOrd>(res.ttl).unwrap();
match res.ip {
IpAddr::V4(v4) => {
buf.write_u16::<NetOrd>(4).unwrap();
buf.write_all(&v4.octets()).unwrap();
},
IpAddr::V6(v6) => {
buf.write_u16::<NetOrd>(16).unwrap();
buf.write_all(&v6.octets()).unwrap();
},
}
}
res.header.addr
}
}
fn main() {
let mut reactor = Core::new().expect("could not create core");
let socket = UdpSocket::bind(&"127.0.0.1:8000".parse().unwrap(), &reactor.handle()).expect("could not bind socket");
let (sink, stream) = socket.framed(Dns).split();
let response_addr: IpAddr = RESPONSE_IP.parse().unwrap();
let server = stream.filter_map(|req| {
if let Some(req) = req {
println!("<- {} A {} {}", req.header.id(), req.header.addr, req.record.domain);
let res = Response {
header: req.header.answer(req.error),
record: req.record,
ttl: DEFAULT_TTL,
ip: response_addr.clone(),
error: req.error,
};
println!("-> {} A {} {} {}\n", res.header.id(), res.record.domain, res.ip, res.ttl);
Some(res)
} else {
None
}
}).forward(sink);
reactor.run(server).expect("server error");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment